adding clone_layer_graph function to clone a graph of layers without cloning actual layers #19600

Closed
wants to merge 14 commits

Conversation

martin-gorner
Contributor

Clone the layer graph between input and output. Actual layers are NOT cloned,
but shared. Only the graph of layers is re-created. The clone_fn function
is called when cloning each node (i.e. each layer invocation in the graph), which
allows you to modify the topology of the graph.

Recommended usage:

def clone_fn(layer, *args, **kwargs):
    # here, you can insert layers, wrap layers, extract activations, etc.
    return layer(*args, **kwargs)  # default to identity

model = ... # a keras model
output = clone_layer_graph(model.input, model.output, clone_fn)
new_model = keras.Model(model.input, output)
  • When cloning a layer graph, shared layers remain shared.
  • Since no actual cloning of layers occurs, layers do not need to have
    serialization implemented (i.e. implement get_config()).
  • Cloning a layer graph with nested subgraphs (i.e. layers that are themselves
    Functional or Sequential models) is possible. If a clone_fn is provided,
    the nested subgraph will be cloned as a new Functional model (note: this
    will change Sequential subgraphs to Functional; see the sketch below).
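
For illustration, here is a minimal sketch of cloning a model that contains a nested Sequential subgraph (hypothetical toy model; clone_layer_graph is the function proposed in this PR):

import keras
from keras import layers

inner = keras.Sequential([layers.Dense(8, activation="relu"), layers.Dense(4)])
inputs = keras.Input(shape=(16,))
outputs = inner(layers.Dense(8)(inputs))  # nested Sequential subgraph
model = keras.Model(inputs, outputs)

def clone_fn(layer, *args, **kwargs):
    return layer(*args, **kwargs)  # identity: layers are shared, not copied

# Since a clone_fn is provided, the nested Sequential is re-created as a Functional.
output = clone_layer_graph(model.input, model.output, clone_fn)
new_model = keras.Model(model.input, output)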

Args:
    input: Instance of KerasTensor or pytree of KerasTensors.
        All inputs must be of type keras.Input. If you wish to
        clone a layer graph that starts with intermediate KerasTensors,
        you have to create a new Functional model first by calling
        model = keras.Model(intermediate_tensors, output), which will
        create proper Input tensors instead of the intermediate ones
        (see the last example below).
    output: Instance of KerasTensor or pytree of KerasTensors.
    clone_fn: Callable that will be called when each layer in the layer
        graph is invoked. The expected signature is
        clone_fn(layer, *args, **kwargs). To leave a layer unchanged,
        return layer(*args, **kwargs).

Examples:

# clone the layer graph identically (actual layers will be shared, not cloned)
def clone_fn(layer, *args, **kwargs):
    output = layer(*args, **kwargs)  # identity call
    return output
model = ... # a keras model
output = clone_layer_graph(model.input, model.output, clone_fn)
new_model = keras.Model(model.input, output)

# wrap every Dense layer in custom layer WrapDense
def clone_fn(layer, *args, **kwargs):
    if isinstance(layer, layers.Dense):
        wrapper = WrapDense(layer)
        return wrapper(*args, **kwargs)
    else:
        return layer(*args, **kwargs)  # default to identity
model = ... # a keras model
output = clone_layer_graph(model.input, model.output, clone_fn)
new_model = keras.Model(model.input, output)
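
Note: WrapDense is not defined in this PR; a minimal hypothetical wrapper (assuming the same keras/layers imports as above) could look like this:

class WrapDense(layers.Layer):
    # Hypothetical wrapper: simply delegates to the wrapped Dense layer;
    # a real wrapper could add logging, scaling, low-rank updates, etc.
    def __init__(self, dense_layer, **kwargs):
        super().__init__(**kwargs)
        self.dense_layer = dense_layer

    def call(self, inputs):
        return self.dense_layer(inputs)
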
# Insert an extra Dense(128) layer after every Dense layer
def clone_fn(layer, *args, **kwargs):
    if isinstance(layer, layers.Dense):
        output = layer(*args, **kwargs)
        output = layers.Dense(128)(output)
        return output
    else:
        return layer(*args, **kwargs)  # default to identity
model = ... # a keras model
output = clone_layer_graph(model.input, model.output, clone_fn)
new_model = keras.Model(model.input, output)

# Collect inner activations from the model and create a new model that returns them
activations = []
def clone_fn(layer, *args, **kwargs):
    output = layer(*args, **kwargs)  # identity call
    activations.append(output)
    return output
model = ... # a keras model
output = clone_layer_graph(model.input, model.output, clone_fn)
new_output = [output] + activations
new_model = keras.Model(model.input, new_output)
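
One more sketch, illustrating the point from the Args section: to clone a subgraph that starts at intermediate KerasTensors, first wrap it in a new Functional model so that proper Input tensors are created (the layer index below is arbitrary):

model = ...  # a keras model
x = model.layers[2].output            # an intermediate KerasTensor
tail = keras.Model(x, model.output)   # creates proper Input tensors
output = clone_layer_graph(tail.input, tail.output, clone_fn)
new_tail = keras.Model(tail.input, output)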

Implementation note: Why not modify the existing functional.clone_graph_nodes?
Context: The purpose of the existing function is to create a new
functional model from any intermediate Keras tensors.

  1. The existing function does not re-execute the operations (layers) in the graph. This new function does, as that is necessary to run clone_fn. It is therefore best to keep them as distinct functions.
  2. The existing function copies input tensors to replace them with proper InputLayers. This is essential when creating a subgraph starting at intermediate tensors. This new function always assumes all inputs are InputLayers.
  3. The existing function does not have a clone_fn. Adding one would complicate the code.

@martin-gorner changed the title: adding clone_layer_graph function to clone → adding clone_layer_graph function to clone a graph of layers without cloning actual layers (Apr 23, 2024)
@codecov-commenter

codecov-commenter commented Apr 23, 2024

Codecov Report

Attention: Patch coverage is 87.17949%, with 10 lines in your changes missing coverage. Please review.

Project coverage is 78.29%. Comparing base (880f0cd) to head (0036744).

Files Patch % Lines
keras/src/models/cloning_layer_graph.py 92.42% 2 Missing and 3 partials ⚠️
keras/src/models/sequential.py 60.00% 2 Missing and 2 partials ⚠️
keras/api/_tf_keras/keras/models/__init__.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19600      +/-   ##
==========================================
- Coverage   78.30%   78.29%   -0.01%     
==========================================
  Files         498      499       +1     
  Lines       45477    45555      +78     
  Branches     8382     8398      +16     
==========================================
+ Hits        35610    35668      +58     
- Misses       8107     8117      +10     
- Partials     1760     1770      +10     
Flag Coverage Δ
keras 78.15% <87.17%> (+<0.01%) ⬆️
keras-jax 62.05% <87.17%> (+0.04%) ⬆️
keras-numpy 56.44% <87.17%> (+0.10%) ⬆️
keras-tensorflow 63.40% <87.17%> (+0.04%) ⬆️
keras-torch 62.07% <87.17%> (+0.03%) ⬆️

Flags with carried forward coverage won't be shown.

@martin-gorner
Contributor Author

I'll fix the PyTorch test failures.
The "code format" check fails in "Check for API changes". Yes, this PR changes the public API, but I'm not sure what to do to validate the change and make the "code format" check pass.

@fchollet
Member

Broadly speaking, this looks good and we should incorporate it into the API. For the interface, I think we should modify the existing clone_model function to support this use case.

We could do something like:

def clone_model(model, input_tensors=None, clone_function=None, call_function=None):

Where clone_function would be used to obtain a layer copy (or just return the same layer, if you want to reuse it):

def my_clone_fn(layer):
    return layer

And where call_function would take the layer and its arguments, and return call outputs:

def my_call_fn(layer, *args, **kwargs):
    return layer(*args, **kwargs)
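
Putting the two together, usage of the proposed signature might look like this (a sketch only; the signature above is still a proposal at this point):

def my_clone_fn(layer):
    return layer  # reuse the same layer instance instead of copying it

def my_call_fn(layer, *args, **kwargs):
    return layer(*args, **kwargs)  # plain call; wrap or insert layers here if needed

new_model = clone_model(model, clone_function=my_clone_fn, call_function=my_call_fn)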

@martin-gorner
Contributor Author

Yes, having a single clone function definitely sounds like a better UX.

  • Is it OK to have clone_function and call_function be mutually exclusive? Working with both at the same time could be complicated.
  • What if both clone_function and call_function are None? The right thing to do would probably be to apply layer graph cloning (no problem with shared layers, no memory explosion), but clone_model(model) already has a meaning in the API. Should we go for clone_model(model, mode="clone_layers / clone_layer_graph")? Another advantage of an explicit "mode" is that there is no need for an extra "call_function" parameter: "clone_function" can have a different signature depending on "mode".
  • What to do with the input_tensors parameter when cloning the layer graph? Error out? Call build(input_tensors) at the end? Can build be useful in this case?

@fchollet
Member

Is it OK to have clone_function and call_function be mutually exclusive? Working with both at the same time could be complicated.

They're completely independent, no? One creates the layer instance and the other uses it. We should be able to combine them.

what if both clone_function and call_function are None?

They have defaults (default to layer cloning and plain layer call respectively).

what to do with the input_tensors parameter when cloning the layer graph

Again, that argument seems fully independent from the others. It's just the starting node in the graph. It can be used in combination with anything.

@martin-gorner
Contributor Author

We should be able to combine them. [clone_fn and call_fn]

Hmm, isn't this the exact equivalent of clone_model:

def clone_fn(layer, *args, **kwargs):
    cloned_layer = layer.__class__.from_config(layer.get_config())
    return cloned_layer(*args, **kwargs)

cloned_model = clone_layer_graph(model, clone_fn=clone_fn)

I'm wondering if there isn't a simpler way to have a single cloning API.

@fchollet
Member

A big factor here is backwards compatibility. The new API must be a strict extension of the old API, one that does not affect existing workflows. Hence it's better to keep the two clone_ and call_ functions.

@gbaned gbaned added this to Assigned Reviewer in PR Queue via automation Apr 24, 2024
@martin-gorner
Contributor Author

martin-gorner commented Apr 24, 2024

Yes, backwards compatibility is important. There is a way to both preserve backwards compatibility and implement the better, simpler API. We can have the new API and implement the old one with it under the hood (and steer people towards the new one in the documentation). Proposal:

def clone(model, clone_fn=None, enter_nested=True)

The old API can be implemented as suggested in the previous comment (with enter_nested=False).
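A sketch of that, assuming the proposed clone() signature (the helper name legacy_clone_model is illustrative only):

def legacy_clone_model(model, clone_function=None):
    # default to the old clone_model behavior: copy each layer from its config
    if clone_function is None:
        clone_function = lambda layer: layer.__class__.from_config(layer.get_config())

    def clone_fn(layer, *args, **kwargs):
        return clone_function(layer)(*args, **kwargs)

    return clone(model, clone_fn=clone_fn, enter_nested=False)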

Benefits:

  • The new API is more generic. It can do graph cloning and layer cloning with the same paradigm. A single clone_fn is simpler to understand and reason about for users.
  • The default behavior of clone(model) is totally non-problematic, whereas the default behavior of clone_model(model) has numerous gotchas: 1) it requires all layers to have serialization implemented, which is optional, 2) it does not handle shared layers, 3) it is not usable for LLMs because of memory. So there are fewer possibilities to shoot yourself in the foot.
  • If we implement clone_model(model, clone_fn, call_fn), then the right thing to do in the documentation would be to recommend never using clone_fn and doing all node manipulations (incl. cloning) in call_fn. Unfortunately, that's impossible since clone_fn=None triggers a default clone_fn which does layer cloning. Having these kinds of conundrums is a sign of a problematic API.

@martin-gorner
Contributor Author

martin-gorner commented Apr 24, 2024

Unrelated to the previous comment, I wanted to log why I initially went with clone_layer_graph(inputs, outputs) rather than clone(model). None of these reasons is sufficient by itself to determine which API is best.

  • Cloning the layer graph only requires input and output tensors, not a model, hence the more direct API.
  • Pretending to "clone a model" is problematic for functional subclassing models: only the layer graph is cloned (which is the desired behavior), no model attributes are cloned, and the resulting model is Functional rather than the original type of the functional subclassing model. Trying to clone model attributes would be a complex can of worms with presently no real use case.
  • clone(model) is problematic for subclassing models, while clone_layer_graph(inputs, outputs) is obviously not relevant if you are not in the presence of a symbolic layer graph.

@martin-gorner
Contributor Author

The proof of concept PR on KerasNLP for this new API is here: keras-team/keras-nlp#1598

It shows how the new layer graph cloning API can be used in XXXCausalLM to wire caches into the backbone in order to implement call_with_cache in a Keras Functional way. This fixes a keras_nlp issue where the layer graph of the backbone used to initialize any XXXCausalLM would be disregarded and only some layers used.

With these two PRs, users can use the new layer graph cloning API to alter an LLM backbone before initializing an XXXCausalLM with it, for example by inserting layers to implement the "control vectors" technique (arxiv.org/abs/2310.01405).

A demo Colab can be found here: Model rewiring demo with LLMs.ipynb.

For example, you can insert control vectors into an LLM backbone with this clone_fn applied to the backbone:

def clone_fn(layer, *args, **kwargs):
  if isinstance(layer, keras_nlp.layers.TransformerDecoder):
    x = layer(*args, **kwargs)
    x = ControlVectorLayer()(x)
    return x
  else:
    return layer(*args, **kwargs) # identity
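
ControlVectorLayer comes from the demo, not from this PR; a minimal hypothetical version that adds a learned steering vector to the decoder output might look like:

class ControlVectorLayer(keras.layers.Layer):
    # Hypothetical: adds a trainable "control vector" to the hidden states,
    # broadcast over the batch and sequence dimensions.
    def build(self, input_shape):
        self.control_vector = self.add_weight(
            name="control_vector",
            shape=(input_shape[-1],),
            initializer="zeros",
            trainable=True,
        )

    def call(self, inputs):
        return inputs + self.control_vector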

[Figure: model graph before and after inserting ControlVectorLayer]

The proof of concept KerasNLP PR also confirms that the new layer graph cloning API can be used to re-wire a backbone with caches for LLM generation. And since the rewired backbone is now a proper Keras Functional model, it can be visualized with plot_model:
[Figure: rewired backbone with caches, visualized with plot_model]

@fchollet
Member

Merged with an alternative API.

@fchollet fchollet closed this Apr 29, 2024
PR Queue automation moved this from Assigned Reviewer to Closed/Rejected Apr 29, 2024