Skip to content

Commit

Permalink
Rewire TensorFlow to rely on tf_keras target.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 559470121
  • Loading branch information
fchollet authored and tensorflower-gardener committed Sep 14, 2023
1 parent d16ccaa commit b78cbb9
Showing 1 changed file with 9 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,16 @@
from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer

try:
from keras.engine import base_layer # pylint: disable=g-import-not-at-top
# OSS case.
import keras # pylint: disable=g-import-not-at-top
if hasattr(keras, 'src'):
# Path as seen in pip packages as of TF/Keras 2.13.
from keras.src.engine import base_layer # pylint: disable=g-import-not-at-top,g-importing-member
else:
from keras.engine import base_layer # pylint: disable=g-import-not-at-top,g-importing-member
except ImportError:
# Path as seen in pip packages as of TF/Keras 2.13.
from keras.src.engine import base_layer # pylint: disable=g-import-not-at-top

# TODO(b/139939526): move to public API.
# Internal case.
base_layer = tf._keras_internal.engine.base_layer # pylint: disable=protected-access

layers = tf.keras.layers
layers_compat_v1 = tf.compat.v1.keras.layers
Expand Down

0 comments on commit b78cbb9

Please sign in to comment.