Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Creates non-breaking changes where necessary in preparation for switching all of Keras to new serialization format. #1039

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,11 @@ def testStripClusteringSequentialModel(self):
stripped_model = cluster.strip_clustering(clustered_model)

self.assertEqual(self._count_clustered_layers(stripped_model), 0)
self.assertEqual(model.get_config(), stripped_model.get_config())
model_config = model.get_config()
for layer in model_config['layers']:
# New serialization format includes `build_config` in wrapper
layer.pop('build_config', None)
self.assertEqual(model_config, stripped_model.get_config())

def testClusterStrippingFunctionalModel(self):
"""Verifies that stripping the clustering wrappers from a functional model produces the expected config."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ py_strict_test(
visibility = ["//visibility:public"],
deps = [
":quantizers",
":utils",
# absl/testing:parameterized dep1,
# numpy dep1,
# tensorflow dep1,
Expand All @@ -87,9 +88,10 @@ py_strict_library(
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = [
":quantizers",
":utils",
# six dep1,
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
],
)

Expand Down Expand Up @@ -125,6 +127,7 @@ py_strict_library(
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = [
":utils",
# tensorflow dep1,
],
)
Expand Down Expand Up @@ -152,6 +155,7 @@ py_strict_library(
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = [
":utils",
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/keras:utils",
],
Expand All @@ -167,6 +171,7 @@ py_strict_test(
deps = [
":quantize_aware_activation",
":quantizers",
":utils",
# absl/testing:parameterized dep1,
# numpy dep1,
# tensorflow dep1,
Expand All @@ -182,6 +187,7 @@ py_strict_library(
visibility = ["//visibility:public"],
deps = [
":quantizers",
":utils",
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/keras:utils",
],
Expand Down Expand Up @@ -211,6 +217,7 @@ py_strict_library(
visibility = ["//visibility:public"],
deps = [
":quantize_aware_activation",
":utils",
# tensorflow dep1,
# python/util tensorflow dep2,
"//tensorflow_model_optimization/python/core/keras:metrics",
Expand Down Expand Up @@ -249,6 +256,7 @@ py_strict_library(
":quantize_layer",
":quantize_wrapper",
":quantizers",
":utils",
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/keras:metrics",
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@
import tensorflow as tf

from tensorflow_model_optimization.python.core.quantization.keras import quantizers
from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry

keras = tf.keras
K = tf.keras.backend
l = tf.keras.layers

deserialize_keras_object = tf.keras.utils.deserialize_keras_object
serialize_keras_object = tf.keras.utils.serialize_keras_object
deserialize_keras_object = quantize_utils.deserialize_keras_object
serialize_keras_object = quantize_utils.serialize_keras_object


class _TestHelper(object):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_configs
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms
Expand Down Expand Up @@ -67,13 +68,17 @@ def _get_params(conv_layer, bn_layer, relu_layer=None):
list(conv_layer['config'].items()) + list(bn_layer['config'].items()))

if relu_layer is not None:
params['post_activation'] = keras.layers.deserialize(relu_layer)
params['post_activation'] = quantize_utils.deserialize_layer(
relu_layer, use_legacy_format=True
)

return params


def _get_layer_node(fused_layer, weights):
layer_config = keras.layers.serialize(fused_layer)
layer_config = quantize_utils.serialize_layer(
fused_layer, use_legacy_format=True
)
layer_config['name'] = layer_config['config']['name']
# This config tracks which layers get quantized, and whether they have a
# custom QuantizeConfig.
Expand Down Expand Up @@ -118,7 +123,10 @@ def _replace(self, bn_layer_node, conv_layer_node):
return bn_layer_node

conv_layer_node.layer['config']['activation'] = (
keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
quantize_utils.serialize_activation(
quantize_aware_activation.NoOpActivation(), use_legacy_format=True
)
)
bn_layer_node.metadata['quantize_config'] = (
default_8bit_quantize_configs.Default8BitOutputQuantizeConfig())

Expand Down Expand Up @@ -180,7 +188,10 @@ def _replace(self, relu_layer_node, bn_layer_node, conv_layer_node):
return relu_layer_node

conv_layer_node.layer['config']['activation'] = (
keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
quantize_utils.serialize_activation(
quantize_aware_activation.NoOpActivation(), use_legacy_format=True
)
)
bn_layer_node.metadata['quantize_config'] = (
default_8bit_quantize_configs.NoOpQuantizeConfig())

Expand Down Expand Up @@ -261,7 +272,10 @@ def _replace(self, bn_layer_node, dense_layer_node):
return bn_layer_node

dense_layer_node.layer['config']['activation'] = (
keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
quantize_utils.serialize_activation(
quantize_aware_activation.NoOpActivation(), use_legacy_format=True
)
)
bn_layer_node.metadata['quantize_config'] = (
default_8bit_quantize_configs.Default8BitOutputQuantizeConfig())

Expand Down Expand Up @@ -297,7 +311,10 @@ def _replace(self, relu_layer_node, bn_layer_node, dense_layer_node):
return relu_layer_node

dense_layer_node.layer['config']['activation'] = (
keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
quantize_utils.serialize_activation(
quantize_aware_activation.NoOpActivation(), use_legacy_format=True
)
)
bn_layer_node.metadata['quantize_config'] = (
default_8bit_quantize_configs.NoOpQuantizeConfig())

Expand Down Expand Up @@ -408,7 +425,9 @@ def replacement(self, match_layer):
else:
spatial_dim = 2

sepconv2d_layer_config = keras.layers.serialize(sepconv2d_layer)
sepconv2d_layer_config = quantize_utils.serialize_layer(
sepconv2d_layer, use_legacy_format=True
)
sepconv2d_layer_config['name'] = sepconv2d_layer.name

# Needed to ensure these new layers are considered for quantization.
Expand All @@ -420,15 +439,19 @@ def replacement(self, match_layer):
expand_layer = tf.keras.layers.Lambda(
lambda x: tf.expand_dims(x, spatial_dim),
name=self._get_name('sepconv1d_expand'))
expand_layer_config = keras.layers.serialize(expand_layer)
expand_layer_config = quantize_utils.serialize_layer(
expand_layer, use_legacy_format=True
)
expand_layer_config['name'] = expand_layer.name
expand_layer_metadata = {
'quantize_config': default_8bit_quantize_configs.NoOpQuantizeConfig()}

squeeze_layer = tf.keras.layers.Lambda(
lambda x: tf.squeeze(x, [spatial_dim]),
name=self._get_name('sepconv1d_squeeze'))
squeeze_layer_config = keras.layers.serialize(squeeze_layer)
squeeze_layer_config = quantize_utils.serialize_layer(
squeeze_layer, use_legacy_format=True
)
squeeze_layer_config['name'] = squeeze_layer.name
squeeze_layer_metadata = {
'quantize_config': default_8bit_quantize_configs.NoOpQuantizeConfig()}
Expand Down Expand Up @@ -493,7 +516,9 @@ def replacement(self, match_layer):
)
dconv_weights = collections.OrderedDict()
dconv_weights['depthwise_kernel:0'] = sepconv_weights[0]
dconv_layer_config = keras.layers.serialize(dconv_layer)
dconv_layer_config = quantize_utils.serialize_layer(
dconv_layer, use_legacy_format=True
)
dconv_layer_config['name'] = dconv_layer.name
# Needed to ensure these new layers are considered for quantization.
dconv_metadata = {'quantize_config': None}
Expand Down Expand Up @@ -521,7 +546,9 @@ def replacement(self, match_layer):
conv_weights['kernel:0'] = sepconv_weights[1]
if sepconv_layer['config']['use_bias']:
conv_weights['bias:0'] = sepconv_weights[2]
conv_layer_config = keras.layers.serialize(conv_layer)
conv_layer_config = quantize_utils.serialize_layer(
conv_layer, use_legacy_format=True
)
conv_layer_config['name'] = conv_layer.name
# Needed to ensure these new layers are considered for quantization.
conv_metadata = {'quantize_config': None}
Expand Down Expand Up @@ -588,7 +615,9 @@ def replacement(self, match_layer):
quant_layer = quantize_layer.QuantizeLayer(
quantizers.AllValuesQuantizer(
num_bits=8, per_axis=False, symmetric=False, narrow_range=False))
layer_config = keras.layers.serialize(quant_layer)
layer_config = quantize_utils.serialize_layer(
quant_layer, use_legacy_format=True
)
layer_config['name'] = quant_layer.name

quant_layer_node = LayerNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -707,4 +707,6 @@ def testConcatConcatTransformDisablesOutput(self):


if __name__ == '__main__':
  # The new Keras serialization format refuses to deserialize arbitrary
  # lambdas unless unsafe deserialization is explicitly switched on; older
  # TF releases do not expose this hook, so probe for it before calling.
  _enable_unsafe = getattr(
      tf.keras.__internal__, 'enable_unsafe_deserialization', None)
  if _enable_unsafe is not None:
    _enable_unsafe()
  tf.test.main()
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,6 @@ def testModelEndToEnd(self, model_fn):


if __name__ == '__main__':
  # Enable unsafe deserialization when the running TF version provides the
  # hook (newer Keras serialization blocks lambda/arbitrary-code loading by
  # default); guarded with hasattr so older TF releases still run this test.
  # NOTE(review): presumably needed because the model under test contains
  # Lambda layers — confirm against the test body.
  if hasattr(tf.keras.__internal__, 'enable_unsafe_deserialization'):
    tf.keras.__internal__.enable_unsafe_deserialization()
  tf.test.main()
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@
import tensorflow as tf

from tensorflow_model_optimization.python.core.quantization.keras import quantizers
from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry as n_bit_registry

keras = tf.keras
K = tf.keras.backend
l = tf.keras.layers

deserialize_keras_object = tf.keras.utils.deserialize_keras_object
serialize_keras_object = tf.keras.utils.serialize_keras_object
deserialize_keras_object = quantize_utils.deserialize_keras_object
serialize_keras_object = quantize_utils.serialize_keras_object


class _TestHelper(object):
Expand Down