Skip to content

Commit

Permalink
llama : support InternLM2 (ggerganov#5184)
Browse files Browse the repository at this point in the history
* support InternLM2 inference
  * add add_space_prefix KV pair
  • Loading branch information
SolenoidWGT authored Feb 1, 2024
1 parent 1cfb537 commit ce32060
Show file tree
Hide file tree
Showing 5 changed files with 387 additions and 5 deletions.
152 changes: 152 additions & 0 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ def from_model_architecture(model_architecture):
return CodeShellModel
if model_architecture == "OrionForCausalLM":
return OrionModel
if model_architecture == "InternLM2ForCausalLM":
return InternLM2Model
return Model

def _is_model_safetensors(self) -> bool:
Expand Down Expand Up @@ -254,6 +256,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
return gguf.MODEL_ARCH.CODESHELL
if arch == "OrionForCausalLM":
return gguf.MODEL_ARCH.ORION
if arch == "InternLM2ForCausalLM":
return gguf.MODEL_ARCH.INTERNLM2

raise NotImplementedError(f'Architecture "{arch}" not supported!')

Expand Down Expand Up @@ -1344,6 +1348,154 @@ def write_tensors(self):
self.gguf_writer.add_tensor("output.weight", data)
print(name, f"=> output.weight, shape = {data.shape}, {old_dtype} --> {data.dtype}")


class InternLM2Model(Model):
def set_vocab(self):
# (TODO): Is there a better way?
# Copy from _set_vocab_sentencepiece, The only difference is that we will treat the character
# \x00 specially and convert it into an emoji character to prevent it from being mistakenly
# recognized as an empty string in C++.
from sentencepiece import SentencePieceProcessor
from sentencepiece import sentencepiece_model_pb2 as model

tokenizer_path = self.dir_model / 'tokenizer.model'

tokens: list[bytes] = []
scores: list[float] = []
toktypes: list[int] = []

if not tokenizer_path.is_file():
print(f'Error: Missing {tokenizer_path}', file=sys.stderr)
sys.exit(1)

sentencepiece_model = model.ModelProto()
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix

tokenizer = SentencePieceProcessor(str(tokenizer_path))
vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())

for token_id in range(vocab_size):
piece = tokenizer.id_to_piece(token_id)
text = piece.encode("utf-8")
score = tokenizer.get_score(token_id)
if text == b"\x00":
# (TODO): fixme
# Hack here and replace the \x00 characters.
print(f"InternLM2 convert token '{text}' to '🐉'!")
text = "🐉"

toktype = SentencePieceTokenTypes.NORMAL
if tokenizer.is_unknown(token_id):
toktype = SentencePieceTokenTypes.UNKNOWN
elif tokenizer.is_control(token_id):
toktype = SentencePieceTokenTypes.CONTROL
elif tokenizer.is_unused(token_id):
toktype = SentencePieceTokenTypes.UNUSED
elif tokenizer.is_byte(token_id):
toktype = SentencePieceTokenTypes.BYTE

tokens.append(text)
scores.append(score)
toktypes.append(toktype)

added_tokens_file = self.dir_model / 'added_tokens.json'
if added_tokens_file.is_file():
with open(added_tokens_file, "r", encoding="utf-8") as f:
added_tokens_json = json.load(f)

for key in added_tokens_json:
tokens.append(key.encode("utf-8"))
scores.append(-1000.0)
toktypes.append(SentencePieceTokenTypes.USER_DEFINED)

self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_add_space_prefix(add_prefix)

special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)

def set_gguf_parameters(self):
self.gguf_writer.add_name("InternLM2")
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])

def post_write_tensors(self, tensor_map, name, data_torch):
old_dtype = data_torch.dtype

# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)

data = data_torch.squeeze().numpy()

# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()

n_dims = len(data.shape)
data_dtype = data.dtype

# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)

# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)

# if f16 desired, convert any float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)

def write_tensors(self):
from einops import rearrange

num_heads = self.hparams.get("num_attention_heads")
num_kv_heads = self.hparams.get("num_key_value_heads")
hidden_size = self.hparams.get("hidden_size")
q_per_kv = num_heads // num_kv_heads
head_dim = hidden_size // num_heads
num_groups = num_heads // q_per_kv

block_count = self.hparams["num_hidden_layers"]
model_kv = dict(self.get_tensors())
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
qkv_pattern = r"model\.layers\.(\d+)\.attention\.wqkv"
for name, data_torch in model_kv.items():
# we don't need these
if name.endswith(".rotary_emb.inv_freq"):
continue

if re.match(qkv_pattern, name):
bid = re.findall(qkv_pattern, name)[0]
qkv = data_torch
qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
q = rearrange(q, " o g n i -> o (g n i)").T
k = rearrange(k, " o g n i -> o (g n i)").T
v = rearrange(v, " o g n i -> o (g n i)").T
self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wq.weight", q)
self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wk.weight", k)
self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wv.weight", v)
else:
self.post_write_tensors(tensor_map, name, data_torch)


###### CONVERSION LOGIC ######


Expand Down
18 changes: 18 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class Tokenizer:
PAD_ID = "tokenizer.ggml.padding_token_id"
ADD_BOS = "tokenizer.ggml.add_bos_token"
ADD_EOS = "tokenizer.ggml.add_eos_token"
ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
HF_JSON = "tokenizer.huggingface.json"
RWKV = "tokenizer.rwkv.world"
CHAT_TEMPLATE = "tokenizer.chat_template"
Expand Down Expand Up @@ -102,6 +103,7 @@ class MODEL_ARCH(IntEnum):
PLAMO = auto()
CODESHELL = auto()
ORION = auto()
INTERNLM2 = auto()


class MODEL_TENSOR(IntEnum):
Expand Down Expand Up @@ -153,6 +155,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.PLAMO: "plamo",
MODEL_ARCH.CODESHELL: "codeshell",
MODEL_ARCH.ORION: "orion",
MODEL_ARCH.INTERNLM2: "internlm2",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
Expand Down Expand Up @@ -446,6 +449,21 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.INTERNLM2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
# TODO
}

Expand Down
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,9 @@ def add_add_bos_token(self, value: bool) -> None:
def add_add_eos_token(self, value: bool) -> None:
self.add_bool(Keys.Tokenizer.ADD_EOS, value)

def add_add_space_prefix(self, value: bool) -> None:
self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)

def add_chat_template(self, value: str) -> None:
self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)

Expand Down
14 changes: 12 additions & 2 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class TensorNameMap:
"language_model.embedding.word_embeddings", # persimmon
"wte", # gpt2
"transformer.embd.wte", # phi2
"model.tok_embeddings", # internlm2
),

# Token type embeddings
Expand All @@ -42,7 +43,7 @@ class TensorNameMap:
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen
"output", # llama-pth bloom
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
),
Expand All @@ -51,7 +52,7 @@ class TensorNameMap:
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 gpt-j falcon
"model.norm", # llama-hf baichuan
"model.norm", # llama-hf baichuan internlm2
"norm", # llama-pth
"embeddings.LayerNorm", # bert
"transformer.norm_f", # mpt
Expand Down Expand Up @@ -84,6 +85,7 @@ class TensorNameMap:
"h.{bid}.ln_1", # gpt2
"transformer.h.{bid}.ln", # phi2
"model.layers.layers.{bid}.norm", # plamo
"model.layers.{bid}.attention_norm", # internlm2
),

# Attention norm 2
Expand Down Expand Up @@ -111,6 +113,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.self.query", # bert
"transformer.h.{bid}.attn.q_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
"model.layers.{bid}.attention.wq" # internlm2
),

# Attention key
Expand All @@ -120,6 +123,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.self.key", # bert
"transformer.h.{bid}.attn.k_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
"model.layers.{bid}.attention.wk" # internlm2
),

# Attention value
Expand All @@ -129,6 +133,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.self.value", # bert
"transformer.h.{bid}.attn.v_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.v_proj", # plamo
"model.layers.{bid}.attention.wv" # internlm2
),

# Attention output
Expand All @@ -147,6 +152,7 @@ class TensorNameMap:
"h.{bid}.attn.c_proj", # gpt2
"transformer.h.{bid}.mixer.out_proj", # phi2
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
"model.layers.{bid}.attention.wo", # internlm2
),

# Rotary embeddings
Expand All @@ -169,6 +175,7 @@ class TensorNameMap:
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
"model.layers.{bid}.ln2", # yi
"h.{bid}.ln_2", # gpt2
"model.layers.{bid}.ffn_norm", # internlm2
),

MODEL_TENSOR.FFN_GATE_INP: (
Expand All @@ -194,6 +201,7 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.fc1", # phi2
"model.layers.{bid}.mlp.fc1", # phi2
"model.layers.layers.{bid}.mlp.up_proj", # plamo
"model.layers.{bid}.feed_forward.w3", # internlm2
),

MODEL_TENSOR.FFN_UP_EXP: (
Expand All @@ -212,6 +220,7 @@ class TensorNameMap:
"layers.{bid}.feed_forward.w1", # llama-pth
"transformer.h.{bid}.mlp.w2", # qwen
"model.layers.layers.{bid}.mlp.gate_proj", # plamo
"model.layers.{bid}.feed_forward.w1", # internlm2
),

MODEL_TENSOR.FFN_GATE_EXP: (
Expand All @@ -236,6 +245,7 @@ class TensorNameMap:
"transformer.h.{bid}.mlp.fc2", # phi2
"model.layers.{bid}.mlp.fc2", # phi2
"model.layers.layers.{bid}.mlp.down_proj", # plamo
"model.layers.{bid}.feed_forward.w2", # internlm2
),

MODEL_TENSOR.FFN_DOWN_EXP: (
Expand Down
Loading

0 comments on commit ce32060

Please sign in to comment.