DistributedDataParallel #382

Open
dgm2 opened this issue Jun 6, 2022 · 2 comments

dgm2 commented Jun 6, 2022

It seems that a DistributedDataParallel (DDP) PyTorch setup is not supported in OT, specifically for the emd2 computation.
Are there any workaround ideas for making this work, or any example of multi-GPU setups with OT?

Ideally, I would like to make OT work with this torch setup:
https://github.com/pytorch/examples/blob/main/distributed/ddp/main.py

Many thanks

Example of a failed DDP run:

  ot.emd2(a, b, dist)
  File "/python3.8/site-packages/ot/lp/__init__.py", line 468, in emd2
    nx = get_backend(M0, a0, b0)
  File "/python3.8/site-packages/ot/backend.py", line 168, in get_backend
    return TorchBackend()
  File "/python3.8/site-packages/ot/backend.py", line 1517, in __init__
    self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
RuntimeError: CUDA error: all CUDA-capable devices are busy or unavailable

My current workaround is to change

    self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))

to

    self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device=device_id))

passing the device id into the backend and reinstalling OT from source.
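
For context, here is a minimal sketch of the kind of per-rank setup I have in mind (not my exact code; the sizes, data and process-group backend are placeholders):

    import os
    import torch
    import torch.distributed as dist
    import ot


    def run(rank, world_size):
        # One process per GPU, as in the linked DDP example.
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
        device = torch.device(f"cuda:{rank}")

        n = 128
        a = torch.full((n,), 1.0 / n, device=device)  # uniform source weights
        b = torch.full((n,), 1.0 / n, device=device)  # uniform target weights
        x = torch.randn(n, 2, device=device)
        y = torch.randn(n, 2, device=device)
        M = ot.dist(x, y)  # cost matrix, stays on this rank's GPU

        # The first call instantiates POT's TorchBackend, whose __init__
        # creates a probe tensor with device='cuda'. Unless
        # torch.cuda.set_device(rank) was called, that resolves to the
        # process's default device (cuda:0), which is where the
        # "all CUDA-capable devices are busy or unavailable" error can
        # surface, e.g. on GPUs running in exclusive-process mode.
        loss = ot.emd2(a, b, M)
        print(f"rank {rank}: emd2 = {float(loss):.6f}")

        dist.destroy_process_group()


    if __name__ == "__main__":
        world_size = torch.cuda.device_count()
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        torch.multiprocessing.spawn(run, args=(world_size,), nprocs=world_size)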

rflamary (Collaborator) commented Jun 7, 2022

Hello @dgm2,

Does this workaround work? Note that the list is there mainly for debugging and tests (so that we can run them on all available devices), so I'm a bit surprised if this is the only bottleneck for running POT with DDP.

We are obviously interested in your contribution if you manage to make it work properly (we do not have multiple GPUs, so it is a bit hard to implement and debug on our side). The device_id should probably be detected automatically when using get_backend and at backend creation; the backends should not need parameters to remain practical to use.
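
For illustration, the automatic detection could look something like the sketch below, which derives the CUDA devices from the tensors passed to get_backend instead of the bare 'cuda' string (a rough sketch only; the helper name is hypothetical and this is not the actual backend code):

    import torch


    def _torch_type_list(*tensors):
        # Float32/float64 probes on CPU, as in the current backend.
        type_list = [
            torch.tensor(1, dtype=torch.float32),
            torch.tensor(1, dtype=torch.float64),
        ]
        # Add one probe per distinct CUDA device actually used by the inputs,
        # so a DDP rank that only touches cuda:<rank> never allocates on cuda:0.
        cuda_devices = {t.device for t in tensors if t.is_cuda}
        for dev in cuda_devices:
            type_list.append(torch.tensor(1, dtype=torch.float32, device=dev))
            type_list.append(torch.tensor(1, dtype=torch.float64, device=dev))
        return type_list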

ncassereau-idris (Contributor) commented

Hello @dgm2,
Could you provide us with the exact code you used to get this error?
I ran https://github.com/pytorch/examples/blob/main/distributed/ddp/main.py with 4 GPUs and ot.emd2 as the loss function, yet did not get any error; everything seems to have run smoothly whether the distribution was launched with torch or with Slurm.
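
For reference, an ot.emd2-based loss for that example can look like the sketch below (the emd2_loss helper name and the uniform weights are illustrative choices, not the exact code I ran):

    import torch
    import ot


    def emd2_loss(outputs, targets):
        # Treat each row of the output/target batch as a point with uniform
        # weight and use the squared Euclidean cost between the two clouds.
        n = outputs.shape[0]
        device = outputs.device
        a = torch.full((n,), 1.0 / n, device=device)
        b = torch.full((n,), 1.0 / n, device=device)
        M = ot.dist(outputs, targets)  # pairwise squared Euclidean costs
        # With the torch backend, emd2 returns a scalar that supports
        # backward(), so it can be used directly as a training loss.
        return ot.emd2(a, b, M)


    # In the training step, this replaces the example's original loss:
    # loss = emd2_loss(ddp_model(inputs), labels); loss.backward()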
