Skip to content

Commit

Permalink
fix fx and export
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 12, 2024
1 parent 41b5532 commit ef321c6
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/brevitas_examples/llm/llm_quant/run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
from brevitas.fx.value_tracer import ValueProxy


def get_fx(model):
def get_fx(model, is_export=True):
forward_signature = inspect.signature(model.forward).parameters
if all(input_name in forward_signature
for input_name in ["input_ids", "attention_mask", "past_key_values"]):
input_names = ["input_ids", "attention_mask", "past_key_values"]
if not is_export:
input_names.remove('past_key_values')
else:
raise ValueError(
f"Quantization with an FX graph is currently only supported for models taking `input_ids`, `attention_mask` and `past_key_values` as inputs. The model only has the following inputs: {forward_signature}"
Expand Down
6 changes: 3 additions & 3 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def main(args):
seqlen=args.seqlen,
split="train",
seed=args.seed,
require_fx=require_fx and args.export_target == 'onnx_qcdq',
require_fx=require_fx and args.export_target is not None,
device=None,
fuse_sequences=args.fuse_sequences)

Expand All @@ -213,7 +213,7 @@ def main(args):
seqlen=args.seqlen,
split="validation",
seed=args.seed,
require_fx=require_fx and args.export_target == 'onnx_qcdq',
require_fx=require_fx and args.export_target is not None,
device=None,
fuse_sequences=args.fuse_sequences)

Expand All @@ -234,7 +234,7 @@ def main(args):

if require_fx:
try:
model = get_fx(model)
model = get_fx(model, is_export=args.export_target is not None)
except:
print("HF symbolic trace not compatible, attempting with dynamo.")
with torch.no_grad():
Expand Down

0 comments on commit ef321c6

Please sign in to comment.