
CUDA OOM when saving checkpoint (in consolidate_state_dict()) using OSS #973

Open

crowsonkb opened this issue Apr 20, 2022 · 6 comments
@crowsonkb

I am experiencing CUDA out-of-memory crashes when consolidating my optimizer state dict before saving it. I am training on 32 40GB A100s (four nodes with eight GPUs each), using PyTorch Lightning's 'ddp_sharded' strategy, which wraps the optimizer in OSS. The crash happens in the middle of consolidate_state_dict(). I have tried adding del statements, gc.collect(), and torch.cuda.empty_cache() inside the loop, to no avail.

I am using a custom optimizer class, a modified AdamW that also keeps an exponential moving average of the weights, and I need optimizer state sharding because the extra memory overhead of the EMA weights is otherwise prohibitive. Here is the custom optimizer code: https://gist.github.com/crowsonkb/ea0ed1f6e88594046c72735f3cef1d05.

I don't understand how I can run out of GPU memory partway through consolidate_state_dict() (I put in print statements and it got through 27 of 32 ranks), since it moves the tensors to CPU after each broadcast. I am using NCCL, so the broadcasts have to happen on GPU, but each shard is copied to CPU right afterwards.
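For context, the checkpoint path I'm describing looks roughly like this outside of Lightning (a minimal sketch only; the model, sizes, and plain AdamW inner optimizer are placeholders for my real setup, where the inner optimizer is the EMA AdamW from the gist above):

```python
import torch
import torch.distributed as dist
from fairscale.optim import OSS

# One process per GPU, NCCL backend (as in my run).
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(1024, 1024).cuda()  # placeholder model
optimizer = OSS(model.parameters(), optim=torch.optim.AdamW, lr=1e-4)

# ... training steps ...

# Each rank broadcasts its shard so the recipient rank can assemble the full
# optimizer state dict; the OOM happens partway through these broadcasts.
optimizer.consolidate_state_dict(recipient_rank=0)
if dist.get_rank() == 0:
    torch.save(optimizer.state_dict(), "optimizer.pt")
```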

Thank you,
Katherine Crowson

@crowsonkb
Author

My fairscale version is 0.4.6, my PyTorch version is 1.11.0+cu113, and my PyTorch Lightning version is 1.6.1.

@crowsonkb
Author

It is the same as this issue AFAICT: huggingface/transformers#14542

@crowsonkb
Author

Except that I can't fix it by setting force_broadcast_object=True when I create the optimizer; when I do, it just hangs instead.

@min-xu-ai
Contributor

Do you have a full traceback for the OOM crash? Pasting it in a gist is fine.

@aced125

aced125 commented Jun 1, 2022

Getting this as well

@aced125

aced125 commented Jun 1, 2022

I fixed this by enabling force_broadcast_object=True in the fairscale.optim.OSS initialization (on V100s).
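For anyone else hitting this, the change was just one keyword argument in the OSS constructor (a sketch; the placeholder model and inner AdamW stand in for the real ones, and as far as I understand the flag forces the state shards to be broadcast as pickled objects rather than tensors during consolidation):

```python
import torch
from fairscale.optim import OSS

model = torch.nn.Linear(1024, 1024).cuda()  # placeholder model

optimizer = OSS(
    model.parameters(),
    optim=torch.optim.AdamW,  # placeholder inner optimizer
    lr=1e-4,
    # Broadcast the optimizer state as pickled objects instead of flat
    # tensors when consolidating the shards; this avoided the OOM for me.
    force_broadcast_object=True,
)
```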
