Add torch.float32 support for Apple M1 #223

Open
SteffenPL opened this issue Mar 3, 2023 · 1 comment

@SteffenPL

Hello,

Thanks for your amazing package!

I wanted to ask whether it would be possible to provide a mechanism for using `torch.float32` as the dtype in the adaptive solvers.
On Apple M1 (the `mps` backend), `torch.float64` is not supported.

For example, here:

```python
alpha=torch.tensor([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], dtype=torch.float64),
```

the dtype is hard-coded to `float64`, regardless of the dtype of the input.
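One way such a mechanism could look, as a minimal sketch and not the package's actual code: keep the tableau constants at full precision on the CPU and cast them to the dtype and device of the initial state when the solver is set up. The names `ALPHA` and `tableau_like` are hypothetical.

```python
import torch

# Hypothetical constant mirroring the hard-coded float64 tableau row above.
ALPHA = torch.tensor([1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], dtype=torch.float64)


def tableau_like(coeffs: torch.Tensor, y0: torch.Tensor) -> torch.Tensor:
    """Cast tableau coefficients to the dtype and device of the initial state y0.

    The dtype cast happens first (on the CPU), so no float64 tensor is ever
    allocated on a backend such as MPS that does not support float64.
    """
    return coeffs.to(dtype=y0.dtype).to(device=y0.device)


# Use the MPS backend when it is available, otherwise fall back to the CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
y0 = torch.ones(3, dtype=torch.float32, device=device)

alpha = tableau_like(ALPHA, y0)
print(alpha.dtype, alpha.device)
```

Casting before moving the tensor keeps the float64 constants purely as a CPU-side source of truth, while the solver only ever sees tensors in the state's own precision.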

I am new to your package (and somewhat new to PyTorch, since I usually use Julia), so apologies if I am overlooking something obvious.

@hchau630

hchau630 commented Mar 6, 2023

I ran into this issue as well, but I think all you need to do is pass `options={'dtype': torch.float32}` to your `odeint` call (assuming you're using `odeint` with an adaptive solver, as I am).
