Problems when jitting Adafactor with inject_hyperparams. #412

Open
mkunesch opened this issue Sep 7, 2022 · 0 comments

mkunesch commented Sep 7, 2022

When wrapping optax.adafactor with optax.inject_hyperparams without specifying static_args

optax.inject_hyperparams(optax.adafactor)(learning_rate=0.1)

the init function of the resulting GradientTransformation cannot be jit compiled. The reason is that, by default, inject_hyperparams treats all arguments as dynamic, but one of the arguments has to be static to avoid a TracerError. A workaround is to specify the static argument:

optax.inject_hyperparams(optax.adafactor, static_args=("min_dim_size_to_factor",))(learning_rate=0.1)

However, this is not ideal since it requires the user to know which arguments should be static and which ones can be dynamic.

We should:

  • Add a test to check whether any other optimizers are affected.
  • Change the implementations so that all optimizers wrapped in inject_hyperparams can be jit compiled without any arguments being specified as static.

mkunesch self-assigned this Sep 7, 2022
copybara-service bot pushed a commit that referenced this issue Sep 16, 2022
This PR adds a test to alias_test to ensure that all optimizers can be wrapped
in inject_hyperparams. This is to check whether issue #412 affects other
optimizers too and not just adafactor.

Currently, adafactor needs static_args to be passed. This will be solved as
part of issue #412.

PiperOrigin-RevId: 474808793
copybara-service bot pushed a commit that referenced this issue Sep 21, 2022
This PR adds a test to alias_test to ensure that all optimizers can be wrapped
in inject_hyperparams. This is to check whether issue #412 affects other
optimizers too and not just adafactor.

Currently, adafactor needs static_args to be passed. This will be solved as
part of issue #412.

PiperOrigin-RevId: 475810505