Counterintuitive speed of einsums vs equivalent matmuls #20952
Comments
Hi - thanks for the question! Could you take a look at https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code and update your benchmarks? In particular, accounting for asynchronous dispatch via `block_until_ready()` is important here.
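For example, the difference looks like this (a minimal sketch with a stand-in computation):

```python
import jax.numpy as jnp

x = jnp.ones((1000, 1000))

# Without blocking, %timeit mostly measures dispatch overhead, because
# JAX returns control before the computation has finished:
%timeit jnp.dot(x, x)

# Blocking on the result measures the actual computation:
%timeit jnp.dot(x, x).block_until_ready()
```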
Thanks, Jake. Compile time should already be separated out from runtime – the code evaluates the functions once first to make sure they all agree on the output before then timing them. I've added `block_until_ready()` to the timed calls as well.
Thanks! A couple of things: first, you can inspect both the unoptimized and the XLA-optimized HLO for a jitted function:

```python
unoptimized_hlo = reduce_double_einsum.lower(w_arr, a_arr, b_arr).as_text()
optimized_hlo = reduce_double_einsum.lower(w_arr, a_arr, b_arr).compile().as_text()
```

Second, note that the optimized version is the result of XLA compilation, and will generally vary by device type (CPU vs GPU). If it looks like XLA is doing something sub-optimal, we should probably report it at http://github.com/openxla/xla and see if the compiler can be improved in this case. Thanks for raising the issue!
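You can also pull XLA's own cost estimate from the compiled object, which can help when comparing the variants (a sketch; the exact return format varies across jaxlib versions):

```python
compiled = reduce_double_einsum.lower(w_arr, a_arr, b_arr).compile()

# cost_analysis() exposes XLA's estimated metrics such as 'flops' and
# bytes accessed; depending on the jaxlib version it may be a dict or
# a single-element list of dicts.
print(compiled.cost_analysis())
```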
cross-ref #2160
Thanks! Just noting that I checked these speeds on GPU as well, arriving at the same conclusions as on CPU. Running the same repro code with `block_until_ready()` added:

```python
%timeit reduce_double_einsum(w_arr, a_arr, b_arr).block_until_ready()
%timeit reduce_einsum(w_arr, a_arr, b_arr).block_until_ready()
%timeit reduce_matmul(w_arr, a_arr, b_arr).block_until_ready()
```

prints timings with the same ordering (the matmul version is fastest by a wide margin), and

```python
m = jax.device_put(rnd(g_size, a_size))
%timeit jnp.einsum('ga,gb->gab', m, m).block_until_ready()
%timeit jnp.einsum('g,ga,gb->ab', m[:, 0], m, m).block_until_ready()
```

again shows the three-operand einsum being much faster.
System info:
Description
We came across some einsums "in the wild" and noticed some surprisingly large differences in speed across three implementations of the same operation.
Here's a reproduction of the issue:
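A minimal sketch of the three methods (the exact `g_size`, the `rnd` helper, and the function bodies here are assumptions; all three compute the same contraction over `g`):

```python
import jax
import jax.numpy as jnp
import numpy as np

g_size, a_size = 512, 650  # g_size is a placeholder; a_size as discussed below
rng = np.random.default_rng(0)
rnd = lambda *shape: rng.standard_normal(shape)

w_arr = jax.device_put(rnd(g_size))
a_arr = jax.device_put(rnd(g_size, a_size))
b_arr = jax.device_put(rnd(g_size, a_size))

@jax.jit
def reduce_double_einsum(w, a, b):
    # Method 1: materialize the (g, a, b) outer products, then contract over g.
    return jnp.einsum('g,gab->ab', w, jnp.einsum('ga,gb->gab', a, b))

@jax.jit
def reduce_einsum(w, a, b):
    # Method 2: the same contraction as a single einsum.
    return jnp.einsum('g,ga,gb->ab', w, a, b)

@jax.jit
def reduce_matmul(w, a, b):
    # Method 3: the same contraction as an explicit weighted matmul.
    return (w[:, None] * a).T @ b

# Evaluate once and check the methods agree before timing anything.
out = reduce_matmul(w_arr, a_arr, b_arr)
assert jnp.allclose(reduce_double_einsum(w_arr, a_arr, b_arr), out, rtol=1e-3, atol=1e-3)
assert jnp.allclose(reduce_einsum(w_arr, a_arr, b_arr), out, rtol=1e-3, atol=1e-3)

%timeit reduce_double_einsum(w_arr, a_arr, b_arr)  # method 1
%timeit reduce_einsum(w_arr, a_arr, b_arr)         # method 2
%timeit reduce_matmul(w_arr, a_arr, b_arr)         # method 3
```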
This prints timings showing that method 1 is a little slower than method 2, which in turn is much slower than method 3.
Some other observations:
- If `a_size = 256` instead of `a_size = 650`, then `reduce_einsum` and `reduce_double_einsum` take around the same time, but `reduce_matmul` is still much faster.
- On other machines, `reduce_einsum` can sometimes be noticeably faster than `reduce_double_einsum` (and also see below).
- Method 1 vs method 2 relates to another counter-intuitive einsum behaviour. On first look, the second einsum (`g,ga,gb->ab`) feels like it should be more expensive, but it is around 30x faster; see the snippet below.
System info (python version, jaxlib version, accelerator, etc.)