
Commit 0906a11

metascroy authored and malfet committed

add et export with gguf with test (#245)

* add et export with gguf with test
* fix generate too
* add gguf path to generate

1 parent a45e86e · commit 0906a11

File tree

5 files changed: +99 -23 lines

* .github/workflows/et.yml
* build/builder.py
* build/gguf_loader.py
* build/model.py
* export.py

.github/workflows/et.yml

Lines changed: 18 additions & 1 deletion

```diff
@@ -60,6 +60,13 @@ jobs:
         wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
         wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
         popd
+
+        mkdir gguf_files
+        export GGUF_PATH=gguf_files/TinyLlama-1.1B-openorca.Q4_0.gguf
+        export GGUF_TOKENIZER_PATH=gguf_files/tokenizer.model
+        wget -O ${GGUF_PATH} "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true"
+        wget -O ${GGUF_TOKENIZER_PATH} https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+
     - name: Run inference
       run: |
         export MODEL_PATH=${PWD}/checkpoints/stories15M/stories15M.pt
@@ -75,7 +82,7 @@ jobs:
         echo "Tests complete."

     - name: Run inference
-      run: |
+      run: |
         export MODEL_PATH=checkpoints/stories15M/stories15M.pt
         export MODEL_NAME=stories15M
         export MODEL_DIR=/tmp
@@ -121,3 +128,13 @@ jobs:
         echo "tests complete"
         echo "******************************************"

+    - name: Run GGUF export + inference
+      run: |
+        export GGUF_PATH=gguf_files/TinyLlama-1.1B-openorca.Q4_0.gguf
+        export GGUF_TOKENIZER_PATH=gguf_files/tokenizer.model
+
+        python torchchat.py export --gguf-path ${GGUF_PATH} --output-pte-path ${PWD}/${MODEL_NAME}.pte
+        python torchchat.py generate --gguf-path ${GGUF_PATH} --pte-path ${PWD}/${MODEL_NAME}.pte --tokenizer-path ${GGUF_TOKENIZER_PATH} --temperature 0 --max-new-tokens 20 > ${PWD}/output_et
+        cat ${PWD}/output_et
+
+        echo "Tests complete."
```

build/builder.py

Lines changed: 34 additions & 2 deletions

```diff
@@ -9,7 +9,7 @@
 import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Union
+from typing import Any, Optional, Union

 import torch
 import torch._dynamo.config
@@ -29,6 +29,7 @@ class BuilderArgs:
     params_path: Optional[Union[Path, str]] = None
     params_table: Optional[str] = None
     gguf_path: Optional[Union[Path, str]] = None
+    gguf_kwargs: Optional[dict[str, Any]] = None
     dso_path: Optional[Union[Path, str]] = None
     pte_path: Optional[Union[Path, str]] = None
     device: str = "cpu"
@@ -91,6 +92,7 @@ def from_args(cls, args): # -> BuilderArgs:
             params_path=args.params_path,
             params_table=args.params_table,
             gguf_path=args.gguf_path,
+            gguf_kwargs=None,
             dso_path=args.dso_path,
             pte_path=args.pte_path,
             device=args.device,
@@ -174,9 +176,30 @@ def device_sync(device):
 sys.path.append(str(wd))


+# TODO: remove these once ET supports _weight_int4pack_mm
+def _set_gguf_kwargs(builder_args, is_et, context: str):
+    assert context in ["export", "generate"]
+    assert builder_args.gguf_kwargs is None
+
+    if builder_args.gguf_path is None:
+        print("No gguf_path provided, so ignoring set_gguf_kwargs.")
+        return
+
+    builder_args.gguf_kwargs = {}
+    if is_et:
+        builder_args.gguf_kwargs["load_as_quantized"] = False
+
+def _unset_gguf_kwargs(builder_args):
+    builder_args.gguf_kwargs = None
+
+
 def _load_model_gguf(builder_args):
     assert builder_args.gguf_path
-    model = Transformer.from_gguf(builder_args.gguf_path)
+    if builder_args.gguf_kwargs is None:
+        kwargs = {}
+    else:
+        kwargs = builder_args.gguf_kwargs
+    model = Transformer.from_gguf(builder_args.gguf_path, **kwargs)
     return model


@@ -254,6 +277,15 @@ def _initialize_model(
 ):
     print("Loading model ...")
     t0 = time.time()
+
+    if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
+        print("Setting gguf_kwargs for generate.")
+        is_dso = builder_args.dso_path is not None
+        is_pte = builder_args.pte_path is not None
+        assert not (is_dso and is_pte)
+        assert builder_args.gguf_kwargs is None
+        _set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")
+
     model_ = _load_model(builder_args)
     device_sync(device=builder_args.device)
     print(f"Time to load model: {time.time() - t0:.02f} seconds")
```

build/gguf_loader.py

Lines changed: 12 additions & 10 deletions

```diff
@@ -139,7 +139,7 @@ def load_model(gguf_file: str) -> torch.nn.Module:
     return model


-def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_k_tiles = 8) -> torch.nn.Module:
+def load_model_and_state_dict(gguf_file: str, *, load_state_dict: bool = True, load_as_quantized: bool = True, inner_k_tiles = 8) -> torch.nn.Module:
     """
     Parses the GGUF file and returns an nn.Module on meta device along with a state_dict
     that can be loaded into it.
@@ -174,14 +174,14 @@ def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_k_tiles = 8) -> torch.nn.Module:
             in_features = mod.in_features
             assert all(t.shape == (in_features, out_features))

-            q, s, z = Q4_0.unpack(t)
-            scales_and_zeros = pack_scales_and_zeros(s, z)
-            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                q, inner_k_tiles
-            )
-
-            state_dict[f"{fqn}.weight"] = weight_int4pack
-            state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
+            if load_state_dict:
+                q, s, z = Q4_0.unpack(t)
+                scales_and_zeros = pack_scales_and_zeros(s, z)
+                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                    q, inner_k_tiles
+                )
+                state_dict[f"{fqn}.weight"] = weight_int4pack
+                state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros

             parent = _fqn_lookup(_fqn_up(fqn), model)
             setattr(
@@ -197,8 +197,10 @@ def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_k_tiles = 8) -> torch.nn.Module:
                 ),
             )
         else:
-            state_dict[f"{fqn}.weight"] = to_float(t)
+            if load_state_dict:
+                state_dict[f"{fqn}.weight"] = to_float(t)

+    assert (state_dict == {}) == (not load_state_dict)
     return model, state_dict
```
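A hedged sketch of the loader's new keyword-only knobs (the GGUF path is hypothetical; the behavior follows the hunks above):

```python
from build.gguf_loader import load_model_and_state_dict

# Default knobs (the old behavior): Q4_0 linear weights are repacked for
# torch.ops.aten._weight_int4pack_mm.
model, state_dict = load_model_and_state_dict("gguf_files/model.Q4_0.gguf")
model.load_state_dict(state_dict, assign=True)

# ExecuTorch-friendly load (what _set_gguf_kwargs arranges): skip the int4
# repacking and materialize weights as floats instead.
model, state_dict = load_model_and_state_dict(
    "gguf_files/model.Q4_0.gguf", load_as_quantized=False
)

# Structure only: get the meta-device module; the state_dict stays empty.
model, state_dict = load_model_and_state_dict(
    "gguf_files/model.Q4_0.gguf", load_state_dict=False
)
assert state_dict == {}
```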

build/model.py

Lines changed: 4 additions & 3 deletions

```diff
@@ -246,10 +246,11 @@ def from_params(cls, params_path: str):
         return cls(ModelArgs.from_params(params_path))

     @classmethod
-    def from_gguf(cls, gguf_path: str):
+    def from_gguf(cls, gguf_path: str, **kwargs):
         from build.gguf_loader import load_model_and_state_dict
-        model, state_dict = load_model_and_state_dict(gguf_path, load_as_quantized=True, inner_k_tiles=8)
-        model.load_state_dict(state_dict, assign=True)
+        model, state_dict = load_model_and_state_dict(gguf_path, **kwargs)
+        if state_dict != {}:
+            model.load_state_dict(state_dict, assign=True)
         return model
```
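With `from_gguf` now forwarding `**kwargs` to the loader, callers choose the load mode directly; for example (path hypothetical):

```python
from build.model import Transformer

# Equivalent to the pre-change hard-coded call: the loader defaults are
# load_state_dict=True, load_as_quantized=True, inner_k_tiles=8.
model = Transformer.from_gguf("gguf_files/model.Q4_0.gguf")

# The ExecuTorch path set up by _set_gguf_kwargs(is_et=True, ...):
model = Transformer.from_gguf("gguf_files/model.Q4_0.gguf", load_as_quantized=False)
```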

export.py

Lines changed: 31 additions & 7 deletions

```diff
@@ -9,7 +9,7 @@

 import torch

-from build.builder import _initialize_model, BuilderArgs
+from build.builder import _initialize_model, BuilderArgs, _set_gguf_kwargs, _unset_gguf_kwargs
 from cli import add_arguments_for_export, arg_init, check_args
 from export_aoti import export_model as export_model_aoti

@@ -42,24 +42,48 @@ def main(args):
     print(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)

+
     builder_args.dso_path = None
     builder_args.pte_path = None
     builder_args.setup_caches = True
-    model = _initialize_model(
-        builder_args,
-        quantize,
-    )

     output_pte_path = args.output_pte_path
     output_dso_path = args.output_dso_path

+    # TODO: clean this up
+    # This mess is because ET does not support _weight_int4pack_mm right now
+    if not builder_args.gguf_path:
+        model = _initialize_model(
+            builder_args,
+            quantize,
+        )
+        model_to_pte = model
+        model_to_dso = model
+    else:
+        if output_pte_path:
+            _set_gguf_kwargs(builder_args, is_et=True, context="export")
+            model_to_pte = _initialize_model(
+                builder_args,
+                quantize,
+            )
+            _unset_gguf_kwargs(builder_args)
+
+        if output_dso_path:
+            _set_gguf_kwargs(builder_args, is_et=False, context="export")
+            model_to_dso = _initialize_model(
+                builder_args,
+                quantize,
+            )
+            _unset_gguf_kwargs(builder_args)
+
+
     with torch.no_grad():
         if output_pte_path:
             output_pte_path = str(os.path.abspath(output_pte_path))
             print(f">{output_pte_path}<")
             if executorch_export_available:
                 print(f"Exporting model using Executorch to {output_pte_path}")
-                export_model_et(model, builder_args.device, args.output_pte_path, args)
+                export_model_et(model_to_pte, builder_args.device, args.output_pte_path, args)
             else:
                 print(
                     "Export with executorch requested but Executorch could not be loaded"
@@ -68,7 +92,7 @@ def main(args):
         if output_dso_path:
             output_dso_path = str(os.path.abspath(output_dso_path))
             print(f"Exporting model using AOT Inductor to {output_dso_path}")
-            export_model_aoti(model, builder_args.device, output_dso_path, args)
+            export_model_aoti(model_to_dso, builder_args.device, output_dso_path, args)


 if __name__ == "__main__":
```
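The split initialization above mirrors the TODO comment: until ExecuTorch supports `_weight_int4pack_mm`, a GGUF model headed for a `.pte` must be loaded with `load_as_quantized=False` (float weights), while the AOT Inductor `.dso` path keeps the int4-packed weights, so the two export targets cannot share a single model instance.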
