
Commit f8884e6
add et export with gguf with test
1 parent: d8c30a3

File tree: 5 files changed, +70 -22 lines

  .github/workflows/et.yml
  build/builder.py
  build/gguf_loader.py
  build/model.py
  export.py

.github/workflows/et.yml

Lines changed: 18 additions & 1 deletion

@@ -60,6 +60,13 @@ jobs:
           wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
           wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
           popd
+
+          mkdir gguf_files
+          export GGUF_PATH=gguf_files/TinyLlama-1.1B-openorca.Q4_0.gguf
+          export GGUF_TOKENIZER_PATH=gguf_files/tokenizer.model
+          wget -O ${GGUF_PATH} "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true"
+          wget -O ${GGUF_TOKENIZER_PATH} https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+
       - name: Run inference
         run: |
           export MODEL_PATH=${PWD}/checkpoints/stories15M/stories15M.pt
@@ -75,7 +82,7 @@ jobs:
           echo "Tests complete."

       - name: Run inference
-        run: |
+        run: |
           export MODEL_PATH=checkpoints/stories15M/stories15M.pt
           export MODEL_NAME=stories15M
           export MODEL_DIR=/tmp
@@ -121,3 +128,13 @@ jobs:
           echo "tests complete"
           echo "******************************************"

+      - name: Run GGUF export + inference
+        run: |
+          export GGUF_PATH=gguf_files/TinyLlama-1.1B-openorca.Q4_0.gguf
+          export GGUF_TOKENIZER_PATH=gguf_files/tokenizer.model
+
+          python torchchat.py export --gguf-path ${GGUF_PATH} --output-pte-path ${PWD}/${MODEL_NAME}.pte
+          python torchchat.py generate --pte-path ${PWD}/${MODEL_NAME}.pte --tokenizer-path ${GGUF_TOKENIZER_PATH} --max-new-tokens 10 --temperature 0 > ${PWD}/output_et
+          cat ${PWD}/output_et
+
+          echo "Tests complete."

build/builder.py

Lines changed: 8 additions & 2 deletions

@@ -9,7 +9,7 @@
 import time
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional, Union
+from typing import Any, Optional, Union

 import torch
 import torch._dynamo.config
@@ -29,6 +29,7 @@ class BuilderArgs:
     params_path: Optional[Union[Path, str]] = None
     params_table: Optional[str] = None
     gguf_path: Optional[Union[Path, str]] = None
+    gguf_kwargs: Optional[dict[str, Any]] = None
     dso_path: Optional[Union[Path, str]] = None
     pte_path: Optional[Union[Path, str]] = None
     device: str = "cpu"
@@ -72,6 +73,7 @@ def from_args(cls, args): # -> BuilderArgs:
             params_path=args.params_path,
             params_table=args.params_table,
             gguf_path=args.gguf_path,
+            gguf_kwargs=None,
             dso_path=args.dso_path,
             pte_path=args.pte_path,
             device=args.device,
@@ -156,7 +158,11 @@ def device_sync(device):

 def _load_model_gguf(builder_args):
     assert builder_args.gguf_path
-    model = Transformer.from_gguf(builder_args.gguf_path)
+    if builder_args.gguf_kwargs is None:
+        kwargs = {}
+    else:
+        kwargs = builder_args.gguf_kwargs
+    model = Transformer.from_gguf(builder_args.gguf_path, **kwargs)
     return model
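The new gguf_kwargs field is deliberately never populated from the CLI (from_args always sets it to None); it exists so in-process callers such as export.py can steer how the GGUF checkpoint is loaded. A minimal sketch of that flow, constructing BuilderArgs directly and calling the private helper for illustration (the path is a placeholder, and it assumes the remaining BuilderArgs fields keep their declared defaults):

from build.builder import BuilderArgs, _load_model_gguf

builder_args = BuilderArgs(
    gguf_path="gguf_files/TinyLlama-1.1B-openorca.Q4_0.gguf",  # placeholder
    gguf_kwargs={"load_as_quantized": False},  # forwarded to Transformer.from_gguf
)

# _load_model_gguf falls back to {} when gguf_kwargs is None, so the
# default CLI path behaves exactly as it did before this change.
model = _load_model_gguf(builder_args)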

build/gguf_loader.py

Lines changed: 12 additions & 10 deletions

@@ -139,7 +139,7 @@ def load_model(gguf_file: str) -> torch.nn.Module:
     return model


-def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_k_tiles = 8) -> torch.nn.Module:
+def load_model_and_state_dict(gguf_file: str, *, load_state_dict: bool = True, load_as_quantized: bool = True, inner_k_tiles = 8) -> torch.nn.Module:
     """
     Parses the GGUF file and returns an nn.Module on meta device along with a state_dict
     that can be loaded into it.
@@ -174,14 +174,14 @@ def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_
             in_features = mod.in_features
             assert all(t.shape == (in_features, out_features))

-            q, s, z = Q4_0.unpack(t)
-            scales_and_zeros = pack_scales_and_zeros(s, z)
-            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                q, inner_k_tiles
-            )
-
-            state_dict[f"{fqn}.weight"] = weight_int4pack
-            state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
+            if load_state_dict:
+                q, s, z = Q4_0.unpack(t)
+                scales_and_zeros = pack_scales_and_zeros(s, z)
+                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+                    q, inner_k_tiles
+                )
+                state_dict[f"{fqn}.weight"] = weight_int4pack
+                state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros

             parent = _fqn_lookup(_fqn_up(fqn), model)
             setattr(
@@ -197,8 +197,10 @@ def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_
                 ),
             )
         else:
-            state_dict[f"{fqn}.weight"] = to_float(t)
+            if load_state_dict:
+                state_dict[f"{fqn}.weight"] = to_float(t)

+    assert (state_dict == {}) == (not load_state_dict)
     return model, state_dict
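The reworked signature makes all loading behavior keyword-only and adds a structure-only mode; the closing assert ties the two together, guaranteeing the returned state_dict is empty exactly when load_state_dict=False. A sketch of the three modes (the file path is a placeholder):

from build.gguf_loader import load_model_and_state_dict

gguf_file = "gguf_files/TinyLlama-1.1B-openorca.Q4_0.gguf"  # placeholder

# Default: int4-packed Q4_0 weights plus a state_dict ready to load.
model, state_dict = load_model_and_state_dict(gguf_file)

# Dequantized: Q4_0 tensors converted to float (used for ExecuTorch export).
model_f, state_dict_f = load_model_and_state_dict(gguf_file, load_as_quantized=False)

# Structure only: meta-device module, no tensor data materialized.
model_meta, empty = load_model_and_state_dict(gguf_file, load_state_dict=False)
assert empty == {}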

build/model.py

Lines changed: 4 additions & 3 deletions

@@ -246,10 +246,11 @@ def from_params(cls, params_path: str):
         return cls(ModelArgs.from_params(params_path))

     @classmethod
-    def from_gguf(cls, gguf_path: str):
+    def from_gguf(cls, gguf_path: str, **kwargs):
         from build.gguf_loader import load_model_and_state_dict
-        model, state_dict = load_model_and_state_dict(gguf_path, load_as_quantized=True, inner_k_tiles=8)
-        model.load_state_dict(state_dict, assign=True)
+        model, state_dict = load_model_and_state_dict(gguf_path, **kwargs)
+        if state_dict != {}:
+            model.load_state_dict(state_dict, assign=True)
         return model
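Since from_gguf now forwards **kwargs straight to load_model_and_state_dict, the previously hard-coded load_as_quantized=True and inner_k_tiles=8 arguments could be dropped; the same values are now the loader's own defaults, so plain calls behave identically. A short usage sketch (the path is a placeholder):

from build.model import Transformer

# Default: quantized load, state_dict assigned onto the meta-device module.
model = Transformer.from_gguf("gguf_files/TinyLlama-1.1B-openorca.Q4_0.gguf")

# Float load, as requested by export.py for the ExecuTorch (.pte) path.
model_float = Transformer.from_gguf(
    "gguf_files/TinyLlama-1.1B-openorca.Q4_0.gguf",
    load_as_quantized=False,
)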

export.py

Lines changed: 28 additions & 6 deletions

@@ -42,24 +42,46 @@ def main(args):
     print(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)

+
     builder_args.dso_path = None
     builder_args.pte_path = None
     builder_args.setup_caches = True
-    model = _initialize_model(
-        builder_args,
-        quantize,
-    )

     output_pte_path = args.output_pte_path
     output_dso_path = args.output_dso_path

+    if not builder_args.gguf_path:
+        model = _initialize_model(
+            builder_args,
+            quantize,
+        )
+        model_to_pte = model
+        model_to_dso = model
+    else:
+        if output_pte_path:
+            assert builder_args.gguf_kwargs is None
+            # TODO: ET does not support _weight_int4pack_mm right now,
+            # so GGUF is converted to float
+            builder_args.gguf_kwargs = {"load_as_quantized": False}
+            model_to_pte = _initialize_model(
+                builder_args,
+                quantize,
+            )
+            builder_args.gguf_kwargs = None
+        if output_dso_path:
+            assert builder_args.gguf_kwargs is None
+            model_to_dso = _initialize_model(
+                builder_args,
+                quantize,
+            )
+

     with torch.no_grad():
         if output_pte_path:
             output_pte_path = str(os.path.abspath(output_pte_path))
             print(f">{output_pte_path}<")
             if executorch_export_available:
                 print(f"Exporting model using Executorch to {output_pte_path}")
-                export_model_et(model, builder_args.device, args.output_pte_path, args)
+                export_model_et(model_to_pte, builder_args.device, args.output_pte_path, args)
             else:
                 print(
                     "Export with executorch requested but Executorch could not be loaded"
@@ -68,7 +90,7 @@ def main(args):
         if output_dso_path:
             output_dso_path = str(os.path.abspath(output_dso_path))
             print(f"Exporting model using AOT Inductor to {output_dso_path}")
-            export_model_aoti(model, builder_args.device, output_dso_path, args)
+            export_model_aoti(model_to_dso, builder_args.device, output_dso_path, args)


 if __name__ == "__main__":
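The duplicated _initialize_model calls exist because the two backends need different weights: ExecuTorch cannot lower torch.ops.aten._weight_int4pack_mm yet, so the .pte path reloads the GGUF as float, while the AOT Inductor (.dso) path keeps the default quantized load. A hypothetical condensation of that branching, with names taken from the diff and error handling elided, not the committed code:

def _models_for_export(builder_args, quantize, output_pte_path, output_dso_path):
    # Hypothetical helper summarizing the branching added to main() above.
    model_to_pte = model_to_dso = None
    if not builder_args.gguf_path:
        # Non-GGUF checkpoints: one model serves both export targets.
        model_to_pte = model_to_dso = _initialize_model(builder_args, quantize)
    else:
        if output_pte_path:
            # ET lacks _weight_int4pack_mm support, so load the GGUF as float.
            builder_args.gguf_kwargs = {"load_as_quantized": False}
            model_to_pte = _initialize_model(builder_args, quantize)
            builder_args.gguf_kwargs = None  # reset so the DSO path loads int4
        if output_dso_path:
            model_to_dso = _initialize_model(builder_args, quantize)
    return model_to_pte, model_to_dso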
