@@ -117,17 +117,26 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
117117 load_dequantized_tensor = gguf_loader .load_gguf_tensor
118118 tensor_file_map = gguf_loader .tensor_file_map
119119
120- if gguf_loader .has_tensor (translated_key ):
120+ if gguf_loader .has_tensor (translated_key ) or "kv_b_proj" in translated_key :
121121 target_dtype = torch .get_default_dtype ()
122122 device = get_device (translated_key [:translated_key .rfind ("." )], gguf_loader .tensor_device_map )
123123 print (f"loading { translated_key } to { device } " )
124124 if torch .cuda .is_available ():
125125 torch .cuda .empty_cache ()
126126 elif torch .xpu .is_available ():
127127 torch .xpu .empty_cache ()
128- weights = load_dequantized_tensor (translated_key , device = device ).to (dtype = target_dtype )
129- set_param (module , name , weights )
130- del weights
128+ if "kv_b_proj" in translated_key and not gguf_loader .has_tensor (translated_key ):
129+ attn_k_b = load_dequantized_tensor (translated_key .replace ("self_attn.kv_b_proj" , "attn_k_b" ), device = device ).to (dtype = target_dtype )
130+ attn_k_b = attn_k_b .transpose (1 , 2 ).contiguous ()
131+ attn_v_b = load_dequantized_tensor (translated_key .replace ("self_attn.kv_b_proj" , "attn_v_b" ), device = device ).to (dtype = target_dtype )
132+ kv_b_proj = torch .cat ((attn_k_b , attn_v_b ), dim = 1 )
133+ set_param (module , name , kv_b_proj )
134+ del attn_k_b
135+ del attn_v_b
136+ else :
137+ weights = load_dequantized_tensor (translated_key , device = device ).to (dtype = target_dtype )
138+ set_param (module , name , weights )
139+ del weights
131140 else :
132141 #print(load_config.tensor_file_map.keys())
133142 raise Exception (f"can't find { translated_key } in GGUF file!" )
0 commit comments