Skip to content

Commit

Permalink
Creates non-breaking changes where necessary in preparation for switc…
Browse files Browse the repository at this point in the history
…hing all of Keras to new serialization format.

PiperOrigin-RevId: 507864605
  • Loading branch information
nkovela1 authored and tensorflower-gardener committed Feb 13, 2023
1 parent 8c33592 commit c504ca6
Show file tree
Hide file tree
Showing 25 changed files with 311 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,11 @@ def testStripClusteringSequentialModel(self):
stripped_model = cluster.strip_clustering(clustered_model)

self.assertEqual(self._count_clustered_layers(stripped_model), 0)
self.assertEqual(model.get_config(), stripped_model.get_config())
model_config = model.get_config()
for layer in model_config['layers']:
# New serialization format includes `build_config` in wrapper
layer.pop('build_config', None)
self.assertEqual(model_config, stripped_model.get_config())

def testClusterStrippingFunctionalModel(self):
"""Verifies that stripping the clustering wrappers from a functional model produces the expected config."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ py_strict_test(
visibility = ["//visibility:public"],
deps = [
":quantizers",
":utils",
# absl/testing:parameterized dep1,
# numpy dep1,
# tensorflow dep1,
Expand All @@ -87,9 +88,10 @@ py_strict_library(
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = [
":quantizers",
":utils",
# six dep1,
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/quantization/keras:quantizers",
],
)

Expand Down Expand Up @@ -125,6 +127,7 @@ py_strict_library(
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = [
":utils",
# tensorflow dep1,
],
)
Expand Down Expand Up @@ -152,6 +155,7 @@ py_strict_library(
srcs_version = "PY3",
visibility = ["//visibility:public"],
deps = [
":utils",
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/keras:utils",
],
Expand All @@ -167,6 +171,7 @@ py_strict_test(
deps = [
":quantize_aware_activation",
":quantizers",
":utils",
# absl/testing:parameterized dep1,
# numpy dep1,
# tensorflow dep1,
Expand All @@ -182,6 +187,7 @@ py_strict_library(
visibility = ["//visibility:public"],
deps = [
":quantizers",
":utils",
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/keras:utils",
],
Expand Down Expand Up @@ -211,6 +217,7 @@ py_strict_library(
visibility = ["//visibility:public"],
deps = [
":quantize_aware_activation",
":utils",
# tensorflow dep1,
# python/util tensorflow dep2,
"//tensorflow_model_optimization/python/core/keras:metrics",
Expand Down Expand Up @@ -249,6 +256,7 @@ py_strict_library(
":quantize_layer",
":quantize_wrapper",
":quantizers",
":utils",
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/keras:metrics",
"//tensorflow_model_optimization/python/core/quantization/keras/default_8bit:default_8bit_quantize_registry",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@
import tensorflow as tf

from tensorflow_model_optimization.python.core.quantization.keras import quantizers
from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry

keras = tf.keras
K = tf.keras.backend
l = tf.keras.layers

deserialize_keras_object = tf.keras.utils.deserialize_keras_object
serialize_keras_object = tf.keras.utils.serialize_keras_object
deserialize_keras_object = quantize_utils.deserialize_keras_object
serialize_keras_object = quantize_utils.serialize_keras_object


class _TestHelper(object):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation
from tensorflow_model_optimization.python.core.quantization.keras import quantize_layer
from tensorflow_model_optimization.python.core.quantization.keras import quantizers
from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_configs
from tensorflow_model_optimization.python.core.quantization.keras.default_8bit import default_8bit_quantize_registry
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms
Expand Down Expand Up @@ -67,13 +68,17 @@ def _get_params(conv_layer, bn_layer, relu_layer=None):
list(conv_layer['config'].items()) + list(bn_layer['config'].items()))

if relu_layer is not None:
params['post_activation'] = keras.layers.deserialize(relu_layer)
params['post_activation'] = quantize_utils.deserialize_layer(
relu_layer, use_legacy_format=True
)

return params


def _get_layer_node(fused_layer, weights):
layer_config = keras.layers.serialize(fused_layer)
layer_config = quantize_utils.serialize_layer(
fused_layer, use_legacy_format=True
)
layer_config['name'] = layer_config['config']['name']
# This config tracks which layers get quantized, and whether they have a
# custom QuantizeConfig.
Expand Down Expand Up @@ -118,7 +123,10 @@ def _replace(self, bn_layer_node, conv_layer_node):
return bn_layer_node

conv_layer_node.layer['config']['activation'] = (
keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
quantize_utils.serialize_activation(
quantize_aware_activation.NoOpActivation(), use_legacy_format=True
)
)
bn_layer_node.metadata['quantize_config'] = (
default_8bit_quantize_configs.Default8BitOutputQuantizeConfig())

Expand Down Expand Up @@ -180,7 +188,10 @@ def _replace(self, relu_layer_node, bn_layer_node, conv_layer_node):
return relu_layer_node

conv_layer_node.layer['config']['activation'] = (
keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
quantize_utils.serialize_activation(
quantize_aware_activation.NoOpActivation(), use_legacy_format=True
)
)
bn_layer_node.metadata['quantize_config'] = (
default_8bit_quantize_configs.NoOpQuantizeConfig())

Expand Down Expand Up @@ -261,7 +272,10 @@ def _replace(self, bn_layer_node, dense_layer_node):
return bn_layer_node

dense_layer_node.layer['config']['activation'] = (
keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
quantize_utils.serialize_activation(
quantize_aware_activation.NoOpActivation(), use_legacy_format=True
)
)
bn_layer_node.metadata['quantize_config'] = (
default_8bit_quantize_configs.Default8BitOutputQuantizeConfig())

Expand Down Expand Up @@ -297,7 +311,10 @@ def _replace(self, relu_layer_node, bn_layer_node, dense_layer_node):
return relu_layer_node

dense_layer_node.layer['config']['activation'] = (
keras.activations.serialize(quantize_aware_activation.NoOpActivation()))
quantize_utils.serialize_activation(
quantize_aware_activation.NoOpActivation(), use_legacy_format=True
)
)
bn_layer_node.metadata['quantize_config'] = (
default_8bit_quantize_configs.NoOpQuantizeConfig())

Expand Down Expand Up @@ -408,7 +425,9 @@ def replacement(self, match_layer):
else:
spatial_dim = 2

sepconv2d_layer_config = keras.layers.serialize(sepconv2d_layer)
sepconv2d_layer_config = quantize_utils.serialize_layer(
sepconv2d_layer, use_legacy_format=True
)
sepconv2d_layer_config['name'] = sepconv2d_layer.name

# Needed to ensure these new layers are considered for quantization.
Expand All @@ -420,15 +439,19 @@ def replacement(self, match_layer):
expand_layer = tf.keras.layers.Lambda(
lambda x: tf.expand_dims(x, spatial_dim),
name=self._get_name('sepconv1d_expand'))
expand_layer_config = keras.layers.serialize(expand_layer)
expand_layer_config = quantize_utils.serialize_layer(
expand_layer, use_legacy_format=True
)
expand_layer_config['name'] = expand_layer.name
expand_layer_metadata = {
'quantize_config': default_8bit_quantize_configs.NoOpQuantizeConfig()}

squeeze_layer = tf.keras.layers.Lambda(
lambda x: tf.squeeze(x, [spatial_dim]),
name=self._get_name('sepconv1d_squeeze'))
squeeze_layer_config = keras.layers.serialize(squeeze_layer)
squeeze_layer_config = quantize_utils.serialize_layer(
squeeze_layer, use_legacy_format=True
)
squeeze_layer_config['name'] = squeeze_layer.name
squeeze_layer_metadata = {
'quantize_config': default_8bit_quantize_configs.NoOpQuantizeConfig()}
Expand Down Expand Up @@ -493,7 +516,9 @@ def replacement(self, match_layer):
)
dconv_weights = collections.OrderedDict()
dconv_weights['depthwise_kernel:0'] = sepconv_weights[0]
dconv_layer_config = keras.layers.serialize(dconv_layer)
dconv_layer_config = quantize_utils.serialize_layer(
dconv_layer, use_legacy_format=True
)
dconv_layer_config['name'] = dconv_layer.name
# Needed to ensure these new layers are considered for quantization.
dconv_metadata = {'quantize_config': None}
Expand Down Expand Up @@ -521,7 +546,9 @@ def replacement(self, match_layer):
conv_weights['kernel:0'] = sepconv_weights[1]
if sepconv_layer['config']['use_bias']:
conv_weights['bias:0'] = sepconv_weights[2]
conv_layer_config = keras.layers.serialize(conv_layer)
conv_layer_config = quantize_utils.serialize_layer(
conv_layer, use_legacy_format=True
)
conv_layer_config['name'] = conv_layer.name
# Needed to ensure these new layers are considered for quantization.
conv_metadata = {'quantize_config': None}
Expand Down Expand Up @@ -588,7 +615,9 @@ def replacement(self, match_layer):
quant_layer = quantize_layer.QuantizeLayer(
quantizers.AllValuesQuantizer(
num_bits=8, per_axis=False, symmetric=False, narrow_range=False))
layer_config = keras.layers.serialize(quant_layer)
layer_config = quantize_utils.serialize_layer(
quant_layer, use_legacy_format=True
)
layer_config['name'] = quant_layer.name

quant_layer_node = LayerNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -707,4 +707,6 @@ def testConcatConcatTransformDisablesOutput(self):


if __name__ == '__main__':
if hasattr(tf.keras.__internal__, 'enable_unsafe_deserialization'):
tf.keras.__internal__.enable_unsafe_deserialization()
tf.test.main()
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,6 @@ def testModelEndToEnd(self, model_fn):


if __name__ == '__main__':
if hasattr(tf.keras.__internal__, 'enable_unsafe_deserialization'):
tf.keras.__internal__.enable_unsafe_deserialization()
tf.test.main()
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@
import tensorflow as tf

from tensorflow_model_optimization.python.core.quantization.keras import quantizers
from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
from tensorflow_model_optimization.python.core.quantization.keras.experimental.default_n_bit import default_n_bit_quantize_registry as n_bit_registry

keras = tf.keras
K = tf.keras.backend
l = tf.keras.layers

deserialize_keras_object = tf.keras.utils.deserialize_keras_object
serialize_keras_object = tf.keras.utils.serialize_keras_object
deserialize_keras_object = quantize_utils.deserialize_keras_object
serialize_keras_object = quantize_utils.serialize_keras_object


class _TestHelper(object):
Expand Down

0 comments on commit c504ca6

Please sign in to comment.