16
16
# At the start of the ggml file we write the model parameters
17
17
# and vocabulary.
18
18
#
19
- import os
19
+ import argparse
20
20
import sys
21
21
import json
22
22
import struct
23
23
import numpy as np
24
24
import torch
25
25
from sentencepiece import SentencePieceProcessor
26
26
27
def parse_args():
    """Parse command-line arguments.

    Returns an argparse.Namespace with:
      dir_model -- directory containing the model checkpoint
      ftype     -- output float type: 0 for float32, 1 for float16
    """
    parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
    parser.add_argument('dir_model', help='directory containing the model checkpoint')
    # nargs='?' is required for the default to take effect: a plain positional
    # argument ignores `default` and is mandatory, which would break the old
    # behavior of falling back to f16 when ftype is omitted.
    parser.add_argument('ftype', type=int, choices=[0, 1], default=1, nargs='?',
                        help='file type (0: float32, 1: float16)')
    return parser.parse_args()
38
33
39
34
def get_n_parts(dim):
    """Return how many consolidated.*.pth checkpoint parts a model of the
    given embedding dimension is split into, or exit on an unknown dim."""
    dim_to_parts = {4096: 1, 5120: 2, 6656: 4, 8192: 8}
    if dim not in dim_to_parts:
        print(f"Invalid dim: {dim}")
        sys.exit(1)
    n_parts = dim_to_parts[dim]
    print(f"n_parts = {n_parts}\n")
    return n_parts
73
44
74
- tokenizer = SentencePieceProcessor ( fname_tokenizer )
45
+ def load_hparams_and_tokenizer ( dir_model ):
75
46
76
- hparams .update ({"vocab_size" : tokenizer .vocab_size ()})
47
+ fname_hparams = f"{ dir_model } /params.json"
48
+ fname_tokenizer = f"{ dir_model } /../tokenizer.model"
77
49
78
- n_parts = get_n_parts (hparams ["dim" ])
50
+ with open (fname_hparams , "r" ) as f :
51
+ hparams = json .load (f )
52
+ print (hparams )
79
53
80
- print ( hparams )
81
- print ( 'n_parts = ' , n_parts )
54
+ tokenizer = SentencePieceProcessor ( fname_tokenizer )
55
+ hparams . update ({ "vocab_size" : tokenizer . vocab_size ()} )
82
56
83
- for p in range (n_parts ):
84
- print ('Processing part ' , p )
57
+ return hparams , tokenizer
85
58
86
- #fname_model = sys.argv[1] + "/consolidated.00.pth"
87
- fname_model = sys .argv [1 ] + "/consolidated.0" + str (p ) + ".pth"
88
- fname_out = sys .argv [1 ] + "/ggml-model-" + ftype_str [ftype ] + ".bin"
89
- if (p > 0 ):
90
- fname_out = sys .argv [1 ] + "/ggml-model-" + ftype_str [ftype ] + ".bin" + "." + str (p )
59
def write_header(fout, hparams, ftype):
    """Write the ggml file header (magic, version, hparams, ftype) to fout.

    All fields are packed as 32-bit little-endian-native ints, in the fixed
    order the ggml loader expects.
    """
    values = [
        0x67676d66,  # magic (the bytes "ggmf")
        1,           # file format version
        hparams["vocab_size"],
        hparams["dim"],
        hparams["multiple_of"],
        hparams["n_heads"],
        hparams["n_layers"],
        hparams["dim"] // hparams["n_heads"],  # rot (obsolete)
        ftype,
    ]
    for value in values:
        fout.write(struct.pack("i", value))
93
70
94
- fout = open ( fname_out , "wb" )
71
def write_tokens(fout, tokenizer):
    """Serialize the vocabulary to fout.

    Each token is written as: int32 byte length, the token text, then a
    float32 score taken from the sentencepiece model.
    """
    for token_id in range(tokenizer.vocab_size()):
        if tokenizer.is_unknown(token_id):
            # the <unk> token is rendered as a double question mark
            text = " \u2047 ".encode("utf-8")
        elif tokenizer.is_control(token_id):
            # control tokens such as <s>/</s> carry no text
            text = b""
        elif tokenizer.is_byte(token_id):
            # byte-fallback tokens look like "<0xAB>" (may be invalid UTF-8)
            piece = tokenizer.id_to_piece(token_id)
            if len(piece) != 6:
                print(f"Invalid token: {piece}")
                sys.exit(1)
            text = struct.pack("B", int(piece[3:-1], 16))
        else:
            # normal token: sentencepiece marks spaces with U+2581
            text = tokenizer.id_to_piece(token_id).replace("\u2581", " ").encode("utf-8")
        fout.write(struct.pack("i", len(text)))
        fout.write(text)
        fout.write(struct.pack("f", tokenizer.get_score(token_id)))
129
90
130
- for k , v in model .items ():
131
- name = k
132
- shape = v .shape
91
+ def process_and_write_variables (fout , model , ftype ):
133
92
134
- # skip layers.X.attention.inner_attention.rope.freqs
135
- if name [- 5 :] == "freqs" :
136
- continue
93
+ for name , datao in model .items ():
137
94
138
- print ("Processing variable: " + name + " with shape: " , shape , " and type: " , v .dtype )
95
+ if name .endswith ("freqs" ):
96
+ continue
139
97
140
- #data = tf.train.load_variable(dir_model, name).squeeze()
141
- data = v .numpy ().squeeze ()
142
- n_dims = len (data .shape );
98
+ shape = datao .shape
143
99
144
- # for efficiency - transpose some matrices
145
- # "model/h.*/attn/c_attn/w"
146
- # "model/h.*/attn/c_proj/w"
147
- # "model/h.*/mlp/c_fc/w"
148
- # "model/h.*/mlp/c_proj/w"
149
- #if name[-14:] == "/attn/c_attn/w" or \
150
- # name[-14:] == "/attn/c_proj/w" or \
151
- # name[-11:] == "/mlp/c_fc/w" or \
152
- # name[-13:] == "/mlp/c_proj/w":
153
- # print(" Transposing")
154
- # data = data.transpose()
100
+ print (f"Processing variable: { name } with shape: { shape } and type: { datao .dtype } " )
155
101
156
- dshape = data .shape
102
+ data = datao .numpy ().squeeze ()
103
+ n_dims = len (shape )
157
104
158
105
# default type is fp16
159
106
ftype_cur = 1
@@ -164,18 +111,40 @@ def get_n_parts(dim):
164
111
165
112
# header
166
113
sname = name .encode ('utf-8' )
167
- fout .write (struct .pack ("iii" , n_dims , len (sname ), ftype_cur ))
168
- for i in range ( n_dims ):
169
- fout .write (struct .pack ("i" , dshape [ n_dims - 1 - i ] ))
170
- fout .write (sname );
114
+ fout .write (struct .pack ("iii" , len ( data . shape ) , len (sname ), ftype_cur ))
115
+ for dim in reversed ( data . shape ):
116
+ fout .write (struct .pack ("i" , dim ))
117
+ fout .write (sname )
171
118
172
- # data
119
+ # data output to file
173
120
data .tofile (fout )
174
121
175
- # I hope this deallocates the memory ..
176
- model = None
122
def main():
    """Drive the conversion: parse arguments, load model metadata, then
    convert every checkpoint part to a ggml-format output file."""
    args = parse_args()
    dir_model = args.dir_model
    ftype = args.ftype
    ftype_str = ["f32", "f16"]

    hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
    n_parts = get_n_parts(hparams["dim"])

    for part in range(n_parts):
        print(f"Processing part {part}\n")

        fname_model = f"{dir_model}/consolidated.0{part}.pth"
        # parts after the first get a numeric suffix on the output name
        suffix = "" if part == 0 else "." + str(part)
        fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin{suffix}"

        model = torch.load(fname_model, map_location="cpu")

        with open(fname_out, "wb") as fout:
            write_header(fout, hparams, ftype)
            write_tokens(fout, tokenizer)
            process_and_write_variables(fout, model, ftype)

        # release the checkpoint tensors before loading the next part
        del model
        print(f"Done. Output file: {fname_out}, (part {part})\n")
179
148
180
- print ( "Done. Output file: " + fname_out + ", (part " , p , ")" )
181
- print ( "" )
149
# Script entry point: only run the conversion when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
0 commit comments