Sharding integration is much slower than pmap #407

Open
lockwo opened this issue Apr 21, 2024 · 10 comments

Comments

@lockwo
Contributor

lockwo commented Apr 21, 2024

This is a follow-up to #403: when I use the equinox sharding design, I see it is substantially slower than pmap (or even just a vmap on 1 device).

Here is an MWE that I run on my laptop with 10 devices (10 CPUs).

import os
import multiprocessing

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)

import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
import equinox as eqx
import matplotlib.pyplot as plt
from diffrax import *
import optax

def f(t, y, args):
    return jnp.sin(t) + args["theta"] * y

def g(t, y, args):
    return 0.1 * jnp.eye(1)

t0 = 0.
t1 = 1.
dt0 = 0.05
diffusion_shape = jax.ShapeDtypeStruct((1,), "float32")
solver, cont = Heun(), PIDController(1e-3, 1e-6)
ts = jnp.linspace(t0, t1, 100)

def solve(init, key, args):
    control = VirtualBrownianTree(
        t0=t0,
        t1=t1,
        tol=dt0 / 2,
        shape=diffusion_shape,
        key=key,
    )
    vf = ODETerm(f)
    cvf = ControlTerm(g, control)
    terms = MultiTerm(vf, cvf)
    saving = SaveAt(
        ts=ts
    )
    sol = diffeqsolve(
        terms,
        solver,
        y0=init,
        t0=t0,
        t1=t1,
        dt0=dt0,
        args=args,
        saveat=saving,
        stepsize_controller=cont,
        max_steps=10_000,
    )
    return sol.ys

batch_size = 300
inits = 0.1 * jnp.ones((batch_size, 1))
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
args = {"theta": 0.1}

num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jax.sharding.PositionalSharding(devices)
replicated = sharding.replicate()

inits_pmap = inits.reshape(num_devices, batch_size // num_devices, *inits.shape[1:])
keys_pmap = keys.reshape(num_devices, batch_size // num_devices, *keys.shape[1:])

args_shard = eqx.filter_shard(args, replicated)
x, y = eqx.filter_shard((inits, keys), sharding)
fn = eqx.filter_jit(eqx.filter_vmap(solve, in_axes=(0, 0, None)), donate="all")
_ = fn(x, y, args_shard).block_until_ready()
_ = eqx.filter_pmap(fn, in_axes=(0, 0, None))(inits_pmap, keys_pmap, args).block_until_ready()

When I do

%%timeit
_ = fn(x, y, args_shard).block_until_ready()

I see 377 ms ± 19.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

and when I do

%%timeit
_ = eqx.filter_pmap(fn, in_axes=(0, 0, None))(inits_pmap, keys_pmap, args).block_until_ready()

I see 10 ms ± 171 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

I thought it could be something with scoping, but moving everything inside the solve function has no impact. This is just slightly more complex than the pseudocode I described in the previous issue, but it seems to be quite slow. Is there an issue with my sharding approach, or is there something more complex going on here?

@lockwo
Contributor Author

lockwo commented Apr 21, 2024

Note, this is not a case of sharding being universally slow on my machine due to some specific problem with my CPUs/install. Running a simpler version of the equinox example yields sharding that is faster than pmap:

import os
import multiprocessing

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)
import equinox as eqx
import jax
import jax.experimental.mesh_utils as mesh_utils
import jax.numpy as jnp
import jax.random as jr
import jax.sharding as jshard
import numpy as np
import optax  # https://github.com/deepmind/optax


# Hyperparameters
dataset_size = 64
channel_size = 4
hidden_size = 32
depth = 1
learning_rate = 3e-4
num_steps = 10
batch_size = 20  # must be a multiple of our number of devices.

# Generate some synthetic data
xs = np.random.normal(size=(dataset_size, channel_size))
ys = np.sin(xs)

model = eqx.nn.MLP(channel_size, channel_size, hidden_size, depth, key=jr.PRNGKey(6789))
optim = optax.adam(learning_rate)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

# Loss function for a batch of data
def compute_loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)

num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jshard.PositionalSharding(devices)
replicated = sharding.replicate()

data, label = xs[:batch_size], ys[:batch_size]
data_pmap = jnp.reshape(data, (num_devices, len(data) // num_devices, channel_size))
label_pmap = jnp.reshape(label, (num_devices, len(label) // num_devices, channel_size))
x, y = eqx.filter_shard((data, label), sharding)
model_shard, opt_state_shard = eqx.filter_shard((model, opt_state), replicated)

@eqx.filter_jit
def train_step(model, opt_state, x, y):
    grads = eqx.filter_grad(compute_loss)(model, x, y)
    return grads.layers[0].weight

%%timeit
_ = train_step(model_shard, opt_state_shard, x, y).block_until_ready()

316 µs ± 6.88 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

%%timeit
_  = eqx.filter_pmap(train_step, in_axes=(None, None, 0, 0))(model, opt_state, data_pmap, label_pmap).block_until_ready()

1.75 ms ± 2.93 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

@patrick-kidger
Owner

Hmm, this is a little odd.

I think the main two things I'd highlight are that (a) I think you may be including compile time in your measurements -- if this is larger for sharding then that may skew things, and also (b) I don't know to what extent CPU is supported in each case. CPU support has always been fairly iffy in JAX, so I wouldn't be surprised if sharding/pmap each take better advantage of the hardware somehow. Can you replicate these findings on a GPU?

Other than that I know the plan is to have pmap eventually become a backward-compatibility wrapper for sharding, so it'd be good to get these details straightened out.
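
For (a), a minimal sketch of one way to be sure compile time is excluded, using JAX's ahead-of-time API on a toy function (this is not the MWE above, just an illustration):

import time
import jax
import jax.numpy as jnp

g = jax.jit(lambda v: jnp.sin(v) ** 2)
v = jnp.ones((1000,))

# Lower and compile once, outside the timed region, so only execution is measured.
compiled = g.lower(v).compile()

start = time.perf_counter()
for _ in range(100):
    compiled(v).block_until_ready()
print("per call:", (time.perf_counter() - start) / 100)

(Calling the jitted function once before timing, as the warm-up calls in the MWE do, achieves the same thing so long as the input shapes don't change.)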

@lockwo
Contributor Author

lockwo commented Apr 21, 2024

This did accidentally include compile time, but I took it out and it has no real impact on the times (I updated the code above).

@lockwo
Contributor Author

lockwo commented Apr 21, 2024

I will try on GPUs and report back.

@lockwo
Contributor Author

lockwo commented Apr 21, 2024

I slightly adjusted the code just to make some things a little faster (I just reduced the integration time). In addition, this is now with 4x A100s rather than 10 (Mac) CPUs. Interestingly, the sharding-based slowdown is actually far more severe with this code on GPUs than on CPUs! For exactness, this is what I ran:

import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
import equinox as eqx
import matplotlib.pyplot as plt
from diffrax import *
import optax

def f(t, y, args):
    return jnp.sin(t) + args["theta"] * y

def g(t, y, args):
    return 0.1 * jnp.eye(1)

t0 = 0.
t1 = 0.1
dt0 = 0.05
diffusion_shape = jax.ShapeDtypeStruct((1,), "float32")
solver, cont = Heun(), PIDController(1e-3, 1e-6)
ts = jnp.linspace(t0, t1, 100)

def solve(init, key, args):
    control = VirtualBrownianTree(
        t0=t0,
        t1=t1,
        tol=dt0 / 2,
        shape=diffusion_shape,
        key=key,
    )
    vf = ODETerm(f)
    cvf = ControlTerm(g, control)
    terms = MultiTerm(vf, cvf)
    saving = SaveAt(
        ts=ts
    )
    sol = diffeqsolve(
        terms,
        solver,
        y0=init,
        t0=t0,
        t1=t1,
        dt0=dt0,
        args=args,
        saveat=saving,
        stepsize_controller=cont,
        max_steps=100,
    )
    return sol.ys

batch_size = 300
inits = 0.1 * jnp.ones((batch_size, 1))
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
args = {"theta": 0.1}

num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jax.sharding.PositionalSharding(devices)
replicated = sharding.replicate()

inits_pmap = inits.reshape(num_devices, batch_size // num_devices, *inits.shape[1:])
keys_pmap = keys.reshape(num_devices, batch_size // num_devices, *keys.shape[1:])

args_shard = eqx.filter_shard(args, replicated)
x, y = eqx.filter_shard((inits, keys), sharding)
fn = eqx.filter_jit(eqx.filter_vmap(solve, in_axes=(0, 0, None)), donate="all")
_ = fn(x, y, args_shard).block_until_ready()
_ = eqx.filter_pmap(fn, in_axes=(0, 0, None))(inits_pmap, keys_pmap, args).block_until_ready()

%%timeit
_ = fn(x, y, args_shard).block_until_ready()

yields 13.7 s ± 261 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
_ = eqx.filter_pmap(fn, in_axes=(0, 0, None))(inits_pmap, keys_pmap, args).block_until_ready()

yields 57.1 ms ± 3.94 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

(naturally, CPUs are faster than GPUs on such small sequential problems).

@patrick-kidger
Owner

Interesting! Okay so the next things to try, in order, are:

  • perhaps it's the runtime errors. Try setting EQX_ON_ERROR=nan.
  • perhaps the additional complexity here is triggering something. Try removing adaptive step size control, using ODEs rather than SDEs, saving at t1 rather than at all ts, using Euler instead of Heun, etc.
  • perhaps it's something else subtle inside Diffrax. The result of the above should be equivalent to a while-loop-over-steps, so rewrite it as that.
  • perhaps there's some dodgy interaction between sharding and lax.while_loop. Try unrolling the loop with a for loop instead; a rough sketch of this is at the end of this comment. (You only have 20 steps, so this shouldn't be too bad from a compile-time perspective. Do make sure to continue to exclude compilation from the benchmark results though!)

...let's start with that and see where we end up!
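
For the last bullet, a rough sketch of what unrolling the solver loop with a Python for loop might look like -- the euler_step helper and the fixed step count here are hypothetical stand-ins, not Diffrax internals:

import jax
import jax.numpy as jnp

def euler_step(t, y, dt, args):
    # hypothetical fixed-step Euler update for the drift f(t, y, args) from the MWE
    return y + dt * (jnp.sin(t) + args["theta"] * y)

def solve_unrolled(y0, args, t0=0.0, dt=0.05, num_steps=20):
    t, y = t0, y0
    for _ in range(num_steps):  # a Python loop, so XLA sees it fully unrolled
        y = euler_step(t, y, dt, args)
        t = t + dt
    return y

fn_unrolled = jax.jit(jax.vmap(solve_unrolled, in_axes=(0, None)))
# fn_unrolled(inits, {"theta": 0.1}) can then be benchmarked under both the
# sharding and pmap setups above.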

@lockwo
Contributor Author

lockwo commented Apr 22, 2024

I added the EQX_ON_ERROR setting, and then I started slowly rolling back complexities (heading towards a manually stepped Euler ODE) when I noticed something: the stepsize controller combined with SaveAt(ts=ts) seems to be the slowdown. If I use the controller and saveat ts, sharding is ~50x slower than pmap. Without it, sharding is slightly faster. To put it concretely, I have this code (a simple ODE); I will also check on GPU and post the results, but I bet they will be the same:

import os
import multiprocessing

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)
os.environ['EQX_ON_ERROR'] = 'nan'
import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
import equinox as eqx
from diffrax import *

def f(t, y, args):
    return jnp.sin(t) + args["theta"] * y

t0 = 0.
t1 = 0.1
dt0 = 0.01
diffusion_shape = jax.ShapeDtypeStruct((1,), "float32")
solver, cont = Heun(), PIDController(1e-3, 1e-6)
ts = jnp.linspace(t0, t1, 100)

def solve(init, key, args):
    vf = ODETerm(f)
    terms = vf
    saving = SaveAt(ts=ts)
    #saving = SaveAt(t1=True)
    sol = diffeqsolve(
        terms,
        solver,
        y0=init,
        t0=t0,
        t1=t1,
        dt0=dt0,
        args=args,
        saveat=saving,
        stepsize_controller=cont,
        max_steps=100,
    )
    return sol.ys

batch_size = 300
inits = 0.1 * jnp.ones((batch_size, 1))
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
args = {"theta": 0.1}

num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jax.sharding.PositionalSharding(devices)
replicated = sharding.replicate()

inits_pmap = inits.reshape(num_devices, batch_size // num_devices, *inits.shape[1:])
keys_pmap = keys.reshape(num_devices, batch_size // num_devices, *keys.shape[1:])

args_shard = eqx.filter_shard(args, replicated)
x, y = eqx.filter_shard((inits, keys), sharding)
fn = eqx.filter_jit(eqx.filter_vmap(solve, in_axes=(0, 0, None)))
_ = fn(x, y, args_shard).block_until_ready()
_ = eqx.filter_pmap(fn, in_axes=(0, 0, None))(inits_pmap, keys_pmap, args).block_until_ready()

and

%%timeit
_ = fn(x, y, args_shard).block_until_ready()

and

%%timeit
_ = eqx.filter_pmap(fn, in_axes=(0, 0, None))(inits_pmap, keys_pmap, args).block_until_ready()

If I have SaveAt(ts=ts) + controller:

  • shard = 36.6 ms ± 1.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
  • pmap = 684 µs ± 5.08 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

If I have SaveAt(ts=ts) + no controller:

  • shard = 165 µs ± 2.43 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
  • pmap = 621 µs ± 2.23 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

If I have SaveAt(t1=True) + controller:

  • shard = 1.01 ms ± 38 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
  • pmap = 648 µs ± 40 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

If I have SaveAt(t1=True) + no controller:

  • shard = 147 µs ± 3.97 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
  • pmap = 630 µs ± 4.27 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

As you can see, the only one with a (meaningful) difference is SaveAt(ts=ts) + controller. I will look into the implementations, but since you are much more familiar with the code base, maybe there is a reason that stands out to you?

@patrick-kidger
Owner

Oh, that's interesting! Okay, so it's probably something to do with this chunk of code:

def save_ts_impl(ts, fn, save_state: SaveState) -> SaveState:
    def _cond_fun(_save_state):
        return (
            keep_step
            & (ts[_save_state.saveat_ts_index] <= state.tnext)
            & (_save_state.saveat_ts_index < len(ts))
        )

    def _body_fun(_save_state):
        _t = ts[_save_state.saveat_ts_index]
        _y = interpolator.evaluate(_t)
        _ts = _save_state.ts.at[_save_state.save_index].set(_t)
        _ys = jtu.tree_map(
            lambda __y, __ys: __ys.at[_save_state.save_index].set(__y),
            fn(_t, _y, args),
            _save_state.ys,
        )
        return SaveState(
            saveat_ts_index=_save_state.saveat_ts_index + 1,
            ts=_ts,
            ys=_ys,
            save_index=_save_state.save_index + 1,
        )

    return inner_while_loop(
        _cond_fun,
        _body_fun,
        save_state,
        max_steps=len(ts),
        buffers=_inner_buffers,
        checkpoints=len(ts),
    )
This is the bit that saves outputs if SaveAt(ts=...) is used. Notably, this is a while loop: at every numerical step we have to loop over the times we haven't yet seen, to check whether they occurred during the step we just made. When t0, t1 and dt are constant over the batch, this while loop will run for the same number of iterations for every batch element. But with an adaptive step size controller, this might not be the case.

So I'd probably suggest playing around with that and seeing what results you get as you start perturbing things.

One very concrete place to start: can you check whether the output of _cond_fun is batched or unbatched in either case? Depending on this we need to perform different logic:

  • If it's unbatched, then the loop performs the same number of iterations for all batch elements, so the vmap and _body_fun commute, and we can lower vmap-of-while-of-(cond, body) into while-of-(cond, vmap-of-body).
  • However, if the predicate is batched, then the loop may run for a different number of iterations for each batch element: to perform the lowering we still need to commute as above, but we also need additional logic to ignore the iterations from shorter-than-maximal batch elements (a hand-rolled sketch of this masking idea is below).
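
To illustrate the second case only -- this is not how JAX actually implements its while_loop batching rule, just a hand-rolled sketch with a hypothetical batched_while helper -- the idea is to keep looping while any batch element's predicate is still true, and to mask out updates for elements that have already finished:

import jax
import jax.numpy as jnp

def batched_while(cond_fun, body_fun, init_state):
    # cond_fun/body_fun act on a single batch element; init_state is a pytree
    # whose leaves all share a leading batch dimension.
    def outer_cond(carry):
        _, active = carry
        return jnp.any(active)  # run while any element still wants to iterate

    def outer_body(carry):
        state, active = carry
        new_state = jax.vmap(body_fun)(state)
        # Elements that have already finished keep their old values.
        state = jax.tree_util.tree_map(
            lambda new, old: jnp.where(
                active.reshape(active.shape + (1,) * (new.ndim - 1)), new, old
            ),
            new_state,
            state,
        )
        return state, active & jax.vmap(cond_fun)(state)

    active0 = jax.vmap(cond_fun)(init_state)
    final_state, _ = jax.lax.while_loop(outer_cond, outer_body, (init_state, active0))
    return final_state

# Each batch element counts up to its own threshold, so the loop lengths differ.
state0 = (jnp.zeros(3), jnp.array([3.0, 7.0, 5.0]))
counts, _ = batched_while(lambda s: s[0] < s[1], lambda s: (s[0] + 1.0, s[1]), state0)
# counts == [3., 7., 5.]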

@lockwo
Contributor Author

lockwo commented Apr 27, 2024

I was looking more into the while loop. Since we have a double while loop as you described, where the first is static and the second may vary in length from iteration to iteration, I tried (to maximally remove complexity) a minimal example of this varying double while loop in equinox, and then an even simpler one in pure JAX. I notice a slowdown in pure JAX when I have a double while loop.

import os
import multiprocessing

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)
import jax
import jax.numpy as jnp
import jax.experimental.mesh_utils as mesh_utils
import functools as ft 

def f(t, y, theta):
    return jnp.abs(jnp.sin(t)) + theta * y

def solve(init, key):
    def inner_loop_cond(state):
        t, y, _ = state
        return y.squeeze() < 10

    def inner_loop_body(state):
        t, y, theta = state
        dy = f(t, y, theta)
        return (t + 0.1, y + 0.1 * dy, theta)
    
    def outer_loop_cond(state):
        _, _, _, count = state
        return count < 5
    
    def outer_loop_body(state):
        t, y, theta, count = state
        y = jax.random.uniform(jax.random.PRNGKey(count), shape=(1,))
        new_t, new_y, _ = inner_while_loop(inner_loop_cond, inner_loop_body, (t, y, theta))
        return (new_t, new_y, theta, count + 1)

    inner_while_loop = jax.lax.while_loop
    outer_while_loop = jax.lax.while_loop
    theta = 5.0
    t_initial = 0.0
    y_initial = init
    count_initial = jax.random.randint(key, minval=-2, maxval=2, shape=())
    final_state = outer_while_loop(outer_loop_cond, outer_loop_body, (t_initial, y_initial, theta, count_initial))
    return final_state[1]


batch_size = 30
inits = 0.1 * jnp.ones((batch_size, 1))
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)

num_devices = len(jax.devices())
devices = mesh_utils.create_device_mesh((num_devices, 1))
sharding = jax.sharding.PositionalSharding(devices)
replicated = sharding.replicate()

inits_pmap = inits.reshape(num_devices, batch_size // num_devices, *inits.shape[1:])
keys_pmap = keys.reshape(num_devices, batch_size // num_devices, *keys.shape[1:])

x, y = jax.device_put((inits, keys), sharding)

fn = jax.jit(jax.vmap(solve))
pmap_fn = jax.pmap(fn)
_ = fn(x, y).block_until_ready()
_ = pmap_fn(inits_pmap, keys_pmap).block_until_ready()

%%timeit
_ = fn(x, y).block_until_ready()

yields 5.11 ms ± 53.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%%timeit
_ = pmap_fn(inits_pmap, keys_pmap).block_until_ready()

yields 251 µs ± 1.14 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

I made a jax issue (google/jax#20968) since this is all in jax. There may be something else that can be done in diffrax but unless I am overlooking something, this jax example should be the same speed but isn't.

@patrick-kidger
Owner

Ah, this is awesome. I'm glad to have this so clearly isolated in an MWE.

Since I believe the goal is, at some point, to convert pmap to use jit under the hood, the pmap version might be of interest to add as a JAX benchmark. If the pmap->jit transition ever happens without this being fixed, it would be good to ensure that this is noticed.

(I'm guessing the root issue here is probably something in XLA. If you're feeling brave then you could maybe try delving down to that level.)
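
If you do want to peek at that level, a couple of possible starting points (assuming fn, x and y refer to the jitted function and sharded inputs from your latest MWE): print the IR that JAX lowers each version to, e.g.

# Inspect what the sharded vmap version lowers to; do the same for the pmap version and diff them.
print(fn.lower(x, y).as_text())

or have XLA dump everything it produces (including the optimized HLO it actually runs) by appending --xla_dump_to=/tmp/xla_dump to your existing XLA_FLAGS, and then compare the sharded and pmap dumps.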
