33 changes: 33 additions & 0 deletions tests/layers/vllm/test_attention.py
@@ -301,6 +301,7 @@ def test_forward_with_attention_sink(self, mesh):
kv_cache_dtype="auto",
attn_type=AttentionType.DECODER,
sinks=sinks)
impl.process_weights_after_loading(torch.bfloat16)

layer = MagicMock()
layer.layer_name = "0"
@@ -314,3 +315,35 @@ def test_forward_with_attention_sink(self, mesh):
layer_name_to_kvcache_index={'0': 0}):
assert impl.sinks is not None
impl.forward(layer, query, key, value, torch.tensor([]), metadata)

def test_forward_with_attention_sink_head_dim_128_raises_error(self, mesh):
head_dim = 128
sinks = torch.rand([NUM_HEADS], dtype=torch.float32)

impl = PallasAttentionBackendImpl(num_heads=NUM_HEADS,
head_size=head_dim,
scale=0.088,
num_kv_heads=NUM_KV_HEADS,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="auto",
attn_type=AttentionType.DECODER,
sinks=sinks)
impl.process_weights_after_loading(torch.bfloat16)

layer = MagicMock()
layer.layer_name = "0"

query, key, value, kv_cache, metadata = create_inputs(
mesh, head_dim=head_dim)

with torchax.default_env(), set_vllm_model_wrapper_context(
kv_caches=[kv_cache],
mesh=mesh,
layer_name_to_kvcache_index={'0': 0}
), pytest.raises(
NotImplementedError,
match=
"Attention sink support is only available when head_dim==64"):
assert impl.sinks is not None
impl.forward(layer, query, key, value, torch.tensor([]), metadata)
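
For context, a minimal, self-contained sketch of the kind of guard this new test exercises: an attention implementation that only accepts sinks when head_dim == 64 and otherwise raises NotImplementedError. The class and method below are hypothetical stand-ins for illustration, not the actual PallasAttentionBackendImpl.

# Illustrative sketch only -- _SinkGuardExample is a hypothetical stand-in,
# not the real tpu_inference implementation.
from typing import Optional

import torch


class _SinkGuardExample:

    def __init__(self, head_size: int, sinks: Optional[torch.Tensor] = None):
        self.head_size = head_size
        self.sinks = sinks

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        # The error message mirrors the one the test matches with pytest.raises.
        if self.sinks is not None and self.head_size != 64:
            raise NotImplementedError(
                "Attention sink support is only available when head_dim==64")
        return query  # the real kernel call is elided in this sketch


if __name__ == "__main__":
    guard = _SinkGuardExample(head_size=128, sinks=torch.rand(8))
    try:
        guard.forward(torch.zeros(1, 8, 128))
    except NotImplementedError as err:
        print(err)
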
60 changes: 43 additions & 17 deletions tpu_inference/layers/vllm/attention.py
@@ -71,14 +71,19 @@ def __init__(
"are not implemented for "
"PallasAttentionBackendImpl")

#TODO (kyuyeunk): Shard the sinks along head axis.
self.sinks = sinks
if self.sinks is not None:
self.sinks = t2j(self.sinks, use_dlpack=False).astype(jnp.float32)
assert self.sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
"heads in the layer")

def process_weights_after_loading(self, act_dtype: torch.dtype):
#TODO (kyuyeunk): Shard the sinks along num_heads dim
if self.sinks is not None:
sinks = t2j(self.sinks, use_dlpack=False)
sinks = torch_view(sinks.astype(jnp.float32))
self.sinks = torch.nn.Parameter(sinks, requires_grad=False)

def forward(
self,
layer: AttentionLayer,
@@ -121,25 +126,44 @@ def forward(
k_scale = layer._k_scale_float
v_scale = layer._v_scale_float

sinks = None if self.sinks is None else jax_view(self.sinks)

new_kv_cache, outputs = _jax_attn_func(kv_cache, query, key, value,
sinks, attn_metadata, mesh,
self.scale, self.head_size,
self.num_heads,
self.num_kv_heads, q_scale,
k_scale, v_scale)
sinks = jax_view(self.sinks)

new_kv_cache, outputs = _jax_attn_func(
kv_cache,
query,
key,
value,
sinks,
attn_metadata,
mesh,
self.scale,
self.head_size,
self.num_heads,
self.num_kv_heads,
q_scale,
k_scale,
v_scale,
self.sliding_window,
)
vllm_model_wrapper_context.kv_caches[kv_cache_index] = new_kv_cache

return torch_view(outputs)


@functools.partial(
jax.jit,
static_argnums=(
6, 7, 8, 9, 10, 11, 12, 13
), # mesh, scale, head_size, num_heads, num_kv_heads, q_scale, k_scale, v_scale
donate_argnums=(0, ), # donate kv_cache
static_argnames=(
"mesh",
"scale",
"head_size",
"num_heads",
"num_kv_heads",
"q_scale",
"k_scale",
"v_scale",
"sliding_window",
),
donate_argnames=("kv_cache"),
)
def _jax_attn_func(
kv_cache: jax.Array,
@@ -153,9 +177,10 @@ def _jax_attn_func(
head_size: int,
num_heads: int,
num_kv_heads: int,
q_scale: Optional[float] = None,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
q_scale: float | None = None,
k_scale: float | None = None,
v_scale: float | None = None,
sliding_window: int | None = None,
) -> Tuple[jax.Array, jax.Array]:
del scale # Unused for now, as the attention function applies a default scale.

@@ -184,6 +209,7 @@ def _jax_attn_func(
k_scale=k_scale,
v_scale=v_scale,
sinks=sinks,
attention_chunk_size=sliding_window,
)

# Convert the shape back to vLLM's convention
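
A side note on the decorator change above: moving from positional static_argnums/donate_argnums to named static_argnames/donate_argnames keeps the jit specification valid when new trailing parameters such as sliding_window are appended, since nothing depends on argument positions anymore. Below is a minimal, self-contained sketch of that pattern with purely illustrative names, not code from this PR.

# Toy example of static_argnames + donate_argnames; all names are illustrative.
import functools

import jax
import jax.numpy as jnp


@functools.partial(
    jax.jit,
    static_argnames=("scale", "sliding_window"),
    donate_argnames=("kv_cache",),
)
def _toy_attn(kv_cache, query, *, scale, sliding_window=None):
    # A real kernel would write new KV entries and run attention; this toy
    # only shows that the static kwargs select a compiled specialization and
    # that kv_cache's buffer may be reused (donated) for the updated cache.
    updated_cache = kv_cache + 1.0
    out = query * scale
    if sliding_window is not None:
        out = out[..., :sliding_window]
    return updated_cache, out


kv = jnp.zeros((2, 4))
q = jnp.ones((2, 4))
new_kv, out = _toy_attn(kv, q, scale=0.088, sliding_window=None)
print(new_kv.shape, out.shape)  # (2, 4) (2, 4)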