Commit cc367f5

Merge pull request #1351 from kvcache-ai/load-DeepSeek-0528
[Patch] load DeepSeek-R1-0528 and enable CPU GGUF dequant when GPU dequant is not implemented
2 parents: ac48a58 + a6b3243

File tree

2 files changed, 18 insertions(+), 5 deletions(-)

ktransformers/util/custom_loader.py

Lines changed: 5 additions & 1 deletion
@@ -446,7 +446,11 @@ def load_gguf_tensor(self, name: str, device:str = "cpu", target_dtype = None)->
             blocks_begin = i * blocks_per_iter
             blocks_end = min(blocks_begin + blocks_per_iter, num_blocks)
             if "cuda" in device.lower():
-                cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)
+                try:
+                    cur_values = GGML_DEQUANTIZE_GPU[ggml_name](data[blocks_begin*block_size : blocks_end*block_size], device, target_dtype)
+                except:
+                    cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
+                    cur_values = torch.from_numpy(cur_values.copy()).to(device)
             else:
                 cur_values = GGML_DEQUANTIZE[ggml_name](data[blocks_begin*block_size : blocks_end*block_size])
                 cur_values = torch.from_numpy(cur_values.copy())
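
Read in isolation, the hunk above is a try-GPU-then-fall-back-to-CPU dequant pattern. Below is a minimal standalone sketch of it; dequantize_chunk, gpu_table, and cpu_table are hypothetical stand-ins for the actual GGML_DEQUANTIZE_GPU / GGML_DEQUANTIZE registries, and the exception types are narrowed here (the commit itself uses a bare except:).

    import numpy as np
    import torch

    def dequantize_chunk(quant_type, raw, device, target_dtype, gpu_table, cpu_table):
        # Prefer the GPU kernel for this quant type; if it is missing or
        # unimplemented, dequantize on CPU with NumPy and move the result
        # to the target device, mirroring the commit's fallback.
        if "cuda" in device.lower():
            try:
                return gpu_table[quant_type](raw, device, target_dtype)
            except (KeyError, NotImplementedError):
                values = cpu_table[quant_type](raw)
                return torch.from_numpy(values.copy()).to(device)
        values = cpu_table[quant_type](raw)
        return torch.from_numpy(values.copy())

    # Toy usage: no GPU kernel is registered for "q4_k", so when a CUDA
    # device is selected the CPU path runs and the result is moved over.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    cpu_table = {"q4_k": lambda raw: np.frombuffer(raw, dtype=np.float32)}
    chunk = np.arange(8, dtype=np.float32).tobytes()
    print(dequantize_chunk("q4_k", chunk, device, torch.float16, {}, cpu_table))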

ktransformers/util/utils.py

Lines changed: 13 additions & 4 deletions
@@ -117,17 +117,26 @@ def load_cur_state_dict(module: nn.Module, gguf_loader: ModelLoader, prefix: str
         load_dequantized_tensor = gguf_loader.load_gguf_tensor
         tensor_file_map = gguf_loader.tensor_file_map

-        if gguf_loader.has_tensor(translated_key):
+        if gguf_loader.has_tensor(translated_key) or "kv_b_proj" in translated_key:
             target_dtype = torch.get_default_dtype()
             device = get_device(translated_key[:translated_key.rfind(".")], gguf_loader.tensor_device_map)
             print(f"loading {translated_key} to {device}")
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
             elif torch.xpu.is_available():
                 torch.xpu.empty_cache()
-            weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
-            set_param(module, name, weights)
-            del weights
+            if "kv_b_proj" in translated_key and not gguf_loader.has_tensor(translated_key):
+                attn_k_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_k_b"), device=device).to(dtype=target_dtype)
+                attn_k_b = attn_k_b.transpose(1, 2).contiguous()
+                attn_v_b = load_dequantized_tensor(translated_key.replace("self_attn.kv_b_proj", "attn_v_b"), device=device).to(dtype=target_dtype)
+                kv_b_proj = torch.cat((attn_k_b, attn_v_b), dim=1)
+                set_param(module, name, kv_b_proj)
+                del attn_k_b
+                del attn_v_b
+            else:
+                weights = load_dequantized_tensor(translated_key, device=device).to(dtype=target_dtype)
+                set_param(module, name, weights)
+                del weights
         else:
             #print(load_config.tensor_file_map.keys())
             raise Exception(f"can't find {translated_key} in GGUF file!")
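
The new branch handles GGUF files (such as the DeepSeek-R1-0528 conversions) that store the fused kv_b_proj weight as two split tensors, attn_k_b and attn_v_b, with attn_k_b kept transposed. Below is a standalone sketch of that reconstruction; the shapes are illustrative assumptions based on DeepSeek-style MLA dimensions, not values taken from the commit.

    import torch

    # Illustrative MLA dimensions only (DeepSeek-V3/R1-style):
    # 128 heads, kv_lora_rank 512, qk_nope_head_dim = v_head_dim = 128.
    num_heads, kv_lora_rank = 128, 512
    qk_nope_head_dim, v_head_dim = 128, 128

    # Stand-ins for the two tensors the GGUF file actually stores.
    attn_k_b = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim)
    attn_v_b = torch.randn(num_heads, v_head_dim, kv_lora_rank)

    # attn_k_b is stored transposed, so swap its last two dims back ...
    k_b = attn_k_b.transpose(1, 2).contiguous()    # (128, 128, 512)

    # ... then concatenate with attn_v_b along the per-head output dim
    # to recover the fused kv_b_proj weight the module expects.
    kv_b_proj = torch.cat((k_b, attn_v_b), dim=1)  # (128, 256, 512)
    print(kv_b_proj.shape)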
