diff --git a/src/brevitas_examples/llm/README.md b/src/brevitas_examples/llm/README.md
index b9518c3c0..89b572d13 100644
--- a/src/brevitas_examples/llm/README.md
+++ b/src/brevitas_examples/llm/README.md
@@ -15,9 +15,12 @@ Set the env variable `BREVITAS_JIT=1` to speed up the quantization process. Curr
 When using `--optimize-rotations`, the rotation training procedure relies on the Trainer class (https://huggingface.co/docs/transformers/en/main_classes/trainer). Therefore, training can be further configured by passing arguments accepted by the dataclass TrainingArguments (https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments), e.g. `--learning_rate`, `--weight_decay`, `per_device_train_batch_size`.
 
 ```bash
-usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
-               [--nsamples NSAMPLES] [--seqlen SEQLEN] [--eval]
-               [--dataset {wikitext2,c4}] [--gpxq-block-name GPXQ_BLOCK_NAME]
+usage: main.py [-h] [--config CONFIG] [--model MODEL]
+               [--dtype {float32,float16,bfloat16}] [--seed SEED]
+               [--nsamples NSAMPLES]
+               [--nsamples-rot-calibration NSAMPLES_ROT_CALIBRATION]
+               [--seqlen SEQLEN] [--eval] [--dataset {wikitext2,c4}]
+               [--gpxq-block-name GPXQ_BLOCK_NAME]
                [--weight-bit-width WEIGHT_BIT_WIDTH]
                [--weight-param-method {stats,mse,hqo}]
                [--weight-scale-precision {float_scale,po2_scale}]
@@ -48,11 +51,10 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
                [--gpxq-max-accumulator-tile-size GPXQ_MAX_ACCUMULATOR_TILE_SIZE]
                [--act-calibration] [--bias-corr] [--ln-affine-merge]
                [--convert-layernorm-to-rmsnorm] [--replace-rmsnorm]
-               [--no-quantize] [--no-float16]
-               [--scaling-min-val SCALING_MIN_VAL] [--quant-sdpa]
-               [--functional-sdpa-quant] [--replace-mha]
+               [--no-quantize] [--scaling-min-val SCALING_MIN_VAL]
+               [--quant-sdpa] [--functional-sdpa-quant] [--replace-mha]
                [--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}]
-               [--rotation-mode {had,ort}] [--optimize-rotations]
+               [--optimize-rotations] [--rotation-mode {had,ort}]
                [--rotation-orphan-sink] [--rotation-sdpa-regions]
                [--act-equalization {None,layerwise,fx}]
                [--act-equalization-alpha ACT_EQUALIZATION_ALPHA]
@@ -65,14 +67,20 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
                [--few-shot-compile] [--few-shot-zeroshot]
                [--few-shot-limit FEW_SHOT_LIMIT]
                [--few-shot-tasks [FEW_SHOT_TASKS ...]]
+               [--rotation-layers-to-expand [ROTATION_LAYERS_TO_EXPAND ...]]
 
 options:
   -h, --help            show this help message and exit
   --config CONFIG       Specify alternative default commandline args (e.g.,
                         config/default_template.yml). Default: None.
   --model MODEL         HF model name. Default: facebook/opt-125m.
+  --dtype {float32,float16,bfloat16}
+                        Data type for model. Default: None
   --seed SEED           Seed for sampling the calibration data. Default: 0.
   --nsamples NSAMPLES   Number of calibration data samples. Default: 128.
+  --nsamples-rot-calibration NSAMPLES_ROT_CALIBRATION
+                        Number of calibration data samples for rotation.
+                        Default: 800.
   --seqlen SEQLEN       Sequence length. Default: 2048.
   --eval                Eval model PPL on the chosen Dataset.
   --dataset {wikitext2,c4}
@@ -134,8 +142,9 @@ options:
                         Granularity for scales/zero-point of inputs. Default:
                         per_tensor.
   --kv-quant-granularity {per_tensor,per_row,per_group}
-                        Granularity for scales/zero-point of inputs. Default:
-                        per_tensor.
+                        Granularity for scales/zero-point of KV cache. If not
+                        set, it will use input-quant-granularity. Default:
+                        None
   --input-group-size INPUT_GROUP_SIZE
                         Group size for per_group input quantization. Default:
                         64.
@@ -174,8 +183,6 @@ options:
                         Merge LN affine params.
   --replace-rmsnorm     Replace HF RMSNorms with Torch one.
   --no-quantize         Disable quantization.
-  --no-float16          Disable float16 as base datatype and switch to
-                        float32.
   --scaling-min-val SCALING_MIN_VAL
                         Minimum value to clamp scale to when using bf16 or
                         fp16 quantization.
@@ -191,6 +198,7 @@ options:
                         models (e.g. OPT).
   --rotation {fx,layerwise,fused_no_fx}
                         Apply graph rotation equalization
+  --optimize-rotations  Whether to optimize the rotations (default: False).
   --rotation-mode {had,ort}
                         If GraphRotation is enabled, decide how to compute the
                         random rotation matrix that is fully fused. Online or
@@ -242,5 +250,8 @@ options:
   --few-shot-tasks [FEW_SHOT_TASKS ...]
                         A list of tasks for zero_shot evaluation. Default:
                         ['arc_challenge', 'arc_easy', 'winogrande', 'piqa']
+  --rotation-layers-to-expand [ROTATION_LAYERS_TO_EXPAND ...]
+                        A list of module names to expand with hadamard
+                        rotation. Default: []
 ```
 
diff --git a/src/brevitas_examples/llm/config/default_template.yml b/src/brevitas_examples/llm/config/default_template.yml
index b956da2fd..2b8696939 100644
--- a/src/brevitas_examples/llm/config/default_template.yml
+++ b/src/brevitas_examples/llm/config/default_template.yml
@@ -51,6 +51,7 @@ model: facebook/opt-125m
 no_float16: false
 no_quantize: false
 nsamples: 128
+nsamples_rot_calibration: 800
 optimize_rotations: false
 quant_sdpa: false
 quantize_input_zero_point: false
@@ -59,6 +60,7 @@ quantize_weight_zero_point: false
 replace_mha: false
 replace_rmsnorm: false
 rotation: null
+rotation_layers_to_expand: []
 rotation_mode: had
 rotation_orphan_sink: false
 rotation_sdpa_regions: false
diff --git a/src/brevitas_examples/llm/llm_args.py b/src/brevitas_examples/llm/llm_args.py
index 3e2136de2..6a2879244 100644
--- a/src/brevitas_examples/llm/llm_args.py
+++ b/src/brevitas_examples/llm/llm_args.py
@@ -19,6 +19,12 @@ def create_llm_args_parser():
         type=str,
         default="facebook/opt-125m",
         help='HF model name. Default: facebook/opt-125m.')
+    parser.add_argument(
+        '--dtype',
+        type=str,
+        default=None,
+        choices=["float32", "float16", "bfloat16"],
+        help='Data type for model. Default: %(default)s')
     parser.add_argument(
         '--seed', type=int, default=0, help='Seed for sampling the calibration data. Default: 0.')
     parser.add_argument(
@@ -220,10 +226,6 @@ def create_llm_args_parser():
     parser.add_argument(
         '--replace-rmsnorm', action='store_true', help='Replace HF RMSNorms with Torch one.')
     parser.add_argument('--no-quantize', action='store_true', help='Disable quantization.')
-    parser.add_argument(
-        '--no-float16',
-        action='store_true',
-        help='Disable float16 as base datatype and switch to float32.')
     parser.add_argument(
         '--scaling-min-val',
         type=float,
diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py
index bb9da3e06..9e1601d53 100644
--- a/src/brevitas_examples/llm/main.py
+++ b/src/brevitas_examples/llm/main.py
@@ -160,10 +160,7 @@ def quantize_llm(args, extra_args=None):
     if args.export_prefix is None:
         args.export_prefix = f"{args.model.replace('/', '--')}"
 
-    if args.no_float16:
-        dtype = torch.float32
-    else:
-        dtype = torch.float16
+    dtype = getattr(torch, args.dtype)
 
     # Whether to quantize SDPA with FX
     quant_sdpa_fx = args.quant_sdpa and not args.replace_mha
diff --git a/tests/brevitas_examples/llm_test_template.yml b/tests/brevitas_examples/llm_test_template.yml
index f37c99152..9b8511771 100644
--- a/tests/brevitas_examples/llm_test_template.yml
+++ b/tests/brevitas_examples/llm_test_template.yml
@@ -46,7 +46,7 @@ ln_affine_merge: false
 load_awq: null
 load_checkpoint: false
 model: facebook/opt-125m
-no_float16: false
+dtype: float32
 no_quantize: false
 nsamples: 128
 optimize_rotations: false
diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py
index 5a20df1eb..8d6ce08a2 100644
--- a/tests/brevitas_examples/test_llm.py
+++ b/tests/brevitas_examples/test_llm.py
@@ -142,7 +142,7 @@ def default_run_args(request):
     args.weight_quant_granularity = "per_channel"  # "per_tensor", "per_channel", "per_group".
     args.input_bit_width = 8
     args.act_calibration = True
-    args.no_float16 = True
+    args.dtype = "float32"
     return args
 
 
@@ -910,7 +910,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
         "rotation_orphan_sink": True,
         "rotation_mode": "ort",
         "nsamples_rot_calibration": 2,
-        "no_float16": True,
+        "dtype": "float32",
         "extra_args": [
             "--learning_rate",
             "1.5",
@@ -938,7 +938,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
         "rotation_orphan_sink": False,
         "rotation_mode": "ort",
         "nsamples_rot_calibration": 2,
-        "no_float16": True,
+        "dtype": "float32",
         "extra_args": [
             "--learning_rate",
             "1.5",
@@ -966,7 +966,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
         "rotation_orphan_sink": True,
         "rotation_mode": "had",
         "nsamples_rot_calibration": 2,
-        "no_float16": True,
+        "dtype": "float32",
         "extra_args": [
             "--learning_rate",
             "1.5",
@@ -994,7 +994,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
         "rotation_orphan_sink": False,
         "rotation_mode": "had",
         "nsamples_rot_calibration": 2,
-        "no_float16": True,
+        "dtype": "float32",
         "extra_args": [
             "--learning_rate",
             "1.5",
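
For scripts migrating from the removed `--no-float16` flag, the new `--dtype` option covers the same ground: the test fixture change above pairs `no_float16 = True` with `dtype = "float32"`. A minimal before/after sketch (model and values are illustrative, not recommendations):

```bash
# Before this change: force the float32 base datatype.
python main.py --model facebook/opt-125m --no-float16

# After this change: select the base datatype explicitly;
# --dtype accepts float32, float16 or bfloat16.
python main.py --model facebook/opt-125m --dtype float32
```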
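As the README text in this diff notes, `--optimize-rotations` hands training off to the HF Trainer, so TrainingArguments fields can be appended as extra CLI flags. A hypothetical invocation combining the new rotation options (values are placeholders; check the option help above for valid combinations):

```bash
# Sketch only: optimize fused rotations on a small model, forwarding
# --learning_rate and --weight_decay to transformers.TrainingArguments.
python main.py \
    --model facebook/opt-125m \
    --dtype bfloat16 \
    --rotation fused_no_fx \
    --optimize-rotations \
    --rotation-mode had \
    --nsamples-rot-calibration 800 \
    --learning_rate 1e-3 \
    --weight_decay 0.0
```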