Feat (brevitas_examples/llm): better dtype selection #1186

Merged 4 commits on Feb 13, 2025
33 changes: 22 additions & 11 deletions src/brevitas_examples/llm/README.md
@@ -15,9 +15,12 @@ Set the env variable `BREVITAS_JIT=1` to speed up the quantization process. Curr
When using `--optimize-rotations`, the rotation training procedure relies on the Trainer class (https://huggingface.co/docs/transformers/en/main_classes/trainer). Therefore, training can be further configured by passing arguments accepted by the TrainingArguments dataclass (https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments), e.g. `--learning_rate`, `--weight_decay`, `--per_device_train_batch_size`.
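For example, a minimal sketch of the `TrainingArguments` these extra flags map onto (the values are illustrative only; `--learning_rate 1.5` mirrors the rotation tests further down this page, and `output_dir` is a hypothetical placeholder):

```python
from transformers import TrainingArguments

# Illustrative only: any TrainingArguments field can be forwarded as an extra CLI
# flag, e.g. `--learning_rate 1.5 --weight_decay 0.0 --per_device_train_batch_size 2`.
training_args = TrainingArguments(
    output_dir="rotation-calibration",  # hypothetical output directory
    learning_rate=1.5,
    weight_decay=0.0,
    per_device_train_batch_size=2,
)
print(training_args.learning_rate)  # 1.5
```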

```bash
usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
[--nsamples NSAMPLES] [--seqlen SEQLEN] [--eval]
[--dataset {wikitext2,c4}] [--gpxq-block-name GPXQ_BLOCK_NAME]
usage: main.py [-h] [--config CONFIG] [--model MODEL]
[--dtype {float32,float16,bfloat16}] [--seed SEED]
[--nsamples NSAMPLES]
[--nsamples-rot-calibration NSAMPLES_ROT_CALIBRATION]
[--seqlen SEQLEN] [--eval] [--dataset {wikitext2,c4}]
[--gpxq-block-name GPXQ_BLOCK_NAME]
[--weight-bit-width WEIGHT_BIT_WIDTH]
[--weight-param-method {stats,mse,hqo}]
[--weight-scale-precision {float_scale,po2_scale}]
@@ -48,11 +51,10 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
[--gpxq-max-accumulator-tile-size GPXQ_MAX_ACCUMULATOR_TILE_SIZE]
[--act-calibration] [--bias-corr] [--ln-affine-merge]
[--convert-layernorm-to-rmsnorm] [--replace-rmsnorm]
[--no-quantize] [--no-float16]
[--scaling-min-val SCALING_MIN_VAL] [--quant-sdpa]
[--functional-sdpa-quant] [--replace-mha]
[--no-quantize] [--scaling-min-val SCALING_MIN_VAL]
[--quant-sdpa] [--functional-sdpa-quant] [--replace-mha]
[--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}]
[--rotation-mode {had,ort}] [--optimize-rotations]
[--optimize-rotations] [--rotation-mode {had,ort}]
[--rotation-orphan-sink] [--rotation-sdpa-regions]
[--act-equalization {None,layerwise,fx}]
[--act-equalization-alpha ACT_EQUALIZATION_ALPHA]
@@ -65,14 +67,20 @@ usage: main.py [-h] [--config CONFIG] [--model MODEL] [--seed SEED]
[--few-shot-compile] [--few-shot-zeroshot]
[--few-shot-limit FEW_SHOT_LIMIT]
[--few-shot-tasks [FEW_SHOT_TASKS ...]]
[--rotation-layers-to-expand [ROTATION_LAYERS_TO_EXPAND ...]]

options:
-h, --help show this help message and exit
--config CONFIG Specify alternative default commandline args (e.g.,
config/default_template.yml). Default: None.
--model MODEL HF model name. Default: facebook/opt-125m.
--dtype {float32,float16,bfloat16}
Data type for model. Default: None
--seed SEED Seed for sampling the calibration data. Default: 0.
--nsamples NSAMPLES Number of calibration data samples. Default: 128.
--nsamples-rot-calibration NSAMPLES_ROT_CALIBRATION
Number of calibration data samples for rotation.
Default: 800.
--seqlen SEQLEN Sequence length. Default: 2048.
--eval Eval model PPL on the chosen Dataset.
--dataset {wikitext2,c4}
@@ -134,8 +142,9 @@ options:
Granularity for scales/zero-point of inputs. Default:
per_tensor.
--kv-quant-granularity {per_tensor,per_row,per_group}
Granularity for scales/zero-point of inputs. Default:
per_tensor.
Granularity for scales/zero-point of KV cache. If not
set, it will use input-quant-granularity. Default:
None
--input-group-size INPUT_GROUP_SIZE
Group size for per_group input quantization. Default:
64.
@@ -174,8 +183,6 @@ options:
Merge LN affine params.
--replace-rmsnorm Replace HF RMSNorms with Torch one.
--no-quantize Disable quantization.
--no-float16 Disable float16 as base datatype and switch to
float32.
--scaling-min-val SCALING_MIN_VAL
Minimum value to clamp scale to when using bf16 or
fp16 quantization.
@@ -191,6 +198,7 @@ options:
models (e.g. OPT).
--rotation {fx,layerwise,fused_no_fx}
Apply graph rotation equalization
--optimize-rotations Whether to optimize the rotations (default: False).
--rotation-mode {had,ort}
If GraphRotation is enabled, decide how to compute the
random rotation matrix that is fully fused. Online or
@@ -242,5 +250,8 @@ options:
--few-shot-tasks [FEW_SHOT_TASKS ...]
A list of tasks for zero_shot evaluation. Default:
['arc_challenge', 'arc_easy', 'winogrande', 'piqa']
--rotation-layers-to-expand [ROTATION_LAYERS_TO_EXPAND ...]
A list of module names to expand with hadamard
rotation. Default: []

```
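A note on `--kv-quant-granularity` above: when it is left unset, the KV cache reuses the input quantization granularity. A minimal sketch of that fallback (the helper name is hypothetical; the actual Brevitas code may differ):

```python
def resolve_kv_quant_granularity(kv_quant_granularity, input_quant_granularity):
    """Hypothetical helper: KV-cache granularity falls back to the input granularity."""
    return kv_quant_granularity if kv_quant_granularity is not None else input_quant_granularity

# With --input-quant-granularity per_row and --kv-quant-granularity unset,
# the KV cache is quantized per_row as well.
assert resolve_kv_quant_granularity(None, "per_row") == "per_row"
assert resolve_kv_quant_granularity("per_group", "per_row") == "per_group"
```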
2 changes: 2 additions & 0 deletions src/brevitas_examples/llm/config/default_template.yml
@@ -51,6 +51,7 @@ model: facebook/opt-125m
no_float16: false
no_quantize: false
nsamples: 128
nsamples_rot_calibration: 800
optimize_rotations: false
quant_sdpa: false
quantize_input_zero_point: false
@@ -59,6 +60,7 @@ quantize_weight_zero_point: false
replace_mha: false
replace_rmsnorm: false
rotation: null
rotation_layers_to_expand: []
rotation_mode: had
rotation_orphan_sink: false
rotation_sdpa_regions: false
10 changes: 6 additions & 4 deletions src/brevitas_examples/llm/llm_args.py
@@ -19,6 +19,12 @@ def create_llm_args_parser():
type=str,
default="facebook/opt-125m",
help='HF model name. Default: facebook/opt-125m.')
parser.add_argument(
'--dtype',
type=str,
default=None,
choices=["float32", "float16", "bfloat16"],
help='Data type for model. Default: %(default)s')
parser.add_argument(
'--seed', type=int, default=0, help='Seed for sampling the calibration data. Default: 0.')
parser.add_argument(
@@ -220,10 +226,6 @@ def create_llm_args_parser():
parser.add_argument(
'--replace-rmsnorm', action='store_true', help='Replace HF RMSNorms with Torch one.')
parser.add_argument('--no-quantize', action='store_true', help='Disable quantization.')
parser.add_argument(
'--no-float16',
action='store_true',
help='Disable float16 as base datatype and switch to float32.')
parser.add_argument(
'--scaling-min-val',
type=float,
5 changes: 1 addition & 4 deletions src/brevitas_examples/llm/main.py
@@ -160,10 +160,7 @@ def quantize_llm(args, extra_args=None):
if args.export_prefix is None:
args.export_prefix = f"{args.model.replace('/', '--')}"

if args.no_float16:
dtype = torch.float32
else:
dtype = torch.float16
dtype = getattr(torch, args.dtype)

# Whether to quantize SDPA with FX
quant_sdpa_fx = args.quant_sdpa and not args.replace_mha
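The new `--dtype` string maps directly onto a `torch` attribute via `getattr`. A minimal sketch of that resolution (the `None` fallback shown here is an assumption for illustration, since the argument defaults to `None`; it is not the code in this PR):

```python
import torch

def resolve_dtype(dtype_name, fallback="float16"):
    # Hypothetical helper mirroring `getattr(torch, args.dtype)`; the fallback
    # for an unset --dtype is an assumption for illustration only.
    return getattr(torch, dtype_name if dtype_name is not None else fallback)

assert resolve_dtype("bfloat16") is torch.bfloat16
assert resolve_dtype(None) is torch.float16
```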
2 changes: 1 addition & 1 deletion tests/brevitas_examples/llm_test_template.yml
@@ -46,7 +46,7 @@ ln_affine_merge: false
load_awq: null
load_checkpoint: false
model: facebook/opt-125m
no_float16: false
dtype: float32
no_quantize: false
nsamples: 128
optimize_rotations: false
10 changes: 5 additions & 5 deletions tests/brevitas_examples/test_llm.py
@@ -142,7 +142,7 @@ def default_run_args(request):
args.weight_quant_granularity = "per_channel" # "per_tensor", "per_channel", "per_group".
args.input_bit_width = 8
args.act_calibration = True
args.no_float16 = True
args.dtype = "float32"
return args


@@ -910,7 +910,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"rotation_orphan_sink": True,
"rotation_mode": "ort",
"nsamples_rot_calibration": 2,
"no_float16": True,
"dtype": "float32",
"extra_args": [
"--learning_rate",
"1.5",
@@ -938,7 +938,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"rotation_orphan_sink": False,
"rotation_mode": "ort",
"nsamples_rot_calibration": 2,
"no_float16": True,
"dtype": "float32",
"extra_args": [
"--learning_rate",
"1.5",
@@ -966,7 +966,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"rotation_orphan_sink": True,
"rotation_mode": "had",
"nsamples_rot_calibration": 2,
"no_float16": True,
"dtype": "float32",
"extra_args": [
"--learning_rate",
"1.5",
@@ -994,7 +994,7 @@ def test_small_models_rotation_ppl(caplog, rotation_ppl_args_and_ppl):
"rotation_orphan_sink": False,
"rotation_mode": "had",
"nsamples_rot_calibration": 2,
"no_float16": True,
"dtype": "float32",
"extra_args": [
"--learning_rate",
"1.5",