Llama-3_1-Nemotron-51B support
ymcki committed Dec 3, 2024
1 parent ef22490 commit 31d5981
Showing 4 changed files with 126 additions and 21 deletions.
85 changes: 81 additions & 4 deletions convert_hf_to_gguf.py
@@ -1520,6 +1520,61 @@ def prepare_tensors(self):
class LlamaModel(Model):
model_arch = gguf.MODEL_ARCH.LLAMA

@staticmethod
def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
# DeciLM-specific code
intermediate_size = int(2 * ffn_mult * n_embd / 3)
return LlamaModel._find_multiple(intermediate_size, 256)

@staticmethod
def _find_multiple(n: int, k: int) -> int:
# DeciLM-specific code
if n % k == 0:
return n
return n + k - (n % k)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

if "block_configs" in self.hparams: # Llama-3.1-51B
_block_configs: list[dict[str,Any]] = self.hparams["block_configs"]
assert self.block_count == len(_block_configs)
self._num_kv_heads = list()
self._num_heads = list()
_ffn_multipliers = list()
# ***linear attention layer***
# if n_heads_in_group is None and replace_with_linear is True
# then _num_kv_heads[il] is 0 and _num_heads[il] is num_attention_heads
# ***attention-free layer***
# if n_heads_in_group is None and replace_with_linear is False
# then _num_kv_heads[il] is 0 and _num_heads[il] is 0
# ***normal attention layer***
# if n_heads_in_group is not None, then
# _num_kv_heads[il] is num_attention_heads // n_heads_in_group and
# _num_heads[il] is num_attention_heads
for il in range(len(_block_configs)):
if _block_configs[il]["attention"]["n_heads_in_group"] is None:
if _block_configs[il]["attention"]["replace_with_linear"] is True:
self._num_kv_heads.append(0)
self._num_heads.append(self.hparams["num_attention_heads"])
else:
self._num_kv_heads.append(0)
self._num_heads.append(0)
else:
self._num_kv_heads.append(self.hparams["num_attention_heads"] // _block_configs[il]["attention"]["n_heads_in_group"])
self._num_heads.append(self.hparams["num_attention_heads"])
_ffn_multipliers.append(_block_configs[il]["ffn"]["ffn_mult"])
assert self.block_count == len(self._num_kv_heads)
assert self.block_count == len(self._num_heads)
assert self.block_count == len(_ffn_multipliers)
assert isinstance(self._num_kv_heads, list) and isinstance(self._num_kv_heads[0], int)
assert isinstance(self._num_heads, list) and isinstance(self._num_heads[0], int)
assert isinstance(_ffn_multipliers, list) and isinstance(_ffn_multipliers[0], float)
self._ffn_dims: list[int] = [
LlamaModel._ffn_mult_to_intermediate_size(multiplier, self.hparams["hidden_size"])
for multiplier in _ffn_multipliers
]

def set_vocab(self):
try:
self._set_vocab_sentencepiece()
@@ -1554,10 +1609,26 @@ def set_vocab(self):
self.gguf_writer.add_add_bos_token(False)

def set_gguf_parameters(self):
super().set_gguf_parameters()
if "block_configs" in self.hparams: # Llama-3.1-51B
assert self.block_count == len(self._num_kv_heads)
assert self.block_count == len(self._num_heads)
assert self.block_count == len(self._ffn_dims)
self.gguf_writer.add_head_count_kv(self._num_kv_heads)
self.gguf_writer.add_head_count(self._num_heads)
self.gguf_writer.add_feed_forward_length(self._ffn_dims)
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_key_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
self.gguf_writer.add_value_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
self.gguf_writer.add_file_type(self.ftype)
else:
super().set_gguf_parameters()

hparams = self.hparams
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
if "num_key_value_heads_per_layer" in hparams:
if "num_key_value_heads_per_layer" in hparams: # DeciLM-7B
self._num_kv_heads: list[int] = hparams["num_key_value_heads_per_layer"]
assert self.block_count == len(self._num_kv_heads)
self.gguf_writer.add_head_count_kv(self._num_kv_heads)
@@ -1585,8 +1656,14 @@ def permute(weights: Tensor, n_head: int, n_head_kv: int | None):

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
if bid is not None and "num_key_value_heads_per_layer" in self.hparams:
n_kv_head = self.hparams["num_key_value_heads_per_layer"][bid]
if bid is not None:
if "num_key_value_heads_per_layer" in self.hparams:
n_kv_head = self.hparams["num_key_value_heads_per_layer"][bid]
elif "block_configs" in self.hparams:
n_kv_head = self._num_kv_heads[bid]
n_head = self._num_heads[bid]
else:
n_kv_head = self.hparams.get("num_key_value_heads")
else:
n_kv_head = self.hparams.get("num_key_value_heads")

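For orientation, a minimal standalone sketch of the bookkeeping the new __init__ code performs. The hyper-parameters (hidden_size=8192, num_attention_heads=64, as in the Llama-3.1-70B base) and the three block_configs entries below are illustrative placeholders, not values copied from the real Nemotron-51B config.json:

def ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
    # same DeciLM rule as in the converter: 2/3 * ffn_mult * n_embd, rounded up to a multiple of 256
    n = int(2 * ffn_mult * n_embd / 3)
    return n if n % 256 == 0 else n + 256 - (n % 256)

hparams = {"hidden_size": 8192, "num_attention_heads": 64}  # assumed values
block_configs = [
    {"attention": {"n_heads_in_group": 8,    "replace_with_linear": False}, "ffn": {"ffn_mult": 1.3125}},  # normal GQA layer
    {"attention": {"n_heads_in_group": None, "replace_with_linear": True},  "ffn": {"ffn_mult": 1.3125}},  # linear attention
    {"attention": {"n_heads_in_group": None, "replace_with_linear": False}, "ffn": {"ffn_mult": 2.0}},     # attention-free
]

num_heads, num_kv_heads, ffn_dims = [], [], []
for cfg in block_configs:
    attn = cfg["attention"]
    if attn["n_heads_in_group"] is None:
        num_kv_heads.append(0)
        num_heads.append(hparams["num_attention_heads"] if attn["replace_with_linear"] else 0)
    else:
        num_kv_heads.append(hparams["num_attention_heads"] // attn["n_heads_in_group"])
        num_heads.append(hparams["num_attention_heads"])
    ffn_dims.append(ffn_mult_to_intermediate_size(cfg["ffn"]["ffn_mult"], hparams["hidden_size"]))

print(num_heads)     # [64, 64, 0]
print(num_kv_heads)  # [8, 0, 0]
print(ffn_dims)      # [7168, 7168, 11008]

These per-layer lists are what set_gguf_parameters writes via add_head_count, add_head_count_kv and add_feed_forward_length, and what modify_tensors reads back per bid to pick the right permutation for q_proj and k_proj.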
1 change: 1 addition & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -189,6 +189,7 @@ class TensorNameMap:
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo_1124
"model.layers.{bid}.self_attn.linear_attn", # deci
"layers.{bid}.attention.wo", # llama-pth
"encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j
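The single new entry above routes the 51B checkpoint's linear-attention projection onto the existing attn_output slot. A hedged sketch of how it resolves, assuming the gguf-py package from this tree is on the Python path (the block count of 80 is illustrative):

import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAMA, 80)
name = tmap.get_name("model.layers.5.self_attn.linear_attn.weight",
                     try_suffixes=(".weight", ".bias"))
print(name)  # expected to resolve to "blk.5.attn_output.weight"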
12 changes: 8 additions & 4 deletions gguf-py/gguf/vocab.py
@@ -109,9 +109,7 @@ def _set_special_token(self, typ: str, tid: Any) -> None:
if tid < 0:
raise ValueError(f'invalid value for special token type {typ}: {tid}')
if self.n_vocab is None or tid < self.n_vocab:
if typ in self.special_token_ids:
return
self.special_token_ids[typ] = tid
self.special_token_ids[typ] = tid # allow override
return
logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')

@@ -188,7 +186,13 @@ def _try_load_from_config_json(self, path: Path) -> bool:
with open(config_file, encoding = 'utf-8') as f:
config = json.load(f)
for typ in self.special_token_types:
self._set_special_token(typ, config.get(f'{typ}_token_id'))
if typ == 'eos' and isinstance(config.get(f'{typ}_token_id'), list):
eos_ids = config.get(f'{typ}_token_id')
self._set_special_token('eos', eos_ids[0])
self._set_special_token('eom', eos_ids[1])
self._set_special_token('eot', eos_ids[2])
else:
self._set_special_token(typ, config.get(f'{typ}_token_id'))
return True


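The new branch above covers configs whose eos_token_id is a list rather than a single id; the first, second and third entries are recorded as eos, eom and eot respectively (note it assumes the list has at least three entries). A minimal sketch of the shape being handled; the ids are the usual Llama-3.1 values and are only illustrative here:

config = {"eos_token_id": [128001, 128008, 128009]}  # e.g. <|end_of_text|>, <|eom_id|>, <|eot_id|>

eos_ids = config.get("eos_token_id")
if isinstance(eos_ids, list):
    special = {"eos": eos_ids[0], "eom": eos_ids[1], "eot": eos_ids[2]}
else:
    special = {"eos": eos_ids}
print(special)  # {'eos': 128001, 'eom': 128008, 'eot': 128009}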
49 changes: 36 additions & 13 deletions src/llama.cpp
@@ -7609,13 +7609,23 @@ static bool llm_load_tensors(
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i);
const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i);
const int64_t n_ff = hparams.n_ff(i);
const int64_t n_head = hparams.n_head(i);
const int64_t n_head_kv = hparams.n_head_kv(i);

if (n_head_kv == 0 && n_head > 0) {
// linear attention for DeciLMCausalModel
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
}
else if (n_head_kv > 0) {
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
}

// optional bias tensors
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED);
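In effect, the branch above creates different attention tensors per layer depending on its head counts. A rough summary of the same rule in Python (illustrative only; tensor names are the GGUF names and shapes follow the create_tensor calls above, with the value head size assumed equal to the key head size):

def attn_tensors(n_head: int, n_head_kv: int, n_embd: int, n_embd_head_k: int) -> dict:
    if n_head_kv == 0 and n_head > 0:   # "linear attention" layer
        return {"attn_norm": (n_embd,), "attn_output": (n_embd, n_embd)}
    if n_head_kv > 0:                   # regular (grouped-query) attention layer
        n_embd_gqa = n_embd_head_k * n_head_kv
        return {
            "attn_norm":   (n_embd,),
            "attn_q":      (n_embd, n_embd_head_k * n_head),
            "attn_k":      (n_embd, n_embd_gqa),
            "attn_v":      (n_embd, n_embd_gqa),
            "attn_output": (n_embd_head_k * n_head, n_embd),
        }
    return {}                           # attention-free layer: no attention tensors at all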
@@ -10715,13 +10725,22 @@ struct llm_build_context {
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_head = hparams.n_head(il);

// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
if (n_head == 0) // attention-free layer of Llama-3_1-Nemotron-51B
cur = inpL;
else {
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
}

if (n_head > 0 && n_head_kv == 0) { // "linear attention" of Llama-3_1-Nemotron-51B
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
cb(cur, "wo", il);
} else if (n_head > 0)
// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
@@ -10781,8 +10800,12 @@ struct llm_build_context {
cur = ggml_scale(ctx0, cur, hparams.f_residual_scale);
}

struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// modified to support attention-free layer of Llama-3_1-Nemotron-51B
struct ggml_tensor * ffn_inp = cur;
if (n_head > 0) {
ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
}

// feed-forward network
if (model.layers[il].ffn_gate_inp == nullptr) {
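Putting the graph changes together, the build loop now distinguishes three per-layer paths based on the head counts the converter stored. A minimal sketch (illustrative Python, not llama.cpp symbols):

def layer_kind(n_head: int, n_head_kv: int) -> str:
    if n_head == 0:
        # attention-free: attn_norm and attention are skipped; the block input
        # goes straight to the FFN, without the usual pre-FFN residual add
        return "attention-free"
    if n_head_kv == 0:
        # "linear attention": attn_norm followed by a single wo projection,
        # then the usual residual add into the FFN input
        return "linear attention"
    # regular grouped-query attention with rope and wq/wk/wv/wo
    return "self-attention"

for il, (n_head, n_head_kv) in enumerate([(64, 8), (64, 0), (0, 0)]):
    print(il, layer_kind(n_head, n_head_kv))
# 0 self-attention
# 1 linear attention
# 2 attention-free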
