Sharding is much slower than pmap for while loops of varying length #20968
Comments
To be extra clear, if you make
Also, if you replace the body and do

```python
new_t, new_y, theta = inner_loop_body((t, y, theta))
# new_t, new_y, _ = inner_while_loop(inner_loop_cond, inner_loop_body, (t, y, theta))
return (new_t, new_y, theta, count + 1)
```

you see the same speed, which indicates the second while loop is important to the slowdown.
@patrick-kidger you mentioned in patrick-kidger/diffrax#407 that you suspect this is within XLA. Do you have any advice on how to approach that? I haven't investigated an XLA system this complex before. Even my reduced-complexity example (shown below) yields XLA output that is not exceedingly readable (shown further below). Is there a go-to issue or piece of XLA/JAX documentation on identifying whether a bug is in JAX vs XLA, and how to spot it?

```python
import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils


def solve(init, key):
    def inner_loop_cond(state):
        t, y, _ = state
        return y.squeeze() < 2

    def inner_loop_body(state):
        t, y, theta = state
        return (t + 0.1, y + 0.1, theta)

    def outer_loop_cond(state):
        _, _, _, count = state
        return count < 5

    def outer_loop_body(state):
        t, y, theta, count = state
        y = jax.random.uniform(jax.random.PRNGKey(count), shape=(1,))
        new_t, new_y, _ = inner_while_loop(inner_loop_cond, inner_loop_body, (t, y, theta))
        return (new_t, new_y, theta, count + 1)

    inner_while_loop = jax.lax.while_loop
    outer_while_loop = jax.lax.while_loop

    theta = 5.0
    t_initial = 0.0
    y_initial = init
    count_initial = jax.random.randint(key, minval=-2, maxval=2, shape=())
    final_state = outer_while_loop(outer_loop_cond, outer_loop_body, (t_initial, y_initial, theta, count_initial))
    return final_state[1]
```
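One way to at least see what XLA is given is to dump the lowered and compiled HLO for the different execution strategies and diff them. The snippet below is only a sketch of that workflow using the `solve` function above; the single-example `init`/`key` inputs are placeholders I made up, and for the pmapped or sharded cases one would lower those wrapped versions instead:

```python
import jax
import jax.numpy as jnp

# Hypothetical single-example inputs, just to have something to lower.
init = jnp.zeros((1,))
key = jax.random.PRNGKey(0)

lowered = jax.jit(solve).lower(init, key)
print(lowered.as_text())    # StableHLO as JAX emits it, before XLA optimizes

compiled = lowered.compile()
print(compiled.as_text())   # HLO after XLA's optimization passes
```

Alternatively, setting `XLA_FLAGS=--xla_dump_to=/tmp/xla_dump` before running dumps every HLO module XLA compiles, which makes it easier to compare the pmap and sharding paths side by side.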
For performance-related things like this it is usually in XLA. JAX is mostly at the mercy of whatever code XLA generates. Unfortunately the parallelism part of this isn't something I'm familiar with at all. I think @sharadmv might know more? This one is out of my wheelhouse I'm afraid.
Description
As the title indicates, with a double while loop where the inner while loop may change in length across outer while loop steps, pmap is substantially faster than sharding. This may sound contrived, but it is exactly what happens in other packages, such as diffrax, where I first identified this issue: patrick-kidger/diffrax#407. I believe there are two possibilities: 1) I am using sharding wrong and that is why it is slow (very possible, I am new to sharding), or 2) something else is going on in sharding. I have included an MVC below. I ran on both CPU and GPU, and the difference on GPU is even more noticeable.
Sharding:
CPU:
5.11 ms ± 53.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
GPU:
1.18 s ± 48.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

pmap:
CPU:
251 µs ± 1.14 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
GPU:
3.93 ms ± 225 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
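The benchmark harness itself isn't reproduced above, so the sketch below shows one plausible way the two strategies could be driven on `solve`, in case the slowdown is in how sharding is invoked rather than in XLA itself. The `inits`/`keys` batched inputs and the `"batch"` axis name are assumptions, not taken from the original MVC:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.experimental import mesh_utils

n_dev = len(jax.devices())
# Hypothetical batched inputs: one (1,)-shaped init and one PRNG key per device.
inits = jnp.zeros((n_dev, 1))
keys = jax.random.split(jax.random.PRNGKey(0), n_dev)

# pmap: one program instance per device.
pmap_solve = jax.pmap(solve)
pmap_solve(inits, keys).block_until_ready()

# Sharding: vmap over the batch, then shard the batch axis across devices.
mesh = Mesh(mesh_utils.create_device_mesh((n_dev,)), axis_names=("batch",))
sharding = NamedSharding(mesh, PartitionSpec("batch"))
sharded_solve = jax.jit(jax.vmap(solve))
inits_s = jax.device_put(inits, sharding)
keys_s = jax.device_put(keys, sharding)
sharded_solve(inits_s, keys_s).block_until_ready()
```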
System info (python version, jaxlib version, accelerator, etc.)
CPU:
GPU: