From ef321c65c987ac89619808f4621dcc89968e5541 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 12 Nov 2024 15:11:00 +0000 Subject: [PATCH] fix fx and export --- src/brevitas_examples/llm/llm_quant/run_utils.py | 4 +++- src/brevitas_examples/llm/main.py | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/run_utils.py b/src/brevitas_examples/llm/llm_quant/run_utils.py index 2afe42d3a..44ba711a5 100644 --- a/src/brevitas_examples/llm/llm_quant/run_utils.py +++ b/src/brevitas_examples/llm/llm_quant/run_utils.py @@ -32,11 +32,13 @@ from brevitas.fx.value_tracer import ValueProxy -def get_fx(model): +def get_fx(model, is_export=True): forward_signature = inspect.signature(model.forward).parameters if all(input_name in forward_signature for input_name in ["input_ids", "attention_mask", "past_key_values"]): input_names = ["input_ids", "attention_mask", "past_key_values"] + if not is_export: + input_names.remove('past_key_values') else: raise ValueError( f"Quantization with an FX graph is currently only supported for models taking `input_ids`, `attention_mask` and `past_key_values` as inputs. The model only has the following inputs: {forward_signature}" diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 852cc7a84..40c1634ac 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -201,7 +201,7 @@ def main(args): seqlen=args.seqlen, split="train", seed=args.seed, - require_fx=require_fx and args.export_target == 'onnx_qcdq', + require_fx=require_fx and args.export_target is not None, device=None, fuse_sequences=args.fuse_sequences) @@ -213,7 +213,7 @@ def main(args): seqlen=args.seqlen, split="validation", seed=args.seed, - require_fx=require_fx and args.export_target == 'onnx_qcdq', + require_fx=require_fx and args.export_target is not None, device=None, fuse_sequences=args.fuse_sequences) @@ -234,7 +234,7 @@ def main(args): if require_fx: try: - model = get_fx(model) + model = get_fx(model, is_export=args.export_target is not None) except: print("HF symbolic trace not compatible, attempting with dynamo.") with torch.no_grad():