diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 754edf014e1dd..6bb709951630c 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1423,6 +1423,20 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter "o_proj.weight")): data_torch = self.weight_quant(data_torch) + # pad 1D tensors + # TODO: is padding with 0s an invariant, or do we also need some scaling factor? + if name.endswith(("input_layernorm.weight", "post_attention_layernorm.weight", "model.norm.weight")): + data_torch = torch.nn.functional.pad(data_torch, (0, 256 - data_torch.size(0) % 256), mode='constant', value=0) + logger.info(f"pad {name} to {data_torch.size()}") + + # pad 2D tensors + # TODO: double-check that this is the correct way to pad the rows + if name.endswith(("embed_tokens.weight", "q_proj.weight", "k_proj.weight", "v_proj.weight", + "down_proj.weight", "up_proj.weight", "gate_proj.weight", + "o_proj.weight")): + data_torch = torch.nn.functional.pad(data_torch, (0, 256 - data_torch.size(1) % 256), mode='constant', value=0) + logger.info(f"pad {name} to {data_torch.size()}") + return [(self.map_tensor_name(name), data_torch)] diff --git a/llama.cpp b/llama.cpp index 16ae07dd33ab6..085a5a236df0b 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2021,9 +2021,9 @@ struct llama_layer { struct ggml_tensor * wkv_b; // attention bias - struct ggml_tensor * bq; - struct ggml_tensor * bk; - struct ggml_tensor * bv; + struct ggml_tensor * bq = nullptr; + struct ggml_tensor * bk = nullptr; + struct ggml_tensor * bv = nullptr; struct ggml_tensor * bo; struct ggml_tensor * bqkv; @@ -6440,14 +6440,18 @@ static bool llm_load_tensors( } break; case LLM_ARCH_BITNET: { - model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + const uint32_t n_ff = hparams.n_ff; + const uint32_t n_ff_pad = GGML_PAD(n_ff, 256); + + const int64_t n_embd_pad = GGML_PAD(n_embd, 256); + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd_pad, n_vocab}); // output { - model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd_pad}); } - const uint32_t n_ff = hparams.n_ff; model.layers.resize(n_layer); for (int i = 0; i < n_layer; ++i) { @@ -6456,20 +6460,20 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; - layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd_pad}); layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); - layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd_pad, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd_pad, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd_pad, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_pad, n_embd}); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd_pad}); layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); - layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd_pad, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_pad, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd_pad, n_ff}); } } break; default: @@ -11658,6 +11662,7 @@ struct llm_build_context { ggml_build_forward_expand(graph, cur_attn); + cur_attn = ggml_pad(ctx0, cur_attn, (256 - cur_attn->ne[0] % 256) % 256, 0, 0, 0); cur = ggml_mul_mat(ctx0, wo, cur_attn); cb(cur, "kqv_out", il); @@ -11670,6 +11675,8 @@ struct llm_build_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } + cur = ggml_pad(ctx0, cur, (256 - cur->ne[0] % 256) % 256, 0, 0, 0); + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); cb(ffn_inp, "ffn_inp", il); @@ -11680,8 +11687,8 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_norm", il); - struct ggml_tensor *tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); - + struct ggml_tensor * tmp = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur); + cb(tmp, "ffn_up", il); cur = ggml_mul_mat(ctx0, model.layers[il].ffn_gate, cur); @@ -11700,9 +11707,14 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "ffn_sub_norm", il); + cur = ggml_pad(ctx0, cur, (256 - cur->ne[0] % 256) % 256, 0, 0, 0); + cur = ggml_mul_mat(ctx0, model.layers[il].ffn_down, cur); cb(cur, "ffn_down", il); } + + cur = ggml_pad(ctx0, cur, (256 - cur->ne[0] % 256) % 256, 0, 0, 0); + cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "l_out", il);