Here's the Pallas kernel from the repo that I've slightly modified by introducing control over the accumulator dtype:
```python
import jax
import jax.numpy as jnp
import jax_triton as jt  # assuming `jt` is jax_triton (used only for jt.cdiv)
from jax.experimental import pallas as pl


def mha_forward_kernel(
    q_ref, k_ref, v_ref,   # inputs
    o_ref,                 # output
    *residual_refs,        # optional refs for the l and m residuals
    dot_product_scale: float,
    block_q: int,
    block_d: int,
    block_kv: int,
):
  dtype = jnp.float32  # HANGS IF I REPLACE THIS WITH BFLOAT16 !!!
  seq_len = q_ref.shape[0]
  start_q = pl.program_id(0)
  neg_inf = -1e20

  # m_i and l_i (see FlashAttention paper) are updated during the k,v loop.
  m_i = jnp.full(block_q, dtype=dtype, fill_value=neg_inf)
  l_i = jnp.zeros(block_q, dtype=dtype)
  # acc is the buffer where we accumulate the output on sram.
  acc = jnp.zeros((block_q, block_d), dtype=dtype)

  # Load q: it will stay in L1 throughout. Indices form a matrix because we
  # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
  # q tile has shape [block_q, block_d], block_d == head_dim.
  q = pl.load(q_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)))

  # In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv
  # (size Bc == block_kv here), and fast over blocks of q (size Br == block_q
  # here). Here we only loop over blocks of kv to process the entire seq_len;
  # the loop over blocks of q is carried out by the grid.
  def body(start_k, carry):
    acc, m_prev, l_prev = carry
    k = pl.load(k_ref, (pl.dslice(start_k * block_kv, block_kv), slice(None)))
    qk = jnp.zeros([block_q, block_kv], dtype=dtype)
    qk += pl.dot(q, k.T)      # [block_q, block_kv]
    qk *= dot_product_scale   # [block_q, block_kv]
    m_curr = jnp.maximum(jnp.max(qk, axis=1), m_prev)
    l_prev *= jnp.exp(m_prev - m_curr)
    p = jnp.exp(qk - m_curr[:, None])
    l_curr = jnp.sum(p, axis=1) + l_prev
    l_rcp = jnp.ones((), dtype=dtype) / l_curr
    p = p * l_rcp[:, None]
    acc *= (l_prev * l_rcp)[:, None]
    v = pl.load(
        v_ref, (pl.dslice(start_k * block_kv, block_kv), pl.dslice(block_d))
    )
    acc = acc + pl.dot(p.astype(v.dtype), v)
    return acc.astype(dtype), m_curr.astype(dtype), l_curr.astype(dtype)

  upper_bound = jt.cdiv(seq_len, block_kv)
  acc, m_i, l_i = jax.lax.fori_loop(0, upper_bound, body, (acc, m_i, l_i))

  if residual_refs:
    l_ref, m_ref = residual_refs
    pl.store(l_ref, (pl.ds(start_q * block_q, block_q),), l_i)
    pl.store(m_ref, (pl.ds(start_q * block_q, block_q),), m_i)

  # Write output to dram.
  acc = acc.astype(o_ref.dtype)
  pl.store(o_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)), acc)
```
Surprisingly, compilation of this kernel hangs (!) if I set `dtype` to bfloat16. I suspect there's a bug somewhere.
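For context, a minimal single-head launch along these lines is enough to trigger compilation; the shapes, block sizes, and softmax scale below are placeholder assumptions for illustration, not the real config:

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

seq_len, head_dim = 1024, 64     # placeholder sizes
block_q = block_kv = 128         # placeholder block sizes
block_d = head_dim

q = jnp.zeros((seq_len, head_dim), dtype=jnp.bfloat16)
k = jnp.zeros((seq_len, head_dim), dtype=jnp.bfloat16)
v = jnp.zeros((seq_len, head_dim), dtype=jnp.bfloat16)

kernel = functools.partial(
    mha_forward_kernel,
    dot_product_scale=1.0 / head_dim ** 0.5,  # placeholder scale
    block_q=block_q,
    block_d=block_d,
    block_kv=block_kv,
)

# No block specs: each program sees the full arrays and the kernel slices
# out its own q block via program_id(0).
out = pl.pallas_call(
    kernel,
    grid=(seq_len // block_q,),  # one program per q block
    out_shape=jax.ShapeDtypeStruct((seq_len, head_dim), q.dtype),
)(q, k, v)
```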
Thanks for the heads-up. This is likely a Triton compiler bug, but I will try to repro and investigate this week.