Add fx act equalization, fixes for float16 support
volcacius committed Jun 27, 2023
Commit 5b6f975 (parent 3ef35e2)
Showing 5 changed files with 93 additions and 22 deletions.
47 changes: 40 additions & 7 deletions src/brevitas_examples/llm/llm_quant/equalize.py
@@ -3,12 +3,15 @@
# SPDX-License-Identifier: BSD-3-Clause
"""

import warnings

import torch

from brevitas.fx.brevitas_tracer import value_trace
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.graph.equalize import EqualizeGraph
from brevitas_examples.llm.llm_quant.run_utils import apply_layer_ptq_fn
from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32


@torch.no_grad()
@@ -24,17 +27,47 @@ def activation_equalization_iter(curr_layer, inps, outs, cached_values, alpha):


@torch.no_grad()
-def apply_act_equalization(model, dataloader, nsamples, seqlen=2048, alpha=0.5):
-    apply_layer_ptq_fn(
+def apply_act_equalization(
         model,
+        act_equalization_type,
         dataloader,
         nsamples,
-        inference_fn=activation_equalization_iter,
-        seqlen=seqlen,
-        alpha=alpha)
+        seqlen=2048,
+        alpha=0.5,
+        ref_kwargs=None):
if act_equalization_type == 'layerwise':
apply_layer_ptq_fn(
model,
dataloader,
nsamples,
inference_fn=activation_equalization_iter,
seqlen=seqlen,
alpha=alpha)
elif act_equalization_type == 'fx':
assert ref_kwargs is not None, "Ref kwargs required to perform tracing and lift the model into FX."
# We can't do fp16 tracing on CPU as many kernels are not implemented
# So we have to cast to fp32 first, trace, apply equalization, and then cast back
with cast_to_float32(model):
graph_model = value_trace(model, value_args=ref_kwargs)
# TODO this is currently running on CPU. We need Accelerate or a TorchDispatchMode
# or an FX interpreter to run it on GPU
warnings.warn(
"FX mode activation equalization currently runs on CPU, expect it to be slow for large models."
)
with activation_equalization_mode(graph_model,
alpha,
add_mul_node=False,
layerwise=False):
for input_ids in dataloader:
graph_model(input_ids=input_ids)
else:
raise RuntimeError(f"{act_equalization_type} not supported.")


@torch.no_grad()
def apply_weight_equalization(model, ref_kwargs, scale_computation_type='range'):
-    graph_model = value_trace(model, value_args=ref_kwargs)
-    EqualizeGraph(scale_computation_type=scale_computation_type).apply(graph_model)
+    # We can't do fp16 tracing on CPU as many kernels are not implemented
+    # So we have to cast to fp32 first, trace, apply equalization, and then cast back
+    with cast_to_float32(model):
+        graph_model = value_trace(model, value_args=ref_kwargs)
+        EqualizeGraph(scale_computation_type=scale_computation_type).apply(graph_model)
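
For reference, a minimal sketch of how the new fx path is meant to be driven. The `model` and `calibration_loader` objects are assumed to come from the surrounding LLM example (HF checkpoint plus calibration dataset); only `apply_act_equalization` and its new signature are taken from the diff above:

```python
from brevitas_examples.llm.llm_quant.equalize import apply_act_equalization

# 'fx' traces the model and folds equalization scales into preceding layers
# where possible; 'layerwise' falls back to the pre-existing per-layer PTQ loop.
apply_act_equalization(
    model,                     # HF causal LM, possibly loaded in float16
    'fx',                      # act_equalization_type: 'layerwise' or 'fx'
    calibration_loader,        # iterable of input_ids tensors
    nsamples=128,              # only consumed by the 'layerwise' branch
    ref_kwargs={'input_ids': calibration_loader[0]})  # required for FX tracing
```
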
9 changes: 7 additions & 2 deletions src/brevitas_examples/llm/llm_quant/ln_affine_merge.py
@@ -10,6 +10,7 @@
from brevitas.graph.equalize import _is_reshaping_op
from brevitas.graph.equalize import _is_scale_invariant_module
from brevitas.graph.utils import get_module
from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32


def replace_bias(next_module, new_bias):
@@ -73,6 +74,7 @@ def merge_layernorm_affine_params(graph_model):
)
for module, merged in merged_dict.items():
if merged:
# We preserve weight and bias in case they are used to merge SmoothQuant scales in fx mode later on
module.weight.data.fill_(1.)
module.bias.data.fill_(0.)
else:
@@ -83,5 +85,8 @@

@torch.no_grad()
def apply_layernorm_affine_merge(model, ref_kwargs):
-    graph_model = value_trace(model, ref_kwargs)
-    merge_layernorm_affine_params(graph_model)
+    # We can't do fp16 tracing on CPU as many kernels are not implemented
+    # So we have to cast to fp32 first, trace, apply merging, and then cast back
+    with cast_to_float32(model):
+        graph_model = value_trace(model, ref_kwargs)
+        merge_layernorm_affine_params(graph_model)
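
The LayerNorm affine merge now uses the same cast-to-fp32 / trace / restore pattern, so the model no longer has to be in float32 when calling it. A small usage sketch; `model` and `calibration_loader` are assumed to exist in the caller, as in the main LLM example:

```python
from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge

# Parameters are cast to fp32 in place for tracing and restored to their
# original dtypes (e.g. float16) once the merge has been applied.
apply_layernorm_affine_merge(model, ref_kwargs={'input_ids': calibration_loader[0]})
```
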
12 changes: 9 additions & 3 deletions src/brevitas_examples/llm/llm_quant/quantize.py
@@ -139,7 +139,9 @@ def quantize_model(
if input_quant is not None:
input_quant = input_quant.let(
**{
-                'bit_width': input_bit_width, 'quantize_zero_point': quantize_input_zero_point})
+                'bit_width': input_bit_width,
+                'quantize_zero_point': quantize_input_zero_point,
+                'dtype': dtype})
if input_quant_granularity == 'per_row':
# QuantMHA internally always uses Seq, B, E
input_quant = input_quant.let(
@@ -150,7 +152,9 @@
if sym_input_quant is not None:
sym_input_quant = sym_input_quant.let(
**{
-                'bit_width': input_bit_width, 'quantize_zero_point': quantize_input_zero_point})
+                'bit_width': input_bit_width,
+                'quantize_zero_point': quantize_input_zero_point,
+                'dtype': dtype})
if input_quant_granularity == 'per_row':
q_scaled_quant = sym_input_quant.let(
**{
@@ -169,7 +173,9 @@
if per_tensor_input_quant is not None:
per_tensor_input_quant = per_tensor_input_quant.let(
**{
-                'bit_width': input_bit_width, 'quantize_zero_point': quantize_input_zero_point})
+                'bit_width': input_bit_width,
+                'quantize_zero_point': quantize_input_zero_point,
+                'dtype': dtype})

quant_linear_kwargs = {
'input_quant': per_tensor_input_quant, 'weight_quant': weight_quant, 'dtype': dtype}
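
The `dtype` entry threaded into the quantizer overrides above reuses Brevitas' injector-style `.let(...)` mechanism, which the surrounding code already applies for `bit_width` and zero-point options. A hedged sketch of the idea; the specific quantizer class below is only an example and not part of this commit:

```python
import torch

from brevitas.quant import Int8ActPerTensorFloat  # example quantizer, assumption

# Derive a float16-aware variant of an input quantizer by overriding attributes,
# mirroring what quantize_model now does with 'dtype': dtype.
fp16_input_quant = Int8ActPerTensorFloat.let(**{'bit_width': 8, 'dtype': torch.float16})
```
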
20 changes: 20 additions & 0 deletions src/brevitas_examples/llm/llm_quant/run_utils.py
@@ -19,6 +19,8 @@
limitations under the License.
"""

from contextlib import contextmanager

import torch
from torch import nn
from tqdm import tqdm
@@ -140,3 +142,21 @@ def apply_layer_ptq_fn(
input_capture_fn=calib_input_capture,
seqlen=seqlen,
**inference_fn_kwargs)


@contextmanager
def cast_to_float32(model):
dtype_dict = {}
for name, p in model.named_parameters():
dtype_dict[name] = p.dtype
for name, b in model.named_buffers():
dtype_dict[name] = b.dtype
if any(dtype != torch.float32 for dtype in dtype_dict.values()):
model.to(dtype=torch.float32)
try:
yield model
finally:
for name, p in model.named_parameters():
p.data = p.data.to(dtype_dict[name])
for name, b in model.named_buffers():
b.data = b.data.to(dtype_dict[name])
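
`cast_to_float32` is an ordinary context manager, so the dtype round-trip it performs can be seen in isolation. The toy module below is purely illustrative:

```python
import torch
from torch import nn

from brevitas_examples.llm.llm_quant.run_utils import cast_to_float32

m = nn.Linear(4, 4).to(dtype=torch.float16)

with cast_to_float32(m):
    # Inside the context every parameter/buffer is fp32, so CPU-only fp32
    # kernels (e.g. those needed for FX value tracing) can run.
    assert m.weight.dtype == torch.float32

# On exit the recorded per-tensor dtypes are restored.
assert m.weight.dtype == torch.float16
```
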
27 changes: 17 additions & 10 deletions src/brevitas_examples/llm/main.py
@@ -4,7 +4,6 @@
"""

import argparse
-import warnings

import numpy as np
import torch
@@ -113,7 +112,12 @@
action='store_true',
help='Apply weight equalization. Relevant to ReLU based models (e.g. OPT).')
parser.add_argument(
-    '--act-equalization', action='store_true', help='Apply activation equalization (SmoothQuant).')
+    '--act-equalization',
+    default=None,
+    choices=[None, 'layerwise', 'fx'],
+    help='Apply activation equalization (SmoothQuant). Layerwise introduces standalone mul nodes,'
+    'while fx merges them whenever possible into previous tensors, which is possible on ReLU based models (e.g. OPT).'
+)
parser.add_argument(
'--export-target',
default=None,
@@ -168,10 +172,8 @@ def validate(args):
assert args.quantize_weight_zero_point, "Quantized weight zero point required."
if args.input_quant_type == 'asym':
assert args.quantize_input_zero_point, "Quantized input zero point required."
-    if args.input_bit_width is not None and not args.act_calibration:
-        warnings.warn(
-            "Input quantization is being applied without activation calibration. Set --act-calibration."
-        )
+    if args.input_bit_width:
+        assert args.act_calibration, "Input quantization is being applied without activation calibration. Set --act-calibration."


def main():
@@ -204,9 +206,9 @@ def main():
apply_layernorm_affine_merge(model, ref_kwargs={'input_ids': calibration_loader[0]})
print("LN affine merge applied.")

-    # Insert standard MHA layers when performing weight equalization to avoid dealing
+    # Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing
     # with all the variability in HF implementations
-    if args.weight_equalization or args.input_bit_width:
+    if args.weight_equalization or args.act_equalization == 'fx' or args.input_bit_width:
print("Replace HF MHA with quantizable variants...")
model = replace_mha_with_quantizable_layers(model, dtype)
print("Replacing done.")
@@ -216,9 +218,14 @@
apply_weight_equalization(model, ref_kwargs={'input_ids': calibration_loader[0]})
print("Weight equalization applied.")

-    if args.act_equalization:
+    if args.act_equalization is not None:
         print("Apply act equalization (SmoothQuant)...")
-        apply_act_equalization(model, calibration_loader, args.nsamples)
+        apply_act_equalization(
+            model,
+            args.act_equalization,
+            calibration_loader,
+            args.nsamples,
+            ref_kwargs={'input_ids': calibration_loader[0]})
print("Act equalization applied.")

if not args.no_quantize:
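
The new `--act-equalization` help text distinguishes the two modes: `layerwise` introduces standalone mul nodes in front of the equalized layers, while `fx` merges the scales into preceding ops where the traced graph allows it. Both rely on the same SmoothQuant identity; a toy, self-contained illustration (not code from this commit):

```python
import torch
import torch.nn.functional as F

# SmoothQuant-style activation equalization: choose per-channel scales s and
# rewrite x @ W.T as (x / s) @ (s * W).T, shifting dynamic range from the
# activations into the weights without changing the layer output.
torch.manual_seed(0)
x = torch.randn(2, 4)
linear = torch.nn.Linear(4, 3, bias=False)

alpha = 0.5
act_range = x.abs().amax(dim=0).clamp(min=1e-5)
weight_range = linear.weight.abs().amax(dim=0).clamp(min=1e-5)
s = act_range.pow(alpha) / weight_range.pow(1.0 - alpha)

y_ref = linear(x)
y_eq = F.linear(x / s, linear.weight * s)
assert torch.allclose(y_ref, y_eq, atol=1e-6)
```
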
