
Why ShardedDDP and OSS are slower than Vanilla DDP #1131

Open
powermano opened this issue Aug 18, 2023 · 0 comments
I have tested https://github.com/facebookresearch/fairscale/blob/main/benchmarks/oss.py using two 3080ti GPUs and two 4080ti GPUs, respectively.

As mentioned in https://fairscale.readthedocs.io/en/latest/deep_dive/oss_sdp_fsdp.html:

The training process can be modified from that carried out by DDP as follows:

1. The wrapped optimizer shards the optimizer state in a greedy fashion based on the parameter size but not the order in which it is used. This is to ensure that each rank has almost the same optimizer memory footprint.

2. The training process is similar to that used by PyTorch’s Distributed Data Parallel (DDP). The forward pass completes on each of the ranks followed by the backward pass. During the backward pass, gradients are synchronized using allreduce.

3. Each rank updates the parameters for the shard of optimizer state that it is responsible for and then discards the rest.

4. After update, a broadcast or allgather follows to ensure all ranks receive the latest updated parameter values.

OSS is very useful when you are using an optimizer such as Adam that has additional state. The wrapping of the optimizer is a one-line, non-intrusive change that provides memory savings.
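For concreteness, the one-line change described above looks roughly like the following (a minimal sketch with a placeholder model; not taken from the benchmark script):

```python
import torch
from fairscale.optim.oss import OSS

model = torch.nn.Linear(1024, 1024)  # placeholder model for illustration

# Before: a regular optimizer that keeps its full state on every rank.
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# After: OSS wraps the same optimizer and shards its state across ranks.
optimizer = OSS(params=model.parameters(), optim=torch.optim.Adam, lr=1e-3)
```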

If you are using SGD or any optimizer with a limited memory footprint, it is likely that you will see a slowdown when using multiple nodes, due to the additional communication in step 4. There is also some wasteful memory used to store gradients during allreduce in step 2 that is then discarded, although this also happens with normal PyTorch (nothing extraneous here).

Compared to DDP, OSS + DDP has the additional communication in step 4, so why should OSS always be faster than vanilla PyTorch on a single node?

Performance tips for fairscale.optim.oss
1. On a single node, OSS should be always faster than vanilla PyTorch, memory savings will vary depending on the optimizer being used
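For reference, the "OSS + ShardedDDP" configuration in the tables below pairs the sharded optimizer with fairscale's ShardedDataParallel instead of PyTorch DDP, roughly like this (a hedged sketch with a placeholder model, assuming torch.distributed has already been initialized; not the benchmark's exact code):

```python
import torch
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

model = torch.nn.Linear(1024, 1024).cuda()  # placeholder model on this rank's GPU

# OSS shards the Adam state; ShardedDDP reduces each gradient only to the rank
# that owns the corresponding optimizer shard instead of allreducing to all ranks.
optimizer = OSS(params=model.parameters(), optim=torch.optim.Adam, lr=1e-3)
model = ShardedDDP(model, optimizer)
```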

3080ti:

| Optimizer | Median Throughput (img/s) (rank 0) | Peak Memory (MiB) |
| --- | --- | --- |
| Vanilla | 1795.03 +/- 34.88 | 1462.5 |
| OSS + DDP | 1645.64 +/- 31.78 | 1290.0 |
| OSS + ShardedDDP | 1468.54 +/- 12.97 | 1049.7 |

4080ti (with export NCCL_P2P_DISABLE=1 set, as this is an issue with the NVIDIA driver that has not been solved):

| Optimizer | Median Throughput (img/s) (rank 0) | Peak Memory (MiB) |
| --- | --- | --- |
| Vanilla | 2117.12 +/- 16.13 | 1556.4 |
| OSS + DDP | 1850.65 +/- 5.97 | 1377.8 |
| OSS + ShardedDDP | 1530.15 +/- 8.69 | 1158.6 |