
Performance Issue Report: JAX Slower Than Autograd on GPU and CPU Setups #20948

Open
alibastami opened this issue Apr 26, 2024 · 4 comments
Labels: bug (Something isn't working), type:performance

Comments

@alibastami

Description

Introduction:

This report outlines a performance issue observed with JAX on both GPU and CPU hardware setups. Its purpose is to give the JAX development team detailed feedback to aid in identifying potential areas for optimization.

Observed Performance Issues:

  • GPU Performance:
    • JAX is significantly slower than Autograd on identical tasks, running at least 5x slower on NVIDIA GPUs.
  • CPU Performance:
    • Similar underperformance is observed on Intel Core i7 CPUs, where JAX operations are markedly slower than those performed with Autograd.

Steps to Reproduce:

  1. Set up the environment with specified hardware and software versions.
  2. Run benchmark tests, including matrix operations and gradient-based optimization (ADAM).
  3. Compare execution times of JAX and Autograd.

Expected Behavior:

JAX should exhibit comparable or better performance than Autograd given its design for high-performance machine learning tasks, especially on platforms supporting GPU acceleration.

Actual Behavior:

JAX underperforms significantly compared to Autograd across all tested hardware setups.

Attachments:

  • Benchmarking Scripts:

ADAM (beta1: 0.95, beta2: 0.99, epsilon: 0.001), plus BFGS, Newton-CG, and CG (standard scipy.optimize.minimize configuration), on the following synthetic function (Autograd version first, then the JAX version):


import autograd.numpy as anp


def f(x):
  term1 = 0.5 * (x[0]**2 + (x[1] - 0.5)**2)  # Central parabolic valley

  # Nested valley 1 (deep, narrow)
  term2_x = -4 * anp.exp(-(x[0] + 0.75)**2 - 10 * (x[1] - 0.3)**2)
  term2_y = -8 * anp.exp(-(x[0] - 0.75)**2 - 10 * (x[1] - 0.3)**2)
  term2 = term2_x + term2_y

  # Nested valley 2 (wide, shallow)
  term3_x = 2 * anp.sin(5 * anp.pi * (x[0] - 1.25)) * anp.sin(5 * anp.pi * (x[1] - 1.75))
  term3_y = 3 * anp.sin(7 * anp.pi * (x[0] - 1.25)) * anp.sin(7 * anp.pi * (x[1] - 1.75))
  term3 = 0.2 * (term3_x + term3_y)  # Adjust coefficient for shallower valley

  term4 = 3 * anp.sin(3 * anp.pi * x[0]) * anp.sin(3 * anp.pi * x[1])  # Oscillating term

  term5 = -5 * anp.exp(-(x[0] + 1)**2 - (x[1] + 1)**2)  # Deeper global minimum

  term6 = -anp.exp(-(x[0] - 1.5)**2 - (x[1] - 1.5)**2)  # Local minimum

  term7 = -2 * anp.exp(-(x[0] + 2)**2 - (x[1] - 2)**2)  # Local minimum

  return term1 + term2 + term3 + term4 + term5 + term6 + term7

and the equivalent JAX version:

import jax.numpy as jnp


def f(x):
  term1 = 0.5 * (x[0]**2 + (x[1] - 0.5)**2)  # Central parabolic valley

  # Nested valley 1 (deep, narrow)
  term2_x = -4 * jnp.exp(-(x[0] + 0.75)**2 - 10 * (x[1] - 0.3)**2)
  term2_y = -8 * jnp.exp(-(x[0] - 0.75)**2 - 10 * (x[1] - 0.3)**2)
  term2 = term2_x + term2_y

  # Nested valley 2 (wide, shallow)
  term3_x = 2 * jnp.sin(5 * jnp.pi * (x[0] - 1.25)) * jnp.sin(5 * jnp.pi * (x[1] - 1.75))
  term3_y = 3 * jnp.sin(7 * jnp.pi * (x[0] - 1.25)) * jnp.sin(7 * jnp.pi * (x[1] - 1.75))
  term3 = 0.2 * (term3_x + term3_y)  # Adjust coefficient for shallower valley

  term4 = 3 * jnp.sin(3 * jnp.pi * x[0]) * jnp.sin(3 * jnp.pi * x[1])  # Oscillating term

  term5 = -5 * jnp.exp(-(x[0] + 1)**2 - (x[1] + 1)**2)  # Deeper global minimum

  term6 = -jnp.exp(-(x[0] - 1.5)**2 - (x[1] - 1.5)**2)  # Local minimum

  term7 = -2 * jnp.exp(-(x[0] + 2)**2 - (x[1] - 2)**2)  # Local minimum

  return term1 + term2 + term3 + term4 + term5 + term6 + term7
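
The full timing harness is not included above; a minimal sketch of the kind of comparison described in the steps to reproduce might look like the following. The names f_anp and f_jnp, the input values, and the call counts are illustrative assumptions, not the original script (in the report both functions are simply called f).

import timeit

import autograd
import autograd.numpy as anp
import jax
import jax.numpy as jnp

# Illustrative names only: assume the Autograd definition above is f_anp
# and the JAX definition is f_jnp.
grad_anp = autograd.grad(f_anp)
grad_jnp = jax.jit(jax.grad(f_jnp))

x_anp = anp.array([-1.0, -1.0])
x_jnp = jnp.array([-1.0, -1.0])

_ = grad_jnp(x_jnp).block_until_ready()  # warm up: first call includes compilation

t_anp = timeit.timeit(lambda: grad_anp(x_anp), number=1000)
t_jnp = timeit.timeit(lambda: grad_jnp(x_jnp).block_until_ready(), number=1000)
print(f"autograd grad: {t_anp:.4f} s, jax grad: {t_jnp:.4f} s (1000 calls each)")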

Conclusion:

JAX is rich in features, but in these tests it is consistently slower than Autograd.

Recommendations:

  • Conduct a thorough investigation into the causes of the observed performance bottlenecks.

Acknowledgments:

Thank you to the developers of JAX for their ongoing efforts and contributions to the open-source community.

System info (python version, jaxlib version, accelerator, etc.)

  • Hardware: Intel Core i7-9750H CPU, 16GB DDR4 RAM, NVIDIA GTX 1650 GPU, NVIDIA Tesla L4 GPU (used in Colab Pro)
  • Software: JAX version 0.4.26, Jaxlib version 0.4.26, Python version 3.9.15, Jupyter notebook 5.7.2.
  • Comparison Reference: Autograd version 1.6.2
  • Additional info:

JAX Available devices: [cuda(id=0)]
Torch CUDA Available: True
Torch CUDA Device Name: NVIDIA GeForce GTX 1650 (on Colab Nvidia L4)
Torch Current CUDA Device ID: 0
Torch Number of GPUs: 1
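
The device listing above was presumably produced by something along these lines; this is a sketch, not the original script:

import jax
import torch

print("JAX Available devices:", jax.devices())
print("Torch CUDA Available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Torch CUDA Device Name:", torch.cuda.get_device_name(0))
    print("Torch Current CUDA Device ID:", torch.cuda.current_device())
    print("Torch Number of GPUs:", torch.cuda.device_count())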

@alibastami added the bug label on Apr 26, 2024
@jakevdp (Collaborator) commented on Apr 26, 2024

What happens if you wrap the jax function in jax.jit?

Also, can you include details on how you ran the benchmarks? Keep in mind these tips to make sure you're measuring what you think you're measuring when running benchmarks of JAX code: https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code

@alibastami (Author)

Hi Jake,

I read the documentation you mentioned, and I don't believe I missed anything important, since my code is simple.

  1. I wrapped the function in one of two ways:

    • The decorator: @jit
    • Directly on the function: jax.jit(f)
  2. I moved x0 to the GPU as follows:

    # assumes: import jax; import jax.numpy as jnp; from jax import device_put
    x = jnp.array([-10., -80])
    x0 = device_put(x, jax.devices('gpu')[0])
    
  3. I ran identical code samples using jnp and anp. The anp version completed in under a second, while the jnp version has been running for over 10 minutes (I reduced the number of iterations to a minimum so the test would finish and I could take screenshots; the anp version, by contrast, broke out of the loop once the convergence criteria were met).

  4. Here is ADAM with timing:

import time

import jax.numpy as jnp

# f is the objective defined above; grad_func is its gradient
# (e.g. grad_func could be jax.grad(f) or autograd.grad(f))

def adam(grad_func, x0, alpha=0.01, beta1=0.95, beta2=0.99, epsilon=1e-3):
    start_time = time.time()  # Start timing

    max_iter = 500

    initial_function_value = f(x0)
    m = jnp.zeros_like(x0)  # First-moment estimate
    v = jnp.zeros_like(x0)  # Second-moment estimate
    t = 0
    x = x0
    path = [x0]

    for i in range(max_iter):
        grad = grad_func(x)
        t += 1
        m = beta1 * m + (1 - beta1) * grad
        v = beta2 * v + (1 - beta2) * grad ** 2
        m_hat = m / (1 - beta1 ** t)  # Bias-corrected first moment
        v_hat = v / (1 - beta2 ** t)  # Bias-corrected second moment
        x = x - alpha * m_hat / (jnp.sqrt(v_hat) + epsilon)
        path.append(x)
        # Stop on a small gradient norm or a small change from the initial value
        if jnp.linalg.norm(grad) < epsilon or abs(f(x) - initial_function_value) <= epsilon:
            break

    end_time = time.time()  # End timing
    execution_time = end_time - start_time  # Total execution time

    return x, path, execution_time

I am using the time library for rough performance measurement. The function in question is simple, as described.
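
For context, grad_func is not shown above; a hedged sketch of how it could be constructed, jitted, and timed follows. The names and input values here are assumptions, not the original setup.

import timeit

import jax
import jax.numpy as jnp

# Assumption: f is the JAX version of the objective defined earlier in this issue.
grad_func = jax.jit(jax.grad(f))       # jit the gradient so each call runs compiled code

x0 = jnp.array([-10.0, -80.0])
_ = grad_func(x0).block_until_ready()  # warm up: the first call includes compilation time

# Time the steady-state gradient evaluation rather than the whole Python loop.
per_call = timeit.timeit(lambda: grad_func(x0).block_until_ready(), number=1000) / 1000
print(f"grad_func: {per_call * 1e6:.1f} microseconds per call")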


@jakevdp (Collaborator) commented on Apr 27, 2024

When I try benchmarking your original function using jax.jit, I find that JAX is 4x faster than autograd on both CPU and GPU for inputs of size 1000:

import autograd.numpy as anp

import jax
import jax.numpy as jnp

def f_autograd(x):

  term1 = 0.5 * (x[0]**2 + (x[1] - 0.5)**2)  # Central parabolic valley

  # Nested valley 1 (deep, narrow)
  term2_x = -4 * anp.exp(-(x[0] + 0.75)**2 - 10 * (x[1] - 0.3)**2)
  term2_y = -8 * anp.exp(-(x[0] - 0.75)**2 - 10 * (x[1] - 0.3)**2)
  term2 = term2_x + term2_y

  # Nested valley 2 (wide, shallow)
  term3_x = 2 * anp.sin(5 * anp.pi * (x[0] - 1.25)) * anp.sin(5 * anp.pi * (x[1] - 1.75))
  term3_y = 3 * anp.sin(7 * anp.pi * (x[0] - 1.25)) * anp.sin(7 * anp.pi * (x[1] - 1.75))
  term3 = 0.2 * (term3_x + term3_y)  # Adjust coefficient for shallower valley

  term4 = 3 * anp.sin(3 * anp.pi * x[0]) * anp.sin(3 * anp.pi * x[1])  # Oscillating term

  term5 = -5 * anp.exp(-(x[0] + 1)**2 - (x[1] + 1)**2)  # Deeper global minimum

  term6 = -anp.exp(-(x[0] - 1.5)**2 - (x[1] - 1.5)**2)  # Local minimum

  term7 = -2 * anp.exp(-(x[0] + 2)**2 - (x[1] - 2)**2)  # Local minimum

  return term1 + term2 + term3 + term4 + term5 + term6 + term7


@jax.jit
def f_jax(x):

  term1 = 0.5 * (x[0]**2 + (x[1] - 0.5)**2)  # Central parabolic valley

  # Nested valley 1 (deep, narrow)
  term2_x = -4 * jnp.exp(-(x[0] + 0.75)**2 - 10 * (x[1] - 0.3)**2)
  term2_y = -8 * jnp.exp(-(x[0] - 0.75)**2 - 10 * (x[1] - 0.3)**2)
  term2 = term2_x + term2_y

  # Nested valley 2 (wide, shallow)
  term3_x = 2 * jnp.sin(5 * jnp.pi * (x[0] - 1.25)) * jnp.sin(5 * jnp.pi * (x[1] - 1.75))
  term3_y = 3 * jnp.sin(7 * jnp.pi * (x[0] - 1.25)) * jnp.sin(7 * jnp.pi * (x[1] - 1.75))
  term3 = 0.2 * (term3_x + term3_y)  # Adjust coefficient for shallower valley

  term4 = 3 * jnp.sin(3 * jnp.pi * x[0]) * jnp.sin(3 * jnp.pi * x[1])  # Oscillating term

  term5 = -5 * jnp.exp(-(x[0] + 1)**2 - (x[1] + 1)**2)  # Deeper global minimum

  term6 = -jnp.exp(-(x[0] - 1.5)**2 - (x[1] - 1.5)**2)  # Local minimum

  term7 = -2 * jnp.exp(-(x[0] + 2)**2 - (x[1] - 2)**2)  # Local minimum

  return term1 + term2 + term3 + term4 + term5 + term6 + term7

shape = (2, 1000)

x_autograd = anp.ones(shape)
%timeit f_autograd(x_autograd)
# 797 µs ± 440 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

x_jax = jnp.ones(shape)
_ = f_jax(x_jax)  # trigger compilation
%timeit f_jax(x_jax).block_until_ready()
# 141 µs ± 17.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

This is on a Colab CPU runtime, using the built-in %timeit magic function. On a Colab T4 GPU, the timings I get are:

374 µs ± 73.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
92.9 µs ± 3.18 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

If you could include your full end-to-end benchmark script, including all imports, array definitions, function definitions, and function calls, I may be able to comment on why you are seeing different results.

@alibastami (Author) commented on Apr 27, 2024

Ah, I think I see now. When I run your snippet I get this warning:

CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
Outdated cuSPARSE installation found.
Version JAX was built against: 12200
Minimum supported: 12100
Installed version: 12002
The local installation version must be no lower than 12100. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
221 µs ± 3.81 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
118 µs ± 3.85 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

And I get:

JAX Available devices: [CpuDevice(id=0)]

When running:

devices = jax.devices()
print("JAX Available devices:", devices)

But when I import jaxlib this warning disappears, and JAX's performance drops to become almost equal to Autograd's (by the way, I have never seen this warning before, maybe because I was importing jaxlib, but why exactly?)

227 µs ± 19.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) -- autograd
182 µs ± 49.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) -- Jax

and still:

JAX Available devices: [CpuDevice(id=0)]
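
A minimal sketch of confirming which backend JAX actually initialized (illustrative, not part of the original comment):

import jax

print("JAX default backend:", jax.default_backend())  # e.g. 'cpu' or 'gpu'
print("JAX devices:", jax.devices())

try:
    print("GPU devices:", jax.devices("gpu"))
except RuntimeError as err:
    # Raised when no GPU backend could be initialized (e.g. the cuSPARSE mismatch above)
    print("No usable GPU backend:", err)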

This is my nvcc --version:

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Feb__7_19:32:13_PST_2023
Cuda compilation tools, release 12.1, V12.1.66
Build cuda_12.1.r12.1/compiler.32415258_0

And my jax and jaxlib versions are both 0.4.26.
