
Stripping disconnects input layer from graph #1063

Open
christian-steinmeyer opened this issue Apr 25, 2023 · 0 comments
Labels: bug Something isn't working
Describe the bug
Stripping the pruning wrappers (via `strip_pruning`) appears to disconnect the input layer from the model graph.

System information

TensorFlow version (installed from source or binary): 2.11 (macOS)

TensorFlow Model Optimization version (installed from source or binary): 0.7.4

Python version: 3.10

Describe the expected behavior
Pruning a model during training, stripping the pruning layers, then creating a new model based on a subset of layers (e.g. to remove additional targets used during training) should work, if I didn't miss anything.
Describe the current behavior
It fails. However, reordering the steps — pruning, then creating the new model, then stripping — works (see the second snippet below).

Code to reproduce the issue

import tempfile

import tensorflow as tf
import numpy as np

from tensorflow import keras
import tensorflow_model_optimization as tfmot

if __name__ == '__main__':
    # Load MNIST dataset
    mnist = keras.datasets.mnist
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()

    # Normalize the input image so that each pixel value is between 0 and 1.
    train_images = train_images / 255.0
    test_images = test_images / 255.0

    # Define the model architecture.
    model = keras.Sequential(
        [
            keras.layers.InputLayer(input_shape=(28, 28, 1)),
            keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
            keras.layers.MaxPooling2D(pool_size=(2, 2)),
            keras.layers.Flatten(),
            keras.layers.Dense(10),
        ]
    )
    model = tf.keras.Model(inputs=model.inputs, outputs=model.outputs)

    prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

    # Compute end step to finish pruning after 2 epochs.
    batch_size = 128
    epochs = 1
    validation_split = 0.1  # 10% of training set will be used for validation set.

    num_images = train_images.shape[0] * (1 - validation_split)
    end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs
    print("end step", end_step)
    # Define model for pruning.
    pruning_params = {
        'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
            initial_sparsity=0.05,
            final_sparsity=0.95,
            begin_step=1,
            end_step=end_step,
            frequency=422,
        )
    }

    model_for_pruning = prune_low_magnitude(model, **pruning_params)

    # `prune_low_magnitude` requires a recompile.
    model_for_pruning.compile(
        optimizer='adam',
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )

    model_for_pruning.summary()

    logdir = tempfile.mkdtemp()

    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
    ]

    model_for_pruning.fit(
        train_images,
        train_labels,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=validation_split,
        callbacks=callbacks,
    )
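For reference, the `end_step` printed by the setup code can be checked by hand (assuming the standard MNIST training set of 60,000 images and the hyperparameters above):

```python
import math

# Sanity check of the end_step arithmetic from the setup code, assuming
# MNIST's 60000 training images and the same hyperparameters.
batch_size = 128
validation_split = 0.1
epochs = 1

num_images = 60000 * (1 - validation_split)            # 54000.0 images for training
end_step = math.ceil(num_images / batch_size) * epochs # steps per epoch * epochs
print(end_step)  # 422 — matching the PolynomialDecay frequency used above
```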

Given the above setup code, running the following snippet fails:

    pruned_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
    inputs = [pruned_model.get_layer("input_1").input]
    outputs = pruned_model.get_layer("dense").output
    _new_model = tf.keras.Model(inputs=inputs, outputs=outputs)  # ValueError: Graph disconnected: cannot obtain value for tensor KerasTensor(type_spec=TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='input_1'), name='input_1', description="created by layer 'input_1'") at layer "conv2d". The following previous layers were accessed without issue: []

while both of the following snippets work:

    # Workaround 1: create the sub-model first, then strip it
    inputs = [model_for_pruning.get_layer("input_1").input]
    outputs = model_for_pruning.get_layer("prune_low_magnitude_dense").output
    _new_model = tf.keras.Model(inputs=inputs, outputs=outputs)
    _new_model = tfmot.sparsity.keras.strip_pruning(_new_model)

    # Workaround 2: strip first, but skip the input layer
    pruned_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)
    inputs = [pruned_model.get_layer("conv2d").input]
    outputs = pruned_model.get_layer("dense").output
    _new_model = tf.keras.Model(inputs=inputs, outputs=outputs)
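A further workaround (not from the original report; `rebuild_submodel` is a hypothetical helper) may be to avoid the stale `input_1` tensor entirely by re-wiring the stripped layers onto a fresh `Input` tensor:

```python
import tensorflow as tf

def rebuild_submodel(stripped_model, last_layer_name):
    """Call the stripped model's layers on a fresh Input tensor, stopping
    after `last_layer_name`, so the new graph never references the old
    (possibly disconnected) KerasTensors. Sketch only, assuming all layers
    form a simple linear chain."""
    inputs = tf.keras.Input(shape=stripped_model.input_shape[1:])
    x = inputs
    for layer in stripped_model.layers:
        if isinstance(layer, tf.keras.layers.InputLayer):
            continue  # skip the (disconnected) input layer itself
        x = layer(x)
        if layer.name == last_layer_name:
            break  # drop any trailing layers (e.g. extra training heads)
    return tf.keras.Model(inputs=inputs, outputs=x)
```

This sidesteps the graph-disconnected error because every `KerasTensor` in the new model is created fresh, at the cost of only supporting sequential topologies.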
@christian-steinmeyer christian-steinmeyer added the bug Something isn't working label Apr 25, 2023