Performance Issue Report: JAX Slower Than Autograd on GPU and CPU Setups #20948
Comments
What happens if you wrap the jax function in `jax.jit`? Also, can you include details on how you ran the benchmarks? Keep in mind these tips to make sure you're measuring what you think you're measuring when running benchmarks of JAX code: https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code
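For reference, the benchmarking pattern that FAQ entry describes looks roughly like this (a minimal sketch; `f` and `x` are placeholders, not the code from this issue):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):  # placeholder standing in for the function under test
    return jnp.sin(x).sum()

x = jnp.ones((2, 1000))

# The first call traces and compiles; run it once so compile time
# is excluded from the measurement.
f(x).block_until_ready()

# JAX dispatches work asynchronously, so block on the result when timing:
# %timeit f(x).block_until_ready()
```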
Hi Jake, I read the documentation you mentioned, and I don't believe I've missed anything important, since my code is simple and trivial.
I am using the
When I try benchmarking your original function using `jax.jit`, here is what I see:

```python
import autograd.numpy as anp
import jax
import jax.numpy as jnp
def f_autograd(x):
term1 = 0.5 * (x[0]**2 + (x[1] - 0.5)**2) # Central parabolic valley
# Nested valley 1 (deep, narrow)
term2_x = -4 * anp.exp(-(x[0] + 0.75)**2 - 10 * (x[1] - 0.3)**2)
term2_y = -8 * anp.exp(-(x[0] - 0.75)**2 - 10 * (x[1] - 0.3)**2)
term2 = term2_x + term2_y
# Nested valley 2 (wide, shallow)
term3_x = 2 * anp.sin(5 * anp.pi * (x[0] - 1.25)) * anp.sin(5 * anp.pi * (x[1] - 1.75))
term3_y = 3 * anp.sin(7 * anp.pi * (x[0] - 1.25)) * anp.sin(7 * anp.pi * (x[1] - 1.75))
term3 = 0.2 * (term3_x + term3_y) # Adjust coefficient for shallower valley
term4 = 3 * anp.sin(3 * anp.pi * x[0]) * anp.sin(3 * anp.pi * x[1]) # Oscillating term
term5 = -5 * anp.exp(-(x[0] + 1)**2 - (x[1] + 1)**2) # Deeper global minimum
term6 = -anp.exp(-(x[0] - 1.5)**2 - (x[1] - 1.5)**2) # Local minimum
term7 = -2 * anp.exp(-(x[0] + 2)**2 - (x[1] - 2)**2) # Local minimum
return term1 + term2 + term3 + term4 + term5 + term6 + term7
@jax.jit
def f_jax(x):
term1 = 0.5 * (x[0]**2 + (x[1] - 0.5)**2) # Central parabolic valley
# Nested valley 1 (deep, narrow)
term2_x = -4 * jnp.exp(-(x[0] + 0.75)**2 - 10 * (x[1] - 0.3)**2)
term2_y = -8 * jnp.exp(-(x[0] - 0.75)**2 - 10 * (x[1] - 0.3)**2)
term2 = term2_x + term2_y
# Nested valley 2 (wide, shallow)
term3_x = 2 * jnp.sin(5 * jnp.pi * (x[0] - 1.25)) * jnp.sin(5 * jnp.pi * (x[1] - 1.75))
term3_y = 3 * jnp.sin(7 * jnp.pi * (x[0] - 1.25)) * jnp.sin(7 * jnp.pi * (x[1] - 1.75))
term3 = 0.2 * (term3_x + term3_y) # Adjust coefficient for shallower valley
term4 = 3 * jnp.sin(3 * jnp.pi * x[0]) * jnp.sin(3 * jnp.pi * x[1]) # Oscillating term
term5 = -5 * jnp.exp(-(x[0] + 1)**2 - (x[1] + 1)**2) # Deeper global minimum
term6 = -jnp.exp(-(x[0] - 1.5)**2 - (x[1] - 1.5)**2) # Local minimum
term7 = -2 * jnp.exp(-(x[0] + 2)**2 - (x[1] - 2)**2) # Local minimum
return term1 + term2 + term3 + term4 + term5 + term6 + term7
shape = (2, 1000)
x_autograd = anp.ones(shape)
%timeit f_autograd(x_autograd)
# 797 µs ± 440 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
x_jax = jnp.ones(shape)
_ = f_jax(x_jax) # trigger compilation
%timeit f_jax(x_jax).block_until_ready()
# 141 µs ± 17.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```

This is on a Colab CPU runtime, using the built-in `%timeit` magic.
If you could include your full end-to-end benchmark script, including all imports, array definitions, function definitions, and function calls, I may be able to comment on why you are seeing different results.
Ah, I think I see now. When I run your snippet I get this warning:
And got:
When running:
But when I import
and still:
This is my
And my
Description
Introduction:
This report describes a performance issue observed with JAX on both GPU and CPU hardware setups. It is intended to give the JAX development team detailed feedback to help identify potential areas for optimization.
Observed Performance Issues:
Steps to Reproduce:
Expected Behavior:
Given its design for high-performance machine learning workloads, JAX should exhibit performance comparable to or better than Autograd, especially on platforms with GPU acceleration.
Actual Behavior:
JAX underperforms significantly compared to Autograd across all tested hardware setups.
Attachments:
ADAM (beta1: 0.95, beta2: 0.99, epsilon: 0.001), BFGS, Newton-CG, and CG (standard scipy.optimize.minimize configuration) on the synthetic function:
[image attachment] and [image attachment]
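For context, a comparison along these lines can be set up as sketched below. This is an illustrative reconstruction, not the reporter's actual script: it reuses the `f_jax` and `f_autograd` definitions from the comment above, uses BFGS as one representative optimizer, and minimizes over a single 2-vector since `scipy.optimize.minimize` expects a scalar objective:

```python
import numpy as np
from scipy.optimize import minimize
import jax
from autograd import grad as ag_grad

# JAX defaults to float32; enable float64 so gradients match
# scipy's double-precision expectations.
jax.config.update("jax_enable_x64", True)

x0 = np.zeros(2)

# JAX: jit-compile the gradient and convert its output to numpy.
jax_jac = jax.jit(jax.grad(f_jax))
res_jax = minimize(f_jax, x0, jac=lambda v: np.asarray(jax_jac(v)),
                   method="BFGS")

# Autograd: differentiate the anp-based version of the same function.
res_ag = minimize(f_autograd, x0, jac=ag_grad(f_autograd), method="BFGS")

print(res_jax.fun, res_ag.fun)
```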
Conclusion:
JAX is rich in features, but in these tests it was slower than Autograd.
Recommendations:
Acknowledgments:
Thank you to the developers of JAX for their ongoing efforts and contributions to the open-source community.
System info (python version, jaxlib version, accelerator, etc.)
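System details of this kind can be gathered with JAX's built-in environment helper (available in recent JAX versions; a minimal snippet):

```python
import jax

# Prints Python, jax/jaxlib, and numpy versions plus available devices.
jax.print_environment_info()
```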