@@ -1,8 +1,8 @@
'''
-Description :
+Description :
Author : Boxin Zhang
Version : 0.1.0
-Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
+Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import torch
from torch import nn
@@ -16,17 +16,16 @@
from typing import Optional, Tuple
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
-from ktransformers.util.utils import get_compute_capability
+from ktransformers.util.feature_gate import KTRANSFORMERS_USE_TORCH_NATIVE, KTRANSFORMERS_USE_FLASHINFER
import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.cache_utils import Cache
-from ktransformers.util.vendors import device_manager, get_device, to_device, GPUVendor

try:
    from flash_attn import flash_attn_func
except:
    pass
-from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
+from ktransformers.operators.triton_attention import decode_attention_fwd_grouped
from ktransformers.operators.triton_attention_prefill import context_attention_fwd
import os
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
@@ -69,7 +68,7 @@ def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
            kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
            self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
            self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
-
+
        return self.q_absorb, self.out_absorb

    def forward_chunck(
@@ -117,7 +116,7 @@ def forward_chunck(

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
-
+
            # compressed_kv [bsz, q_len, self.kv_lora_rank]
            # k_pe [bsz, 1, q_len, self.qk_rope_head_dim]
            k_pe = k_pe.transpose(1,2)
@@ -128,7 +127,7 @@ def forward_chunck(
            )
            # k_pe [pages, page_size, 1, self.qk_rope_head_dim]
            # compressed_kv [pages, page_size, 1, self.kv_lora_rank]
-
+
        q_absorb, out_absorb = self.get_absorbed()

        # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
@@ -142,9 +141,9 @@ def forward_chunck(
        #print(k_pe.shape)
        #print(q_nope.shape)
        #print(compressed_kv.shape)
-
+
        attn_weights = (torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.mT)) * self.softmax_scale
-
+
        #attn_weights [bsz, self.num_heads, q_len, kv_seq_len]
        compressed_kv = compressed_kv.squeeze(1)
        """
@@ -172,10 +171,10 @@ def forward_chunck(
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.attention_dropout, training=self.training
        )
-
+
        attn_output = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv)
-
-        attn_output = torch.matmul(attn_output, out_absorb.mT)
+
+        attn_output = torch.matmul(attn_output, out_absorb.mT)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim):
            raise ValueError(
@@ -184,7 +183,7 @@ def forward_chunck(
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
-
+
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)

        attn_output = self.o_proj(attn_output)
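
Note: the forward_chunck hunks above are whitespace-only, but they sit in the matrix-absorbed MLA path. The sketch below reproduces that two-term score computation on its own, with made-up illustrative shapes rather than the model's real config, and with masking, dropout, and the KV cache update omitted:

import torch

# Hypothetical toy sizes, for illustration only.
bsz, q_len, kv_len = 1, 1, 8
num_heads, qk_nope_head_dim, qk_rope_head_dim = 4, 32, 16
kv_lora_rank, v_head_dim = 64, 32
softmax_scale = (qk_nope_head_dim + qk_rope_head_dim) ** -0.5

q_nope = torch.randn(bsz, num_heads, q_len, qk_nope_head_dim)
q_pe = torch.randn(bsz, num_heads, q_len, qk_rope_head_dim)
k_pe = torch.randn(bsz, 1, kv_len, qk_rope_head_dim)           # RoPE part, shared across heads
compressed_kv = torch.randn(bsz, 1, kv_len, kv_lora_rank)      # latent KV cache
q_absorb = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)   # slices of kv_b_proj, see get_absorbed()
out_absorb = torch.randn(num_heads, v_head_dim, kv_lora_rank)

# Absorb the key up-projection into the query, then score against the latent
# cache plus the RoPE term; these are the same two matmuls as attn_weights above.
q_nope = torch.matmul(q_nope, q_absorb)              # [bsz, heads, q_len, kv_lora_rank]
attn_weights = (torch.matmul(q_pe, k_pe.mT)
                + torch.matmul(q_nope, compressed_kv.mT)) * softmax_scale
attn_weights = torch.softmax(attn_weights, dim=-1)   # [bsz, heads, q_len, kv_len]

# Weighted sum over the latent cache, then expand back to the per-head value dim.
ctx = torch.einsum('bhql,blc->bhqc', attn_weights, compressed_kv.squeeze(1))
attn_output = torch.matmul(ctx, out_absorb.mT)       # [bsz, heads, q_len, v_head_dim]

The point of the absorption is that only compressed_kv and k_pe ever need to be cached; the per-head key and value up-projections live in q_absorb and out_absorb and are applied to the query and to the attention output instead.
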
@@ -231,11 +230,11 @@ def forward_linux_triton(
231230 "with a layer index."
232231 )
233232 kv_seq_len += past_key_value .get_usable_length (kv_seq_len , self .layer_idx )
234-
233+
235234 cos , sin = self .rotary_emb (q_pe , position_ids )
236235 q_pe , k_pe = apply_rotary_pos_emb (q_pe , k_pe , cos , sin , unsqueeze_dim = 2 )
237236 # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
238-
237+
239238 # decode
240239 if q_len == 1 :
241240 if past_key_value is not None :
@@ -252,20 +251,20 @@ def forward_linux_triton(
            q_nope = torch.matmul(q_nope, q_absorb) # batched MM
            q_nope = q_nope.transpose(1, 2)
            #assert q_nope.is_contiguous()
-
+
            # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
            # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
            query_states = torch.cat([q_nope, q_pe], dim=-1)
-
+
            query_states = query_states.squeeze(1)
            attn_output = torch.zeros_like(q_nope) # [bsz, q_len, self.num_heads, self.kv_lora_rank]
-
+
            attn_logits = torch.empty(
                (
                    bsz,
                    self.num_heads,
                    4, #num_kv_splits # follow vLLM, fix it TODO
-                   self.kv_lora_rank + 1,
+                   self.kv_lora_rank + 1,
                ),
                dtype=torch.float32,
                device = attn_output.device
@@ -286,16 +285,16 @@ def forward_linux_triton(
                4, #num_kv_splits # follow vLLM, fix it TODO
                self.softmax_scale,
                past_key_value.page_size)
-
+
            # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
            # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
            attn_output = attn_output.transpose(1, 2)
            attn_output = torch.matmul(attn_output, out_absorb.mT)
            attn_output = attn_output.transpose(1, 2)
-
+
            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim)
            attn_output = self.o_proj(attn_output)
-
+
            #print("attn_output", torch.isnan(attn_output).any())
            return attn_output, None, past_key_value
        else:
@@ -323,7 +322,7 @@ def forward_linux_triton(
            key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
            key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
-
+
            value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
            value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

@@ -384,11 +383,11 @@ def forward_linux_flashinfer(
384383 "with a layer index."
385384 )
386385 kv_seq_len += past_key_value .get_usable_length (kv_seq_len , self .layer_idx )
387-
386+
388387 cos , sin = self .rotary_emb (q_pe , position_ids )
389388 q_pe , k_pe = apply_rotary_pos_emb (q_pe , k_pe , cos , sin , unsqueeze_dim = 2 )
390389 # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim] k_pe [bsz, q_len, 1, self.qk_rope_head_dim]
391-
390+
392391 # decode
393392 if q_len == 1 or self .absorb_for_prefill :
394393 if past_key_value is not None :
@@ -407,7 +406,7 @@ def forward_linux_flashinfer(
            q_nope = q_nope.transpose(1, 2)
            q_nope = q_nope.contiguous()
            #assert q_nope.is_contiguous()
-
+
            # q_nope [bsz, q_len, self.num_heads, self.kv_lora_rank]
            # q_pe [bsz, q_len, self.num_heads, self.qk_rope_head_dim]
            q_nope.squeeze_(0)
@@ -460,17 +459,17 @@ def forward_linux_flashinfer(
            )
            attn_output = attn_ref.view(bsz, q_len, self.num_heads, self.kv_lora_rank)
            """
-
+
            # mla_wrapper run output: [tokens, self.num_heads, self.kv_lora_rank]
            # attn_output [bsz, q_len, self.num_heads, self.kv_lora_rank]
            # out_absorb [self.num_heads, self.v_head_dim, self.kv_lora_rank]
            attn_output = attn_output.transpose(1, 2) # [bsz, self.num_heads, q_len, self.kv_lora_rank]
            attn_output = torch.matmul(attn_output, out_absorb.mT) # [bsz, self.num_heads, q_len, self.v_head_dim]
            attn_output = attn_output.transpose(1, 2).contiguous() # [bsz, q_len, self.num_heads, self.kv_lora_rank]
-
+
            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) # [bsz, q_len, self.num_heads * self.v_head_dim]
            attn_output = self.o_proj(attn_output)
-
+
            return attn_output, None, past_key_value
        else:
            if past_key_value is not None:
@@ -497,7 +496,7 @@ def forward_linux_flashinfer(
            key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
            key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
            key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)
-
+
            value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
            value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

@@ -517,7 +516,7 @@ def forward_linux_flashinfer(
            ).contiguous()
            attn_output = self.o_proj(attn_output)
            return attn_output, None, past_key_value
-
+
    def forward_windows(
        self,
        hidden_states: torch.Tensor,
@@ -581,7 +580,7 @@ def forward_windows(
                attn_output = cur_output
            else:
                attn_output = torch.cat((attn_output, cur_output), dim=-2)
-
+
        return attn_output, None, past_key_value

    def forward(
@@ -595,7 +594,7 @@ def forward(
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        if os.name == 'nt' or get_compute_capability() < 8 or device_manager.gpu_vendor != GPUVendor.NVIDIA:
+        if KTRANSFORMERS_USE_TORCH_NATIVE:
            return self.forward_windows(
                hidden_states,
                attention_mask,
@@ -607,7 +606,7 @@ def forward(
                **kwargs,
            )
        else:
-            if flashinfer_enabled:
+            if KTRANSFORMERS_USE_FLASHINFER:
                return self.forward_linux_flashinfer(
                    hidden_states,
                    attention_mask,
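
For context on the one functional change in this diff: forward() used to probe the OS, the GPU compute capability, and the GPU vendor on every call (and then consult flashinfer_enabled) to pick a backend; it now branches on two module-level flags imported from ktransformers.util.feature_gate. That module's contents are not shown in this diff, so the following is only a minimal sketch that mirrors the conditions the old code evaluated inline; everything in it beyond the two exported flag names is an assumption:

# Hypothetical sketch of ktransformers/util/feature_gate.py (not part of this diff).
import os

from ktransformers.util.utils import get_compute_capability
from ktransformers.util.vendors import device_manager, GPUVendor

# Fall back to the torch-native path (forward_windows) on Windows, on GPUs with
# compute capability < 8, or on non-NVIDIA devices: the condition the old
# forward() re-evaluated on every call.
KTRANSFORMERS_USE_TORCH_NATIVE = (
    os.name == 'nt'
    or get_compute_capability() < 8
    or device_manager.gpu_vendor != GPUVendor.NVIDIA
)

# Prefer the flashinfer MLA wrapper when the package is importable; the real
# gate may instead reuse flashinfer_wrapper.flashinfer_enabled or an env var.
try:
    import flashinfer  # noqa: F401
    _flashinfer_available = True
except ImportError:
    _flashinfer_available = False

KTRANSFORMERS_USE_FLASHINFER = _flashinfer_available and not KTRANSFORMERS_USE_TORCH_NATIVE

Evaluated once at import time, these flags reduce forward() to two boolean branches, which is what the last two hunks show.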