Fix (ptq/rotation): fix for rotation implementation (#1095)

Merged: 3 commits, Nov 19, 2024
Changes from 2 commits
3 changes: 1 addition & 2 deletions src/brevitas/graph/equalize.py
@@ -1367,7 +1367,6 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
module.offload_params(module)

if insert_rotation_module and len(region.srcs) == 0:
# print(name, module.in_features, K)
rewriter = ModuleInstanceToModuleInstance(
module, RotatedModule(had_mat=rot_mat, k=K, layer=module))
rewriters.append(rewriter)
@@ -1467,7 +1466,7 @@ def rotate_matmuls(self, graph_module):

def apply(self,
graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]:

rewriters = []
regions = _extract_regions(
graph_model,
state_impl_kwargs={
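For context, the rewriter above wraps an orphan sink in a `RotatedModule`. A minimal sketch of the idea (illustrative only, not Brevitas' actual `RotatedModule`): the sink's input is multiplied by a fixed orthogonal matrix before the wrapped layer, so a matching counter-rotation fused into upstream weights leaves the network function unchanged.

```python
import torch
import torch.nn as nn


class RotatedSinkSketch(nn.Module):
    """Illustrative stand-in for a rotated sink: apply a fixed orthogonal
    rotation to the input before the wrapped layer."""

    def __init__(self, layer: nn.Linear, rot_mat: torch.Tensor):
        super().__init__()
        self.layer = layer
        # (in_features, in_features) orthogonal matrix, kept as a buffer.
        self.register_buffer("rot_mat", rot_mat)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x @ self.rot_mat)


# Hypothetical usage: wrap a sink Linear with a random orthogonal rotation.
layer = nn.Linear(8, 4, bias=False)
q, _ = torch.linalg.qr(torch.randn(8, 8))  # random orthogonal matrix
rotated = RotatedSinkSketch(layer, q)
out = rotated(torch.randn(2, 8))
```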
9 changes: 4 additions & 5 deletions src/brevitas_examples/llm/README.md
@@ -44,9 +44,8 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
[--convert-layernorm-to-rmsnorm] [--replace-rmsnorm]
[--no-quantize] [--no-float16]
[--scaling-min-val SCALING_MIN_VAL] [--replace-mha]
[--weight-equalization]
[--graph-rotation {fx,layerwise,fused_no_fx}]
[--graph-rotation-mode {had,ort}] [--rotation-orphan-sink]
[--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}]
[--rotation-mode {had,ort}] [--rotation-orphan-sink]
[--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ]
[--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}]
[--export-prefix EXPORT_PREFIX]
@@ -148,9 +147,9 @@ options:
--weight-equalization
Apply weight equalization. Relevant to ReLU based
models (e.g. OPT).
--graph-rotation {fx,layerwise,fused_no_fx}
--rotation {fx,layerwise,fused_no_fx}
Apply graph rotation equalization
--graph-rotation-mode {had,ort}
--rotation-mode {had,ort}
If GraphRotation is enabled, decide how to compute the
random rotation matrix that is fully fused. Online or
partial rotation will always be Hadamard
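As a rough illustration of the two `--rotation-mode` choices documented above, the fully fused rotation matrix can be either a sign-randomized normalized Hadamard matrix or a random orthogonal matrix. The sketch below is illustrative only, not Brevitas' implementation; `make_rotation` is a hypothetical helper.

```python
import torch
from scipy.linalg import hadamard  # Hadamard matrices require dim to be a power of two


def make_rotation(dim: int, mode: str = "had") -> torch.Tensor:
    """Sketch of the two rotation-matrix flavours behind --rotation-mode."""
    if mode == "had":
        # Normalized Hadamard matrix: orthogonal, entries +-1/sqrt(dim).
        h = torch.from_numpy(hadamard(dim)).float() / dim ** 0.5
        # Random column sign flips keep it orthogonal while randomizing the rotation.
        signs = torch.sign(torch.randn(dim))
        return h * signs
    if mode == "ort":
        # Random orthogonal matrix from the QR decomposition of a Gaussian matrix.
        q, _ = torch.linalg.qr(torch.randn(dim, dim))
        return q
    raise ValueError(f"unknown rotation mode: {mode}")


r = make_rotation(64, mode="had")
assert torch.allclose(r @ r.T, torch.eye(64), atol=1e-5)
```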
20 changes: 11 additions & 9 deletions src/brevitas_examples/llm/main.py
@@ -61,7 +61,7 @@ def fused_rotation_no_fx(model, calibration_loader, args):
new_model = offload_model(new_model)
eq = GraphRotationEqualization(
orphan_sink=args.rotation_orphan_sink,
full_rotation_method=args.graph_rotation_mode,
full_rotation_method=args.rotation_mode,
return_rewriters=True)
new_model, rewriters = eq.apply(new_model)
rewriters = fix_rewriter(rewriters, model, 'weight')
@@ -104,10 +104,12 @@ def model_export(model, ref_input, args):


def validate(args):
if args.graph_rotation == 'fx':
if args.rotation == 'fx':
assert args.ln_affine_merge, 'Graph rotation requires to merge LN/RMS norm affine parameters'
assert args.replace_rmsnorm, 'Graph rotation requires to replace HF RMSNorm with PyTorch ones (torch 2.4+ require)'
assert args.convert_layernorm_to_rmsnorm, 'Graph rotation requires to replace LayerNorm with RMSNorm'
elif args.rotation == 'fused_no_fx':
assert args.replace_rmsnorm, 'Graph rotation requires to replace HF RMSNorm with PyTorch ones (torch 2.4+ require)'
if not args.no_quantize:
if args.gptq and args.gpfq:
warn("Both GPTQ and GPFQ are enabled.")
@@ -259,16 +261,16 @@ def main(args):
apply_layernorm_to_rmsnorm(model)
print("Layernorm To RMSNorm applied.")

if args.graph_rotation == 'fx':
if args.rotation == 'fx':
model = offload_model(model)
eq = GraphRotationEqualization(
orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.graph_rotation_mode)
orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode)
model = eq.apply(model)
remove_hooks(model)
elif args.graph_rotation == 'layerwise':
elif args.rotation == 'layerwise':
eq = LayerwiseActivationRotation()
model = eq.apply(model)
elif args.graph_rotation == 'fused_no_fx':
elif args.rotation == 'fused_no_fx':
fused_rotation_no_fx(model, calibration_loader, args)

# Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing
@@ -354,7 +356,7 @@ def main(args):
# If any equalization has taken places, the embedding layer and the fully connected one are
# not tied anymore, and they need to be treated as standalone, separate layers.
# In all other cases we can tie them back so to preserve memory.
if args.act_equalization is None and not require_fx:
if args.act_equalization is None and not require_fx and args.rotation is None:
model.tie_weights()

if args.bias_corr:
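The added `args.rotation is None` guard in the hunk above skips `tie_weights()` after rotation. A minimal PyTorch illustration of why (a sketch, not the Hugging Face implementation): once one of the tied weights is rewritten with a rotated copy, re-tying would point both layers back at the untouched embedding and discard the rotation.

```python
import torch
import torch.nn as nn

vocab_size, hidden = 100, 16
emb = nn.Embedding(vocab_size, hidden)
lm_head = nn.Linear(hidden, vocab_size, bias=False)

# Tied: both modules point at the same Parameter object.
lm_head.weight = emb.weight
assert lm_head.weight is emb.weight

# A rotation rewrite replaces the head's weight with a transformed copy,
# so the two layers become standalone...
with torch.no_grad():
    q, _ = torch.linalg.qr(torch.randn(hidden, hidden))
    lm_head.weight = nn.Parameter(lm_head.weight @ q)
assert lm_head.weight is not emb.weight

# ...and re-tying (what tie_weights() would do) would discard the rotation:
lm_head.weight = emb.weight  # the rotated copy is gone
```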
@@ -600,13 +602,13 @@ def parse_args(args):
action='store_true',
help='Apply weight equalization. Relevant to ReLU based models (e.g. OPT).')
parser.add_argument(
'--graph-rotation',
'--rotation',
type=str,
default=None,
choices=['fx', 'layerwise', 'fused_no_fx'],
help='Apply graph rotation equalization')
parser.add_argument(
'--graph-rotation-mode',
'--rotation-mode',
default='had',
choices=['had', 'ort'],
help=
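Putting the renamed options together, a hypothetical programmatic driver might look like the sketch below; the module path, model id, and the `--ln-affine-merge` spelling are assumptions, while the companion flags mirror the `validate()` requirements for `--rotation fx`.

```python
# Hypothetical driver sketch for the renamed CLI options (module path and
# model id are assumptions, not part of this PR).
from brevitas_examples.llm.main import main, parse_args

args = parse_args([
    "--model", "facebook/opt-125m",   # placeholder model id
    "--rotation", "fx",               # previously --graph-rotation
    "--rotation-mode", "had",         # previously --graph-rotation-mode
    "--rotation-orphan-sink",
    # validate() asserts these when --rotation fx is selected:
    "--ln-affine-merge",              # flag spelling assumed from args.ln_affine_merge
    "--replace-rmsnorm",
    "--convert-layernorm-to-rmsnorm",
])
main(args)
```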
7 changes: 3 additions & 4 deletions tests/brevitas_examples/test_llm.py
@@ -380,7 +380,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"no_quantize": True,
"rotation_orphan_sink": True,
"convert_layernorm_to_rmsnorm": True,
"graph_rotation": "fx",
"rotation": "fx",
"exp_layer_types": {
"L__self___model_layers_0_self_attn_k_proj":
"<class 'torch.nn.modules.linear.Linear'>",
@@ -394,7 +394,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"no_quantize": True,
"rotation_orphan_sink": False,
"convert_layernorm_to_rmsnorm": True,
"graph_rotation": "fx",
"rotation": "fx",
"exp_layer_types": {
"L__self___model_layers_0_self_attn_k_proj":
"<class 'torch.nn.modules.linear.Linear'>",
@@ -417,8 +417,7 @@ def test_small_models_quant_layer(caplog, layer_args):
if args.replace_rmsnorm:
if torch_version < version.parse('2.4'):
pytest.skip("Replacing RMSNorm requires torch 2.4+ or greater")
if hasattr(args, 'graph_rotation') and args.graph_rotation == 'fx' and platform.system(
) == 'Windows':
if hasattr(args, 'rotation') and args.rotation == 'fx' and platform.system() == 'Windows':
pytest.skip("Skipping dynamo + windows")
float_ppl, quant_ppl, model = validate_args_and_run_main(args)
assert_layer_types(model, exp_layer_types)