Counterintuitive speed of einsums vs equivalent matmuls #20952

Open
shehzaidi opened this issue Apr 26, 2024 · 5 comments
Labels: bug (Something isn't working)

Comments

shehzaidi commented Apr 26, 2024

Description

We came across some einsums "in the wild" and noticed some surprisingly large differences in speed across three implementations of the same operation.

Here's a reproduction of the issue:

import jax.numpy as jnp
import jax
jax.config.update('jax_enable_x64', True)

# Just a utility for generating random arrays.
class RandomArr:

  def __init__(self, rng=0):
    self.rng = jax.random.PRNGKey(rng)

  def __call__(self, *shape):
    self.rng, _ = jax.random.split(self.rng, 2)
    return jax.random.normal(self.rng, shape)

rnd = RandomArr()

# Method 1: two einsums done successively. Initially these were far away in the
# code, so it wasn't immediately obvious to combine them.
@jax.jit
def reduce_double_einsum(w, a, b):
  aab = jnp.einsum('xga,xgb,gs->gsab', a, a, b)
  prod = 0.5 * jnp.einsum('sg,gsab->sab', w, aab)
  return prod

# Method 2: the two einsums above combined into one einsum.
@jax.jit
def reduce_einsum(w, a, b):
  prod = 0.5 * jnp.einsum('sg,xga,xgb,gs->sab', w, a, a, b)
  return prod

# Method 3: rewriting the whole operation without any einsums, instead using
# element-wise multiplications, matmuls, reshapes etc.
@jax.jit
def reduce_matmul(w, a, b):
  sg = w * b.T
  xga = a
  xg_a = xga.reshape(-1, xga.shape[-1])

  xgsa = sg.T[None, :, :, None] * xga[:, :, None, :]
  xg_sa = xgsa.reshape(-1, sg.shape[0] * xga.shape[-1])
  prod = jnp.matmul(xg_sa.T, xg_a).reshape(
      sg.shape[0], xga.shape[-1], xga.shape[-1]
  )
  prod *= 0.5
  return prod

# Make some random arrays.
# NOTE: sizes matter for the run times. a_size = 650 matches the observations
# below; the value of g_size is not stated in the issue.
a_size = 650
g_size = ...  # not given in the original snippet
w_arr = rnd(2, g_size)
a_arr = rnd(10, g_size, a_size)
b_arr = rnd(g_size, 2)

# Run them all (including jit compilation).
double_einsum = reduce_double_einsum(w_arr, a_arr, b_arr)
einsum = reduce_einsum(w_arr, a_arr, b_arr)
matmul = reduce_matmul(w_arr, a_arr, b_arr)

# Make sure all functions agree.
assert jnp.allclose(matmul, double_einsum)
assert jnp.allclose(matmul, einsum)
assert jnp.allclose(einsum, double_einsum)

%timeit reduce_double_einsum(w_arr, a_arr, b_arr).block_until_ready()
%timeit reduce_einsum(w_arr, a_arr, b_arr).block_until_ready()
%timeit reduce_matmul(w_arr, a_arr, b_arr).block_until_ready()

This prints:

14.4 s ± 412 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
11.4 s ± 1.11 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.04 s ± 96 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Method 1 is a little slower than method 2, which in turn is much slower than method 3.

Some other observations:

  • If you set a_size = 256 instead of a_size = 650, then reduce_einsum and reduce_double_einsum take around the same time, but reduce_matmul is still much faster. On other machines, reduce_einsum can sometimes be noticeably faster than reduce_double_einsum (see also below).
  • Without JIT compilation, the rankings remain the same but some methods become slower (one way to check this is sketched below).
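
For reference, a minimal sketch of one way to reproduce the no-JIT comparison (the exact code used isn't shown here) is to wrap the calls in jax.disable_jit():

import timeit

# Sketch only: time a single call of each variant with JIT disabled.
with jax.disable_jit():
  for name, fn in [('double_einsum', reduce_double_einsum),
                   ('einsum', reduce_einsum),
                   ('matmul', reduce_matmul)]:
    t = timeit.timeit(
        lambda: fn(w_arr, a_arr, b_arr).block_until_ready(), number=1)
    print(f'{name} without jit: {t:.2f} s')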

Method 1 vs method 2 relates to another counterintuitive einsum behaviour. At first glance, the second einsum below (g,ga,gb->ab) feels like it should be more expensive than the first, but it is around 30x faster:

m = rnd(g_size, a_size)
%timeit jnp.einsum('ga,gb->gab', m, m).block_until_ready()
%timeit jnp.einsum('g,ga,gb->ab', m[:, 0], m, m).block_until_ready()

prints:

1.82 s ± 328 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
50.8 ms ± 648 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
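
For what it's worth, the faster contraction reduces over g and can be expressed as a single (a, g) @ (g, b) matmul, whereas the slower one materializes the full (g, a, b) outer product. A quick equivalence check (a sketch, reusing m from above):

v = m[:, 0]
fast = jnp.einsum('g,ga,gb->ab', v, m, m)
# Same contraction written as one matmul: sum over g of v[g] * m[g, a] * m[g, b].
as_matmul = (v[:, None] * m).T @ m
assert jnp.allclose(fast, as_matmul)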

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='1c084c76869d', release='6.1.58+', version='#1 SMP PREEMPT_DYNAMIC Sat Nov 18 15:31:17 UTC 2023', machine='x86_64')
shehzaidi added the 'bug' label on Apr 26, 2024
jakevdp (Collaborator) commented Apr 26, 2024

Hi - thanks for the question! Could you take a look at https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code and update your benchmarks? In particular, accounting for asynchronous dispatch via block_until_ready() and separating-out compile time and runtime would make the benchmark results more reliable and easier to compare.
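
A minimal sketch of that pattern, using the first function above, might look like:

# Warm-up call: includes tracing and XLA compilation.
reduce_double_einsum(w_arr, a_arr, b_arr).block_until_ready()
# Timed calls: steady-state runtime only, blocking on the result each time.
%timeit reduce_double_einsum(w_arr, a_arr, b_arr).block_until_ready()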

shehzaidi (Author) commented:

> Hi - thanks for the question! Could you take a look at https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code and update your benchmarks? In particular, accounting for asynchronous dispatch via block_until_ready() and separating-out compile time and runtime would make the benchmark results more reliable and easier to compare.

Thanks, Jake. Compile time should already be separated out from runtime: the code evaluates each function once (to make sure they all agree on the output) before timing them. I've added .block_until_ready(), which gives almost the same timings as before.

jakevdp (Collaborator) commented Apr 26, 2024

Thanks! A few things:

  1. Since you're running on CPU, you might also try on GPU. The XLA GPU compiler is a bit more mature than the CPU compiler, so you might see different performance characteristics.
  2. Either way, this seems like something the compiler should be able to optimize more effectively. The place where that will happen is in http://github.com/openxla/xla, which is the repo for the compiler that JAX targets by default.
  3. If you want to dig in yourself, you could use JAX's ahead-of-time compilation tools to get a sense for the HLO instructions that JAX is emitting and what the compiler is doing with them. It would look something like this for your first function:
unoptimized_hlo = reduce_double_einsum.lower(w_arr, a_arr, b_arr).as_text()
optimized_hlo = reduce_double_einsum.lower(w_arr, a_arr, b_arr).compile().as_text()

Note that the optimized version is the result of XLA compilation, and will generally vary by device type (CPU vs GPU).
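
To compare what XLA emits for all three variants, one option (a sketch; the output filenames are arbitrary) is to dump the optimized HLO for each and diff them:

# Write the post-optimization HLO for each variant to a file for comparison.
for name, fn in [('double_einsum', reduce_double_einsum),
                 ('einsum', reduce_einsum),
                 ('matmul', reduce_matmul)]:
  hlo_text = fn.lower(w_arr, a_arr, b_arr).compile().as_text()
  with open(f'{name}_optimized.hlo', 'w') as out:
    out.write(hlo_text)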

If it looks like XLA is doing something sub-optimal, we should probably report it at http://github.com/openxla/xla and see if the compiler can be improved in this case.

Thanks for raising the issue!

mattjj (Member) commented Apr 26, 2024

cross-ref #2160

shehzaidi (Author) commented Apr 29, 2024

Thanks! Just noting that I checked these speeds on GPU as well, arriving at the same conclusions as on CPU. Running the same code as above (with device_put applied to the arrays before timing):

%timeit reduce_double_einsum(w_arr, a_arr, b_arr).block_until_ready()
%timeit reduce_einsum(w_arr, a_arr, b_arr).block_until_ready()
%timeit reduce_matmul(w_arr, a_arr, b_arr).block_until_ready()

prints

521 ms ± 1.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
293 ms ± 469 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
75.7 ms ± 209 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

and

m = jax.device_put(rnd(g_size, a_size))
%timeit jnp.einsum('ga,gb->gab', m, m).block_until_ready()
%timeit jnp.einsum('g,ga,gb->ab', m[:, 0], m, m).block_until_ready()

prints

24.2 ms ± 55.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
6.63 ms ± 103 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

System info:

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.25.2
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='133beff4b4f6', release='6.1.58+', version='#1 SMP PREEMPT_DYNAMIC Sat Nov 18 15:31:17 UTC 2023', machine='x86_64')


$ nvidia-smi
Mon Apr 29 11:13:25 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   64C    P0              32W /  70W |  11455MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
