Skip to content

Commit

Permalink
args name
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 12, 2024
1 parent a06ea82 commit fe4b3fc
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def validate(args):
if args.graph_rotation:
assert args.ln_affine_merge, 'Graph rotation requires to merge LN/RMS norm affine parameters'
assert args.replace_rmsnorms, 'Graph rotation requires to replace HF RMSNorm with PyTorch ones (torch 2.4+ require)'
assert args.apply_layernorm_to_rmsnorm, 'Graph rotation requires to replace LayerNorm with RMSNorm'
assert args.convert_layernorm_to_rmsnorm, 'Graph rotation requires to replace LayerNorm with RMSNorm'
if not args.no_quantize:
if args.gptq and args.gpfq:
warn("Both GPTQ and GPFQ are enabled.")
Expand Down Expand Up @@ -190,7 +190,7 @@ def main(args):
with CastFloat16ToFloat32():
apply_awq(model, awq_results)

require_fx = True if args.weight_equalization or args.act_equalization == 'fx' or args.ln_affine_merge or args.apply_layernorm_to_rmsnorm else False
require_fx = True if args.weight_equalization or args.act_equalization == 'fx' or args.ln_affine_merge or args.convert_layernorm_to_rmsnorm else False

# Load the data for calibration and evaluation.
calibration_loader = get_dataset_for_model(
Expand Down Expand Up @@ -249,7 +249,7 @@ def main(args):
apply_layernorm_affine_merge(model)
print("LN affine merge applied.")

if args.apply_layernorm_to_rmsnorm:
if args.convert_layernorm_to_rmsnorm:
print("Convert LayerNorm to RMSNorm...")
apply_layernorm_to_rmsnorm(model)
print("Layernorm To RMSNorm applied.")
Expand Down Expand Up @@ -557,6 +557,8 @@ def parse_args(args):
'--act-calibration', action='store_true', help='Apply activation calibration.')
parser.add_argument('--bias-corr', action='store_true', help='Apply bias correction.')
parser.add_argument('--ln-affine-merge', action='store_true', help='Merge LN affine params.')
parser.add_argument(
'--convert-layernorm-to-rmsnorm', action='store_true', help='Merge LN affine params.')
parser.add_argument(
'--replace-rmsnorm', action='store_true', help='Replace HF RMSNorms with Torch one.')
parser.add_argument('--no-quantize', action='store_true', help='Disable quantization.')
Expand Down

0 comments on commit fe4b3fc

Please sign in to comment.