nt.linearize induces a CUDA_ERROR_OUT_OF_MEMORY error #144

Open

RylanSchaeffer opened this issue Mar 5, 2022 · 7 comments
Labels: question (Further information is requested)

RylanSchaeffer commented Mar 5, 2022

We're trying to fine-tune a linearized Vision Transformer by adapting code from https://github.com/google-research/vision_transformer/blob/main/vit_jax.ipynb.

We're running into a really puzzling problem: when we load a model, we can train it, and after linearizing it we can still train the original (pre-linearized) model. However, when we try to train the linearized model itself, we get:

RuntimeError: Internal: Failed to load in-memory CUBIN: CUDA_ERROR_OUT_OF_MEMORY: out of memory

The error occurs regardless of whether we use 1 GPU or several, and regardless of whether the batch size is large (512) or small (1).

We manually verified that a forward pass and a backward pass each run without error. We suspect the error might arise from the following code (although we could be wrong!):

Their code:

def make_update_fn(*, apply_fn, accum_steps, lr_fn):
  """Returns update step for data parallel training."""

  def update_fn(opt, step, batch, rng):

    _, new_rng = jax.random.split(rng)
    # Bind the rng key to the device id (which is unique across hosts)
    # Note: This is only used for multi-host training (i.e. multiple computers
    # each with multiple accelerators).
    dropout_rng = jax.random.fold_in(rng, jax.lax.axis_index('batch'))

    def cross_entropy_loss(*, logits, labels):
      logp = jax.nn.log_softmax(logits)
      return -jnp.mean(jnp.sum(logp * labels, axis=1))

    def loss_fn(params, images, labels):
      logits = apply_fn(
          dict(params=params),
          rngs=dict(dropout=dropout_rng),
          inputs=images,
          train=True)
      return cross_entropy_loss(logits=logits, labels=labels)

    l, g = utils.accumulate_gradient(
        jax.value_and_grad(loss_fn), opt.target, batch['image'], batch['label'],
        accum_steps)
    g = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), g)
    l = jax.lax.pmean(l, axis_name='batch')

    opt = opt.apply_gradient(g, learning_rate=lr_fn(step))
    return opt, l, new_rng

  return jax.pmap(update_fn, axis_name='batch', donate_argnums=(0,))

That function is then called via:

# Check out train.make_update_fn in the editor on the right side for details.
lr_fn = utils.create_learning_rate_schedule(total_steps, base_lr, decay_type, warmup_steps)
update_fn_repl = train.make_update_fn(
    apply_fn=vit_apply, accum_steps=accum_steps, lr_fn=lr_fn)
# We use a momentum optimizer that uses half precision for state to save
# memory. It also implements gradient clipping.
opt = momentum_clip.Optimizer(grad_norm_clip=grad_norm_clip).create(params)
opt_repl = flax.jax_utils.replicate(opt)

The training loop where the memory error arises:

losses = []
lrs = []
# Completes in ~20 min on the TPU runtime.
for step, batch in zip(
    tqdm.trange(1, total_steps + 1),
    ds_train.as_numpy_iterator(),
):

  opt_repl, loss_repl, update_rng_repl = update_fn_repl(
      opt_repl, flax.jax_utils.replicate(step), batch, update_rng_repl)  # ERROR IS HERE
  losses.append(loss_repl[0])
  lrs.append(lr_fn(step))

The above code is all copied from the ViT repo. This is how we linearize the ViT model:

def vit_apply(params, input):
  return model.apply(dict(params=params), input, train=True)
f_lin = nt.linearize(vit_apply, params)
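
For context, this is roughly how we then feed f_lin into the training step. This is our own wiring rather than code copied from the ViT notebook, so the exact details in our Colab may differ:

# Our own wiring (not from the ViT notebook): f_lin has the same
# (params, inputs) signature as vit_apply, so it slots into the same
# cross-entropy loss as above.
def loss_fn_lin(params, images, labels):
  logits = f_lin(params, images)
  logp = jax.nn.log_softmax(logits)
  return -jnp.mean(jnp.sum(logp * labels, axis=1))

lin_grad_fn = jax.value_and_grad(loss_fn_lin)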

RylanSchaeffer (Author) commented:

@romanngg, you were really helpful previously - any thoughts here? Thanks in advance :)

romanngg (Contributor) commented Mar 6, 2022

nt.linearize is essentially a Jacobian-vector product (jax.jvp), so the peak memory consumption of the linearized forward pass should be about 2x that of the original forward pass. Likewise, I believe the costs of the backward passes (jax.vjp) of the linearized and non-linearized models should differ by about 2x (see https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#how-it-s-made-two-foundational-autodiff-functions or https://openreview.net/pdf?id=ym68T6OoO6L). If you have a way to diagnose peak memory consumption when you train your model, could you check that it's less than half of your GPU memory?
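
For intuition, here is a rough sketch of what linearization via jax.jvp looks like (not the exact nt.linearize implementation, just the idea behind the 2x estimate):

import jax

def linearize_sketch(f, params):
  # Rough sketch only: f_lin(p, x) is the first-order Taylor expansion of f
  # around `params`, computed with a single Jacobian-vector product. The
  # extra tangent computation is why the forward pass costs roughly 2x.
  def f_lin(p, x):
    dparams = jax.tree_util.tree_map(lambda a, b: a - b, p, params)
    f_0, df = jax.jvp(lambda q: f(q, x), (params,), (dparams,))
    return f_0 + df
  return f_lin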

romanngg added the question label on Mar 6, 2022

RylanSchaeffer (Author) commented Mar 6, 2022

@romanngg thanks for getting back to me so soon! I'll check the peak memory consumption, but I don't think that's the cause, because we can successfully perform a forward and backward pass of f_lin "manually" on a single GPU with batch size = 512. By "manually," I mean executing the following on its own:

    l, g = utils.accumulate_gradient(
        jax.value_and_grad(loss_fn), opt.target, batch['image'], batch['label'],
        accum_steps)

I suspect that there might be an odd interaction between f_lin as constructed by Neural Tangents and the code used in the vision transformer notebook (pasted above).
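
If it helps, this is one way we could check peak device memory on the GPU (an assumption on our side: device.memory_stats() is available in our jaxlib build, and the exact key names may vary by version):

import jax

# Assumption: memory_stats() is exposed for GPU devices in this jaxlib build;
# key names may differ across versions.
stats = jax.local_devices()[0].memory_stats()
print('peak bytes in use:', stats.get('peak_bytes_in_use'))
print('bytes limit:', stats.get('bytes_limit'))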

RylanSchaeffer (Author) commented Mar 7, 2022

Another bizarre observation: if we try

    l, g = utils.accumulate_gradient(
        jax.value_and_grad(loss_fn), opt.target, batch['image'], batch['label'],
        accum_steps)

with the original (non-linearized) model outside the update_fn() defined above, we get an OOM error for about 155 MiB, even though the GPU has plenty of additional free memory. This problem does not occur when using the linearized model.

Edit: Ignore that last observation. The problem vanished when we reduced the per-GPU batch size.

RylanSchaeffer (Author) commented:

Here's a self-contained Colab that reproduces the issue: https://colab.research.google.com/drive/184moQLq3tjo-wEpc8gD7fXCFguAVDBOm#scrollTo=k4CjYqp5qLvj

We suspect that pmap might be the culprit: if we don't use it, the linearized model trains fine via jax.value_and_grad(loss_fn), but once we wrap it in jax.pmap(jax.value_and_grad(loss_fn)), we hit the OOM.
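
Schematically (placeholder names and shapes; our own illustration rather than the exact Colab code):

# Illustration only -- placeholder variables, not the exact Colab code.
grad_fn = jax.value_and_grad(loss_fn)  # loss_fn built on top of f_lin

# Single device: this runs fine with the linearized model.
loss, grads = grad_fn(params, images, labels)

# Data-parallel: wrapping the same function in pmap is where we hit the OOM.
p_grad_fn = jax.pmap(jax.value_and_grad(loss_fn), axis_name='batch')
loss, grads = p_grad_fn(replicated_params, sharded_images, sharded_labels)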

RylanSchaeffer (Author) commented:

Another observation: on Colab, the GPU runtime hits the OOM error, but the TPU runtime works fine.

RylanSchaeffer (Author) commented Mar 7, 2022

Looks like someone else has hit a similar problem while using Neural Tangents, also potentially arising from pmap:

google/jax#8585 (comment)
