Passing arguments to train multiple models in parallel #932
Hi,
I want to perform a grid search over different arguments to train multiple models in parallel using optax and flax. My initial idea was to pass an array of learning rates to an initialization function using vmap, but this results in a side-effect transformation error.
What is the best way to pass a list of arguments, and can this be solved? The issue seems to be related to the adamw optimizer, which I believe modifies the learning rate parameter.
I have attached a reduced example of my code:
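The attached example is not preserved in this copy. The following is a minimal hypothetical reconstruction of the setup described above, not the original attachment; the model, shapes, and learning-rate values are made up:

```python
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import train_state

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(nn.relu(nn.Dense(32)(x)))

model = MLP()

def create_state(rng, learning_rate):
    params = model.init(rng, jnp.ones((1, 8)))
    # adamw closes over learning_rate at construction time.
    tx = optax.adamw(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx)

rngs = jax.random.split(jax.random.PRNGKey(0), 4)
learning_rates = jnp.array([1e-2, 1e-3, 1e-4, 1e-5])

# Fails: the traced learning rate escapes vmap inside the closure held
# by tx, producing the side-effect / leaked-tracer error described above.
states = jax.vmap(create_state)(rngs, learning_rates)
```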
Comments

Hello @kclauw,
Thanks for reaching out.
Hi, thanks.
I looked at the code of adamw: the problem is that adamw (and SGD etc.) applies the learning rate via transform.scale_by_learning_rate(learning_rate), which reduces to scale(m * learning_rate). What would be the best way to deal with passing arguments that change under vmap, if this is even possible? I figure this will also become a problem when passing weight decay arguments.
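One pattern that makes such hyperparameters traceable under vmap is optax.inject_hyperparams, which stores learning_rate (and weight_decay) in the optimizer state rather than in a closure. The sketch below is illustrative only, not necessarily the approach taken in the replies; the toy loss, shapes, and learning-rate values are made up:

```python
import jax
import jax.numpy as jnp
import optax

def loss_fn(params, batch):
    # Toy quadratic loss on a single weight vector.
    return jnp.mean((batch @ params) ** 2)

def create_and_step(rng, learning_rate, batch):
    params = jax.random.normal(rng, (8,))
    # inject_hyperparams stores learning_rate and weight_decay in the
    # optimizer state instead of a closure, so traced values are fine.
    tx = optax.inject_hyperparams(optax.adamw)(
        learning_rate=learning_rate, weight_decay=1e-4)
    opt_state = tx.init(params)
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = tx.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state

rngs = jax.random.split(jax.random.PRNGKey(0), 4)
learning_rates = jnp.array([1e-2, 1e-3, 1e-4, 1e-5])
batch = jnp.ones((32, 8))

# One optimizer step for four models, each with its own learning rate.
params, opt_states = jax.vmap(
    create_and_step, in_axes=(0, 0, None))(rngs, learning_rates, batch)
```

Because the hyperparameters now live in the optimizer state, each model's learning rate can also be read or adjusted between steps via opt_state.hyperparams['learning_rate'].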
When dealing with parameters that change during …
Hello @kclauw,
Sorry for the delayed answer.