
How to properly use quimb with jax #219

Answered by jcmgray
stfnmangini asked this question in Q&A

Hi @stfnmangini, sorry for being slow to get to this, and thanks for the detailed examples! I get the same results when I run them.

Yes, at a high level the aim is to allow both the "jax in quimb" approach (TNOptimizer) for simple things and the "quimb in jax" approach (where quimb just orchestrates various array operations) for more detailed jax work.
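A minimal sketch of the "quimb in jax" pattern, written in plain JAX with a hypothetical `ToyTN` class rather than quimb's real classes: the arrays are registered as pytree leaves and the index labels as static aux data, so `jax.jit` and `jax.grad` can take and return the container directly. This only illustrates the kind of thing pytree registration is meant to enable; it is not quimb's implementation.

```python
import jax
import jax.numpy as jnp


class ToyTN:
    """Hypothetical stand-in for a tensor network (NOT a quimb class).

    Arrays are the dynamic data jax traces; index labels are static metadata.
    """

    def __init__(self, arrays, inds):
        self.arrays = tuple(arrays)
        self.inds = tuple(tuple(ix) for ix in inds)

    def contract(self):
        # Give each distinct index name its own einsum letter.
        names = sorted({ix for t in self.inds for ix in t})
        letter = {name: chr(ord("a") + i) for i, name in enumerate(names)}
        # Indices that appear on only one tensor are left open (output).
        counts = {name: sum(t.count(name) for t in self.inds) for name in names}
        inputs = ",".join("".join(letter[ix] for ix in t) for t in self.inds)
        output = "".join(letter[n] for n in names if counts[n] == 1)
        return jnp.einsum(f"{inputs}->{output}", *self.arrays)


# Register as a pytree: arrays become leaves, index labels (a hashable
# tuple of tuples) become the static aux data stored in the treedef.
jax.tree_util.register_pytree_node(
    ToyTN,
    lambda tn: (tn.arrays, tn.inds),
    lambda inds, arrays: ToyTN(arrays, inds),
)


@jax.jit
def loss(tn):
    # The container only orchestrates array operations; jax sees arrays.
    x = tn.contract()
    return jnp.sum(jnp.abs(x) ** 2)


# grad returns a ToyTN whose arrays are the gradients w.r.t. each tensor.
grad_fn = jax.jit(jax.grad(loss))

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
tn = ToyTN(
    [jax.random.normal(key_a, (2, 3)), jax.random.normal(key_b, (3, 2))],
    [("k0", "bond"), ("bond", "k1")],
)
value = loss(tn)
g = grad_fn(tn)
print(float(value), type(g).__name__, [a.shape for a in g.arrays])
```

The design choice that matters is what goes into the aux data: keeping it a plain, hashable tuple means two networks with the same index structure produce equal treedefs, so jit can reuse the compiled function.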

Certainly my understanding of the `jax_register_pytree` functionality was that it should enable jittable functions to accept and return quimb structures. However, I have not looked much into this direction, so I am not sure whether this re-compilation behaviour is a bug or a misunderstanding of how pytrees work in jax. I can try and look into it but …
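One known way custom pytrees can trigger repeated compilation in jax (not confirmed to be the cause here) is aux data that does not compare equal across calls: jit keys its cache on the treedef, which includes the aux data. The sketch below uses hypothetical `GoodBox`, `BadBox` and `Meta` classes and counts traces via a Python side effect, which only runs while jax is tracing.

```python
import jax
import jax.numpy as jnp


class Meta:
    """Hypothetical metadata holder with default identity-based == and hash."""

    def __init__(self, inds):
        self.inds = inds


class GoodBox:
    def __init__(self, array, meta):
        self.array, self.meta = array, meta


class BadBox:
    def __init__(self, array, meta):
        self.array, self.meta = array, meta


# Good: aux data is a plain tuple, so equal labels give equal treedefs.
jax.tree_util.register_pytree_node(
    GoodBox,
    lambda b: ((b.array,), b.meta.inds),
    lambda inds, children: GoodBox(children[0], Meta(inds)),
)

# Bad: aux data is the Meta object itself, compared by identity, so every
# freshly built container has a "new" treedef and misses the jit cache.
jax.tree_util.register_pytree_node(
    BadBox,
    lambda b: ((b.array,), b.meta),
    lambda meta, children: BadBox(children[0], meta),
)


def count_traces(box_cls, n_calls=3):
    traces = []

    @jax.jit
    def f(box):
        traces.append(1)  # Python side effect: only runs while tracing
        return jnp.sum(box.array ** 2)

    for _ in range(n_calls):
        # Build a fresh container each call, as an optimisation loop might.
        f(box_cls(jnp.ones(4), Meta(("k0",))))
    return len(traces)


print(count_traces(GoodBox), count_traces(BadBox))
```

With value-comparable aux data the function should be traced once, while the identity-compared version is re-traced (and recompiled) on every call with a new container. Counting traces like this on the real quimb objects could help tell which situation applies.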

Answer selected by stfnmangini
Category: Q&A, Labels: none yet, 2 participants