
Commit e57dc62

pculliton authored, with co-authors Andrei Betlen (abetlen) and slaren
llama: Add support for Gemma2ForCausalLM (ggml-org#8156)
* Inference support for Gemma 2 model family
* Update convert-hf-to-gguf.py, constants, and tensor mappings
* cleanup
* format fix
* Fix special token vocab bug
* Don't add space prefix
* fix deleted lines
* Update src/llama.cpp

Co-authored-by: slaren <[email protected]>

* Add model type names
* Add control vector
* Fix model type identification

---------

Co-authored-by: Andrei Betlen <[email protected]>
Co-authored-by: slaren <[email protected]>
1 parent a27aa50 commit e57dc62

4 files changed: +274, −1 lines changed


convert-hf-to-gguf.py

+40
@@ -2340,6 +2340,46 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@Model.register("Gemma2ForCausalLM")
+class Gemma2Model(Model):
+    model_arch = gguf.MODEL_ARCH.GEMMA2
+
+    def set_vocab(self):
+        self._set_vocab_llama_hf()
+        self.gguf_writer.add_add_space_prefix(False)
+
+    def set_gguf_parameters(self):
+        hparams = self.hparams
+        block_count = hparams["num_hidden_layers"]
+
+        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
+        self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(hparams["hidden_size"])
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
+        self.gguf_writer.add_head_count(hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+        self.gguf_writer.add_key_length(hparams["head_dim"])
+        self.gguf_writer.add_value_length(hparams["head_dim"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        # lm_head is not used in llama.cpp, while autoawq will include this tensor in the model.
+        # To prevent errors, skip loading lm_head.weight.
+        if name == "lm_head.weight":
+            logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
+            return []
+
+        # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
+        if name.endswith("norm.weight"):
+            data_torch = data_torch + 1
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @Model.register("Starcoder2ForCausalLM")
 class StarCoder2Model(Model):
     model_arch = gguf.MODEL_ARCH.STARCODER2
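The `data_torch + 1` adjustment reflects the RMSNorm variant referenced in the comment above, which scales the normalized activations by `(1 + weight)` rather than `weight`; folding the offset into the stored tensor lets a conventional RMSNorm kernel reproduce the same output at inference time. Below is a minimal sketch of that equivalence, not part of the commit, with illustrative function names:

```python
# Sketch (illustrative only): why the converter adds 1 to every *norm.weight tensor.
# Assumption: a "conventional" RMSNorm scales by `weight`, while the Gemma-style
# norm referenced in the diff scales by (1 + weight).
import torch

def rms_norm(x, weight, eps=1e-6):
    # conventional RMSNorm: scale by `weight`
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

def gemma_style_rms_norm(x, weight, eps=1e-6):
    # Gemma-style RMSNorm: scale by (1 + weight)
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * (1.0 + weight)

x = torch.randn(4, 8)
w = torch.randn(8) * 0.1
# Storing (w + 1) in the GGUF file makes the conventional kernel match Gemma's output.
assert torch.allclose(gemma_style_rms_norm(x, w), rms_norm(x, w + 1))
```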

gguf-py/gguf/constants.py

+23
@@ -150,6 +150,7 @@ class MODEL_ARCH(IntEnum):
     INTERNLM2 = auto()
     MINICPM = auto()
     GEMMA = auto()
+    GEMMA2 = auto()
     STARCODER2 = auto()
     MAMBA = auto()
     XVERSE = auto()
@@ -180,10 +181,13 @@ class MODEL_TENSOR(IntEnum):
     ATTN_NORM = auto()
     ATTN_NORM_2 = auto()
     ATTN_OUT_NORM = auto()
+    ATTN_POST_NORM = auto()
     ATTN_ROT_EMBD = auto()
     FFN_GATE_INP = auto()
     FFN_GATE_INP_SHEXP = auto()
     FFN_NORM = auto()
+    FFN_PRE_NORM = auto()
+    FFN_POST_NORM = auto()
     FFN_GATE = auto()
     FFN_DOWN = auto()
     FFN_UP = auto()
@@ -270,6 +274,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.INTERNLM2: "internlm2",
     MODEL_ARCH.MINICPM: "minicpm",
     MODEL_ARCH.GEMMA: "gemma",
+    MODEL_ARCH.GEMMA2: "gemma2",
     MODEL_ARCH.STARCODER2: "starcoder2",
     MODEL_ARCH.MAMBA: "mamba",
     MODEL_ARCH.XVERSE: "xverse",
@@ -303,9 +308,12 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
     MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
     MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
+    MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm",
     MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
     MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp",
     MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
+    MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm",
+    MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm",
     MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
     MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
     MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
@@ -751,6 +759,21 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_UP,
         MODEL_TENSOR.FFN_NORM,
     ],
+    MODEL_ARCH.GEMMA2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.FFN_PRE_NORM,
+        MODEL_TENSOR.FFN_POST_NORM,
+    ],
     MODEL_ARCH.STARCODER2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
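For reference, the template strings added to the tensor-name table above expand into per-block GGUF tensor names by substituting the block index. A small illustrative snippet; the dict below simply copies the three new entries from this diff and is not the real `gguf` constants module:

```python
# Illustrative only: expanding the new Gemma 2 norm-name templates for block 0.
GEMMA2_NORM_TEMPLATES = {
    "ATTN_POST_NORM": "blk.{bid}.post_attention_norm",
    "FFN_PRE_NORM":   "blk.{bid}.ffn_norm",
    "FFN_POST_NORM":  "blk.{bid}.post_ffw_norm",
}

for tensor, template in GEMMA2_NORM_TEMPLATES.items():
    print(tensor, "->", template.format(bid=0))
# ATTN_POST_NORM -> blk.0.post_attention_norm
# FFN_PRE_NORM -> blk.0.ffn_norm
# FFN_POST_NORM -> blk.0.post_ffw_norm
```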

gguf-py/gguf/tensor_mapping.py

+14
@@ -187,6 +187,10 @@ class TensorNameMap:
             "transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
         ),
 
+        MODEL_TENSOR.ATTN_POST_NORM: (
+            "model.layers.{bid}.post_attention_layernorm", # gemma2
+        ),
+
         # Rotary embeddings
         MODEL_TENSOR.ATTN_ROT_EMBD: (
             "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
@@ -210,6 +214,16 @@ class TensorNameMap:
             "transformer.decoder_layer.{bid}.rms_norm_2", # Grok
         ),
 
+        # Pre feed-forward norm
+        MODEL_TENSOR.FFN_PRE_NORM: (
+            "model.layers.{bid}.pre_feedforward_layernorm", # gemma2
+        ),
+
+        # Post feed-forward norm
+        MODEL_TENSOR.FFN_POST_NORM: (
+            "model.layers.{bid}.post_feedforward_layernorm", # gemma2
+        ),
+
         MODEL_TENSOR.FFN_GATE_INP: (
             "layers.{bid}.feed_forward.gate", # mixtral
             "model.layers.{bid}.block_sparse_moe.gate", # mixtral

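The new `tensor_mapping.py` entries are what let the converter translate Hugging Face tensor names for the extra Gemma 2 norms into the GGUF names defined in `constants.py`. The snippet below is a simplified, hypothetical re-implementation of that lookup for just these three tensors; the real `TensorNameMap` class builds its tables from the entries shown above and covers many more architectures:

```python
# Hypothetical, simplified version of the HF -> GGUF name mapping for the
# three Gemma 2 norm tensors added in this commit (illustrative only).
import re

HF_TO_GGUF = {
    r"model\.layers\.(\d+)\.post_attention_layernorm":   "blk.{bid}.post_attention_norm",
    r"model\.layers\.(\d+)\.pre_feedforward_layernorm":  "blk.{bid}.ffn_norm",
    r"model\.layers\.(\d+)\.post_feedforward_layernorm": "blk.{bid}.post_ffw_norm",
}

def map_name(hf_name: str) -> str | None:
    base, _, suffix = hf_name.rpartition(".")  # split off the ".weight" suffix
    for pattern, template in HF_TO_GGUF.items():
        m = re.fullmatch(pattern, base)
        if m:
            return template.format(bid=int(m.group(1))) + "." + suffix
    return None

print(map_name("model.layers.3.post_attention_layernorm.weight"))
# blk.3.post_attention_norm.weight
```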