
Support for MPS devices (i.e. the GPU on Mac) when using PyTorch backend #42

Open
magnusross opened this issue Aug 25, 2023 · 11 comments

@magnusross

Currently, it seems that only CUDA devices can be used as GPUs with the PyTorch backend. It would be nice to also be able to use the MPS device, which PyTorch now supports for acceleration on Macs.

I guess a good use case for this is for demoing the package on a laptop, which would avoid the need to connect to a cluster just to try the models out, but would still provide enough compute to run more interesting models than simple toy examples in a reasonable time.

I think it should be a reasonably straightforward change but I am quite inexperienced with the backend stuff, so maybe it's quite complex!

@magnusross
Author

I have looked into this a bit more deeply, and it is slightly more complex than expected. I have managed to get it working somewhat, but there are some complications:

  • The changes required in DeepSensor are pretty small, just a line or so.
  • Unfortunately, a small change is also required in the backends package. I will raise an issue there to try to get that merged.

Given these two small changes, the forward pass of the model seems to work, and indeed gives a big speedup over the CPU.

However, there is an additional problem: evaluating the likelihood in the neuralprocesses package requires float64, which MPS does not support (only float32 is supported). I am guessing float64 is used for numerical stability, so this might be a dealbreaker. Would be interested to hear thoughts on this.
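A minimal sketch of one data-side workaround, assuming the task data arrive as nested dicts/lists of NumPy arrays: downcast every float64 array to float32 before converting it to tensors and moving it to the MPS device. The helper name `cast_float64_to_float32` and the task-like dict are hypothetical, and this does nothing about the float64 that is hard-coded inside the likelihood itself.

```python
import numpy as np


def cast_float64_to_float32(obj):
    """Recursively downcast float64 NumPy arrays to float32.

    Hypothetical helper: dicts, lists and tuples are traversed; anything else
    (strings, ints, non-float64 arrays) is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        # Only double-precision arrays are cast; int, bool and float32 pass through.
        return obj.astype(np.float32) if obj.dtype == np.float64 else obj
    if isinstance(obj, dict):
        return {key: cast_float64_to_float32(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [cast_float64_to_float32(value) for value in obj]
    if isinstance(obj, tuple):
        return tuple(cast_float64_to_float32(value) for value in obj)
    return obj


# A hypothetical task-like dict; np.random.rand returns float64 arrays.
task = {"X_c": [np.random.rand(2, 10)], "Y_c": [np.random.rand(1, 10)], "time": "2023-09-10"}
task32 = cast_float64_to_float32(task)
print(task32["X_c"][0].dtype)  # float32
```

The resulting float32 arrays could then be converted with `torch.from_numpy(...)` and moved with `.to("mps")`.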

@tom-andersson
Collaborator

Hi @magnusross, thank you for opening this issue. It would be great if DeepSensor users could capitalise on MPS devices. I can't handle this myself because I don't have a Mac, so I appreciate your efforts to get this working.

As you've realised, the backend agnosticism of DeepSensor is enabled by @wesselb's backends library, so it's good that you've opened a PR over there. Once that's working, we will need to generalise set_gpu_default_device to instead choose between CPU, GPU, and MPS (e.g. set_default_device("mps")) - please open a PR once that's working on your side.
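A rough sketch of what a generalised `set_default_device` could look like, using PyTorch's `torch.cuda.is_available()`, `torch.backends.mps.is_available()` and `torch.set_default_device` (the latter needs PyTorch 2.0 or newer). Both function names here are hypothetical, the CUDA-then-MPS-then-CPU preference order is an assumption, and the device used by the backends library would presumably also need to be set, which is omitted here.

```python
def resolve_device(preferred=None, cuda_available=False, mps_available=False):
    """Choose a device string: an explicit request if usable, else CUDA, MPS, CPU."""
    if preferred is None:
        if cuda_available:
            return "cuda"
        return "mps" if mps_available else "cpu"
    if preferred not in ("cpu", "cuda", "mps"):
        raise ValueError(f"Unknown device: {preferred!r}")
    if preferred == "cuda" and not cuda_available:
        raise ValueError("CUDA was requested but is not available")
    if preferred == "mps" and not mps_available:
        raise ValueError("MPS was requested but is not available")
    return preferred


def set_default_device(preferred=None):
    """Hypothetical generalisation of set_gpu_default_device for PyTorch."""
    import torch  # imported lazily so resolve_device can be used without torch

    device = resolve_device(
        preferred,
        cuda_available=torch.cuda.is_available(),
        mps_available=torch.backends.mps.is_available(),
    )
    torch.set_default_device(device)  # available from PyTorch 2.0
    return device
```

Keeping the device-resolution logic in a plain function makes it testable on machines without a GPU: for example, `resolve_device(mps_available=True)` returns `"mps"`.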

Regarding the issue of whether we can evaluate the log-likelihood ConvNP.loss_fn in single precision to get training working on MPS devices, I will need to loop @wesselb in here. There are two things to consider:

  1. Implementation: ConvNP.loss_fn calls neuralprocesses.loglik under the hood, which itself appears to be hard-coded to use double precision, for numerical stability as you suggest. We would need to update the code there to be able to evaluate the loss in single precision (float32).
  2. Numerical stability: My hunch is that the loglik will be unstable in single precision for the ConvGNP, due to the Cholesky decomposition of the covariance matrix, but stable for the ConvCNP, because no matrix inversion is necessary. Here I'm particularly keen to hear what @wesselb thinks.
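To make the contrast concrete, here is a NumPy sketch (not the neuralprocesses implementation) of the two kinds of Gaussian log-likelihood. The ConvCNP-style one is purely elementwise, whereas the ConvGNP-style one has to factorise the covariance, which is where reduced precision can bite when the covariance is ill-conditioned. The function names are hypothetical.

```python
import numpy as np


def diag_gaussian_loglik(y, mean, var):
    """ConvCNP-style: sum of independent univariate Gaussian log-densities.

    Everything is elementwise, so no matrix is factorised or inverted.
    """
    return -0.5 * np.sum(np.log(2 * np.pi * var) + (y - mean) ** 2 / var)


def mvn_loglik(y, mean, cov):
    """ConvGNP-style: joint Gaussian with a full covariance matrix.

    Requires a Cholesky factor of `cov`, which fails (LinAlgError) if `cov`
    is not numerically positive definite.
    """
    chol = np.linalg.cholesky(cov)
    # Solve chol @ z = y - mean, so z @ z equals the Mahalanobis term.
    z = np.linalg.solve(chol, y - mean)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (y.size * np.log(2 * np.pi) + log_det + z @ z)


rng = np.random.default_rng(0)
y = rng.normal(size=5)
mean = np.zeros(5)
var = rng.uniform(0.5, 2.0, size=5)
print(diag_gaussian_loglik(y, mean, var), mvn_loglik(y, mean, np.diag(var)))
```

With a diagonal covariance the two agree, which is a handy sanity check; the difference is that only `mvn_loglik` depends on a factorisation that can break down in low precision.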

@wesselb
Collaborator

wesselb commented Sep 5, 2023

Hey @magnusross and @tom-andersson! Apologies for the delay on my part. I've just come back from a holiday and am still catching up on email.

@magnusross, I will look at your PR for backends very shortly! Thanks so much for putting that together. It's truly much appreciated. :)

@tom-andersson, to answer your questions:

  1. Double precision was indeed hardcoded for numerical stability. To support MPS, it would be a simple change to allow the float64 there to be configured to float32.

  2. The ConvGNP might indeed run into stability issues because of the Cholesky. Nevertheless, I think you should still be able to manage with appropriate regularisation. B.cholesky automatically increases the level of regularisation until the Cholesky succeeds (or a threshold is exceeded), so with a bit of luck this will just work.
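The regularisation strategy described above can be sketched in NumPy as follows. This only mimics the behaviour described for `B.cholesky` (start with no jitter, then grow it geometrically until the factorisation succeeds or a threshold is exceeded); the function name and the default jitter values are illustrative assumptions, not those used by backends.

```python
import numpy as np


def cholesky_with_jitter(a, initial_jitter=1e-9, max_jitter=1e-1, factor=10.0):
    """Cholesky factor of `a`, adding increasing diagonal jitter until it succeeds.

    Returns the factor and the jitter that was finally used (0.0 if none was needed).
    """
    eye = np.eye(a.shape[-1], dtype=a.dtype)
    jitter = 0.0
    while True:
        try:
            return np.linalg.cholesky(a + jitter * eye), jitter
        except np.linalg.LinAlgError:
            jitter = initial_jitter if jitter == 0.0 else jitter * factor
            if jitter > max_jitter:
                raise  # give up: re-raise the last LinAlgError


# Squared-exponential kernel on densely spaced inputs: numerically near-singular.
x = np.linspace(0.0, 1.0, 50)
k = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / 0.5**2)
chol, used_jitter = cholesky_with_jitter(k)
print(used_jitter)
```

The returned jitter shows how much regularisation was needed; if it climbs close to the threshold, the covariance is probably too ill-conditioned for the chosen precision.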

@tom-andersson
Collaborator

Thanks @wesselb, perhaps it will indeed just work once nps.loglik is allowed to use single precision. Is this something you have time to implement yourself?

Once the two changes are made in backends and neuralprocesses, @magnusross can try training the CNP and GNP on his MPS device using the quickstart code in the DeepSensor Gallery.

@wesselb
Collaborator

wesselb commented Sep 7, 2023

Yes, I should be able to merge @magnusross’s PR and allow single precision in the relatively short term. :) Will keep you updated!

@magnusross
Author

magnusross commented Sep 7, 2023

Thanks for your help both!

@wesselb, if you need any help with the PR on backends, just let me know; what is there now is probably not the best way of doing it.

@wesselb
Collaborator

wesselb commented Sep 10, 2023

@tom-andersson, I've added a keyword argument dtype_lik to nps.elbo and nps.loglik, which you can set to e.g. torch.float32 to use float32s everywhere.

@magnusross, I've left some comments on your PR. I think that's basically ready to go, pending a unit test. :)

@tom-andersson
Collaborator

Excellent, thanks @wesselb!

@magnusross, if this gets MPS support working on your side then we can bump the neuralprocesses requirement for deepsensor to the new 0.2.3 and generalise the deepsensor.train.train.set_gpu_default_device method.

@magnusross
Author

Hey both, I've been looking at this briefly this morning and have unfortunately run into another problem. I thought I was running the full forward pass before, but I was actually only running the encoding step. Now, when I run the full forward pass, I get:

NotImplementedError: convolution_overrideable not implemented. You are likely triggering this with tensor backend other than CPU/CUDA/MKLDNN, if this is intended, please use TORCH_LIBRARY_IMPL to override this function 

I'm not sure exactly what's causing this, but I'll try to look into it later this week or next week. I'm a bit busy at the moment, sorry this is taking some time. I've made a PR (#49), so I'll add to that if I find the cause; if either of you has any ideas, they'd be very welcome!

@tom-andersson
Collaborator

Hey @magnusross, I'm afraid I've never come across this error before... Don't worry if this takes you a while to dig into though.

@wesselb
Collaborator

wesselb commented Sep 20, 2023

@magnusross, @tom-andersson, my suspicion is that it might take some time before all advanced convolution operations are supported by MPS. Convolutions are implemented with highly optimised GPU kernels, and these will need to be ported to MPS. You might be able to run a forward pass with a simpler convolutional architecture.
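One way to follow up on this suggestion is to probe which convolution layers actually run on a given device before choosing an architecture. This is a hypothetical diagnostic, assuming PyTorch is installed; the layer selection is illustrative and is not necessarily what neuralprocesses uses internally.

```python
import torch


def probe_conv_support(device):
    """Run a tiny forward pass through several convolution layers on `device`.

    Returns a dict mapping each layer name to True if the forward pass worked,
    or to the error message if it raised.
    """
    candidates = {
        "Conv1d": (torch.nn.Conv1d(1, 1, kernel_size=3, padding=1), (1, 1, 8)),
        "Conv2d": (torch.nn.Conv2d(1, 1, kernel_size=3, padding=1), (1, 1, 8, 8)),
        "Conv3d": (torch.nn.Conv3d(1, 1, kernel_size=3, padding=1), (1, 1, 8, 8, 8)),
        "ConvTranspose2d": (
            torch.nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2),
            (1, 1, 8, 8),
        ),
    }
    results = {}
    with torch.no_grad():
        for name, (layer, shape) in candidates.items():
            try:
                layer = layer.to(device)
                layer(torch.randn(shape, device=device))
                results[name] = True
            except (NotImplementedError, RuntimeError, TypeError) as err:
                # The error reported in this thread was a NotImplementedError.
                results[name] = str(err)
    return results


if __name__ == "__main__":
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    for name, outcome in probe_conv_support(device).items():
        print(f"{device} {name}: {outcome}")
```

PyTorch also offers a CPU fallback for ops without an MPS implementation via the `PYTORCH_ENABLE_MPS_FALLBACK=1` environment variable, though whether it covers this particular op would need checking.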


3 participants