From da53078ce13ef5e056aba8f680f1a9727fde2a8b Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 17 Apr 2024 15:50:08 -0700 Subject: [PATCH] fix generate too --- build/builder.py | 26 ++++++++++++++++++++++++++ export.py | 16 +++++++++------- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/build/builder.py b/build/builder.py index cce5846ce..10d6c3717 100644 --- a/build/builder.py +++ b/build/builder.py @@ -176,6 +176,23 @@ def device_sync(device): sys.path.append(str(wd)) +# TODO: remove these once ET supports _weight_int4pack_mm +def _set_gguf_kwargs(builder_args, is_et, context: str): + assert context in ["export", "generate"] + assert builder_args.gguf_kwargs is None + + if builder_args.gguf_path is None: + print("No gguf_path provided, so ignoring set_gguf_kwargs.") + return + + builder_args.gguf_kwargs = {} + if is_et: + builder_args.gguf_kwargs["load_as_quantized"] = False + +def _unset_gguf_kwargs(builder_args): + builder_args.gguf_kwargs = None + + def _load_model_gguf(builder_args): assert builder_args.gguf_path if builder_args.gguf_kwargs is None: @@ -260,6 +277,15 @@ def _initialize_model( ): print("Loading model ...") t0 = time.time() + + if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path): + print("Setting gguf_kwargs for generate.") + is_dso = builder_args.dso_path is not None + is_pte = builder_args.pte_path is not None + assert not (is_dso and is_pte) + assert builder_args.gguf_kwargs is None + _set_gguf_kwargs(builder_args, is_et=is_pte, context="generate") + model_ = _load_model(builder_args) device_sync(device=builder_args.device) print(f"Time to load model: {time.time() - t0:.02f} seconds") diff --git a/export.py b/export.py index 8786be9d2..3723e4b65 100644 --- a/export.py +++ b/export.py @@ -9,7 +9,7 @@ import torch -from build.builder import _initialize_model, BuilderArgs +from build.builder import _initialize_model, BuilderArgs, _set_gguf_kwargs, _unset_gguf_kwargs from cli import add_arguments_for_export, arg_init, check_args from export_aoti import export_model as export_model_aoti @@ -50,6 +50,8 @@ def main(args): output_pte_path = args.output_pte_path output_dso_path = args.output_dso_path + # TODO: clean this up + # This mess is because ET does not support _weight_int4pack_mm right now if not builder_args.gguf_path: model = _initialize_model( builder_args, @@ -59,21 +61,21 @@ def main(args): model_to_dso = model else: if output_pte_path: - assert builder_args.gguf_kwargs is None - # TODO: ET does not support _weight_int4pack_mm right now, - # so GGUF is converted to float - builder_args.gguf_kwargs = {"load_as_quantized": False} + _set_gguf_kwargs(builder_args, is_et=True, context="export") model_to_pte = _initialize_model( builder_args, quantize, ) - builder_args.gguf_kwargs = None + _unset_gguf_kwargs(builder_args) + if output_dso_path: - assert builder_args.gguf_kwargs is None + _set_gguf_kwargs(builder_args, is_et=False, context="export") model_to_dso = _initialize_model( builder_args, quantize, ) + _unset_gguf_kwargs(builder_args) + with torch.no_grad(): if output_pte_path: