add et export with gguf with test #245

Merged 3 commits on Apr 18, 2024
19 changes: 18 additions & 1 deletion .github/workflows/et.yml
@@ -60,6 +60,13 @@ jobs:
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
popd

mkdir gguf_files
export GGUF_PATH=gguf_files/TinyLlama-1.1B-openorca.Q4_0.gguf
export GGUF_TOKENIZER_PATH=gguf_files/tokenizer.model
wget -O ${GGUF_PATH} "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true"
wget -O ${GGUF_TOKENIZER_PATH} https://github.com/karpathy/llama2.c/raw/master/tokenizer.model

- name: Run inference
run: |
export MODEL_PATH=${PWD}/checkpoints/stories15M/stories15M.pt
@@ -75,7 +82,7 @@ jobs:
echo "Tests complete."

- name: Run inference
run: |
run: |
export MODEL_PATH=checkpoints/stories15M/stories15M.pt
export MODEL_NAME=stories15M
export MODEL_DIR=/tmp
@@ -121,3 +128,13 @@ jobs:
echo "tests complete"
echo "******************************************"

- name: Run GGUF export + inference
run: |
export GGUF_PATH=gguf_files/TinyLlama-1.1B-openorca.Q4_0.gguf
export GGUF_TOKENIZER_PATH=gguf_files/tokenizer.model

python torchchat.py export --gguf-path ${GGUF_PATH} --output-pte-path ${PWD}/${MODEL_NAME}.pte
python torchchat.py generate --gguf-path ${GGUF_PATH} --pte-path ${PWD}/${MODEL_NAME}.pte --tokenizer-path ${GGUF_TOKENIZER_PATH} --temperature 0 --max-new-tokens 20 > ${PWD}/output_et
cat ${PWD}/output_et

echo "Tests complete."
36 changes: 34 additions & 2 deletions build/builder.py
@@ -9,7 +9,7 @@
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union
from typing import Any, Optional, Union

import torch
import torch._dynamo.config
@@ -29,6 +29,7 @@ class BuilderArgs:
params_path: Optional[Union[Path, str]] = None
params_table: Optional[str] = None
gguf_path: Optional[Union[Path, str]] = None
gguf_kwargs: Optional[dict[str, Any]] = None
dso_path: Optional[Union[Path, str]] = None
pte_path: Optional[Union[Path, str]] = None
device: str = "cpu"
@@ -91,6 +92,7 @@ def from_args(cls, args): # -> BuilderArgs:
params_path=args.params_path,
params_table=args.params_table,
gguf_path=args.gguf_path,
gguf_kwargs=None,
dso_path=args.dso_path,
pte_path=args.pte_path,
device=args.device,
@@ -174,9 +176,30 @@ def device_sync(device):
sys.path.append(str(wd))


# TODO: remove these once ET supports _weight_int4pack_mm
def _set_gguf_kwargs(builder_args, is_et, context: str):
assert context in ["export", "generate"]
assert builder_args.gguf_kwargs is None

if builder_args.gguf_path is None:
print("No gguf_path provided, so ignoring set_gguf_kwargs.")
return

builder_args.gguf_kwargs = {}
if is_et:
builder_args.gguf_kwargs["load_as_quantized"] = False

def _unset_gguf_kwargs(builder_args):
builder_args.gguf_kwargs = None


def _load_model_gguf(builder_args):
assert builder_args.gguf_path
model = Transformer.from_gguf(builder_args.gguf_path)
if builder_args.gguf_kwargs is None:
kwargs = {}
else:
kwargs = builder_args.gguf_kwargs
model = Transformer.from_gguf(builder_args.gguf_path, **kwargs)
return model


@@ -254,6 +277,15 @@ def _initialize_model(
):
print("Loading model ...")
t0 = time.time()

if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
print("Setting gguf_kwargs for generate.")
is_dso = builder_args.dso_path is not None
is_pte = builder_args.pte_path is not None
assert not (is_dso and is_pte)
assert builder_args.gguf_kwargs is None
_set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")

model_ = _load_model(builder_args)
device_sync(device=builder_args.device)
print(f"Time to load model: {time.time() - t0:.02f} seconds")
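For readers skimming the builder changes, here is a condensed sketch (not part of the diff) of how the new helpers are intended to be used together: `_set_gguf_kwargs` records whether GGUF weights should be loaded dequantized (the ExecuTorch case), `_initialize_model` then builds the model through `_load_model_gguf`, and `_unset_gguf_kwargs` resets the flag so a later build starts clean. The function name `build_gguf_model_for_et` and its arguments are placeholders; the pattern mirrors what export.py does further down in this PR.

```python
# Sketch only; mirrors the pattern export.py adopts further down in this PR.
# `builder_args` is a BuilderArgs with gguf_path set; `quantize` is whatever
# quantization config the CLI parsed. build_gguf_model_for_et is a made-up name.
from build.builder import _initialize_model, _set_gguf_kwargs, _unset_gguf_kwargs


def build_gguf_model_for_et(builder_args, quantize):
    # ET cannot run _weight_int4pack_mm yet, so ask the GGUF loader for
    # dequantized float weights instead of int4-packed ones.
    _set_gguf_kwargs(builder_args, is_et=True, context="export")
    model = _initialize_model(builder_args, quantize)
    # Reset gguf_kwargs so a subsequent (e.g. AOT Inductor) build sees a clean state.
    _unset_gguf_kwargs(builder_args)
    return model
```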
22 changes: 12 additions & 10 deletions build/gguf_loader.py
@@ -139,7 +139,7 @@ def load_model(gguf_file: str) -> torch.nn.Module:
return model


def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_k_tiles = 8) -> torch.nn.Module:
def load_model_and_state_dict(gguf_file: str, *, load_state_dict: bool = True, load_as_quantized: bool = True, inner_k_tiles = 8) -> torch.nn.Module:
"""
Parses the GGUF file and returns an nn.Module on meta device along with a state_dict
that can be loaded into it.
@@ -174,14 +174,14 @@ def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_
in_features = mod.in_features
assert all(t.shape == (in_features, out_features))

q, s, z = Q4_0.unpack(t)
scales_and_zeros = pack_scales_and_zeros(s, z)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
q, inner_k_tiles
)

state_dict[f"{fqn}.weight"] = weight_int4pack
state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros
if load_state_dict:
q, s, z = Q4_0.unpack(t)
scales_and_zeros = pack_scales_and_zeros(s, z)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
q, inner_k_tiles
)
state_dict[f"{fqn}.weight"] = weight_int4pack
state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros

parent = _fqn_lookup(_fqn_up(fqn), model)
setattr(
@@ -197,8 +197,10 @@ def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_
),
)
else:
state_dict[f"{fqn}.weight"] = to_float(t)
if load_state_dict:
state_dict[f"{fqn}.weight"] = to_float(t)

assert (state_dict == {}) == (not load_state_dict)
return model, state_dict


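Assuming the loader behaves as the diff above describes, a minimal usage sketch of the reworked `load_model_and_state_dict` signature (the GGUF path below is a placeholder):

```python
# Sketch (not from the PR) of the ways the reworked loader can be called.
from build.gguf_loader import load_model_and_state_dict

# Default: int4-packed weights plus scales/zeros in the state_dict.
model, state_dict = load_model_and_state_dict("path/to/model.Q4_0.gguf")

# ET-friendly: dequantized float weights instead of int4-packed ones.
model, state_dict = load_model_and_state_dict(
    "path/to/model.Q4_0.gguf", load_as_quantized=False
)

# Structure only: meta-device module and an empty state_dict.
model, empty = load_model_and_state_dict(
    "path/to/model.Q4_0.gguf", load_state_dict=False
)
assert empty == {}
```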
7 changes: 4 additions & 3 deletions build/model.py
@@ -246,10 +246,11 @@ def from_params(cls, params_path: str):
return cls(ModelArgs.from_params(params_path))

@classmethod
def from_gguf(cls, gguf_path: str):
def from_gguf(cls, gguf_path: str, **kwargs):
from build.gguf_loader import load_model_and_state_dict
model, state_dict = load_model_and_state_dict(gguf_path, load_as_quantized=True, inner_k_tiles=8)
model.load_state_dict(state_dict, assign=True)
model, state_dict = load_model_and_state_dict(gguf_path, **kwargs)
if state_dict != {}:
model.load_state_dict(state_dict, assign=True)
return model


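At the model level, a small sketch (paths are placeholders, not from the PR) of how `from_gguf` now forwards loader options:

```python
# Sketch only: from_gguf passes **kwargs through to load_model_and_state_dict.
from build.model import Transformer

# Default behaviour: int4-packed GGUF weights are loaded into the module.
model = Transformer.from_gguf("path/to/model.Q4_0.gguf")

# ExecuTorch export path: load dequantized float weights instead
# (what gguf_kwargs={"load_as_quantized": False} in builder.py produces).
model_et = Transformer.from_gguf("path/to/model.Q4_0.gguf", load_as_quantized=False)
```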
38 changes: 31 additions & 7 deletions export.py
@@ -9,7 +9,7 @@

import torch

from build.builder import _initialize_model, BuilderArgs
from build.builder import _initialize_model, BuilderArgs, _set_gguf_kwargs, _unset_gguf_kwargs
from cli import add_arguments_for_export, arg_init, check_args
from export_aoti import export_model as export_model_aoti

@@ -42,24 +42,48 @@ def main(args):
print(f"Using device={builder_args.device}")
set_precision(builder_args.precision)


builder_args.dso_path = None
builder_args.pte_path = None
builder_args.setup_caches = True
model = _initialize_model(
builder_args,
quantize,
)

output_pte_path = args.output_pte_path
output_dso_path = args.output_dso_path

# TODO: clean this up
# This mess is because ET does not support _weight_int4pack_mm right now
if not builder_args.gguf_path:
model = _initialize_model(
builder_args,
quantize,
)
model_to_pte = model
model_to_dso = model
else:
if output_pte_path:
Review comment (Contributor):

This is very kludgy and I would prefer to export to int4 and then handle it from there. Basing front end decisions on backend is a very bad practice because we're going to end up in a world of hurt.

Kimish and I had discussed doing a transform from int4 -> a8w4dq. Right now we just get a dequantized model.
Please plan to land that ASAP, Kimish?

cc: @kimishpatel

_set_gguf_kwargs(builder_args, is_et=True, context="export")
model_to_pte = _initialize_model(
builder_args,
quantize,
)
_unset_gguf_kwargs(builder_args)

if output_dso_path:
_set_gguf_kwargs(builder_args, is_et=False, context="export")
model_to_dso = _initialize_model(
builder_args,
quantize,
)
_unset_gguf_kwargs(builder_args)


with torch.no_grad():
if output_pte_path:
output_pte_path = str(os.path.abspath(output_pte_path))
print(f">{output_pte_path}<")
if executorch_export_available:
print(f"Exporting model using Executorch to {output_pte_path}")
export_model_et(model, builder_args.device, args.output_pte_path, args)
export_model_et(model_to_pte, builder_args.device, args.output_pte_path, args)
Review comment (Contributor):
I don't like this at all :( But we're out of runway, so I will approve for now.

else:
print(
"Export with executorch requested but Executorch could not be loaded"
@@ -68,7 +92,7 @@ def main(args):
if output_dso_path:
output_dso_path = str(os.path.abspath(output_dso_path))
print(f"Exporting model using AOT Inductor to {output_dso_path}")
export_model_aoti(model, builder_args.device, output_dso_path, args)
export_model_aoti(model_to_dso, builder_args.device, output_dso_path, args)
Review comment (Contributor):
Ditto.



if __name__ == "__main__":