Module Slicing 2 #169

Open · wants to merge 15 commits into master
Conversation

@alexander-g
Contributor

alexander-g commented Feb 28, 2021

Alternative approach to #115

  • This is now based on jax.make_jaxpr and jax.named_call, not on summaries. Specifically, I've extended the hooks context with a new attribute; when it is active, Module.call is wrapped with jax.named_call. This makes individual modules identifiable in the low-level jaxpr equations, so the interface is basically the same as before. However, one can now use arbitrary JAX functions without having to wrap them in modules (see also the small jaxpr sketch after this list):
import jax
import numpy as np
import elegy

class Module(elegy.Module):
    def call(self, x):
        x = x/255.
        x = elegy.nn.Linear(300, name="linear0")(x)
        x = jax.nn.relu(x)
        x = elegy.nn.Linear(100, name="linear1")(x)
        x = jax.nn.relu(x)
        x = elegy.nn.Linear(10,  name="linear2")(x)
        return x

x         = np.random.random([32,1024])
model     = elegy.Model(Module())
submodule = model.slice(start='input', end=['linear0', 'linear1'], sample_input=x)
out       = elegy.Model(submodule).predict(x, initialize=True)

assert out[0].shape == (32,300)
assert out[1].shape == (32,100)
  • Some limitations still apply:
    • Only Elegy Modules can be used as start and end targets
    • Only one start target is allowed; this could be fixed but might get complex
  • The internal logic is simpler, no longer graph-based as before
  • WIP; some edge cases still need to be covered and more testing is needed
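
For reference, a minimal standalone sketch of the tagging idea from the first bullet above: wrapping a plain function with jax.named_call makes it appear as a named call in the jaxpr produced by jax.make_jaxpr, which is what the slicing logic can use to locate module boundaries. This is hedged: it uses only plain JAX (no Elegy), the dense function is a stand-in for an Elegy Linear, and jax.named_call availability depends on the JAX version (it existed in 2021-era JAX).

import jax
import jax.numpy as jnp

def dense(x, w, b):
    # plain JAX function standing in for an Elegy Linear module
    return x @ w + b

# tag the function so it shows up as a named call in traced jaxprs
linear0 = jax.named_call(dense, name="linear0")

def net(x, w, b):
    return jax.nn.relu(linear0(x, w, b))

x = jnp.ones((32, 1024))
w = jnp.ones((1024, 300))
b = jnp.ones((300,))

# the printed jaxpr contains an equation tagged "linear0"
print(jax.make_jaxpr(net)(x, w, b))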

elegy/slicing.py Outdated
sample_input: np.ndarray,
) -> elegy.Model:

model.maybe_initialize(elegy.types.Mode.pred, x=sample_input)
Collaborator


When you update to master, maybe you can check for model.initialized instead, since maybe_initialize is gone.

@cgarciae
Collaborator

cgarciae commented Mar 1, 2021

Thanks @alexander-g!

I have one main question: are the parameters being transferred to the new module thanks to Module.set_default_parameters?

I am assuming the answer is yes, which is fine; however, if so, should this method be part of elegy.Module instead? That might make more sense now that we support other module systems.

Due to recent changes your example might require some minor modifications:

  • To apply set_default_parameters to an elegy Module from a Model, use model.update_module()
  • To initialize the new Model on predict, you can use model.predict(x, initialize=True)
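
A minimal sketch of how the original example might look after those two changes, assuming the model.update_module() and predict(..., initialize=True) calls mentioned above (not verified against a released Elegy version):

model = elegy.Model(Module())
model.update_module()  # applies set_default_parameters to the underlying elegy Module
submodule = model.slice(start="input", end=["linear0", "linear1"], sample_input=x)
out = elegy.Model(submodule).predict(x, initialize=True)  # opt-in initialization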

alexander-g changed the title [WIP] Module Slicing 2 → Module Slicing 2 on Mar 8, 2021
alexander-g marked this pull request as ready for review on March 8, 2021, 14:23
@alexander-g
Contributor Author

I have one main question: are the parameters being transferred to the new module thanks to Module.set_default_parameters?

Parameters are transferred, but I do not call set_default_parameters explicitly; I construct a new Module with the old modules as submodules. Not sure exactly how this works, but it does.

should this method be part of elegy.Module instead?

Yes, I've added Model.slice()

To initialize the new Model on predict you can use model.predict(x, initialize=True)

It is a bit annoying and counterintuitive that I have to initialize a pretrained Model. Can this be avoided?


Ready for review.
How do I trigger the tests?

@cgarciae
Collaborator

cgarciae commented Mar 8, 2021

Parameters are transferred, but I do not call set_default_parameters explicitly; I construct a new Module with the old modules as submodules. Not sure exactly how this works, but it does.

Interesting. In the latest version I think you have to be more explicit about this, but it's just a single line of code.

This is a bit annoying and counterintuitive that I have to initialize a pretrained Model. Can this be avoided?

OK, so there is a pathological case here, which I'll show next, so I coded a "defensive" solution, but I am open to suggestions:

We conditionally initialize on predict, evaluate, and fit based on whether the Model is initialized or not, using the information available to that method (x, y, sample_weights, etc.), so if you have losses and metrics, doing this is OK:

model.fit(X_train, y_train)
preds = model.predict(X_train)

but doing this will raise exceptions on initialization since the losses and metrics don't get the label information because predict doesn't have it:

preds = model.predict(X_train)
model.fit(X_train, y_train)

So the proposed solution is to have the user opt in via initialize=True and, if they didn't, raise a friendly error with information on how to fix the issue instead of just letting the program crash. All this stems from #165.
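
In other words, under the proposed behaviour the pathological ordering becomes an explicit opt-in. A hedged sketch, using the initialize flag discussed in this thread:

preds = model.predict(X_train, initialize=True)  # opt in: initialize from x only, no labels needed
model.fit(X_train, y_train)                      # fit proceeds as usual afterwards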

@cgarciae
Collaborator

cgarciae commented Mar 8, 2021

@alexander-g Now that it's not a draft, I think just committing a random change will trigger the tests.

@alexander-g
Contributor Author

Interesting. In the latest version I think you have to be more explicit about this, but it's just a single line of code.

You're right, it's required; I've added it. It had worked before simply because of the same random seed.

@cgarciae
Collaborator

@alexander-g is this ready for review? :)

@alexander-g
Contributor Author

Yes, it is.
I've just added some comments in the inner logic.

@cgarciae
Collaborator

cgarciae commented Apr 1, 2021

Hey @alexander-g! Sorry for the hiatus; a couple of important changes recently kept me busy. The good news is that my new employer is sponsoring a couple of hours a week for me to work on Elegy continuously :)

Regarding this PR, I definitely think it's best that slice becomes a feature of elegy.Module instead of elegy.Model, since it won't work with any other Module system.

@codecov-io

Codecov Report

Merging #169 (b39da4d) into master (2c78a78) will increase coverage by 0.65%.
The diff coverage is 98.22%.


@@            Coverage Diff             @@
##           master     #169      +/-   ##
==========================================
+ Coverage   87.19%   87.85%   +0.65%     
==========================================
  Files         136      138       +2     
  Lines        7433     7770     +337     
==========================================
+ Hits         6481     6826     +345     
+ Misses        952      944       -8     
Impacted Files                       Coverage Δ
elegy/slicing_test.py                97.90% <97.90%> (ø)
elegy/slicing.py                     98.13% <98.13%> (ø)
elegy/__init__.py                    84.00% <100.00%> (+0.66%) ⬆️
elegy/hooks.py                       86.53% <100.00%> (+0.92%) ⬆️
elegy/hooks_test.py                  100.00% <100.00%> (ø)
elegy/module.py                      95.56% <100.00%> (+0.30%) ⬆️
elegy/callbacks/progbar_logger.py    67.88% <0.00%> (+5.28%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@alexander-g
Contributor Author

  • slice has been moved from Model to Module. The user has to call model.update_modules() manually to ensure parameters are transferred.
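
A hedged sketch of the resulting workflow, using only names from this thread (module.slice, model.update_modules, predict(initialize=True)); details may differ from the merged API:

module = Module()
model = elegy.Model(module)   # plus losses/metrics as needed for training
model.fit(X_train, y_train)   # placeholder training step
model.update_modules()        # transfer trained parameters back into `module`
submodule = module.slice(start="input", end=["linear0", "linear1"], sample_input=x)
out = elegy.Model(submodule).predict(x, initialize=True)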
