
Commit 6bd4e16

[Torchax] Fix sink parameter initialization (#1056)
Signed-off-by: Kyuyeun Kim <[email protected]>
1 parent a05034c commit 6bd4e16

2 files changed: 76 additions, 17 deletions

tests/layers/vllm/test_attention.py: 33 additions, 0 deletions
@@ -301,6 +301,7 @@ def test_forward_with_attention_sink(self, mesh):
                                            kv_cache_dtype="auto",
                                            attn_type=AttentionType.DECODER,
                                            sinks=sinks)
+        impl.process_weights_after_loading(torch.bfloat16)
 
         layer = MagicMock()
         layer.layer_name = "0"
@@ -314,3 +315,35 @@ def test_forward_with_attention_sink(self, mesh):
                 layer_name_to_kvcache_index={'0': 0}):
             assert impl.sinks is not None
             impl.forward(layer, query, key, value, torch.tensor([]), metadata)
+
+    def test_forward_with_attention_sink_head_dim_128_raises_error(self, mesh):
+        head_dim = 128
+        sinks = torch.rand([NUM_HEADS], dtype=torch.float32)
+
+        impl = PallasAttentionBackendImpl(num_heads=NUM_HEADS,
+                                          head_size=head_dim,
+                                          scale=0.088,
+                                          num_kv_heads=NUM_KV_HEADS,
+                                          alibi_slopes=None,
+                                          sliding_window=None,
+                                          kv_cache_dtype="auto",
+                                          attn_type=AttentionType.DECODER,
+                                          sinks=sinks)
+        impl.process_weights_after_loading(torch.bfloat16)
+
+        layer = MagicMock()
+        layer.layer_name = "0"
+
+        query, key, value, kv_cache, metadata = create_inputs(
+            mesh, head_dim=head_dim)
+
+        with torchax.default_env(), set_vllm_model_wrapper_context(
+                kv_caches=[kv_cache],
+                mesh=mesh,
+                layer_name_to_kvcache_index={'0': 0}
+        ), pytest.raises(
+                NotImplementedError,
+                match=
+                "Attention sink support is only available when head_dim==64"):
+            assert impl.sinks is not None
+            impl.forward(layer, query, key, value, torch.tensor([]), metadata)
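
The new negative test relies on pytest.raises(..., match=...), which checks both the exception type and, via re.search, that the message matches the given pattern. Below is a self-contained sketch of that pattern; sink_attention and the test name are illustrative stand-ins, not code from this repository.

import pytest


def sink_attention(head_dim: int) -> None:
    # Stand-in for the backend's head_dim check; not the real implementation.
    if head_dim != 64:
        raise NotImplementedError(
            "Attention sink support is only available when head_dim==64")


def test_sink_attention_rejects_head_dim_128():
    # Both the exception type and its message (regex match) are asserted.
    with pytest.raises(NotImplementedError,
                       match="only available when head_dim==64"):
        sink_attention(head_dim=128)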

tpu_inference/layers/vllm/attention.py: 43 additions, 17 deletions
@@ -71,14 +71,19 @@ def __init__(
                                       "are not implemented for "
                                       "PallasAttentionBackendImpl")
 
-        #TODO (kyuyeunk): Shard the sinks along head axis.
         self.sinks = sinks
         if self.sinks is not None:
-            self.sinks = t2j(self.sinks, use_dlpack=False).astype(jnp.float32)
             assert self.sinks.shape[0] == num_heads, (
                 "Sinks must have the same number of heads as the number of "
                 "heads in the layer")
 
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+        #TODO (kyuyeunk): Shard the sinks along num_heads dim
+        if self.sinks is not None:
+            sinks = t2j(self.sinks, use_dlpack=False)
+            sinks = torch_view(sinks.astype(jnp.float32))
+            self.sinks = torch.nn.Parameter(sinks, requires_grad=False)
+
     def forward(
         self,
         layer: AttentionLayer,
@@ -121,25 +126,44 @@ def forward(
         k_scale = layer._k_scale_float
         v_scale = layer._v_scale_float
 
-        sinks = None if self.sinks is None else jax_view(self.sinks)
-
-        new_kv_cache, outputs = _jax_attn_func(kv_cache, query, key, value,
-                                               sinks, attn_metadata, mesh,
-                                               self.scale, self.head_size,
-                                               self.num_heads,
-                                               self.num_kv_heads, q_scale,
-                                               k_scale, v_scale)
+        sinks = jax_view(self.sinks)
+
+        new_kv_cache, outputs = _jax_attn_func(
+            kv_cache,
+            query,
+            key,
+            value,
+            sinks,
+            attn_metadata,
+            mesh,
+            self.scale,
+            self.head_size,
+            self.num_heads,
+            self.num_kv_heads,
+            q_scale,
+            k_scale,
+            v_scale,
+            self.sliding_window,
+        )
         vllm_model_wrapper_context.kv_caches[kv_cache_index] = new_kv_cache
 
         return torch_view(outputs)
 
 
 @functools.partial(
     jax.jit,
-    static_argnums=(
-        6, 7, 8, 9, 10, 11, 12, 13
-    ),  # mesh, scale, head_size, num_heads, num_kv_heads, q_scale, k_scale, v_scale
-    donate_argnums=(0, ),  # donate kv_cache
+    static_argnames=(
+        "mesh",
+        "scale",
+        "head_size",
+        "num_heads",
+        "num_kv_heads",
+        "q_scale",
+        "k_scale",
+        "v_scale",
+        "sliding_window",
+    ),
+    donate_argnames=("kv_cache"),
 )
 def _jax_attn_func(
     kv_cache: jax.Array,
@@ -153,9 +177,10 @@ def _jax_attn_func(
     head_size: int,
     num_heads: int,
     num_kv_heads: int,
-    q_scale: Optional[float] = None,
-    k_scale: Optional[float] = None,
-    v_scale: Optional[float] = None,
+    q_scale: float | None = None,
+    k_scale: float | None = None,
+    v_scale: float | None = None,
+    sliding_window: int | None = None,
 ) -> Tuple[jax.Array, jax.Array]:
     del scale  # Unused for now, as the attention function applies a default scale.
 
@@ -184,6 +209,7 @@ def _jax_attn_func(
         k_scale=k_scale,
         v_scale=v_scale,
         sinks=sinks,
+        attention_chunk_size=sliding_window,
     )
 
     # Convert the shape back to vLLM's convention
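
The heart of the fix is that the torch-to-JAX conversion of the sinks is deferred from __init__ to process_weights_after_loading, where the values are cast to float32 and re-registered as a non-trainable torch.nn.Parameter. The rough sketch below uses plain torch and jax.numpy instead of the t2j/torch_view helpers the module actually imports, so it only mirrors the dtype handling, not the exact code path.

import jax.numpy as jnp
import torch

# One sink logit per attention head, as loaded with the model weights.
sinks_torch = torch.rand(8, dtype=torch.bfloat16)

# Deferred conversion, mirroring process_weights_after_loading: move the
# values to the JAX side and keep them in float32 regardless of act_dtype.
sinks_jax = jnp.asarray(sinks_torch.to(torch.float32).numpy())
assert sinks_jax.dtype == jnp.float32

# In the module itself, the JAX array is wrapped back into a torch tensor
# view (torch_view) and stored as torch.nn.Parameter(..., requires_grad=False).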

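The jit decorator also moves from positional static_argnums/donate_argnums to name-based static_argnames/donate_argnames, so appending a new trailing parameter such as sliding_window no longer requires renumbering argument indices. A minimal, self-contained sketch of the same jax.jit pattern follows; update_cache and its arguments are illustrative, not the repository's _jax_attn_func.

import functools

import jax
import jax.numpy as jnp


@functools.partial(
    jax.jit,
    static_argnames=("scale_by", ),  # hashable config baked into the compiled trace
    donate_argnames=("kv_cache", ),  # buffer may be reused by XLA for the output
)
def update_cache(kv_cache: jax.Array, new_kv: jax.Array,
                 scale_by: float) -> jax.Array:
    # Write the scaled new entries into the (donated) cache buffer.
    return kv_cache.at[:new_kv.shape[0]].set(new_kv * scale_by)


cache = jnp.zeros((8, 4))
fresh = jnp.ones((2, 4))
print(update_cache(cache, fresh, scale_by=2.0))  # donation is a no-op (with a warning) on CPU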