diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index a6e03911c..d619c56e5 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -185,6 +185,8 @@ def main(args): kwargs = {"torch_dtype": dtype} if quant_sdpa_fx: kwargs["attn_implementation"] = "sdpa" + elif args.replace_mha: + kwargs["attn_implementation"] = "eager" if args.export_target == 'torch_qcdq': kwargs['torchscript'] = True