Fix LayerProfile class check with SavedModels #719
base: main
Conversation
Cool! Can you also add a unittest to make sure we don't run into this issue again?
@lgeiger The test was a good call; it turns out there is a more serious issue at play here as well. I've pushed an unfinished test that prints the weight profiles of the layer before and after SavedModel loading, and they're different:

```
# op_profiles, [w.bitwidth for w in profile.weight_profiles], mac_containing_layer, input_precision
[OperationProfile(n=1179648, precision=1, op_type='mac')] [1, 32] True 1
[OperationProfile(n=1179648, precision=32, op_type='mac')] [32, 32] True 1
```

At this point this is a bit beyond the parts of Larq I'm familiar with, so I'll have to postpone fixing this for now. I suspect the weight precision somehow isn't part of the layer config, or something like that.
To expand on the above, the problem is as follows:

```python
from tempfile import TemporaryDirectory

import tensorflow as tf
from tensorflow.python.keras.utils.generic_utils import get_custom_objects

import larq as lq

model = tf.keras.models.Sequential(
    [
        lq.layers.QuantConv2D(
            filters=32,
            kernel_size=(3, 3),
            kernel_quantizer="ste_sign",
            input_quantizer="ste_sign",
            input_shape=(64, 64, 1),
            padding="same",
        )
    ]
)

# Save and reload
with TemporaryDirectory() as dir:
    model.save(dir)
    del get_custom_objects()["QuantConv2D"]
    loaded_model = tf.keras.models.load_model(dir, compile=False)

# Pre-save
print(type(model.layers[0].weights[0]))
print(model.layers[0].weights[0].precision)

# Loaded model
print(type(loaded_model.layers[0].weights[0]))
print(loaded_model.layers[0].weights[0].precision)
```
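The suspicion above (that the weight precision isn't part of the serialized layer config) can be illustrated abstractly: if an attribute is omitted from the serialized config, a round trip through save/load reconstructs the object with the class default instead. This is a toy sketch of that failure mode, not larq or Keras code; `FakeWeight` and its methods are invented for illustration:

```python
class FakeWeight:
    """Toy stand-in for a quantized weight with a `precision` attribute."""

    def __init__(self, precision=32):
        # 32 (full precision) is the default, mimicking an unquantized weight.
        self.precision = precision

    def get_config(self):
        # `precision` is deliberately NOT serialized here, mimicking the
        # suspected bug: the config round trip silently drops it.
        return {}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


binary_weight = FakeWeight(precision=1)
restored = FakeWeight.from_config(binary_weight.get_config())

print(binary_weight.precision)  # 1
print(restored.precision)       # 32 -- precision lost in the round trip
```

This mirrors the observed profiles: `[1, 32]` before saving becomes `[32, 32]` after loading.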
Unfortunately there is no way to check whether the `RevivedLayer` with name `QuantConv2D` was originally a `larq` layer and not some custom layer with the same name. However, the only situation in which that would break is if you not only subclass a `larq` layer with an identical name, but also change whether it has MACs, which seems extremely unlikely.
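The name-based fallback described above can be sketched as follows. This is illustrative only: the set of layer names and the `is_larq_layer` helper are assumptions for this example, not larq's actual API, and the stand-in classes simulate what `load_model` revival does to class identity:

```python
# When a SavedModel is loaded without larq's custom objects, its layers are
# revived as generic objects, so isinstance() checks against larq classes
# fail. The class *name* survives revival, so fall back to comparing names.

# Hypothetical subset of larq layer names, for illustration.
LARQ_LAYER_NAMES = {"QuantConv2D", "QuantDense", "QuantDepthwiseConv2D"}


def is_larq_layer(layer) -> bool:
    """Heuristically decide whether `layer` is (or was) a larq layer."""
    return type(layer).__name__ in LARQ_LAYER_NAMES


# Stand-ins: a revived larq layer keeps its name, an ordinary layer doesn't match.
class QuantConv2D:
    pass


class Conv2D:
    pass


print(is_larq_layer(QuantConv2D()))  # True
print(is_larq_layer(Conv2D()))       # False
```

As the comment notes, this only misfires if someone defines a non-larq layer that reuses a larq layer name *and* differs in whether it has MACs.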