
Fully Sharded Data Parallel #3740

Merged
merged 33 commits on Jul 1, 2021

Conversation

@stephenroller (Contributor) commented on Jun 22, 2021

Patch description
Add support for Fairscale's FullyShardedDataParallel (FSDP). This is an implementation of DeepSpeed's ZeRO Stage 2 optimization (Zero2), wherein optimizer state and gradients are sharded across workers in order to reduce memory usage. Switching to --ddp-backend zero2 results in roughly a 25% speedup in UPS (measured without background workers; it can probably be a bit higher) and roughly a 50% reduction in memory usage. It's recommended that everyone switch to this for distributed training, and use the savings to increase batch size or lower the number of GPUs.
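As a rough illustration of what the zero2 backend does under the hood, here is a minimal sketch (not ParlAI's actual wrapping code; it assumes a torch.distributed process group is already initialized):

```python
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

def wrap_zero2(model: nn.Module) -> nn.Module:
    # Shard optimizer state and gradients across workers, but keep full
    # parameters on every worker (ZeRO Stage 2).
    # reshard_after_forward=True would additionally shard parameters
    # (Zero3), which this PR does not support; see #3753.
    return FSDP(model, reshard_after_forward=False)
```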

We also carve out support for Zero3, but cannot support it at this time due to high-level design constraints in ParlAI. See #3753 for a detailed description of why, and how we might overcome this in the future.

As a side change, this also makes our unit tests use OS-assigned free ports, instead of randomized ones, to slightly improve the reliability of running our test suites. I tried pulling this into another PR, but got tired of dealing with stacking.
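For the curious, the usual way to get an OS-assigned free port is to bind to port 0 and read back whichever port the OS chose. A sketch of the general idea (not necessarily the exact helper used here):

```python
import socket

def find_free_port() -> int:
    # Binding to port 0 asks the OS to pick any unused port; reading the
    # socket name back tells us which one it chose.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('localhost', 0))
        return s.getsockname()[1]
```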

Testing steps
Manual tests. New CI.

Here are some screenshots from a sweep that contained both --ddp-backend ddp and --ddp-backend zero2

[4 sweep screenshots comparing the two backends omitted]

@stephenroller (Contributor, Author):

See #3753 for why Zero3 won't be supported in this implementation.

@stephenroller marked this pull request as ready for review on June 28, 2021 14:49
@EricMichaelSmith (Contributor) left a comment:

Seems reasonable - minor comments

@@ -772,6 +772,16 @@ def add_distributed_training_args(self):
         grp.add_argument(
             '--distributed-world-size', type=int, help='Number of workers.'
         )
+        grp.add_argument(
+            '--ddp-backend',
+            choices=['ddp', 'zero2', 'zero3'],
Contributor:

Hmm should we even give 'zero3' as an option for the time being? (Don't really care either way)


def should_sync_gradnorm(opt):
    """
    Indicates whether fp16 optimizer wrappers should cumulate over workers.
Contributor:

Nit: "accumulate"?


For models or optimizers that shard parameters, this ensures we sync.
"""
if self.opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3'):
Contributor:

Nit: should we pull in DEFAULT_DDP_BACKEND here?

if (
    shared is None
    and is_distributed()
    and opt.get('ddp_backend', 'ddp') == 'ddp'
Contributor:

(same here about maybe using DEFAULT_DDP_BACKEND instead)

parlai/utils/distributed.py (comment resolved)
tests/test_distributed.py (comment resolved)
@klshuster (Contributor) left a comment:

really really cool. lots of nits though (and a few real questions 😄 )

@@ -1969,10 +1974,11 @@ def state_dict(self):
         """
         states = {}
         if hasattr(self, 'model'):  # save model params
-            if hasattr(self.model, 'module'):
+            # did we wrap in a DistributedDataParallel
+            if hasattr(self.model, 'module') and not is_fsdp(self.model):
Contributor:

nit: could make this a helper function too? like should_sync_gradnorm (not necessary of course)

-        self.model = self.build_model()
+        with fsdp_utils.maybe_fsdp_wrap(opt):
+            self.model = fsdp_utils.fsdp_wrap(self.build_model())
+            if self.fp16 and not fsdp_utils.should_use_fsdp(opt):
Contributor:

remember that bug with the instability stuff? is this not re-introducing it?

Contributor:

(because we moved the model.half() call?)

Contributor (Author):

Okay I think this needs to use my utility should_delay_halving. Forgot this.

We haven't really moved the moment of halving. The operations between these two points don't do much, and the original code path should be about the same.

  • We now halve the model on CPU instead of GPU, and then transfer. That's probably a small speedup in initialization, with maybe some small numerical differences.
  • We apply model parallelism after halving. Probably a small speedup at initialization.
  • We synchronize parameters after halving. Again, a small initialization speedup.

The catch is that FSDP expects the model to still be in fp32 (not yet halved) if we're doing safe fp16, and already halved if we're doing memory-efficient. (Similar to the optimizer wrappers, it looks at the parameter dtypes to decide what dtype the gradients will be.)

This is the desired pattern (sketched in code after this list):

  • If we're in Safe and using DDP, we SHOULD still halve, just as before
  • If we're in MemEff and using DDP, we SHOULD still halve, just as before
  • If we're in Safe and Zero2, we should NOT halve here
  • If we're in MemEff and Zero2, we SHOULD halve here.
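
In code, that boils down to something like the following sketch (the real helper, should_delay_halving, is quoted further down; it's written here as a free function over opt for brevity):

```python
def should_delay_halving(opt) -> bool:
    # Delay .half() (let FSDP manage precision) only for safe fp16 combined
    # with a sharded backend; in every other combination we halve up front,
    # exactly as before.
    return (
        opt.get('fp16', False)
        and opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3')
        and opt.get('fp16_impl', 'safe') == 'safe'
    )
```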

@@ -55,10 +54,12 @@ def multiprocess_train(
         raise


-def launch_and_train(opt, port):
+def launch_and_train(opt, port=None):
Contributor:

will we ever specify a port here?

@@ -543,7 +546,7 @@ def validate(self):
             )
             self.best_valid = new_valid
             self.impatience = 0
-            if opt.get('model_file') and is_primary_worker():
+            if opt.get('model_file'):
Contributor:

just making sure I understand - we can get rid of this check because it's handled in save_model right?

Contributor (Author):

We need to be able to do save_on_nonprimary_worker, actually.

if max_norm > 0:
    clip_coef = max_norm / (grad_norm + 1e-6)
    for p in params:
        p.grad.detach().mul_(clip_coef)
Contributor:

why do we detach here?

Contributor (Author):

Don't want grads of grads! (This is in the original pytorch code too)
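
For context, the standard clipping pattern (as in torch.nn.utils.clip_grad_norm_) looks roughly like this sketch; the helper name is made up for illustration, and the detach() is what keeps the in-place rescale off the autograd graph:

```python
from typing import Iterable
import torch

def clip_to_norm_(params: Iterable[torch.nn.Parameter], grad_norm: float, max_norm: float) -> None:
    # Rescale gradients in place so their total norm does not exceed max_norm.
    # p.grad.detach() means autograd does not record the multiply, so we never
    # build gradients of gradients.
    if max_norm > 0:
        clip_coef = min(max_norm / (grad_norm + 1e-6), 1.0)  # only ever scale down
        for p in params:
            if p.grad is not None:
                p.grad.detach().mul_(clip_coef)
```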

return

# zero3 not supported at this time. Throw an exception
if opt['ddp_backend'] == 'zero3':
Contributor:

I know this is just for overkill testing, but it's not even a choice in the param options, so we'll already error there if calling from the command line.

Contributor (Author):

I'm leaving it for the future

return (
    self.fp16
    and self.opt.get('ddp_backend', 'ddp') in ('zero2', 'zero3')
    and self.opt['fp16_impl'] == 'safe'
Contributor:

but if we're using mem_efficient we don't delay?

Contributor (Author):

Correct, see main comment
