
GPU-enabled solver? #2320

Open
reTELEport opened this issue Feb 7, 2024 · 9 comments

@reTELEport

Describe the Issue!

Hi, I would like to know if there is any way I can run the solvers (mesolve, mcsolve, etc.) on a GPU. I don't see any tutorial or manual that covers this feature.

@Ericgig
Member

Ericgig commented Feb 7, 2024

It's in development.
If you use qutip v5.0.0a2 on PyPI, or the master branch here together with qutip-jax, some solvers can work on GPU. (We tested sesolve and mesolve; mcsolve should work; stochastic, HEOM, and brmesolve don't.)
The readthedocs of qutip-jax shows how to use it.

qutip-jax is not up to date with the latest versions.
If you are interested in trying it, I will make it work with the latest jax version and qutip master.
Any and all feedback is appreciated.

We expect an official release of these features in March.
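
For reference, a rough setup sketch (the version pin follows what is mentioned above; the right CUDA wheel for jax depends on your system, see the jax installation docs):

# pip install qutip==5.0.0a2 qutip-jax
# plus a CUDA-enabled jax build matching your driver/CUDA version
import qutip
import qutip_jax  # importing this registers the jax data layers with qutip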

@reTELEport
Author

Thank you! I'll try qutip-jax. Looking forward to the official release!

@reTELEport
Author

reTELEport commented Feb 12, 2024

Is there a simple example showing how to run mesolve on Linux with an Nvidia GPU?

@Ericgig
Member

Ericgig commented Feb 12, 2024

I suggest using the example in qutip/qutip-jax#26.

First, ensure that jax is using the GPU:

import jax
jax.devices()
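# expect a GPU entry here (e.g. a cuda/CudaDevice entry), not only CpuDevice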

After that, I would suggest using the branch in qutip/qutip-jax#20; otherwise only dense matrices are supported, and dense on GPU...

import qutip as qt
import qutip_jax  # registers the "jax" and "jaxdia" data layers
from diffrax import Dopri5, PIDController

# three coupled two-level modes, built with the diagonal jax format
with qt.CoreOptions(default_dtype="jaxdia"):
    N = 2
    a = qt.destroy(N) & qt.qeye(N) & qt.qeye(N)
    b = qt.qeye(N) & qt.destroy(N) & qt.qeye(N)
    c = qt.qeye(N) & qt.qeye(N) & qt.destroy(N)
    H = (
        a.dag() * a
        + b.dag() * b
        + c.dag() * c
        + (a.dag() + a) * (b + b.dag())
        + (b.dag() + b) * (c + c.dag())
    )

c_ops = [a, b, c]

t = 10
options = {
    "method": "diffrax",
    "normalize_output": False,
    "stepsize_controller": PIDController(rtol=1e-5, atol=1e-5),
    "solver": Dopri5(),
}
psi_0 = qt.basis(N, 1, dtype="jax") & qt.basis(N, 0, dtype="jax") & qt.basis(N, 0, dtype="jax")
e_ops = qt.num(N, dtype="jaxdia") & qt.qeye(N, dtype="jaxdia") & qt.qeye(N, dtype="jaxdia")

result = qt.mesolve(H, psi_0, [0, t], c_ops, e_ops=e_ops, options=options)
result.expect
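
As an aside on the dense point above: the data layer can also be converted after the fact. A small sketch, assuming qutip_jax is imported so the dtype names are registered (the "jaxdia" format is the one from qutip/qutip-jax#20):

import qutip as qt
import qutip_jax  # registers the "jax" / "jaxdia" dtype names

op = qt.destroy(4)          # default CPU data layer
op_gpu = op.to("jax")       # dense jax array, can live on GPU
op_dia = op.to("jaxdia")    # diagonal jax format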

@nwlambert
Member

Just to add: I benchmarked Eric's jax data layer a bit more with an Ising model. The example is at the end of this colab notebook we made for a tutorial talk, and it shows a crossover in performance at certain system sizes: https://colab.research.google.com/drive/1RcgX7oEzGjzPAF8Ryus54Q5UmyMddmLA?usp=sharing

[benchmark plot: CPU vs. GPU solver runtime, crossing over at larger system sizes]

Note colab does not have free GPUs, so you will have to download and use it locally. Also, in the actual Ising example, replace

with jax.default_device(jax.devices("cpu")[0]):

with

with jax.default_device(jax.devices("gpu")[0]):

@reTELEport
Author

Hi again,

I followed the instructions in the qutip-jax docs and ran the following code:

import qutip
import qutip_jax

with qutip.CoreOptions(default_dtype="jax"):
    H = qutip.rand_herm(5)
    c_ops = [qutip.destroy(5)]
    rho0 = qutip.basis(5, 4)

result = qutip.mesolve(H, rho0, [0, 1], c_ops=c_ops, options={"method": "diffrax"})

and this will output a warning:

UserWarning: Complex dtype support is work in progress, please read https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully.
  out = fun(*args, **kwargs)

Is this something I need to worry about if I need reliable simulation results?

@Ericgig
Member

Ericgig commented May 2, 2024

It should not affect the results of simulations. The diffrax package does not interact with complex numbers directly when used through qutip-jax.

Gradients or other derivatives could be affected; this needs more testing.
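
For anyone who wants to probe that, a rough sketch of the kind of gradient test one could run. Whether the derivative actually survives the complex-dtype path is exactly the open question here, and the parameter w, the drive coefficient, and the args plumbing are made up for illustration:

import jax
import jax.numpy as jnp
import qutip
import qutip_jax

N = 5
with qutip.CoreOptions(default_dtype="jax"):
    a = qutip.destroy(N)
    num = qutip.num(N)
    psi0 = qutip.basis(N, 1)

def drive(t, w):
    # time-dependent drive amplitude, differentiable in w
    return jnp.cos(w * t)

H = qutip.QobjEvo([num, [a + a.dag(), drive]], args={"w": 0.5})

def final_population(w):
    result = qutip.mesolve(
        H, psi0, [0.0, 1.0], c_ops=[a], e_ops=[num],
        args={"w": w},
        options={"method": "diffrax", "normalize_output": False},
    )
    return result.expect[0][-1].real

gradient = jax.grad(final_population)(0.5)
print(gradient)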

@reTELEport
Author

reTELEport commented May 2, 2024

Thanks. I came across another problem when I tried to run the following code on an L4 GPU provided by Google Colab:

import qutip
import qutip_jax

with qutip.CoreOptions(default_dtype="jax"):
    H = qutip.rand_herm(5)
    c_ops = [qutip.destroy(5)]
    rho0 = qutip.basis(5, 4)

result = qutip.mesolve(H, rho0, [0, 1], c_ops=c_ops, options={"method": "diffrax"})

An error occurs:

JaxStackTraceBeforeTransformation: NotImplementedError: Schur decomposition is only implemented on the CPU backend.

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

NotImplementedError                       Traceback (most recent call last)
    [... skipping hidden 28 frame]

/usr/local/lib/python3.10/site-packages/jax/_src/lax/linalg.py in _schur_lowering(ctx, *args, **kwargs)
   2231 
   2232 def _schur_lowering(ctx, *args, **kwargs):
-> 2233   raise NotImplementedError(
   2234       "Schur decomposition is only implemented on the CPU backend.")
   2235 

NotImplementedError: Schur decomposition is only implemented on the CPU backend.

So I assume qutip-jax doesn't work with the L4 GPU, right? What GPU or TPU can be used in this case?

@Ericgig
Member

Ericgig commented May 2, 2024

I don't know; since jax is developed by Google, I expect it to work well with the GPUs it provides through Colab...

We don't use the Schur decomposition for mesolve directly. I think it's the integrator from diffrax that does; if not, it could be the norm (it uses the trace norm, which calls sqrtm instead of trace, fixed in #2408). Maybe trying another ODE solver or not normalizing would work. Neill seems to have run most of his tests using Dopri5:

from diffrax import Dopri5, PIDController

options = {
    "method": "diffrax",
    "normalize_output": False,
    "stepsize_controller" : PIDController(rtol=1e-8, atol=1e-6), # This is now the default.
    "solver": Dopri5(),
}
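
Applied to the failing snippet above, that looks like the following (same system; whether it actually sidesteps the Schur error on the L4 is untested):

import qutip
import qutip_jax
from diffrax import Dopri5, PIDController

with qutip.CoreOptions(default_dtype="jax"):
    H = qutip.rand_herm(5)
    c_ops = [qutip.destroy(5)]
    rho0 = qutip.basis(5, 4)

options = {
    "method": "diffrax",
    "normalize_output": False,  # skips normalization, which uses the trace norm (sqrtm)
    "stepsize_controller": PIDController(rtol=1e-8, atol=1e-6),
    "solver": Dopri5(),
}
result = qutip.mesolve(H, rho0, [0, 1], c_ops=c_ops, options=options)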
