Pallas jax.lax.fori_loop over long inputs slows down #20909

Open

jbuckman opened this issue Apr 24, 2024 · 1 comment

Labels: bug (Something isn't working)

jbuckman commented Apr 24, 2024
Description

Inside Pallas kernels, we often want a loop, and to speed up compilation, we typically use a scan function such as jax.lax.fori_loop. (For example, in the attention kernel example here.)
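As a minimal sketch of that pattern (illustrative names and block size, not the actual attention kernel): a kernel that accumulates row-blocks of its input with jax.lax.fori_loop.

import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl

BLK = 128  # rows processed per loop iteration (illustrative choice)

def sum_blocks_kernel(x_ref, o_ref):
    t, d = x_ref.shape

    def body(j, acc):
        # load the j-th block of rows and add it to the running accumulator
        x_j = pl.load(x_ref, (pl.ds(j * BLK, BLK), slice(None)))
        return acc + x_j

    acc = jax.lax.fori_loop(0, t // BLK, body, jnp.zeros((BLK, d), x_ref.dtype))
    o_ref[...] = acc

@jax.jit
def sum_blocks(x):
    return pl.pallas_call(
        sum_blocks_kernel,
        out_shape=jax.ShapeDtypeStruct((BLK, x.shape[1]), x.dtype),
    )(x)

x = jax.random.normal(jax.random.PRNGKey(0), (1024, 64))
print(sum_blocks(x).shape)  # (128, 64)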

As the length of the loop grows, fori_loop slows down execution substantially relative to using a Python for-loop. I put together a minimal script to isolate the issue; running it on an A6000, I saw a 2-3x slowdown on long loops:

T=256
python for-loop: compile = 177ms, execution ms_per_kernel_call = 0.317ms
jax.lax.fori_loop: compile = 193ms, execution ms_per_kernel_call = 0.318ms
T=2048
python for-loop: compile = 242ms, execution ms_per_kernel_call = 2.247ms
jax.lax.fori_loop: compile = 200ms, execution ms_per_kernel_call = 2.255ms
T=8192
python for-loop: compile = 473ms, execution ms_per_kernel_call = 8.946ms
jax.lax.fori_loop: compile = 194ms, execution ms_per_kernel_call = 9.281ms
T=16384
python for-loop: compile = 776ms, execution ms_per_kernel_call = 18.177ms
jax.lax.fori_loop: compile = 198ms, execution ms_per_kernel_call = 22.288ms
T=32768
python for-loop: compile = 1460ms, execution ms_per_kernel_call = 36.009ms
jax.lax.fori_loop: compile = 200ms, execution ms_per_kernel_call = 58.552ms
T=65536
python for-loop: compile = 2978ms, execution ms_per_kernel_call = 71.313ms
jax.lax.fori_loop: compile = 195ms, execution ms_per_kernel_call = 172.925ms

Here is the script that generated these results:

import time

import jax
import jax.experimental.pallas as pl
import jax.numpy as jnp

# ---------- PALLAS ------------

class JaxKernel:
    fwd_blk_i = 128
    fwd_blk_j = 64

    def __init__(self, use_scan):
        self.use_scan = use_scan

    def __call__(self, X):
        t, d = X.shape
        grid = (t // self.fwd_blk_i,)
        Y = pl.pallas_call(
            self.fwd_kernel,
            grid=grid,
            out_shape=jax.ShapeDtypeStruct(X.shape, X.dtype),
        )(X)
        return Y

    def fwd_kernel(self, X_ref, Y_ref):
        i = pl.program_id(0)
        t, d = X_ref.shape

        X_i = pl.load(X_ref, pl.ds(start=i * self.fwd_blk_i, size=self.fwd_blk_i))
        Y_i_acc = jnp.zeros([self.fwd_blk_i, d], dtype=X_i.dtype)
        def body(j, carry):
            B_ij = X_i.sum()
            carry += B_ij + j # crashes if loop variable not involved
            return carry

        if self.use_scan:
            Y_i = jax.lax.fori_loop(0, t // self.fwd_blk_j, body, Y_i_acc)
        else:
            for j in range(0, t // self.fwd_blk_j):
                Y_i_acc = body(j, Y_i_acc)
            Y_i = Y_i_acc

        pl.store(Y_ref, pl.ds(start=i*self.fwd_blk_i, size=self.fwd_blk_i), Y_i)

# ------ BENCHMARK UTILS -------

def prepare_data(b, t, d, dtype):
    """Creates the data for a forward pass."""
    return jax.random.normal(jax.random.PRNGKey(0), shape=(b, t, d), dtype=dtype)

def heavyweight(f, n, b, t, d, dtype):
    """Given a kernel for a single batch, transforms it to repeat on many batches.

    Returns output to make sure no computation gets compiled away."""
    @jax.jit
    def heavy_f():
        def batch_f(X):
            out = f(X).mean()
            return out
        def scanner(_, __):
            X = prepare_data(b, t, d, dtype)
            return None, batch_f(X)
        _, Y = jax.lax.scan(scanner, None, jnp.arange(n))
        return Y
    return heavy_f

# ---- RUN -----

def main():
    # set up pallas att
    pallas_scanless = jax.vmap(JaxKernel(use_scan=False))
    pallas_scanner = jax.vmap(JaxKernel(use_scan=True))

    # confirm output is correct
    X = prepare_data(1, 2048, 16, jnp.float32)
    pallas_scanless_Y = pallas_scanless(X)
    pallas_scanner_Y = pallas_scanner(X)
    assert jnp.allclose(pallas_scanner_Y, pallas_scanless_Y, atol=.001)

    # choose hyperparameters
    N = 16      # number of times to repeat execution
    B = 512     # batch size
    T = 8192   # context size
    D = 64      # feature size
    dtype = jnp.float16
    print(f'{N=} {B=} {T=} {D=}')

    # jit functions
    heavy_scanless = heavyweight(pallas_scanless, N, B, T, D, dtype)
    heavy_scanner = heavyweight(pallas_scanner, N, B, T, D, dtype)

    # compile
    _t = time.time()
    jax.block_until_ready(heavy_scanless())
    scanless_time_to_compile_and_execute = time.time() - _t

    _t = time.time()
    jax.block_until_ready(heavy_scanner())
    scanner_time_to_compile_and_execute = time.time() - _t

    # time the main execution
    _t = time.time()
    jax.block_until_ready(heavy_scanless())
    scanless_time_to_execute = time.time() - _t
    ms_per_kernel_call = 1000 * scanless_time_to_execute / N
    scanless_time_to_compile = (
            scanless_time_to_compile_and_execute - scanless_time_to_execute)
    print(f'{1000*scanless_time_to_compile = :.3f}ms, '
          f'execution {ms_per_kernel_call = :.3f}ms')

    _t = time.time()
    jax.block_until_ready(heavy_scanner())
    scanner_time_to_execute = time.time() - _t
    ms_per_kernel_call = 1000 * scanner_time_to_execute / N
    scanner_time_to_compile = (
            scanner_time_to_compile_and_execute - scanner_time_to_execute)
    print(f'{1000*scanner_time_to_compile = :.3f}ms, '
          f'execution {ms_per_kernel_call = :.3f}ms')


if __name__ == '__main__':
    main()

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.25
jaxlib: 0.4.25
numpy:  1.26.4
python: 3.11.4 (main, Dec  7 2023, 15:43:41) [GCC 12.3.0]
jax.devices (2 total, 2 local): [cuda(id=0) cuda(id=1)]
process_count: 1
platform: uname_result(system='Linux', node='jacob-manifestai', release='6.2.0-39-generic', version='#40-Ubuntu SMP PREEMPT_DYNAMIC Tue Nov 14 14:18:00 UTC 2023', machine='x86_64')


$ nvidia-smi
Wed Apr 24 01:15:09 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07             Driver Version: 535.161.07   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA RTX A6000               On  | 00000000:2D:00.0 Off |                  Off |
| 30%   39C    P2              38W / 300W |    269MiB / 49140MiB |      1%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA RTX A6000               On  | 00000000:41:00.0 Off |                  Off |
| 30%   40C    P2              30W / 300W |    269MiB / 49140MiB |      1%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A   4176762      C   ...anifest2-ozEuPuop-py3.11/bin/python      262MiB |
|    1   N/A  N/A   4176762      C   ...anifest2-ozEuPuop-py3.11/bin/python      262MiB |
+---------------------------------------------------------------------------------------+

masylum commented Apr 24, 2024

I think the best explanation I found online is the following:

To elaborate on this, the reason GPU is so fast for vectorized operations is not that individual floating point operations are particularly fast (they're actually often slower than similar operations on a CPU!), but rather that it can very efficiently run many operations in parallel. For an operation like scan in which each step depends on the output of the previous, the sequence of operations as a whole cannot be parallelized. So you end up not taking advantage of any of the GPU's inherent parallelism, and the result is slow execution.
Contrast this with a CPU, where individual floating point operations are relatively fast, but there is not as much in-built parallelism available. Because of this, scan does not incur as much of a performance penalty.

I think this falls into the same bucket. As in my example, where the input of each layer depended on the output of the previous one (the carry), this is a purely sequential loop. There is probably a sweet spot for the unroll parameter where the trade-off between compilation time and loop time is best.
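
As a rough illustration of that unroll trade-off (outside of Pallas, using jax.lax.scan's unroll argument; the toy body and the numbers are made up):

import jax
import jax.numpy as jnp

STEPS = 65536   # loop length, mirroring the long-T case above
UNROLL = 8      # copies of the body emitted per sequential loop iteration

def body(carry, _):
    # purely sequential toy body: each step depends on the previous carry
    return carry * 0.999 + 1.0, None

@jax.jit
def run(x0):
    # larger UNROLL means fewer sequential iterations at runtime but a bigger
    # program (and a longer compile); UNROLL=1 is a plain sequential scan
    y, _ = jax.lax.scan(body, x0, None, length=STEPS, unroll=UNROLL)
    return y

print(run(jnp.float32(0.0)))

Sweeping UNROLL over a few powers of two and re-timing should show where the compile-time/run-time balance lands for a given loop length.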
