Passing arguments to train multiple models in parallel #932

Open
kclauw opened this issue Apr 14, 2024 · 4 comments

@kclauw

kclauw commented Apr 14, 2024

Hi,

I want to perform a grid search over different arguments to train multiple models in parallel using optax and flax. My initial idea is to pass an array of learning rates to an initialization function using vmap, but this results in a side-effect transformation error.

What is the best way to pass a list of arguments, and can this be solved? The issue seems to be related to the adamw optimizer, which I believe modifies the learning rate parameter.

I have attached a reduced example of my code:


import jax
import jax.numpy as jnp
import optax
import hydra
from flax.training import train_state

# FCNN_2 and get_dataloaders are project-specific and not defined here.

def calculate_loss_acc(state, params, batch):
    data_input, labels = batch
    logits = state.apply_fn(params, data_input)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    acc = jnp.mean(jnp.argmax(logits, -1) == labels)
    return loss, acc

@jax.jit  # Jit the function for efficiency
def train_step(state, batch):
    # Gradient function
    
    grad_fn = jax.value_and_grad(calculate_loss_acc,  # Function to calculate the loss
                                 argnums=1,  # Parameters are second argument of the function
                                 has_aux=True  # Function has additional outputs, here accuracy
                                )
    # Determine gradients for current model, parameters and batch
    (loss, acc), grads = grad_fn(state, state.params, batch)
    
    # Perform parameter update with gradients and optimizer
    state = state.apply_gradients(grads=grads)
    # Return state and any other value we might want
    return state, loss, acc

def initialization(model, learning_rate, input_size, seed):
    rng = jax.random.PRNGKey(seed)
    
    rng, init_rng = jax.random.split(rng)
    dummy_input = jax.random.normal(init_rng, (8, input_size))  # Batch size 8, input size 2
    params = model.init(init_rng, dummy_input)
    model.apply(params, dummy_input)
    optimizer = optax.adamw(learning_rate=learning_rate)
    model_state = train_state.TrainState.create(apply_fn=model.apply,
                                                params=params,
                                                tx=optimizer)
    return model_state


@hydra.main(version_base=None, config_name="main", config_path="config")
def main(cfg) -> None:
    seed = 0
    num_epochs = 1
    input_size = 194
    output_size = 97
    learning_rates = jnp.array([0.01, 0.1])

    train_dataloader, test_dataloader = get_dataloaders(cfg)
    model = FCNN_2(num_hidden=1000, 
                   num_outputs=output_size, 
                   activation = cfg.model.parameters.activation)

    parallel_init_fn = jax.vmap(initialization, in_axes=(None, 0, None, None))
    parallel_train_step_fn = jax.vmap(train_step, in_axes=(0, None))
    
    params = parallel_init_fn(model, learning_rates, input_size, seed)
    
    for epoch in range(num_epochs):
        #Run training on epoch
        for batch in train_dataloader:
            params, loss, acc = parallel_train_step_fn(params, batch)
            print(loss)
@vroulet
Collaborator

vroulet commented Apr 14, 2024

Hello @kclauw,

  1. What error do you get exactly?
  2. Why are you saying that the issue is with adamw? Adamw does not modify the learning rate internally. Have you tried with sgd and did that produce the same error?

Thanks for reaching out

@kclauw
Author

kclauw commented Apr 14, 2024

Hi,

Thanks

  1. I am still learning JAX, coming from PyTorch, but my understanding of the error is that something is changing the value of the learning rate parameter in the initialization function:

params, loss, acc = parallel_train_step_fn(params, batch)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type int32[] wrapped in a BatchTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
Detail: Different traces at same level: Traced<ShapedArray(int32[], weak_type=True)>with<BatchTrace(level=1/0)> with
  val = Array([0, 0], dtype=int32, weak_type=True)
  batch_dim = 0, BatchTrace(level=1/0)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
  2. The code works when using a fixed learning rate. However, when using the learning rate passed through vmap it gives the error above. Changing to SGD did not resolve the issue. Based on this, I figured the optax optimizer might be changing the learning rate.
  3. I passed a list of seeds as an argument to initialization, which is not used by the optimizer. This works fine, so the issue only seems to happen when passing the learning rate parameter in combination with any optax optimizer.

I looked at the code of adamw:


def adamw(
    learning_rate: base.ScalarOrSchedule,
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    eps_root: float = 0.0,
    mu_dtype: Optional[Any] = None,
    weight_decay: float = 1e-4,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
    *,
    nesterov: bool = False,
) -> base.GradientTransformation:
  return combine.chain(
      transform.scale_by_adam(
          b1=b1,
          b2=b2,
          eps=eps,
          eps_root=eps_root,
          mu_dtype=mu_dtype,
          nesterov=nesterov,
      ),
      transform.add_decayed_weights(weight_decay, mask),
      transform.scale_by_learning_rate(learning_rate),
  )

def scale_by_learning_rate(
    learning_rate: base.ScalarOrSchedule,
    *,
    flip_sign: bool = True,
) -> base.GradientTransformation:
  m = -1 if flip_sign else 1
  if callable(learning_rate):
    return scale_by_schedule(lambda count: m * learning_rate(count))
  return scale(m * learning_rate)

The problem seems to be due to adamw (and sgd, etc.) transforming the learning rate via transform.scale_by_learning_rate(learning_rate), i.e. scale(m * learning_rate).

What would be the best way to deal with passing arguments that vary under vmap, if this is even possible? I figure this will also become a problem when passing weight decay arguments.
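
A minimal sketch, not from the thread, of the kind of capture described above (assuming this diagnosis is right; sgd and a toy params dict are used only to keep the example small):

import jax
import jax.numpy as jnp
import optax
from flax.training import train_state


def init_state(learning_rate, params):
    # Under vmap, `learning_rate` is a batch tracer. optax.sgd computes
    # scale(-learning_rate) at construction time, so the tracer is captured
    # inside the update closure of the returned GradientTransformation.
    tx = optax.sgd(learning_rate=learning_rate)
    return train_state.TrainState.create(
        apply_fn=lambda params, x: x, params=params, tx=tx)


params = {"w": jnp.zeros(3)}
states = jax.vmap(init_state, in_axes=(0, None))(jnp.array([0.01, 0.1]), params)

# apply_gradients calls tx.update outside the vmap trace, where the captured
# tracer has escaped its scope -> jax.errors.UnexpectedTracerError.
states.apply_gradients(grads={"w": jnp.ones(3)})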

@Ekundayo39283

When dealing with parameters that change during vmap, like learning rates or weight decay values, you can use partial function application or closures. This allows you to fix certain arguments while leaving others flexible. For instance, you can create a function that takes only the parameters that remain constant during vmap, then partially apply it with the varying parameters within the vmap loop. This ensures that only the necessary parameters are passed through vmap, avoiding unexpected tracer errors.
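
A minimal sketch of that pattern, reusing the initialization function and variables from the original post (whether this alone avoids the tracer error is not confirmed in the thread):

import functools

import jax
import jax.numpy as jnp

# Fix the arguments that stay constant across models, so only the learning
# rate remains as a mapped argument for vmap.
init_for_lr = functools.partial(initialization, model,
                                input_size=input_size, seed=seed)

learning_rates = jnp.array([0.01, 0.1])
states = jax.vmap(init_for_lr)(learning_rates)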

@vroulet
Collaborator

vroulet commented Apr 21, 2024

Hello @kclauw,

Sorry for the delayed answer.

  1. It would help if you could reduce the example to a minimal one that reproduces the error (some dependencies are not defined in what you sent). Also, you may try to trace the error as suggested in the message, just to be sure. It's not clear to me yet that the learning rate is really the culprit here.
  2. If it is truly the learning rate, one quick workaround would be to use optax.inject_hyperparams. You would instantiate the optimizer as opt = optax.inject_hyperparams(optax.adamw)(learning_rate=1.) outside the vmap, and inside the vmap call the optimizer's init, state = opt.init(params). In the resulting state you can then set the learning rate of your choice, state = optax.tree_utils.tree_set(state, learning_rate=your_learning_rate), and the optimizer will run with that learning rate in the vmap. Happy to try it out to be sure, but I'd need a minimal example for that.
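
A sketch of that workaround, assuming the TrainState setup from the original post (the placeholder learning rate and the update of the hyperparams dict below follow the optax.inject_hyperparams documentation; they are not verified against the poster's full code):

import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

# Built once, outside vmap, with a placeholder learning rate.
opt = optax.inject_hyperparams(optax.adamw)(learning_rate=1.0)


def initialization(model, learning_rate, input_size, seed):
    rng = jax.random.PRNGKey(seed)
    rng, init_rng = jax.random.split(rng)
    dummy_input = jax.random.normal(init_rng, (8, input_size))
    params = model.init(init_rng, dummy_input)
    state = train_state.TrainState.create(apply_fn=model.apply,
                                          params=params, tx=opt)
    # inject_hyperparams keeps the learning rate in the optimizer state,
    # so it can be overwritten per vmap lane instead of being baked into
    # the transformation's closures.
    opt_state = state.opt_state
    opt_state.hyperparams['learning_rate'] = learning_rate
    # Alternatively, as suggested above:
    # opt_state = optax.tree_utils.tree_set(opt_state, learning_rate=learning_rate)
    return state.replace(opt_state=opt_state)


parallel_init_fn = jax.vmap(initialization, in_axes=(None, 0, None, None))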
