Multi GPU support #9

Open
Idate96 opened this issue Sep 27, 2023 · 5 comments

Comments

@Idate96

Idate96 commented Sep 27, 2023

I was wondering whether there are any plans to release multi-GPU training code.
Naively pmapping and using DDPPO does not seem to scale well, as the GPUs remain idle while syncing the gradients.

@luchris429
Owner

Ahh that's a good idea. It was not on the roadmap, but I would imagine doing something like it would not be that difficult.

Do you think it would largely just involve pmean-ing the gradient updates?
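A minimal sketch of that idea, assuming a toy linear model and plain SGD in place of PureJaxRL's actual PPO loss and optimizer (the names `loss_fn`, `update_step`, the `"devices"` axis name, and the learning rate are hypothetical). Each device computes gradients on its own shard of the batch under `jax.pmap`, and `jax.lax.pmean` averages them across devices before the update, so every replica applies the same step and the parameters stay identical.

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical toy loss standing in for the PPO loss; only the
# gradient-sync pattern matters here.
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@functools.partial(jax.pmap, axis_name="devices")
def update_step(params, x, y):
    # Each device computes gradients on its own shard of the batch.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average gradients (and the loss, for logging) across all devices so
    # every replica applies the same update and parameters stay in sync.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    # Plain SGD step with a hypothetical learning rate of 0.1.
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return new_params, loss

n_dev = jax.local_device_count()
key = jax.random.PRNGKey(0)
params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
# Replicate parameters: add a leading device axis of size n_dev.
params = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n_dev), params)

# Global batch of 8 * n_dev examples, shaped (n_dev, 8, 3) so each device gets 8.
x = jax.random.normal(key, (n_dev, 8, 3))
y = x @ jnp.array([1.0, -2.0, 0.5])

for _ in range(100):
    params, loss = update_step(params, x, y)
```

In a real training loop the same `pmean` would sit inside the minibatch update, with the optimizer state replicated alongside the parameters; this sketch says nothing about how much time the synchronization itself costs.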

@ugurbolat

Also interested in this. +1

@Howuhh

Howuhh commented Dec 4, 2023

We adapted the PureJaxRL PPO+RNN implementation for multi-GPU training with pmap in XLand-MiniGrid, and it scales well (almost linearly from 1 up to 8 A100 GPUs)!
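The sketch below shows one common way such a setup can split work across devices. It assumes a hypothetical toy environment (`toy_env_step`) and hypothetical constants `TOTAL_ENVS` and `NUM_STEPS`, and it is not taken from the XLand-MiniGrid code. The global environment count is divided evenly across devices, each device gets its own PRNG key, and `jax.pmap` runs an independent `vmap`-ed rollout per device, so data collection itself needs no cross-device communication.

```python
import jax
import jax.numpy as jnp

TOTAL_ENVS = 8   # hypothetical global env count; must divide evenly across devices
NUM_STEPS = 16   # hypothetical rollout length per env

def toy_env_step(state, action):
    # Hypothetical stand-in for a pure-JAX env step; it only needs to be
    # jittable and vmappable.
    new_state = state + action
    reward = -jnp.abs(new_state)
    return new_state, reward

def rollout_one_device(key, envs_per_device):
    # Steps envs_per_device environments in parallel (vmap) for NUM_STEPS (scan).
    def step(carry, _):
        states, key = carry
        key, subkey = jax.random.split(key)
        actions = jax.random.choice(subkey, jnp.array([-1.0, 1.0]), (envs_per_device,))
        states, rewards = jax.vmap(toy_env_step)(states, actions)
        return (states, key), rewards

    init_states = jnp.zeros((envs_per_device,))
    _, rewards = jax.lax.scan(step, (init_states, key), None, length=NUM_STEPS)
    return rewards  # shape (NUM_STEPS, envs_per_device)

n_dev = jax.local_device_count()
assert TOTAL_ENVS % n_dev == 0, "TOTAL_ENVS must be divisible by the device count"
envs_per_device = TOTAL_ENVS // n_dev

# One independent PRNG key per device so each device collects different data.
keys = jax.random.split(jax.random.PRNGKey(0), n_dev)

# envs_per_device is a Python int that determines array shapes, so it is passed
# as a static (broadcast) argument instead of being mapped over devices.
rollout = jax.pmap(rollout_one_device, static_broadcasted_argnums=(1,))
rewards = rollout(keys, envs_per_device)  # shape (n_dev, NUM_STEPS, envs_per_device)
```

Because each device steps only its own slice of environments, total env steps per second can grow with the device count; the gradient averaging during the update phase is where devices have to synchronize.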

@luchris429
Owner

Awesome! I took a quick look -- I see that the env steps per second scale linearly; however, do you know how performance scales with time?

@Howuhh

Howuhh commented Dec 5, 2023

@luchris429 It just takes a bit longer to compile in general (if I correctly understood "time" as the total number of timesteps). I didn't notice any other performance dips for the 10-minute and ~8-hour runs. GPU utilization is 100%, and OOM does not happen.
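One way to separate the extra compile cost from steady-state throughput when benchmarking is to time the first call of the pmapped function (which includes tracing and compilation) separately from a later call with the same shapes (which reuses the cached executable). A minimal sketch, with a hypothetical `train_step` standing in for the real update:

```python
import time

import jax
import jax.numpy as jnp

def train_step(w, x):
    # Hypothetical stand-in for one training step: a few tanh layers with shared weights.
    h = x
    for _ in range(4):
        h = jnp.tanh(h @ w)
    return jnp.mean(h)

n_dev = jax.local_device_count()
# Inputs carry a leading device axis, as jax.pmap requires.
w = jnp.broadcast_to(0.5 * jnp.eye(128), (n_dev, 128, 128))
x = jnp.ones((n_dev, 64, 128))

step = jax.pmap(train_step)

# First call: trace + compile + run. block_until_ready() waits for JAX's
# asynchronous dispatch to finish so the timing covers the actual work.
t0 = time.perf_counter()
out = step(w, x).block_until_ready()
first_call_s = time.perf_counter() - t0

# Second call with identical shapes: reuses the compiled executable.
t0 = time.perf_counter()
out = step(w, x).block_until_ready()
steady_s = time.perf_counter() - t0

print(f"first call (incl. compile): {first_call_s:.3f}s, steady state: {steady_s:.4f}s")
```

For long runs the one-time compile cost is amortized, so the steady-state number is the one that reflects scaling with more devices.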


4 participants