
Commit

Adds ragged attention.
patemotter committed Aug 20, 2024
1 parent 14379df commit de4e075
Showing 7 changed files with 479 additions and 18 deletions.
3 changes: 3 additions & 0 deletions MaxText/configs/base.yml
@@ -413,3 +413,6 @@ enable_single_controller: False

# Split physical axes for https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.mesh_utils.create_device_mesh.html
allow_split_physical_axes: False

use_ragged_attention: False
ragged_block_size: 256
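
These two keys gate the decode-time attention path and set the Pallas block length along seq_len. A minimal sketch of how they might be consumed, assuming a pyconfig-style config object whose attribute names mirror the YAML keys; the import path and the maybe_ragged_attention wrapper below are hypothetical, not part of this commit:

    from kernels import ragged_attention

    def maybe_ragged_attention(config, query, key, value, lengths):
      # Hypothetical dispatch on the new flag (assumed config attribute names).
      if not config.use_ragged_attention:
        return ragged_attention.reference_mha(query, key, value, lengths)
      # seq_len must be a multiple of ragged_block_size for the Pallas grid.
      return ragged_attention.ragged_mha(
          query, key, value, lengths, block_size=config.ragged_block_size
      )
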
302 changes: 302 additions & 0 deletions MaxText/kernels/ragged_attention.py
@@ -0,0 +1,302 @@
"""Kernels for ragged attention."""

import functools

import jax
from jax import lax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
import common_types

from jax.experimental import shard_map


DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)
BATCH = common_types.BATCH
shard_map = shard_map.shard_map


@functools.partial(jax.jit, static_argnames=["mask_value"])
def reference_mqa(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    lengths: jax.Array,
    *,
    mask_value: float = DEFAULT_MASK_VALUE,
) -> tuple[jax.Array, jax.Array, jax.Array]:
  """Multi-query attention reference.

  Args:
    q: A [batch_size, num_heads, head_dim] jax.Array.
    k: A [batch_size, seq_len, head_dim] jax.Array.
    v: A [batch_size, seq_len, head_dim] jax.Array.
    lengths: An i32[batch_size] jax.Array.
    mask_value: The value used for padding in attention. By default it is a
      very negative floating point number.

  Returns:
    The output of attention ([batch_size, num_heads, head_dim]), along with the
    max logit ([batch_size, num_heads, 1]) and softmax denominator
    ([batch_size, num_heads, 1]).
  """
  logits = jnp.einsum("bhd,btd->bht", q.astype(jnp.float32), k.astype(jnp.float32))
  # Mask out key positions beyond each sequence's true length.
  mask = jnp.arange(k.shape[1])[None] < lengths[:, None]
  logits = logits + jnp.where(mask, 0.0, mask_value)[:, None]
  logits_max = logits.max(axis=-1)

  unnormalized = jnp.exp(logits - logits_max[..., None])
  denominator = unnormalized.sum(axis=-1)
  o = jnp.einsum("bht,btd->bhd", unnormalized.astype(v.dtype), v) / denominator[..., None]
  return o, logits_max[..., None], denominator[..., None]

@jax.jit
def reference_mha(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    lengths: jax.Array,
    *,
    mask_value: float = DEFAULT_MASK_VALUE,
) -> tuple[jax.Array, jax.Array, jax.Array]:
  """Multi-head attention reference.

  Args:
    q: A [batch_size, 1, num_heads, head_dim] jax.Array.
    k: A [batch_size, seq_len, num_heads, head_dim] jax.Array.
    v: A [batch_size, seq_len, num_heads, head_dim] jax.Array.
    lengths: An i32[batch_size] jax.Array.
    mask_value: The value used for padding in attention. By default it is a
      very negative floating point number.

  Returns:
    The output of attention ([batch_size, num_heads, head_dim]), along with the
    max logit ([batch_size, num_heads]) and softmax denominator
    ([batch_size, num_heads]).
  """
  # Move the heads axis next to batch, then vmap the MQA reference over heads.
  q = jnp.swapaxes(q, 1, 2)
  k = jnp.swapaxes(k, 1, 2)
  v = jnp.swapaxes(v, 1, 2)
  return jax.vmap(
      functools.partial(reference_mqa, mask_value=mask_value),
      in_axes=(1, 1, 1, None),
      out_axes=2,
  )(q, k, v, lengths)
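
# Usage sketch (illustrative only): shape check of the reference path on dummy
# data, following the docstrings above (batch=2, seq_len=8, heads=4, dim=16).
#
#   q = jnp.zeros((2, 1, 4, 16), jnp.float32)
#   k = jnp.zeros((2, 8, 4, 16), jnp.float32)
#   v = jnp.zeros((2, 8, 4, 16), jnp.float32)
#   lengths = jnp.array([5, 8], jnp.int32)
#   o, m, l = reference_mha(q, k, v, lengths)   # o has shape [2, 1, 4, 16]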


def ragged_flash_attention_kernel(
    lengths_ref,
    q_ref,
    k_ref,
    v_ref,
    o_ref,
    m_ref,
    l_ref,
    *,
    block_size: int,
    mask_value: float,
):
  """Pallas kernel for ragged flash attention."""
  b, i = pl.program_id(0), pl.program_id(1)

  @pl.when(i == 0)
  def init():
    # At the first block of each batch row, reset the running max, denominator,
    # and output accumulators.
    m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
    l_ref[...] = jnp.zeros_like(l_ref)
    o_ref[...] = jnp.zeros_like(o_ref)

  length = lengths_ref[b]

  @pl.when(i * block_size < length)
  def run():
    q = q_ref[...].astype(jnp.float32)
    k = k_ref[...].astype(jnp.float32)
    v = v_ref[...].astype(jnp.float32)
    m_prev, l_prev = m_ref[...], l_ref[...]

    qk = lax.dot_general(q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32)

    # Mask out key positions in this block that lie beyond the sequence length.
    mask = i * block_size + jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1) < length
    qk = qk + jnp.where(mask, 0.0, mask_value)
    m_curr = qk.max(axis=-1)

    s_curr = jnp.exp(qk - m_curr[..., None])
    l_curr = jax.lax.broadcast_in_dim(s_curr.sum(axis=-1), l_prev.shape, (0,))
    o_curr_times_l_curr = jnp.dot(s_curr, v)

    # Online-softmax update: rescale the previous statistics by alpha and the
    # current block by beta so both are expressed against the new running max.
    # o_ref holds the normalized output, so it is multiplied back by l_prev
    # before combining and re-normalizing by l_next_safe.
    m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,))
    m_next = jnp.maximum(m_prev, m_curr)
    alpha = jnp.exp(m_prev - m_next)
    beta = jnp.exp(m_curr - m_next)
    l_next = alpha * l_prev + beta * l_curr
    l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)

    m_ref[...], l_ref[...] = m_next, l_next_safe
    o_ref[...] = (
        (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr) / l_next_safe
    ).astype(o_ref.dtype)
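
# The update above is the standard streaming-softmax recurrence. A plain-jnp
# sketch of the same bookkeeping for one incoming block (illustrative only):
#
#   def combine(m_prev, l_prev, o_prev, qk_blk, v_blk):
#     m_curr = qk_blk.max(axis=-1, keepdims=True)
#     s = jnp.exp(qk_blk - m_curr)
#     m_next = jnp.maximum(m_prev, m_curr)
#     alpha, beta = jnp.exp(m_prev - m_next), jnp.exp(m_curr - m_next)
#     l_next = alpha * l_prev + beta * s.sum(axis=-1, keepdims=True)
#     o_next = (alpha * l_prev * o_prev + beta * (s @ v_blk)) / l_next
#     return m_next, l_next, o_next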


def ragged_mqa(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    lengths: jax.Array,
    *,
    block_size: int = 256,
    mask_value: float = DEFAULT_MASK_VALUE,
    cost_estimate: pltpu.CostEstimate | None = None,
) -> tuple[jax.Array, jax.Array, jax.Array]:
  """Ragged multi-query attention.

  Args:
    q: A [batch_size, 1, head_dim] jax.Array.
    k: A [batch_size, seq_len, head_dim] jax.Array.
    v: A [batch_size, seq_len, head_dim] jax.Array.
    lengths: An i32[batch_size] jax.Array.
    block_size: The Pallas block size along the seq_len dimension.
    mask_value: The value used for padding in attention. By default it is a
      very negative floating point number.
    cost_estimate: A Pallas TPU cost estimate based on a reference implementation.

  Returns:
    The output of attention ([batch_size, num_heads, head_dim]), along with the
    max logit ([batch_size, num_heads]) and softmax denominator
    ([batch_size, num_heads]).
  """
  batch_size, num_heads, head_dim = q.shape
  assert lengths.shape == (batch_size,)
  assert lengths.dtype == jnp.int32
  seq_len = k.shape[1]

  def compute_ragged_block_indices(b, i, lengths_ref):
    # Once the current row is exhausted, point the copy at the next row's first
    # block (or, for the last row, its last useful block) so fully masked
    # blocks trigger no fresh data movement.
    length = lengths_ref[b]
    not_done = i * block_size < length
    am_last_batch = b == batch_size - 1
    last_good_block = lax.div(length, block_size) - 1
    b_next = jnp.where(not_done, b, jnp.where(am_last_batch, b, b + 1))
    i_next = jnp.where(not_done, i, jnp.where(am_last_batch, last_good_block, 0))
    return b_next, i_next, 0

  out, m, l = pl.pallas_call(
      functools.partial(
          ragged_flash_attention_kernel,
          block_size=block_size,
          mask_value=mask_value,
      ),
      grid_spec=pltpu.PrefetchScalarGridSpec(
          num_scalar_prefetch=1,
          in_specs=[
              pl.BlockSpec((None, num_heads, head_dim), lambda b, i, _: (b, 0, 0)),
              pl.BlockSpec((None, block_size, head_dim), compute_ragged_block_indices),
              pl.BlockSpec((None, block_size, head_dim), compute_ragged_block_indices),
          ],
          out_specs=[
              pl.BlockSpec((None, num_heads, head_dim), lambda b, i, _: (b, 0, 0)),
              pl.BlockSpec((None, num_heads, head_dim), lambda b, i, _: (b, 0, 0)),
              pl.BlockSpec((None, num_heads, head_dim), lambda b, i, _: (b, 0, 0)),
          ],
          grid=(batch_size, seq_len // block_size),
      ),
      compiler_params=dict(
          mosaic=dict(
              dimension_semantics=("parallel", "arbitrary"),
              cost_estimate=cost_estimate,
          )
      ),
      out_shape=[
          jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32),
          jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32),
          jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32),
      ],
  )(lengths, q, k, v)
  return out, m[..., 0], l[..., 0]
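
# Usage sketch (illustrative only): ragged_mqa expects per-head inputs and a
# seq_len divisible by block_size; entries of lengths may fall anywhere inside
# a block, and the kernel masks the remainder.
#
#   q = jnp.zeros((4, 8, 128), jnp.bfloat16)        # [batch, num_heads, head_dim]
#   k = v = jnp.zeros((4, 1024, 128), jnp.bfloat16) # [batch, seq_len, head_dim]
#   lengths = jnp.array([512, 1024, 130, 256], jnp.int32)
#   o, m, l = ragged_mqa(q, k, v, lengths, block_size=256)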


@functools.partial(
    jax.jit,
    static_argnames=[
        "block_size",
        "mask_value",
    ],
)
def ragged_mha(
    query: jax.Array,
    key: jax.Array,
    value: jax.Array,
    lengths: jax.Array,
    *,
    block_size: int = 256,
    mask_value: float = DEFAULT_MASK_VALUE,
) -> tuple[jax.Array, jax.Array, jax.Array]:
  """Ragged multi-head attention.

  Args:
    query: A [batch_size, 1, num_heads, head_dim] jax.Array.
    key: A [batch_size, seq_len, num_heads, head_dim] jax.Array.
    value: A [batch_size, seq_len, num_heads, head_dim] jax.Array.
    lengths: An i32[batch_size] jax.Array.
    block_size: The Pallas block size along the seq_len dimension.
    mask_value: The value used for padding in attention. By default it is a
      very negative floating point number.

  Returns:
    The output of attention ([batch_size, num_heads, head_dim]), along with the
    max logit ([batch_size, num_heads, 1]) and softmax denominator
    ([batch_size, num_heads, 1]).
  """
  # Derive a Pallas cost estimate from the compiled reference implementation.
  cost_analysis = (
      reference_mha.lower(
          query,
          key,
          value,
          lengths,
          mask_value=mask_value,
      )
      .compile()
      .cost_analysis()[0]
  )
  cost_estimate = pltpu.CostEstimate(
      flops=int(cost_analysis["flops"]),
      transcendentals=int(cost_analysis["transcendentals"]),
      bytes_accessed=int(cost_analysis["bytes accessed"]),
  )

  # Move the heads axis next to batch, then vmap the ragged MQA kernel over heads.
  query = jnp.swapaxes(query, 1, 2)
  key = jnp.swapaxes(key, 1, 2)
  value = jnp.swapaxes(value, 1, 2)
  o, m, l = jax.vmap(
      functools.partial(
          ragged_mqa,
          block_size=block_size,
          mask_value=mask_value,
          cost_estimate=cost_estimate,
      ),
      in_axes=(1, 1, 1, None),
      out_axes=2,
  )(query, key, value, lengths)
  m = jnp.expand_dims(m, axis=-1)
  l = jnp.expand_dims(l, axis=-1)
  # The kernel divides by the denominator internally; multiply back by l so
  # callers receive the unnormalized output plus (m, l) for later recombination.
  o = o * l
  return o, m, l
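
# Usage sketch (illustrative only): ragged_mha returns the unnormalized output
# together with the running max logit and softmax denominator, so a caller that
# needs the normalized attention output divides by l afterwards.
#
#   query = jnp.zeros((4, 1, 8, 128), jnp.bfloat16)
#   key = value = jnp.zeros((4, 1024, 8, 128), jnp.bfloat16)
#   lengths = jnp.array([17, 256, 1024, 640], jnp.int32)
#   o, m, l = ragged_mha(query, key, value, lengths, block_size=256)
#   attn = o / l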