Kernel compilation hangs with a particular dtype #197

Open
hr0nix opened this issue Jul 16, 2023 · 1 comment

hr0nix commented Jul 16, 2023

Here's the Pallas kernel from the repo that I've slightly modified by introducing control over the accumulator dtype:

# Imports assumed from the surrounding file (not included in the original snippet);
# depending on the version, pallas may instead be imported via
# `from jax_triton import pallas as pl`.
import jax
import jax.numpy as jnp
import jax_triton as jt
from jax.experimental import pallas as pl


def mha_forward_kernel(
    q_ref,
    k_ref,
    v_ref,
    o_ref,
    *residual_refs,
    dot_product_scale: float,
    block_q: int,
    block_d: int,
    block_kv: int
):
    dtype = jnp.float32  # HANGS IF I REPLACE THIS WITH BFLOAT16 !!!

    seq_len = q_ref.shape[0]
    start_q = pl.program_id(0)

    neg_inf = -1e20

    # m_i and l_i (see the FlashAttention paper) are updated during the k,v loop.
    m_i = jnp.full(block_q, dtype=dtype, fill_value=neg_inf)
    l_i = jnp.zeros(block_q, dtype=dtype)
    # acc is the buffer where we accumulate the output on sram.
    acc = jnp.zeros((block_q, block_d), dtype=dtype)

    # Load q: it will stay in L1 throughout. Indices form a matrix because we
    # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
    # q tile has shape [block_q, block_d], block_d == head_dim.
    q = pl.load(q_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)))

    # In FlashAttention algorithm 1 there are 2 loops: a slow one over tiles of kv
    # (size Bc == block_kv here) and a fast one over blocks of q (size Br == block_q
    # here). Here we only loop over blocks of kv to cover the entire seq_len; the
    # loop over blocks of q is carried out by the grid.
    def body(start_k, carry):
        acc, m_prev, l_prev = carry

        k = pl.load(k_ref, (pl.dslice(start_k * block_kv, block_kv), slice(None)))

        qk = jnp.zeros([block_q, block_kv], dtype=dtype)
        qk += pl.dot(q, k.T)  # [block_q, block_kv]
        qk *= dot_product_scale  # [block_q, block_kv]

        m_curr = jnp.maximum(jnp.max(qk, axis=1), m_prev)
        l_prev *= jnp.exp(m_prev - m_curr)
        p = jnp.exp(qk - m_curr[:, None])
        l_curr = jnp.sum(p, axis=1) + l_prev

        l_rcp = jnp.ones((), dtype=dtype) / l_curr
        p = p * l_rcp[:, None]
        acc *= (l_prev * l_rcp)[:, None]

        v = pl.load(
            v_ref, (pl.dslice(start_k * block_kv, block_kv), pl.dslice(block_d))
        )
        acc = acc + pl.dot(p.astype(v.dtype), v)
        return acc.astype(dtype), m_curr.astype(dtype), l_curr.astype(dtype)

    upper_bound = jt.cdiv(seq_len, block_kv)
    acc, m_i, l_i = jax.lax.fori_loop(0, upper_bound, body, (acc, m_i, l_i))

    if residual_refs:
        l_ref, m_ref = residual_refs
        pl.store(l_ref, (pl.ds(start_q * block_q, block_q),), l_i)
        pl.store(m_ref, (pl.ds(start_q * block_q, block_q),), m_i)

    # Write output to dram.
    acc = acc.astype(o_ref.dtype)
    pl.store(o_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)), acc)
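
For reference, the kernel is wired up with pl.pallas_call roughly as follows (reusing the imports above; the shapes, block sizes, and dot_product_scale below are illustrative placeholders, not my exact configuration):

import functools

# Hypothetical driver, not taken from the repo: all concrete numbers here are
# made up for illustration.
seq_len, head_dim = 1024, 64
block_q = block_kv = 128

q = jnp.zeros((seq_len, head_dim), dtype=jnp.bfloat16)
k = jnp.zeros((seq_len, head_dim), dtype=jnp.bfloat16)
v = jnp.zeros((seq_len, head_dim), dtype=jnp.bfloat16)

kernel = functools.partial(
    mha_forward_kernel,
    dot_product_scale=1.0 / head_dim**0.5,
    block_q=block_q,
    block_d=head_dim,
    block_kv=block_kv,
)
out = pl.pallas_call(
    kernel,
    grid=(jt.cdiv(seq_len, block_q),),  # one program per block of q
    out_shape=jax.ShapeDtypeStruct((seq_len, head_dim), q.dtype),
)(q, k, v)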

Surprisingly, the compilation of this kernel hangs (!) if I set the dtype to bfloat16. I suspect there's a bug somewhere.
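
If it helps narrow things down, my (unverified) guess is that the relevant pattern is accumulating pl.dot results into a bfloat16 buffer inside a fori_loop; a much smaller kernel along these lines might already reproduce it:

# Untested reduced-case sketch (reusing the imports above): a bfloat16
# accumulator updated with pl.dot inside a fori_loop. This is a guess at the
# trigger, not a confirmed repro.
def bf16_accum_kernel(x_ref, y_ref, o_ref, *, steps: int):
    # The full kernel above compiles when this accumulator is float32.
    acc = jnp.zeros(o_ref.shape, dtype=jnp.bfloat16)

    def body(i, acc):
        x = pl.load(x_ref, (slice(None), slice(None)))
        y = pl.load(y_ref, (slice(None), slice(None)))
        # Cast back to bfloat16 so the carry dtype stays stable across iterations.
        return (acc + pl.dot(x, y)).astype(jnp.bfloat16)

    acc = jax.lax.fori_loop(0, steps, body, acc)
    pl.store(o_ref, (slice(None), slice(None)), acc.astype(o_ref.dtype))

x = jnp.ones((64, 64), dtype=jnp.bfloat16)
y = jnp.ones((64, 64), dtype=jnp.bfloat16)
out = pl.pallas_call(
    functools.partial(bf16_accum_kernel, steps=4),
    out_shape=jax.ShapeDtypeStruct((64, 64), jnp.bfloat16),
)(x, y)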

sharadmv (Collaborator) commented

Thanks for the heads-up. This is likely a Triton compiler bug, but I will try to repro and investigate this week.
