Replies: 1 comment
-
Hi, I believe I have accomplished a similar thing using:

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

def train_step(state, batch):
    ...

def initialization(model, learning_rate, input_size, seed, weight_decay):
    # keep this function as written
    ...

# create the init functions -- assume we have a list of learning rates, seeds, and weight decays
initializers = [
    jtu.Partial(initialization, model=model, learning_rate=lr, input_size=input_size, seed=seed, weight_decay=wd)
    for lr, seed, wd in zip(learning_rates, seeds, weight_decays)
]

# vmap over the initializers: lax.switch picks the initializer matching each index
states = jax.vmap(jtu.Partial(jax.lax.switch, branches=initializers))(jnp.arange(len(learning_rates)))

# vectorise the stepping function over the first dimension of the states
fn_step = jax.vmap(train_step, in_axes=(0, None))
```

This is quite a crude example, but I think you can play around with it to achieve what you are asking for.

Note: I did not test this code, but I think the idea is clear.
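To make this concrete, here is a self-contained sketch of the same pattern. Everything specific in it (the toy flax model, the MSE loss, the hyperparameter grids) is made up for illustration and is not from either post. Two details differ from the snippet above: each run's learning rate and weight decay are baked into its own `optax.adamw` transformation rather than stored in its state, so the train step also dispatches on the run index with `jax.lax.switch`; and each run is kept as a plain `(params, opt_state)` tuple rather than a flax `TrainState`, since `lax.switch` requires every branch to return exactly the same pytree structure and a `TrainState` carries its optimizer as a static field.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import flax.linen as nn
import optax


# Toy model and hyperparameter grid -- placeholders, not the original setup.
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(nn.relu(nn.Dense(32)(x)))


model = MLP()
input_size = 8
learning_rates = [1e-3, 1e-2, 1e-1]
weight_decays = [1e-4, 1e-4, 1e-3]
seeds = [0, 1, 2]

# One optax transformation per run; lr and weight decay are baked into it.
optimizers = [optax.adamw(lr, weight_decay=wd)
              for lr, wd in zip(learning_rates, weight_decays)]


def initialization(model, tx, input_size, seed):
    """Create one run's parameters and optimizer state."""
    params = model.init(jax.random.PRNGKey(seed), jnp.ones((1, input_size)))
    return params, tx.init(params)


def make_train_step(model, tx):
    """Build a step function with this run's optimizer closed over."""
    def train_step(params, opt_state, batch):
        x, y = batch

        def loss_fn(p):
            return jnp.mean((model.apply(p, x) - y) ** 2)

        grads = jax.grad(loss_fn)(params)
        updates, opt_state = tx.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    return train_step


initializers = [
    jtu.Partial(initialization, model=model, tx=tx, input_size=input_size, seed=seed)
    for tx, seed in zip(optimizers, seeds)
]
step_fns = [make_train_step(model, tx) for tx in optimizers]

run_ids = jnp.arange(len(optimizers))

# lax.switch picks the initializer matching each run id; vmapping over the ids
# stacks the resulting (params, opt_state) pytrees along a leading "run" axis.
params, opt_states = jax.vmap(jtu.Partial(jax.lax.switch, branches=initializers))(run_ids)


# The step also dispatches on the run id so each run uses its own optimizer.
def step_one(run_id, params, opt_state, batch):
    return jax.lax.switch(run_id, step_fns, params, opt_state, batch)


# Vectorise over runs; the batch itself is shared across all runs.
fn_step = jax.vmap(step_one, in_axes=(0, 0, 0, None))

batch = (jnp.ones((16, input_size)), jnp.zeros((16, 1)))
params, opt_states = fn_step(run_ids, params, opt_states, batch)
```

One caveat: when the switch index is batched by `vmap`, JAX evaluates all branches and selects the results, so the work grows with the number of runs times the number of branches.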
-
Hi,
I made a script to perform a grid search over hyperparameters using vmap. I want to pass parameters such as the seed, weight decay, and learning rate to train a model using optax and flax. This works fine for variables like the seed that are not changed inside vmap. However, the learning rate is modified inside the optax optimizer, which results in a side effect.
How can you pass variables to vmap that are changed during execution? Is this even possible?
My code is as follows:
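As an illustration of what is being asked (not the actual script; the model, loss, and grid values below are placeholders), one way to make the learning rate vmap-friendly is `optax.inject_hyperparams`, which moves it and the weight decay into the optimizer state so they can be passed to `jax.vmap` as ordinary batched arrays instead of being baked into the transformation:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax


# Toy model -- a placeholder, not the actual script.
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(nn.relu(nn.Dense(32)(x)))


model = MLP()
input_size = 8

# inject_hyperparams stores learning_rate and weight_decay in the optimizer
# state as arrays, so they are no longer fixed at construction time.
tx = optax.inject_hyperparams(optax.adamw)(learning_rate=1e-3, weight_decay=1e-4)


def init_run(seed, learning_rate, weight_decay):
    params = model.init(jax.random.PRNGKey(seed), jnp.ones((1, input_size)))
    opt_state = tx.init(params)
    # Overwrite the placeholder hyperparameters with this run's values.
    opt_state.hyperparams['learning_rate'] = learning_rate
    opt_state.hyperparams['weight_decay'] = weight_decay
    return params, opt_state


def train_step(params, opt_state, batch):
    x, y = batch

    def loss_fn(p):
        return jnp.mean((model.apply(p, x) - y) ** 2)

    grads = jax.grad(loss_fn)(params)
    # update() reads learning_rate/weight_decay from opt_state.hyperparams,
    # so each run in the vmapped batch uses its own values.
    updates, opt_state = tx.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state


seeds = jnp.arange(3)
learning_rates = jnp.array([1e-3, 1e-2, 1e-1])
weight_decays = jnp.array([1e-4, 1e-4, 1e-3])

params, opt_states = jax.vmap(init_run)(seeds, learning_rates, weight_decays)
fn_step = jax.vmap(train_step, in_axes=(0, 0, None))

batch = (jnp.ones((16, input_size)), jnp.zeros((16, 1)))
params, opt_states = fn_step(params, opt_states, batch)
```

With this approach there is no need for `lax.switch`: all runs share one transformation, and the per-run hyperparameters travel in the batched optimizer state.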