This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Add support for Zero3 FSDP #3753

Open
stephenroller opened this issue Jun 28, 2021 · 3 comments

Comments

@stephenroller
Contributor

In #3740, we added support for FullyShardedDataParallel, but limited the implementation to Zero2, not Zero3. Zero3 substantially decreases memory usage compared with Zero2 while bringing speed back in line with vanilla DDP.
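For reference, the difference between the two modes roughly comes down to one fairscale FSDP flag. The snippet below is a minimal sketch, assuming `reshard_after_forward` is the knob that separates Zero2-style from Zero3-style behavior; the Linear layers are placeholders, not ParlAI code.

```python
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP

# Both wrappings assume torch.distributed has already been initialized.
# The Linear layers below stand in for ParlAI's Transformer modules.

# Zero2-style: shard optimizer state and gradients, but keep the full
# parameters resident on every GPU after the forward pass.
zero2_layer = FSDP(nn.Linear(1024, 1024), reshard_after_forward=False)

# Zero3-style: additionally re-shard the parameters after each forward,
# re-gathering them just in time on every call. This saves memory, but it
# is also why every rank must call forward the same number of times.
zero3_layer = FSDP(nn.Linear(1024, 1024), reshard_after_forward=True)
```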

We have already added the necessary wrapping (via manual calls to wrap) within the Transformer modules, but we still cannot enable Zero3. The main issue is that Zero3 assumes every worker calls forward exactly the same number of times, and it performs a parameter transfer during each forward (moving the sharded parameters to each worker just in time). ParlAI cannot provide this guarantee because:

  • During validation, each worker sees a variable number of examples. This is fine on its own, but it causes a hang if any worker ends up with extra batches.
  • During generation, workers perform a variable number of forwards due to variable sequence lengths. Everything stays in sync for a while, but if one worker ends the run needing more generation steps than the others, we get hangs.

It seems far too difficult (and ugly) to try to force this equality in worlds.py or in our dataset sharding. Our best bet going forward is to implement something like .join() from vanilla DDP. It would work roughly as follows (see the sketch after this list):

  • On every forward, each worker synchronizes a True boolean saying "Am I doing a real forward?"
  • Upon __exit__ of the context, workers enter an infinite loop where they sync a False boolean. As long as any worker is providing a True value, they participate in a dummy batch forward.
  • When all workers agree on the False boolean, we can end the infinite loop.
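A very rough sketch of what that handshake could look like, using a plain all_reduce to sync the boolean. Everything here (the class name, the `dummy_batch` argument, the `step` helper) is hypothetical and only illustrates the scheme described above; the real implementation would live in Fairscale and hook into FSDP's forward directly.

```python
import torch
import torch.distributed as dist


class JoinForwards:
    """Hypothetical sketch of the proposed join-style context.

    Every real forward announces "I'm still active"; on exit, a finished
    rank keeps running dummy forwards until all ranks report they're done.
    """

    def __init__(self, model, dummy_batch):
        self.model = model
        self.dummy_batch = dummy_batch

    def _anyone_active(self, i_am_active: bool) -> bool:
        # Sum the per-rank booleans; > 0 means at least one rank still has
        # real work. Uses a CPU tensor, so this assumes a gloo process group.
        flag = torch.tensor([1.0 if i_am_active else 0.0])
        dist.all_reduce(flag, op=dist.ReduceOp.SUM)
        return flag.item() > 0

    def step(self, batch):
        # A real forward: sync True so idle ranks know to shadow it.
        self._anyone_active(True)
        return self.model(batch)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # This rank is out of data: keep matching the active ranks'
        # collectives with dummy forwards until everyone syncs False.
        while self._anyone_active(False):
            with torch.no_grad():
                self.model(self.dummy_batch)
```

Usage on each rank would then be something like `with JoinForwards(model, dummy_batch) as ctx:` followed by `ctx.step(batch)` for every real batch, so a rank that runs out of data simply falls through to `__exit__` and keeps the collectives balanced.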

This feature makes the most sense to implement upstream in Fairscale, and then integrate into ParlAI.

@blefaudeux

> During validation, each worker sees a variable number of examples. This is okay in itself, but it is problematic (hang) if it results in any worker having extra batches.

PyTorch distributed has a wrapper for that; I've tried to look it up, to no avail (maybe it isn't public yet). Not sure how applicable it would be, just a heads-up.

@seed93

seed93 commented Aug 25, 2021

So, any updates here? We are really looking forward to using Zero3 to boost training.

@blefaudeux


https://pytorch.org/tutorials/advanced/generic_join.html
@stephenroller would that help? cc @min-xu-ai
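For reference, the pattern from that tutorial looks roughly like this (a minimal sketch adapted from the linked page: DDP already implements the Joinable interface, and the uneven batch counts stand in for ParlAI's validation case; whether the FSDP wrapper can participate the same way is presumably the open question here).

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = DDP(torch.nn.Linear(1, 1))
    # Uneven inputs: rank 0 gets 5 batches, rank 1 gets 6, and so on,
    # mimicking a validation split that doesn't divide evenly.
    inputs = [torch.ones(1) for _ in range(5 + rank)]

    # Join lets ranks that run out of batches shadow the collective
    # operations issued by the ranks still working, avoiding hangs.
    with Join([model]):
        for x in inputs:
            loss = model(x).sum()
            loss.backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```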
