adding clone_layer_graph function to clone a graph of layers without cloning actual layers #19600
Conversation
Codecov Report

Attention: Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master   #19600      +/-   ##
==========================================
- Coverage   78.30%   78.29%   -0.01%
==========================================
  Files         498      499       +1
  Lines       45477    45555      +78
  Branches     8382     8398      +16
==========================================
+ Hits        35610    35668      +58
- Misses       8107     8117      +10
- Partials     1760     1770      +10
```
I'll fix the PyTorch test failures.
Broadly speaking, this looks good and we should incorporate it into the API. For the interface, I think we should modify the existing `clone_model` function to support this use case. We could do something like:

```python
def clone_model(model, input_tensors=None, clone_function=None, call_function=None):
```

Where `clone_function` would be used to obtain a layer copy (or just return the same layer, if you want to reuse it):

```python
def my_clone_fn(layer):
    return layer
```

And where `call_function` would take the layer and its arguments, and return call outputs:

```python
def my_call_fn(layer, *args, **kwargs):
    return layer(*args, **kwargs)
```
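To make the proposal concrete, here is a minimal usage sketch; the toy model and the Dropout insertion are illustrative assumptions, not part of the PR:

```python
import keras

inputs = keras.Input(shape=(8,))
x = keras.layers.Dense(4, activation="relu")(inputs)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

def my_clone_fn(layer):
    # Share the original layers (and their weights) instead of copying them.
    return layer

def my_call_fn(layer, *args, **kwargs):
    out = layer(*args, **kwargs)
    if isinstance(layer, keras.layers.Dense):
        # Splice a new node into the graph after every Dense call.
        out = keras.layers.Dropout(0.2)(out)
    return out

new_model = keras.models.clone_model(
    model,
    clone_function=my_clone_fn,
    call_function=my_call_fn,
)
```

This reuses every original layer while rebuilding the graph with extra Dropout nodes, which is exactly the "clone the graph, not the layers" use case this PR targets.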
Yes, having a single clone function definitely sounds like a better UX.
They're completely independent, no? One creates the layer instance and the other uses it. We should be able to combine them.
They have defaults (default to layer cloning and plain layer call respectively).
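For reference, the implied defaults could look something like the following sketch; the `from_config` round-trip is the conventional way Keras copies a layer, but the exact form here is my assumption:

```python
def default_clone_fn(layer):
    # "Layer cloning": a fresh instance with the same config and
    # newly initialized weights.
    return layer.__class__.from_config(layer.get_config())

def default_call_fn(layer, *args, **kwargs):
    # "Plain layer call": invoke the layer on the node's inputs as-is.
    return layer(*args, **kwargs)
```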
Again, that argument seems fully independent from the others. It's just the starting node in the graph. It can be used in combination with anything.
Hmm, isn't this the exact equivalent of `clone_model`:

I'm wondering if there isn't a simpler way of having a single cloning API?
A big factor here is backwards compatibility. The new API must be a strict extension of the old API that does not affect existing workflows. Hence why it's better to keep two.
Yes, backwards compatibility is important. There is a way to both preserve backwards compatibility and implement the better, simpler API: we can have the new API and implement the old one with it under the hood (and steer people towards the new one in the documentation). Proposal:

The old API can be implemented as suggested in the previous comment (with `enter_nested=False`). Benefits:
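A rough sketch of what implementing the old API on top of the new one could look like; `clone_layer_graph` and `enter_nested` are the names used in this thread, while the memoization and the default copy behavior are my assumptions (`input_tensors` handling is omitted for brevity):

```python
import keras

def clone_model(model, input_tensors=None, clone_function=None):
    # Legacy default: copy each layer from its config.
    clone_function = clone_function or (
        lambda layer: layer.__class__.from_config(layer.get_config())
    )
    memo = {}  # copy each layer once, even if it is invoked at several nodes

    def clone_fn(layer, *args, **kwargs):
        if id(layer) not in memo:
            memo[id(layer)] = clone_function(layer)
        return memo[id(layer)](*args, **kwargs)

    # enter_nested=False keeps nested Functional/Sequential models opaque,
    # matching the old clone_model semantics.
    outputs = clone_layer_graph(
        model.input, model.output, clone_fn=clone_fn, enter_nested=False
    )
    return keras.Model(model.input, outputs)
```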
Unrelated to the previous comment, I wanted to log why I initially went with
The proof of concept PR on KerasNLP for this new API is here: keras-team/keras-nlp#1598. It shows how the new layer graph cloning API can be used in practice.

With these two PRs, users can use the new layer graph cloning API to alter an LLM backbone before initializing an LLM from it.

A demo Colab can be found here: Model rewiring demo with LLMs.ipynb. For example, you can insert control vectors into an LLM backbone with this `clone_fn` applied to the backbone:

```python
def clone_fn(layer, *args, **kwargs):
    if isinstance(layer, keras_nlp.layers.TransformerDecoder):
        x = layer(*args, **kwargs)
        x = ControlVectorLayer()(x)
        return x
    else:
        return layer(*args, **kwargs)  # identity
```

The proof of concept KerasNLP PR also confirms that the new layer graph cloning API can be used to re-wire a backbone with caches for LLM generation. And since the rewired backbone is now a proper Keras Functional model, it can be visualized with `keras.utils.plot_model`.
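For context, applying this `clone_fn` to a backbone with the API proposed in this PR might look like the following sketch; `backbone` stands in for any KerasNLP backbone model:

```python
import keras

# Rebuild the backbone's layer graph: the original layers are shared,
# and a ControlVectorLayer node is spliced in after every TransformerDecoder.
new_output = clone_layer_graph(backbone.input, backbone.output, clone_fn=clone_fn)
new_backbone = keras.Model(backbone.input, new_output)
```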
Merged with alternative API.
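The "alternative API" is the `clone_model` extension discussed above. Assuming that signature, the same control-vector rewiring would look roughly like:

```python
new_backbone = keras.models.clone_model(
    backbone,
    clone_function=lambda layer: layer,  # share layers instead of copying them
    call_function=clone_fn,              # re-wire each node through clone_fn
)
```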
Clone the layer graph between `input` and `output`. Actual layers are NOT cloned, but shared. Only the graph of layers is re-created. The `clone_fn` function is called when cloning each node (i.e. each layer invocation in the graph), which allows you to modify the topology of the graph.
Recommended usage: layers in the graph should have serialization implemented (i.e. implement `get_config()`). Cloning nested layer graphs (i.e. nested Functional or Sequential models) is possible. If a `clone_fn` is provided, the nested subgraph will be cloned as a new Functional model (note: this will change Sequential subgraphs to Functional).
Args:
    input: Instance of `KerasTensor` or pytree of `KerasTensor`s. All
        inputs must be of type `keras.Input`. If you wish to clone a
        layer graph that starts with intermediate KerasTensors, you have
        to create a new Functional model first by calling
        `model = keras.Model(intermediate_tensors, output)`, which will
        create proper `Input` tensors instead of the intermediate ones.
    output: Instance of `KerasTensor` or pytree of `KerasTensor`s.
    clone_fn: Callable that will be called when each layer in the layer
        graph is invoked. The expected signature is
        `clone_fn(layer, *args, **kwargs)`. To leave a layer unchanged,
        return `layer(*args, **kwargs)`.

Examples:
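As one possible example, assuming the signature described above (the Dense/Dropout graph is my illustration, not from the PR):

```python
import keras

inputs = keras.Input(shape=(8,))
x = keras.layers.Dense(16, activation="relu")(inputs)
outputs = keras.layers.Dense(1)(x)

# Insert a Dropout node after every Dense invocation; the Dense layers
# themselves are shared with the original graph, not copied.
def clone_fn(layer, *args, **kwargs):
    out = layer(*args, **kwargs)
    if isinstance(layer, keras.layers.Dense):
        out = keras.layers.Dropout(0.1)(out)
    return out

new_outputs = clone_layer_graph(inputs, outputs, clone_fn=clone_fn)
new_model = keras.Model(inputs, new_outputs)
```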
Implementation note: why not modify the existing `functional.clone_graph_nodes`?

Context: the purpose of the existing function is to create a new Functional model from any intermediate Keras tensors.
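To illustrate what that existing function enables: `keras.Model` accepts intermediate tensors as inputs, and uses `clone_graph_nodes` internally to create fresh `Input` tensors in their place. A small sketch:

```python
import keras

inputs = keras.Input(shape=(4,))
h = keras.layers.Dense(8)(inputs)   # `h` is an intermediate KerasTensor
out = keras.layers.Dense(2)(h)

# Building a Model from the intermediate tensor `h` requires creating a
# fresh Input in its place; clone_graph_nodes takes care of that.
submodel = keras.Model(h, out)
```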