MPS support #790
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/790
Note: Links to docs will display an error until the docs builds have completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@maximegmd This is awesome! Can you post some loss curves for the finetune you ran?
I will complete a run over the weekend; the losses looked fine, but the Llama 3 release changed my priorities ^^
Thanks so much for making this change, supporting MPS has been on our TODO!
I'm a bit confused about how this is working: the device param is used to fetch the device via this utility function, which in turn depends on this function, and we never seem to actually return mps as the device. So how is this working? I think this will just default to CPU.
device is not None when this function is called, so it just passes 'mps' to torch.device(), which is the expected PyTorch name. But you are correct that there is room for improvement: we could automatically return mps when device is not manually specified in the config.
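For reference, a minimal sketch of what that automatic fallback could look like; the function names and structure below are illustrative and not the actual torchtune utilities:

```python
from typing import Optional

import torch


def _get_device_type_from_env() -> str:
    """Pick a default device type when none is specified (illustrative sketch).

    Prefers CUDA, then MPS (Apple Silicon), then falls back to CPU.
    """
    if torch.cuda.is_available():
        return "cuda"
    # MPS is available on Apple Silicon with a recent PyTorch build
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def get_device(device: Optional[str] = None) -> torch.device:
    # If the config specifies a device (e.g. "mps"), pass it straight through;
    # otherwise infer one from the environment.
    device_type = device if device is not None else _get_device_type_from_env()
    return torch.device(device_type)
```

With this, `get_device("mps")` behaves as it does in this PR, while `get_device(None)` would also resolve to mps on Apple Silicon instead of falling back to CPU.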
Oh good point, totally glossed over the fact that device is not None here.
If I recall correctly it was around 20 s/it, but I suspect I was swapping a bit, so I can probably improve the speed. The main issue is that bitsandbytes does not support MPS, so the optimizer state uses quite a bit of memory. I will try to push a llama3 config tomorrow with some numbers now that my Llama 3 finetune is running :)
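Since bitsandbytes' 8-bit optimizers are CUDA-only, a run on MPS ends up keeping optimizer state in full precision. A hedged sketch of the kind of fallback one might wire up (the helper name is hypothetical and this is not what the PR implements):

```python
import torch


def build_optimizer(params, lr: float = 2e-5, device_type: str = "cpu"):
    """Illustrative sketch: use bitsandbytes' paged 8-bit AdamW on CUDA,
    and plain AdamW elsewhere (e.g. MPS), where bitsandbytes is unsupported."""
    if device_type == "cuda":
        try:
            import bitsandbytes as bnb
            return bnb.optim.PagedAdamW8bit(params, lr=lr)
        except ImportError:
            pass
    # On MPS (or if bitsandbytes is missing) fall back to the stock optimizer,
    # which stores optimizer state in full precision and therefore uses more memory.
    return torch.optim.AdamW(params, lr=lr)
```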
Here is a training run on Gemma 2B. Sadly the laptop went to sleep right before the end, but this is a 14-hour run, so it should be representative enough.
Context
Changelog
Test plan