As nobody else has answered yet, I'll suggest something, although I can't test it without a reproducible example. I believe this is due to `jax.lax.switch` tracing and compiling every branch. My solution was to replace `return jax.lax.switch(which_gen, generators, key)` with plain list indexing, `return generators[which_gen](key)`. This avoids evaluating all branches in the switch statement.
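A minimal sketch of the two variants, assuming two hypothetical generators (`gen_normal`, `gen_uniform`) that each map a PRNG key to an array of the same shape and dtype, as `lax.switch` requires. One caveat: list indexing only works when `which_gen` is a concrete Python int. Inside `jit` it therefore has to be marked static, and it cannot be a value that is batched by `vmap`, because a traced index cannot be used to index a Python list.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical generators standing in for the different simulations.
def gen_normal(key):
    return jax.random.normal(key, (4,))

def gen_uniform(key):
    return jax.random.uniform(key, (4,))

generators = [gen_normal, gen_uniform]

# Variant 1: lax.switch. Every branch is traced and compiled into the program;
# if which_gen is batched under vmap, the switch becomes a select and every
# branch is also executed.
@jax.jit
def run_switch(which_gen, key):
    return jax.lax.switch(which_gen, generators, key)

# Variant 2: Python list indexing. which_gen must be a concrete Python int, so it
# is marked static; each distinct value compiles a program with only one branch.
@partial(jax.jit, static_argnums=0)
def run_indexed(which_gen, key):
    return generators[which_gen](key)

key = jax.random.PRNGKey(0)
a = run_switch(1, key)
b = run_indexed(1, key)
```

Both calls run `gen_uniform` on the same key, so they return the same values. The difference lies in what gets compiled, not in the result.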
I have $M$ different simulations, and simulation $m$ should be executed $N_m$ times, so there are $\sum_m N_m$ simulation runs in total. Each single run generates some data from an RNG.
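One way to lay out this setup is sketched below. It assumes $M = 3$ hypothetical simulation functions and hypothetical counts $N_m = (2, 3, 5)$. It draws one independent key per run ($\sum_m N_m$ keys), then `vmap`s each simulation type over its own block of keys, so no branch selection between types is traced.

```python
import itertools

import jax

# Hypothetical stand-ins for the M = 3 simulation types; each maps one key to a result.
def sim_a(key):
    return jax.random.normal(key, (8,)).mean()

def sim_b(key):
    return jax.random.uniform(key, (8,)).sum()

def sim_c(key):
    return jax.random.bernoulli(key, 0.3, (8,)).sum()

sim_fns = [sim_a, sim_b, sim_c]
N = [2, 3, 5]  # hypothetical N_m: number of runs for each simulation type

# One independent key per run: sum_m N_m = 10 keys in total.
keys = jax.random.split(jax.random.PRNGKey(42), sum(N))

# Boundaries of each simulation type's block of keys: [0, 2, 5, 10].
offsets = [0, *itertools.accumulate(N)]

results = []
for m, fn in enumerate(sim_fns):
    block = keys[offsets[m]:offsets[m + 1]]
    # vmap this simulation type over its N_m keys only.
    results.append(jax.vmap(fn)(block))
```

`results[m]` then holds the $N_m$ outputs of simulation $m$.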
I am using a nested combination of jax.pmap and jax.vmap to parallelize the $\sum_m N_m$ runs. It works just fine, but the problem is that the compilation time is huge (several hours).
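For reference, a minimal sketch of a pmap-over-vmap nesting, assuming a single hypothetical `simulate` function and a hypothetical `runs_per_device` of 4. The key array gets a leading device axis (consumed by `pmap`) and a second run axis (consumed by `vmap`). As far as I know, the compiled program depends on the traced body of `simulate` rather than on the number of runs. Compile time is therefore driven mainly by how large that traced program is (for example many `lax.switch` branches, or Python loops that get unrolled during tracing).

```python
import jax

def simulate(key):
    # Placeholder for a single simulation run drawing data from its own key.
    return jax.random.normal(key, (8,)).mean()

n_dev = jax.local_device_count()
runs_per_device = 4  # hypothetical batch size per device

keys = jax.random.split(jax.random.PRNGKey(0), n_dev * runs_per_device)
# Axis 0 is mapped across devices by pmap, axis 1 across runs by vmap.
keys = keys.reshape((n_dev, runs_per_device) + keys.shape[1:])

results = jax.pmap(jax.vmap(simulate))(keys)  # shape (n_dev, runs_per_device)
```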
My question: is this some kind of anti-pattern?
Toy Example Code