
Memory overflow using scale_by_radam #580

Open
HGangloff opened this issue Aug 29, 2023 · 1 comment
Labels: bug (Something isn't working), help wanted (Extra attention is needed)

Comments

HGangloff commented Aug 29, 2023

Hi,

My RAM fills up until it overflows when I use the scale_by_radam gradient transformation (or, equivalently, optax.radam) without JIT-compiling the code. The problem appears on both CPU and GPU, but it disappears when I use JIT compilation, and it does not seem to occur with optax.adam.

Here is an MWE derived from the Optax quick start tutorial:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''  # force CPU; comment this line out to run on GPU

import optax
import jax.numpy as jnp
import jax
import numpy as np

BATCH_SIZE = 500
NUM_TRAIN_STEPS = 10000
RAW_TRAINING_DATA = np.random.randint(255, size=(NUM_TRAIN_STEPS, BATCH_SIZE, 1))

TRAINING_DATA = np.unpackbits(RAW_TRAINING_DATA.astype(np.uint8), axis=-1)
LABELS = jax.nn.one_hot(RAW_TRAINING_DATA % 2, 2).astype(jnp.float32).reshape(NUM_TRAIN_STEPS, BATCH_SIZE, 2)

initial_params = {
    'hidden': jax.random.normal(shape=[8, 200], key=jax.random.PRNGKey(0)),
    'hidden2': jax.random.normal(shape=[200, 100], key=jax.random.PRNGKey(0)),
    'output': jax.random.normal(shape=[100, 2], key=jax.random.PRNGKey(1)),
}


def net(x: jnp.ndarray, params: optax.Params) -> jnp.ndarray:
  x = jnp.dot(x, params['hidden'])
  x = jax.nn.relu(x)
  x = jnp.dot(x, params['hidden2'])
  x = jax.nn.relu(x)
  x = jnp.dot(x, params['output'])
  return x


def loss(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
  y_hat = net(batch, params)

  # optax also provides a number of common loss functions.
  loss_value = optax.sigmoid_binary_cross_entropy(y_hat, labels).sum(axis=-1)

  return loss_value.mean()

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)

  #@jax.jit  # NOTE: the leak only shows up when this decorator stays commented out
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'step {i}, loss: {loss_value}')

  return params

# Finally, we can fit our parametrized function using the RAdam optimizer
# provided by optax.
optimizer = optax.radam(learning_rate=1e-2)
params = fit(initial_params, optimizer)
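
To make the growth visible from outside JAX, here is a minimal sketch (assuming the third-party psutil package is installed; log_rss is an illustrative helper, not part of the original MWE) that logs the process RSS during training:

import psutil  # assumption: third-party psutil is available in the environment

_process = psutil.Process()

def log_rss(step_idx):
  # Resident set size of the current Python process, in MiB.
  rss_mib = _process.memory_info().rss / (1024 ** 2)
  print(f'step {step_idx}, RSS: {rss_mib:.1f} MiB')

Calling log_rss(i) next to the existing print in fit makes any per-step memory growth easy to see.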

Of course, this example is simple enough that it takes a long time to saturate the RAM, but the issue is a real problem in another, larger research project of mine.

The problem seems to be linked to this RAdam-specific computation: https://github.com/deepmind/optax/blob/fc5de3d3951c4dfd87513c6426e30baf505d89ae/optax/_src/transform.py#L685C7-L685C7, but I do not know how to investigate further.
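
For context, the rectification term that RAdam recomputes on every update looks roughly like the following (a standalone sketch following the RAdam paper, not a copy of the Optax code; the function name is illustrative):

import jax.numpy as jnp

def radam_rectification(count, b2=0.999):
  # Maximum length of the approximated simple moving average (SMA).
  rho_inf = 2.0 / (1.0 - b2) - 1.0
  b2t = b2 ** count
  # Per-step SMA length; it depends on the step count, so it cannot be
  # precomputed and is re-evaluated on every update.
  rho_t = rho_inf - 2.0 * count * b2t / (1.0 - b2t)
  # Variance-rectification multiplier, only applied once rho_t is large
  # enough (the paper uses a threshold of 4).
  rect = jnp.sqrt((rho_t - 4.0) * (rho_t - 2.0) * rho_inf /
                  ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))
  return rho_t, rect

Without jax.jit, this scalar arithmetic is dispatched eagerly on every step, so profiling around it could show whether something is being retained across iterations.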

Thanks for your feedback.

fabianp added the bug (Something isn't working) and help wanted (Extra attention is needed) labels on Dec 10, 2023
@itstalmeez

Hi HGangloff,

A few suggestions:

1. Prioritize JIT compilation: compile your code with jax.jit whenever possible to benefit from JAX's optimizations and potentially avoid the RAM issue (see the sketch after this list).

2. Investigate the RAdam implementation: look at the RAdam implementation in Optax (https://github.com/deepmind/optax/blob/fc5de3d3951c4dfd87513c6426e30baf505d89ae/optax/_src/transform.py#L685C7-L685C7), focusing on places that might create large temporary arrays or perform memory-intensive operations, and profile memory usage to pinpoint the specific lines or functions causing excessive consumption.

3. Experiment with alternative optimizers: if RAdam's performance is crucial for your research, consider modifying its implementation to reduce the memory footprint (if feasible), or try a related optimizer such as Yogi, which shares similarities with RAdam but may have different memory characteristics.

4. Report to the Optax maintainers: share your findings and code examples with the maintainers to bring attention to the issue and potentially contribute to a fix.

Additional considerations:

- Memory profiling: use tools like jax.profiler or an external profiler to track memory usage and identify bottlenecks (also covered in the sketch below).
- Batch size adjustment: experiment with smaller batch sizes to reduce the memory required per step.
- Hardware constraints: keep the available RAM and other hardware limitations in mind.

I'm happy to assist further if you have more questions or need additional guidance.
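
To make the first suggestion and the profiling point concrete, here is a minimal sketch (it reuses loss, optimizer, and the other names from the MWE above; the dump interval and file-name pattern are arbitrary choices):

import jax
import jax.profiler
import optax

# 1) JIT the whole update step (i.e. re-enable the decorator that is
#    commented out in the MWE) so each iteration reuses one compiled
#    computation instead of dispatching every op eagerly.
@jax.jit
def step(params, opt_state, batch, labels):
  loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss_value

# 2) Inside the training loop, periodically dump a device-memory profile;
#    the resulting .prof files can be inspected with pprof to see which
#    arrays are still alive at each point.
def maybe_dump_memory_profile(step_idx, every=1000):
  if step_idx % every == 0:
    jax.profiler.save_device_memory_profile(f'memory_{step_idx}.prof')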
