
optax.MultiSteps out of memory #472

Open
ein-ich opened this issue Jan 7, 2023 · 15 comments
@ein-ich

ein-ich commented Jan 7, 2023

I always get an out of memory error using optax.MultiSteps, even when every_k_schedule=1.
Using optax.apply_every(k=1) in a chain works fine.

optimizer = optax.chain(
    optax.clip_by_global_norm(0.5),
    optax.adam(lr),
    #optax.apply_every(k=1)
)
optimizer = optax.MultiSteps(optimizer, every_k_schedule=1)

Later I'm using
opt_state = optimizer.init(params)
and

updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

I have no idea what I could be doing wrong. I'm not changing anything else, like batch size.

@mkunesch
Member

mkunesch commented Jan 8, 2023

Hi! Interesting - thanks for reporting this!

Are you also at more than ~2/3 memory usage when you use apply_every? From a first look, I can see that apply_every returns 0*updates for skipped steps, while MultiSteps constructs a new array of zeros (even if every_k_schedule=1), so the former has a better memory footprint. That would explain up to 50% higher memory usage, but not more.

I'm not sure why the two functions use completely different code paths - we should be able to merge them (and deprecate one of them).
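To illustrate the difference described above (a sketch, not the actual optax internals): one path scales the incoming update buffer to zero, while the other materialises a fresh tree of zeros.

```python
import jax
import jax.numpy as jnp

updates = {"w": jnp.ones((3,))}

# apply_every-style skip: 0 * updates, which XLA can compute while
# reusing the existing buffer (subject to donation/aliasing).
skipped = jax.tree_util.tree_map(lambda u: 0.0 * u, updates)

# MultiSteps-style skip (at the time): allocate a new tree of zeros.
fresh = jax.tree_util.tree_map(jnp.zeros_like, updates)
```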

@ein-ich
Author

ein-ich commented Jan 8, 2023

I have most of my available memory preallocated by JAX. I tried reducing the batch size from 120 (which works with apply_every) to 30, but it still crashed with MultiSteps.

@ayaka14732
I am training Llama 2 7B on TPU. Without optax.MultiSteps my batch_size can be 4. However, after applying optax.MultiSteps, I get OOM even with batch_size = 1.

@hr0nix
hr0nix commented Aug 21, 2023

I can confirm that the MultiSteps implementation has a much larger memory overhead than just one extra buffer for the gradients (something like 4x extra buffers). This is very problematic when using this class with large models.

@Sea-Snell

I also noticed this issue

@philippe-eecs

I am having this issue as well for use in diffusion models

@agrimgupta92

Facing the same issue.

copybara-service bot pushed a commit that referenced this issue Aug 29, 2023
Change the implementation to allow JAX/XLA to re-use memory buffers. #472

PiperOrigin-RevId: 561129449
@hbq1 hbq1 self-assigned this Aug 29, 2023
copybara-service bot pushed a commit that referenced this issue Aug 30, 2023
Change the implementation to allow JAX/XLA to re-use memory buffers. #472

PiperOrigin-RevId: 561390202
@hbq1
Collaborator

hbq1 commented Aug 30, 2023

Hi everyone, thanks for flagging this. I just merged a new version of optax.MultiSteps that should be more memory-friendly. Could you check whether it fixes the problem?

@philippe-eecs

you're a king

@celiolarcher
Contributor

Hi @hbq1! Thank you for the fix!

One question: I am still seeing higher memory consumption with MultiSteps than with apply_every. Is this expected?

@celiolarcher
Contributor

As a follow-up, I did some debugging myself, and it seems that the problem is in this part of the code (line 414):

new_updates, new_state = jax.lax.cond(
    state.mini_step < k_steps - 1,
    _mid_step, _final_step, *(state, params, acc_grads))

If I understand correctly, JAX allocates memory for the outputs of both branches (_mid_step and _final_step), so this basically doubles the space needed to store the optimizer state and gradients.

Still trying to figure out a way to solve it, though.
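A stripped-down sketch of the pattern in question (hypothetical _mid_step/_final_step bodies, not the real optax ones): jax.lax.cond requires both branches to return outputs of identical shape and dtype, which is why buffers can end up reserved for both.

```python
import jax
import jax.numpy as jnp

def _mid_step(acc, g):
    # Accumulate the gradient and emit zero updates.
    return jnp.zeros_like(g), acc + g

def _final_step(acc, g):
    # Emit the accumulated gradient and reset the accumulator.
    return acc + g, jnp.zeros_like(acc)

k_steps = 2
mini_step = 0
acc = jnp.zeros(4)
g = jnp.ones(4)

# Both branches must produce matching output pytrees.
updates, new_acc = jax.lax.cond(
    mini_step < k_steps - 1, _mid_step, _final_step, acc, g)
```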

@celiolarcher
Contributor

Just added a PR merging the apply_every logic into MultiSteps. From my initial tests, it reduces the memory footprint (I'm able to train Llama 2 7B on a v3-8 now) without affecting convergence.
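For readers following along, the apply_every approach replaces the branch with a scalar mask multiply, so no second set of branch outputs needs to be materialised (a sketch with hypothetical names, not the PR itself):

```python
import jax
import jax.numpy as jnp

k = 2                                      # accumulate over k mini-steps
step = 1                                   # 0-based step index
acc_updates = {"w": jnp.full((3,), 0.5)}   # hypothetical accumulated updates

# Emit on the last mini-step of each window, otherwise scale to zero.
emit = jnp.where(step % k == k - 1, 1.0, 0.0)
updates = jax.tree_util.tree_map(lambda u: u * emit, acc_updates)
```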

@mtthss
Collaborator

mtthss commented Oct 31, 2023

This is really great!

@hbq1
Collaborator

hbq1 commented Nov 23, 2023

Awesome work @celiolarcher!

jax.lax.cond seems to be suboptimal in some use cases. Here, for example, it should in theory understand that only one of _mid_step or _final_step will be executed, so it shouldn't allocate memory for both outputs. It might be something the JAX/XLA devs would like to look at. Let me know if you'd like me to file a bug at https://github.com/google/jax/issues, or feel free to do it yourself, of course!

@celiolarcher
Contributor

I'm glad to be able to help!
About the issue, @hbq1: I can open it there, no problem.
