Sharding integration is much slower than pmap #407
Comments
Note that this isn't simply a case of sharding being slow on my machine because of some specific problem with my CPUs or install. Running a simpler version of the Equinox example gives a sharded version that is faster than pmap:
import os
import multiprocessing
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)
import equinox as eqx
import jax
import jax.experimental.mesh_utils as mesh_utils
import jax.numpy as jnp
import jax.random as jr
import jax.sharding as jshard
import numpy as np
import optax # https://github.com/deepmind/optax
# Hyperparameters
dataset_size = 64
channel_size = 4
hidden_size = 32
depth = 1
learning_rate = 3e-4
num_steps = 10
batch_size = 20 # must be a multiple of our number of devices.
# Generate some synthetic data
xs = np.random.normal(size=(dataset_size, channel_size))
ys = np.sin(xs)
model = eqx.nn.MLP(channel_size, channel_size, hidden_size, depth, key=jr.PRNGKey(6789))
optim = optax.adam(learning_rate)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
# Loss function for a batch of data
def compute_loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)
num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jshard.PositionalSharding(devices)
replicated = sharding.replicate()
data, label = xs[:batch_size], ys[:batch_size]
data_pmap = jnp.reshape(data, (num_devices, len(data) // num_devices, channel_size))
label_pmap = jnp.reshape(label, (num_devices, len(label) // num_devices, channel_size))
x, y = eqx.filter_shard((data, label), sharding)
model_shard, opt_state_shard = eqx.filter_shard((model, opt_state), replicated)
@eqx.filter_jit
def train_step(model, opt_state, x, y):
    grads = eqx.filter_grad(compute_loss)(model, x, y)
    return grads.layers[0].weight

%%timeit
_ = train_step(model_shard, opt_state_shard, x, y).block_until_ready()

%%timeit
_ = eqx.filter_pmap(train_step, in_axes=(None, None, 0, 0))(model, opt_state, data_pmap, label_pmap).block_until_ready()
Hmm, this is a little odd. The two main things I'd highlight are that (a) you may be including compile time in your measurements, and if this is larger for sharding then that may skew things, and (b) I don't know to what extent CPU is supported in each case. CPU support has always been fairly iffy in JAX, so I wouldn't be surprised if one of sharding/pmap takes better advantage of the hardware somehow. Can you replicate these findings on a GPU? Other than that, I know the plan is for pmap to eventually become a backward-compatibility wrapper for sharding, so it'd be good to get these details straightened out.
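To make point (a) concrete, here is a minimal sketch of keeping the first (compiling) call out of the steady-state timing. The helper `benchmark` and the stand-in workload `fake_jitted` are hypothetical names introduced for illustration and are not part of JAX or Equinox. The sketch assumes the callable blocks until its result is ready.

```python
import statistics
import time


def benchmark(run, warmup=1, repeats=5):
    """Time `run()` in seconds, reporting warm-up calls separately.

    `run` is a hypothetical zero-argument callable that must block until
    its result is ready (for JAX, by calling `.block_until_ready()`).
    """
    warmup_times = []
    for _ in range(warmup):
        start = time.perf_counter()
        run()
        warmup_times.append(time.perf_counter() - start)

    steady_times = []
    for _ in range(repeats):
        start = time.perf_counter()
        run()
        steady_times.append(time.perf_counter() - start)

    return {
        "warmup": warmup_times,  # includes any one-off compile cost
        "median": statistics.median(steady_times),
        "min": min(steady_times),
    }


# Stand-in workload: slow on the first call (like a compile), fast afterwards.
_cache = {}


def fake_jitted():
    if "compiled" not in _cache:
        time.sleep(0.05)  # pretend compilation
        _cache["compiled"] = True
    return sum(range(1000))


stats = benchmark(fake_jitted)
print(stats)
```

For the snippets in this thread, one would pass something like `lambda: train_step(model_shard, opt_state_shard, x, y).block_until_ready()` as `run`, and compare the `"median"` entries of the sharded and pmap versions.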
This did accidentally include compile time, but I took it out and it has no real impact on the timings (I've updated the code above).
I will try on GPUs and report back.
I slightly adjusted the code to make some things a little faster (I just reduced the integration time). In addition, this is now with 4x A100s rather than 10 (Mac) CPUs. Interestingly, the sharding slowdown with this code is far more severe on GPUs than on CPUs! For exactness, this is what I ran:
import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
import equinox as eqx
import matplotlib.pyplot as plt
from diffrax import *
import optax
def f(t, y, args):
    return jnp.sin(t) + args["theta"] * y

def g(t, y, args):
    return 0.1 * jnp.eye(1)
t0 = 0.
t1 = 0.1
dt0 = 0.05
diffusion_shape = jax.ShapeDtypeStruct((1,), "float32")
solver, cont = Heun(), PIDController(1e-3, 1e-6)
ts = jnp.linspace(t0, t1, 100)
def solve(init, key, args):
    control = VirtualBrownianTree(
        t0=t0,
        t1=t1,
        tol=dt0 / 2,
        shape=diffusion_shape,
        key=key,
    )
    vf = ODETerm(f)
    cvf = ControlTerm(g, control)
    terms = MultiTerm(vf, cvf)
    saving = SaveAt(ts=ts)
    sol = diffeqsolve(
        terms,
        solver,
        y0=init,
        t0=t0,
        t1=t1,
        dt0=dt0,
        args=args,
        saveat=saving,
        stepsize_controller=cont,
        max_steps=100,
    )
    return sol.ys
batch_size = 300
inits = 0.1 * jnp.ones((batch_size, 1))
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
args = {"theta": 0.1}
num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jax.sharding.PositionalSharding(devices)
replicated = sharding.replicate()
inits_pmap = inits.reshape(num_devices, batch_size // num_devices, *inits.shape[1:])
keys_pmap = keys.reshape(num_devices, batch_size // num_devices, *keys.shape[1:])
args_shard = eqx.filter_shard(args, replicated)
x, y = eqx.filter_shard((inits, keys), sharding)
fn = eqx.filter_jit(eqx.filter_vmap(solve, in_axes=(0, 0, None)), donate="all")
_ = fn(x, y, args_shard).block_until_ready()
_ = eqx.filter_pmap(fn, in_axes=(0, 0, None))(inits_pmap, keys_pmap, args).block_until_ready()

%%timeit
_ = fn(x, y, args_shard).block_until_ready()

yields

%%timeit
_ = eqx.filter_pmap(fn, in_axes=(0, 0, None))(inits_pmap, keys_pmap, args).block_until_ready()

yields (naturally, CPUs are faster than GPUs on such small sequential problems).
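The two code paths above lay the batch out differently: pmap expects an explicit leading device axis, while the sharded version keeps the global batch shape and lets jit partition it. A minimal sketch of that difference follows. It uses `NamedSharding` with a one-dimensional mesh instead of the `PositionalSharding` used in the thread, and the names `double` and `per_device` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

num_devices = len(jax.devices())
per_device = 4  # illustrative per-device batch size
batch_size = per_device * num_devices  # must be a multiple of num_devices
inits = 0.1 * jnp.arange(batch_size, dtype=jnp.float32).reshape(batch_size, 1)


def double(x):
    # Stand-in for the per-sample `solve` function.
    return 2.0 * x


# pmap layout: (num_devices, per_device, ...), one leading slice per device.
inits_pmap = inits.reshape(num_devices, per_device, *inits.shape[1:])
out_pmap = jax.pmap(jax.vmap(double))(inits_pmap)

# Sharding layout: the global (batch_size, ...) array, split along axis 0.
mesh = Mesh(np.array(jax.devices()), ("batch",))
inits_sharded = jax.device_put(inits, NamedSharding(mesh, P("batch")))
out_sharded = jax.jit(jax.vmap(double))(inits_sharded)

print(out_pmap.shape, out_sharded.shape)
```

Both versions compute the same values, so any timing gap comes from how the computation is partitioned and executed rather than from the arithmetic itself.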
Interesting! Okay so the next things to try, in order, are:
...let's start with that and see where we end up!
I added EQX_ON_ERROR, then started slowly rolling back complexity (toward a manually stepped Euler ODE) when I noticed something: the combination of a step size controller and SaveAt(ts=...) seems to be the source of the slowdown. With both the controller and SaveAt(ts=...), sharding is ~50x slower than pmap. Without that combination, sharding is slightly faster. Concretely, I have this code (a simple ODE). I will also check on GPU and post the results, but I bet they will be the same:
import os
import multiprocessing
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)
os.environ['EQX_ON_ERROR'] = 'nan'
import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
import equinox as eqx
from diffrax import *
def f(t, y, args):
    return jnp.sin(t) + args["theta"] * y
t0 = 0.
t1 = 0.1
dt0 = 0.01
diffusion_shape = jax.ShapeDtypeStruct((1,), "float32")
solver, cont = Heun(), PIDController(1e-3, 1e-6)
ts = jnp.linspace(t0, t1, 100)
def solve(init, key, args):
    vf = ODETerm(f)
    terms = vf
    saving = SaveAt(ts=ts)
    # saving = SaveAt(t1=True)
    sol = diffeqsolve(
        terms,
        solver,
        y0=init,
        t0=t0,
        t1=t1,
        dt0=dt0,
        args=args,
        saveat=saving,
        stepsize_controller=cont,
        max_steps=100,
    )
    return sol.ys
batch_size = 300
inits = 0.1 * jnp.ones((batch_size, 1))
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
args = {"theta": 0.1}
num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jax.sharding.PositionalSharding(devices)
replicated = sharding.replicate()
inits_pmap = inits.reshape(num_devices, batch_size // num_devices, *inits.shape[1:])
keys_pmap = keys.reshape(num_devices, batch_size // num_devices, *keys.shape[1:])
args_shard = eqx.filter_shard(args, replicated)
x, y = eqx.filter_shard((inits, keys), sharding)
fn = eqx.filter_jit(eqx.filter_vmap(solve, in_axes=(0, 0, None)))
_ = fn(x, y, args_shard).block_until_ready()
_ = eqx.filter_pmap(fn, in_axes=(0, 0, None))(inits_pmap, keys_pmap, args).block_until_ready()

and

%%timeit
_ = fn(x, y, args_shard).block_until_ready()

and

%%timeit
_ = eqx.filter_pmap(fn, in_axes=(0, 0, None))(inits_pmap, keys_pmap, args).block_until_ready()

If I have
If I have
If I have
If I have
As you can see, the only one with a (meaningful) difference is
Oh, that's interesting! Okay, so it's probably something to do with this chunk of code: Lines 325 to 355 in 5bc3e07
Which is the bit that saves outputs if So I'd probably suggest playing around with that and seeing what results you get as you start perturbing things. One very concrete place to start: can you check whether the output of
I looked more into the while loop. As you described, we have a double while loop: the first is static, and the second may vary in length depending on the iteration. To strip away as much complexity as possible, I first tried a minimal Equinox example of this varying double while loop, then an even simpler one in pure JAX. I see the slowdown in pure JAX alone when I have a double while loop.
import os
import multiprocessing
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)
import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
import functools as ft
def f(t, y, theta):
    return jnp.abs(jnp.sin(t)) + theta * y

def solve(init, key):
    def inner_loop_cond(state):
        t, y, _ = state
        return y.squeeze() < 10

    def inner_loop_body(state):
        t, y, theta = state
        dy = f(t, y, theta)
        return (t + 0.1, y + 0.1 * dy, theta)

    def outer_loop_cond(state):
        _, _, _, count = state
        return count < 5

    def outer_loop_body(state):
        t, y, theta, count = state
        y = jax.random.uniform(jax.random.PRNGKey(count), shape=(1,))
        new_t, new_y, _ = inner_while_loop(inner_loop_cond, inner_loop_body, (t, y, theta))
        return (new_t, new_y, theta, count + 1)

    inner_while_loop = jax.lax.while_loop
    outer_while_loop = jax.lax.while_loop
    theta = 5.0
    t_initial = 0.0
    y_initial = init
    count_initial = jax.random.randint(key, minval=-2, maxval=2, shape=())
    final_state = outer_while_loop(outer_loop_cond, outer_loop_body, (t_initial, y_initial, theta, count_initial))
    return final_state[1]
batch_size = 30
inits = 0.1 * jnp.ones((batch_size, 1))
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jax.sharding.PositionalSharding(devices)
replicated = sharding.replicate()
inits_pmap = inits.reshape(num_devices, batch_size // num_devices, *inits.shape[1:])
keys_pmap = keys.reshape(num_devices, batch_size // num_devices, *keys.shape[1:])
x, y = jax.device_put((inits, keys), sharding)
fn = jax.jit(jax.vmap(solve))
pmap_fn = jax.pmap(fn)
_ = fn(x, y).block_until_ready()
_ = pmap_fn(inits_pmap, keys_pmap).block_until_ready()

%%timeit
_ = fn(x, y).block_until_ready()

yields

%%timeit
_ = pmap_fn(inits_pmap, keys_pmap).block_until_ready()

yields

I made a JAX issue (google/jax#20968) since this is all in JAX. There may be something else that can be done in Diffrax, but unless I am overlooking something, the two versions of this JAX example should run at the same speed and they don't.
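As background on why a loop whose length varies per batch element behaves differently from a fixed-length one: under `jax.vmap`, a `jax.lax.while_loop` with a batched condition keeps iterating until every element's condition is false, and elements that have already finished keep their values unchanged. The sketch below only illustrates that semantics and does not claim to reproduce the slowdown. The function `steps_to_threshold` and its inputs are illustrative and not taken from the thread.

```python
import jax
import jax.numpy as jnp


def steps_to_threshold(x0):
    """Count how many doublings it takes for x0 to reach at least 10."""

    def cond(state):
        x, _ = state
        return x < 10.0

    def body(state):
        x, n = state
        return (2.0 * x, n + 1)

    _, n = jax.lax.while_loop(cond, body, (x0, jnp.asarray(0, dtype=jnp.int32)))
    return n


# Each element needs a different number of iterations (4, 2 and 0).
x0s = jnp.array([1.0, 3.0, 20.0])
batched_counts = jax.vmap(steps_to_threshold)(x0s)
print(batched_counts)
```

The batched loop runs for as many iterations as the slowest element needs, yet each element still reports its own count, because finished elements are masked rather than stopped.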
Ah, this is awesome. I'm glad to have this so clearly isolated in an MWE. Since I believe at some point the goal is to convert (I'm guessing the root issue here is probably something in XLA. If you're feeling brave then you could maybe try delving down to that level.)
This is a follow-up to #403: when I use the Equinox sharding design, I see that it is substantially slower than pmap (or even just vmap on 1 device).
Here is a MVC that I run on my laptop with 10 devices (10 cpus).
When I do
I see
377 ms ± 19.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
and when I do
I see
10 ms ± 171 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
I thought it could be something to do with scoping, but moving everything inside the solve function has no impact. This is just slightly more complex than the pseudocode I described in the previous issue, but it seems to be quite slow. Is there an issue with my sharding approach, or is there something more complex going on here?