Describe the bug
Stripping the pruning layers seems to somehow disconnect the input layer from the graph.
System information
TensorFlow version (installed from source or binary): 2.11 (macos)
TensorFlow Model Optimization version (installed from source or binary): 0.7.4
Python version: 3.10
Describe the expected behavior
Pruning a model during training, stripping the pruning layers, then creating a new model based on a subset of layers (e.g. to remove additional targets used during training) should work, unless I missed something.
Describe the current behavior
It fails with "ValueError: Graph disconnected", whereas pruning first, then creating the new model, and stripping last works.
Code to reproduce the issue
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow import keras

if __name__ == '__main__':
    # Load MNIST dataset
    mnist = keras.datasets.mnist
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()

    # Normalize the input image so that each pixel value is between 0 and 1.
    train_images = train_images / 255.0
    test_images = test_images / 255.0

    # Define the model architecture.
    model = keras.Sequential(
        [
            keras.layers.InputLayer(input_shape=(28, 28, 1)),
            keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
            keras.layers.MaxPooling2D(pool_size=(2, 2)),
            keras.layers.Flatten(),
            keras.layers.Dense(10),
        ]
    )
    # Re-wrap the Sequential model as a functional model.
    model = tf.keras.Model(inputs=model.inputs, outputs=model.outputs)

    prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

    # Compute end step to finish pruning after `epochs` epochs.
    batch_size = 128
    epochs = 1
    validation_split = 0.1  # 10% of training set will be used for validation set.

    num_images = train_images.shape[0] * (1 - validation_split)
    end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs
    print("end step", end_step)

    # Define model for pruning.
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0.05,
            final_sparsity=0.95,
            begin_step=1,
            end_step=end_step,
            frequency=422,
        )
    }

    model_for_pruning = prune_low_magnitude(model, **pruning_params)

    # `prune_low_magnitude` requires a recompile.
    model_for_pruning.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )

    model_for_pruning.summary()

    logdir = tempfile.mkdtemp()
    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
    ]

    model_for_pruning.fit(
        train_images,
        train_labels,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=validation_split,
        callbacks=callbacks,
    )
Given the above setup code, running the following snippet fails:
pruned_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
inputs = [pruned_model.get_layer("input_1").input]
outputs = pruned_model.get_layer("dense").output
new_model = tf.keras.Model(inputs=inputs, outputs=outputs)
# ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='input_1'), name='input_1', description="created by layer 'input_1'") at layer "conv2d". The following previous layers were accessed without issue: []
By contrast, creating the new model from model_for_pruning first and only then calling strip_pruning on it works.