@@ -71,14 +71,19 @@ def __init__(
7171 "are not implemented for "
7272 "PallasAttentionBackendImpl" )
7373
74- #TODO (kyuyeunk): Shard the sinks along head axis.
7574 self .sinks = sinks
7675 if self .sinks is not None :
77- self .sinks = t2j (self .sinks , use_dlpack = False ).astype (jnp .float32 )
7876 assert self .sinks .shape [0 ] == num_heads , (
7977 "Sinks must have the same number of heads as the number of "
8078 "heads in the layer" )
8179
80+ def process_weights_after_loading (self , act_dtype : torch .dtype ):
81+ #TODO (kyuyeunk): Shard the sinks along num_heads dim
82+ if self .sinks is not None :
83+ sinks = t2j (self .sinks , use_dlpack = False )
84+ sinks = torch_view (sinks .astype (jnp .float32 ))
85+ self .sinks = torch .nn .Parameter (sinks , requires_grad = False )
86+
8287 def forward (
8388 self ,
8489 layer : AttentionLayer ,
@@ -121,25 +126,44 @@ def forward(
         k_scale = layer._k_scale_float
         v_scale = layer._v_scale_float
 
-        sinks = None if self.sinks is None else jax_view(self.sinks)
-
-        new_kv_cache, outputs = _jax_attn_func(kv_cache, query, key, value,
-                                               sinks, attn_metadata, mesh,
-                                               self.scale, self.head_size,
-                                               self.num_heads,
-                                               self.num_kv_heads, q_scale,
-                                               k_scale, v_scale)
+        sinks = jax_view(self.sinks)
+
+        new_kv_cache, outputs = _jax_attn_func(
+            kv_cache,
+            query,
+            key,
+            value,
+            sinks,
+            attn_metadata,
+            mesh,
+            self.scale,
+            self.head_size,
+            self.num_heads,
+            self.num_kv_heads,
+            q_scale,
+            k_scale,
+            v_scale,
+            self.sliding_window,
+        )
         vllm_model_wrapper_context.kv_caches[kv_cache_index] = new_kv_cache
 
         return torch_view(outputs)
 
 
 @functools.partial(
     jax.jit,
-    static_argnums=(
-        6, 7, 8, 9, 10, 11, 12, 13
-    ),  # mesh, scale, head_size, num_heads, num_kv_heads, q_scale, k_scale, v_scale
-    donate_argnums=(0, ),  # donate kv_cache
+    static_argnames=(
+        "mesh",
+        "scale",
+        "head_size",
+        "num_heads",
+        "num_kv_heads",
+        "q_scale",
+        "k_scale",
+        "v_scale",
+        "sliding_window",
+    ),
+    donate_argnames=("kv_cache"),
 )
 def _jax_attn_func(
     kv_cache: jax.Array,
@@ -153,9 +177,10 @@ def _jax_attn_func(
     head_size: int,
     num_heads: int,
     num_kv_heads: int,
-    q_scale: Optional[float] = None,
-    k_scale: Optional[float] = None,
-    v_scale: Optional[float] = None,
+    q_scale: float | None = None,
+    k_scale: float | None = None,
+    v_scale: float | None = None,
+    sliding_window: int | None = None,
 ) -> Tuple[jax.Array, jax.Array]:
     del scale  # Unused for now, as the attention function applies a default scale.
 
@@ -184,6 +209,7 @@ def _jax_attn_func(
         k_scale=k_scale,
         v_scale=v_scale,
         sinks=sinks,
+        attention_chunk_size=sliding_window,
     )
 
     # Convert the shape back to vLLM's convention
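
Background sketch (not part of the commit): the decorator change above replaces positional static_argnums/donate_argnums with name-based static_argnames/donate_argnames, so appending a trailing parameter such as sliding_window does not require renumbering argument indices. Below is a minimal, self-contained illustration of the same jax.jit pattern, using a hypothetical scaled_add function; the function name and arguments are invented for illustration only.

import functools

import jax
import jax.numpy as jnp


# Hypothetical example (not from this commit): `scale` is marked static by name,
# so parameters can be added later without renumbering positional indices, and
# `buf` is donated by name, allowing XLA to reuse its buffer for the output.
@functools.partial(
    jax.jit,
    static_argnames=("scale", ),
    donate_argnames=("buf", ),
)
def scaled_add(buf: jax.Array, x: jax.Array, scale: float) -> jax.Array:
    return buf + scale * x


out = scaled_add(jnp.zeros(4), jnp.ones(4), scale=2.0)
print(out)  # [2. 2. 2. 2.]

Static arguments trigger recompilation when their values change, while donated buffers may be reused for outputs, which is the same reason the kv_cache is donated in the hunk above.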