Skip to content

Commit

Permalink
fix generate too
Browse files Browse the repository at this point in the history
  • Loading branch information
metascroy committed Apr 17, 2024
1 parent 8ad99f2 commit da53078
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
26 changes: 26 additions & 0 deletions build/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,23 @@ def device_sync(device):
sys.path.append(str(wd))


# TODO: remove these once ET supports _weight_int4pack_mm
def _set_gguf_kwargs(builder_args, is_et, context: str):
assert context in ["export", "generate"]
assert builder_args.gguf_kwargs is None

if builder_args.gguf_path is None:
print("No gguf_path provided, so ignoring set_gguf_kwargs.")
return

builder_args.gguf_kwargs = {}
if is_et:
builder_args.gguf_kwargs["load_as_quantized"] = False

def _unset_gguf_kwargs(builder_args):
    """Clear ``builder_args.gguf_kwargs``.

    Teardown counterpart of ``_set_gguf_kwargs``: resets the attribute to
    ``None`` so a subsequent set call starts from a clean slate.
    """
    builder_args.gguf_kwargs = None


def _load_model_gguf(builder_args):
assert builder_args.gguf_path
if builder_args.gguf_kwargs is None:
Expand Down Expand Up @@ -260,6 +277,15 @@ def _initialize_model(
):
print("Loading model ...")
t0 = time.time()

if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
print("Setting gguf_kwargs for generate.")
is_dso = builder_args.dso_path is not None
is_pte = builder_args.pte_path is not None
assert not (is_dso and is_pte)
assert builder_args.gguf_kwargs is None
_set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")

model_ = _load_model(builder_args)
device_sync(device=builder_args.device)
print(f"Time to load model: {time.time() - t0:.02f} seconds")
Expand Down
16 changes: 9 additions & 7 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch

from build.builder import _initialize_model, BuilderArgs
from build.builder import _initialize_model, BuilderArgs, _set_gguf_kwargs, _unset_gguf_kwargs
from cli import add_arguments_for_export, arg_init, check_args
from export_aoti import export_model as export_model_aoti

Expand Down Expand Up @@ -50,6 +50,8 @@ def main(args):
output_pte_path = args.output_pte_path
output_dso_path = args.output_dso_path

# TODO: clean this up
# This mess is because ET does not support _weight_int4pack_mm right now
if not builder_args.gguf_path:
model = _initialize_model(
builder_args,
Expand All @@ -59,21 +61,21 @@ def main(args):
model_to_dso = model
else:
if output_pte_path:
assert builder_args.gguf_kwargs is None
# TODO: ET does not support _weight_int4pack_mm right now,
# so GGUF is converted to float
builder_args.gguf_kwargs = {"load_as_quantized": False}
_set_gguf_kwargs(builder_args, is_et=True, context="export")
model_to_pte = _initialize_model(
builder_args,
quantize,
)
builder_args.gguf_kwargs = None
_unset_gguf_kwargs(builder_args)

if output_dso_path:
assert builder_args.gguf_kwargs is None
_set_gguf_kwargs(builder_args, is_et=False, context="export")
model_to_dso = _initialize_model(
builder_args,
quantize,
)
_unset_gguf_kwargs(builder_args)


with torch.no_grad():
if output_pte_path:
Expand Down

0 comments on commit da53078

Please sign in to comment.