Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add casting to real option. #2329

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

Ericgig
Copy link
Member

@Ericgig Ericgig commented Feb 19, 2024

Description
In a few places, we remove the imaginary part for hermitian matrices.
It's nice for the user, but it breaks a few cases: isherm check breaks jax.jit and tensorflows does not support casting for auto-differentiation. The solution we had when tensorflow was added was to detect it (no real method). But this did not fix jit as seen in qutip/qutip-jax#34.

This add a setting that is checked first, so after qutip.settings.core["auto_real_casting"] =False all these checks are removed and jit works.

@coveralls
Copy link

coveralls commented Feb 19, 2024

Coverage Status

coverage: 85.82% (+0.006%) from 85.814%
when pulling 05cc4d7 on Ericgig:misc.less_branching
into 0695ad3 on qutip:master.

@hodgestar
Copy link
Contributor

It would be nice to avoid a proliferation of settings if we can. Could we perhaps add a casting function to the data layer somewhere? For example:

if oper.isherm and ...:
    out = oper.data.cast_to_real(out)

where for the JAX backend cast_to_real is something like jax.numpy,real.

This is a little bit awkward if oper.data and state.data are different dtypes, because there are two possible functions to choose from, but perhaps that's okay?

@Ericgig
Copy link
Member Author

Ericgig commented Feb 19, 2024

For jax, the issue is with isherm, not the casting. We cannot call it in jitted functions since they cannot branch depending on the input data. The casting itself cause issues with tensorflow.

The isherm_jax could be set to always return False, but returning False only inside jitted function is too hacky and could break with jax release. When applied everywhere, it will have strange side effect such as qeye(n) *2 is hermitian but qeye(n)+qeye(n) is not. Don't want that.

If the Qobj are always the inputs of the function, then we could probably compute the hermiticity in qobj_tree_flatten, but it would not be any help when the object is build inside the function.

I am somewhat at a lost of idea other than that...

I see a helper function like qutip_jax.set_as_default() that would set all the settings for the user for the session in one go. Not have them manually updated.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants