Skip to content

Commit

Permalink
Tentative
Browse files: browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Aug 13, 2024
1 parent f96872a commit 5fef0b1
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 46 deletions.
16 changes: 9 additions & 7 deletions src/brevitas/core/stats/stats_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def __init__(
def parameter_search(self, xl, x):
best_loss = torch.tensor(float('inf'), device=x.device, dtype=x.dtype)
candidate = xl
candidate = self.input_view_shape_impl(candidate)
# candidate = self.input_view_shape_impl(candidate)
best_candidate = candidate
beta = self.beta
with torch.no_grad():
Expand All @@ -614,22 +614,24 @@ def parameter_search(self, xl, x):
self.set_local_loss_mode(True)
quant_tensor = self.proxy_forward(x).detach()
self.set_local_loss_mode(False)
loss = torch.abs(quant_tensor.value - x).mean()
loss = torch.abs(quant_tensor.value_ - x).mean()

best_candidate = torch.where(loss < best_loss, candidate, best_candidate)
if loss >= best_loss:
break
best_loss = torch.min(loss, best_loss)
W_e = shrink_lp_op(x - quant_tensor.value, beta, self.lp_norm)
W_e = shrink_lp_op(x - quant_tensor.value_, beta, self.lp_norm)
zero_point = quant_tensor.zero_point
num = self.input_view_shape_impl(x - W_e).detach()
den = self.input_view_shape_impl(quant_tensor.int() - zero_point).detach()
den = self.input_view_shape_impl(
torch.round(quant_tensor.value_ / quant_tensor.scale_) - zero_point).detach()
mask = (num != 0.) & (den != 0.)
if self.stats_reduce_dim is None:
candidate = masked_median(num / den, mask)
else:
candidate = masked_median(num / den, mask, dim=self.stats_reduce_dim)
candidate = self.input_view_shape_impl(candidate)
candidate = masked_median(
num / den, mask, dim=self.stats_reduce_dim, keepdim=True)
# candidate = self.input_view_shape_impl(candidate)
candidate = self.clamp_min_ste(candidate)
bit_width = self.msb_clamp_bit_width_impl()
int_threshold = self.int_scaling_impl(bit_width)
Expand All @@ -643,7 +645,7 @@ def optimize(self, x):
x_view = self.input_view_shape_impl(x)

init = self.hqo_init_op(x_view).detach()
best_candidate = self.parameter_search(init, x)
best_candidate = self.parameter_search(init, x_view)

# Save for evaluation by other modules (e.g. zp) invoking local loss mode
self.internal_candidate = best_candidate.detach()
Expand Down
16 changes: 10 additions & 6 deletions src/brevitas/quant/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,11 +529,13 @@ def zero_point_stats_impl():
return this.stats_impl_zp

@value
def inner_stats_input_view_shape_impl(scaling_per_output_channel):
if scaling_per_output_channel:
def inner_stats_input_view_shape_impl(scaling_per_output):
if scaling_per_output == ScalingPerOutputType.CHANNEL:
return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
else:
elif scaling_per_output == ScalingPerOutputType.TENSOR:
return StatsInputViewShapeImpl.OVER_TENSOR
elif scaling_per_output == ScalingPerOutputType.GROUP:
return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK


class HQOScale(ExtendedInjector):
Expand All @@ -547,11 +549,13 @@ def scaling_stats_impl():
return this.stats_impl_scale

@value
def inner_stats_input_view_shape_impl(scaling_per_output_channel):
if scaling_per_output_channel:
def inner_stats_input_view_shape_impl(scaling_per_output):
if scaling_per_output == ScalingPerOutputType.CHANNEL:
return StatsInputViewShapeImpl.OVER_OUTPUT_CHANNELS
else:
elif scaling_per_output == ScalingPerOutputType.TENSOR:
return StatsInputViewShapeImpl.OVER_TENSOR
elif scaling_per_output == ScalingPerOutputType.GROUP:
return StatsInputViewShapeImpl.OVER_SUBCHANNEL_BLOCK


class HQOAsymmetricScale(HQOScale):
Expand Down
42 changes: 24 additions & 18 deletions src/brevitas_examples/common/generative/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,27 +63,33 @@ class ShiftedUintWeightAsymmetricGroupQuant(ShiftedUint8WeightPerChannelFloat):
scaling_per_output_type = ScalingPerOutputType.GROUP


class ShiftedUintWeightAsymmetricGroupQuantHQO(IntWeightSymmetricGroupQuant):
from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO


class ShiftedUintWeightAsymmetricGroupQuantHQO(Int8WeightPerChannelFloatHQO):
"""
Block / group / vector signed asymmetric weight quantizer with float scales and zero-points.
"""
zero_point_input_shape = this.scaling_input_shape
reshaped_zero_point_shape = this.reshaped_scaling_shape
zero_point_shape = this.scaling_shape
inner_expanded_zero_point_shape = this.expanded_scaling_shape
expanded_zero_point_shape = this.expanded_scaling_shape
zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
zero_point_stats_input_concat_dim = 0
zero_point_impl = ExpandReshapeZeroPointWrapper
zero_point_stats_impl = HalfQuadraticOptimizerZeroPoint
hqo_init_op_zp = NegativeMinOrZero
scaling_stats_impl = AbsMinMax
keepdim = True
# zero-point is converted to a parameter right away
wrapped_zero_point_impl = ParameterFromStatsFromParameterZeroPoint
quantize_zero_point = False
signed = False
inner_stats_input_view_shape_impl = torch.nn.Identity()
proxy_class = GroupwiseWeightQuantProxyFromInjector
scaling_per_output_type = ScalingPerOutputType.GROUP

# zero_point_input_shape = this.scaling_input_shape
# reshaped_zero_point_shape = this.reshaped_scaling_shape
# zero_point_shape = this.scaling_shape
# # inner_expanded_zero_point_shape = this.expanded_scaling_shape
# # expanded_zero_point_shape = this.expanded_scaling_shape
# zero_point_stats_input_view_shape_impl = this.scaling_stats_input_view_shape_impl
# zero_point_stats_input_concat_dim = 0
# # zero_point_impl = ExpandReshapeZeroPointWrapper
# zero_point_stats_impl = HalfQuadraticOptimizerZeroPoint
# hqo_init_op_zp = NegativeMinOrZero
# scaling_stats_impl = AbsMinMax
# keepdim = True
# # zero-point is converted to a parameter right away
# zero_point_impl = ParameterFromStatsFromParameterZeroPoint
# quantize_zero_point = False
# signed = False
# inner_stats_input_view_shape_impl = torch.nn.Identity()


class Int8DynamicActPerTensorFloat(DynamicActProxyMixin, Int8ActPerTensorFloat):
Expand Down
22 changes: 11 additions & 11 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import re

import numpy as np
from optimum.amd.brevitas.accelerate_utils import offload_model
from optimum.amd.brevitas.accelerate_utils import remove_hooks
from optimum.exporters.onnx import onnx_export_from_model
# from optimum.amd.brevitas.accelerate_utils import offload_model
# from optimum.amd.brevitas.accelerate_utils import remove_hooks
# from optimum.exporters.onnx import onnx_export_from_model
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
Expand Down Expand Up @@ -52,8 +52,8 @@
parser.add_argument(
'--weight-param-method',
type=str,
default='stats',
choices=['stats', 'mse'],
default='hqo',
choices=['stats', 'mse', 'hqo'],
help='How scales/zero-point are determined. Default: stats.')
parser.add_argument(
'--weight-scale-precision',
Expand All @@ -64,7 +64,7 @@
parser.add_argument(
'--weight-quant-type',
type=str,
default='sym',
default='asym',
choices=['sym', 'asym'],
help='Weight quantization type. Default: asym.')
parser.add_argument(
Expand Down Expand Up @@ -302,17 +302,17 @@ def main():
if args.weight_equalization:
print("Apply weight equalization...")
# In case of float16 model, we need to offload to account for missing ops
model = offload_model(model)
# model = offload_model(model)
apply_weight_equalization(model)
remove_hooks(model)
# remove_hooks(model)
print("Weight equalization applied.")

if args.act_equalization is not None:
offload_model(model)
# offload_model(model)
print("Apply act equalization (SmoothQuant)...")
apply_act_equalization(model, args.act_equalization, calibration_loader)
print("Act equalization applied.")
remove_hooks(model)
# remove_hooks(model)

if not args.no_quantize:
print("Applying model quantization...")
Expand Down Expand Up @@ -369,7 +369,7 @@ def main():
print("Model eval...")
ppl = model_eval(model, val_data, args.seqlen)
print(f"C4 perplexity: {ppl}")
remove_hooks(model)
# remove_hooks(model)

if args.export_target:
print(f"Export to {args.export_target}")
Expand Down
16 changes: 12 additions & 4 deletions src/brevitas_examples/stable_diffusion/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,9 @@ def input_zp_stats_type():
torch.cuda.empty_cache()
if args.bias_correction:
print("Applying bias correction")
for m in pipe.unet.modules():
if hasattr(m, 'cache_inference_quant_weight'):
m.cache_inference_quant_weight = True
with torch.no_grad(), bias_correction_mode(pipe.unet):
run_val_inference(
pipe,
Expand Down Expand Up @@ -520,7 +523,13 @@ def input_zp_stats_type():
if args.use_mlperf_inference:
print(f"Computing accuracy with MLPerf pipeline")
compute_mlperf_fid(
args.model, args.path_to_coco, pipe, args.prompt, output_dir, not args.vae_fp16_fix)
args.model,
args.path_to_coco,
pipe,
args.prompt,
output_dir,
device=args.device,
vae_force_upcast=not args.vae_fp16_fix)
else:
print(f"Computing accuracy on default prompt")
testing_prompts = TESTING_PROMPTS[:args.prompt]
Expand Down Expand Up @@ -552,7 +561,6 @@ def input_zp_stats_type():
fid.update(float_images_values, real=True)
fid.update(quant_images_values, real=False)
print(f"FID: {float(fid.compute())}")

if args.export_target:
# Move to cpu and to float32 to enable CPU export
if args.export_cpu_float32:
Expand Down Expand Up @@ -689,7 +697,7 @@ def input_zp_stats_type():
parser.add_argument(
'--conv-input-bit-width',
type=int,
default=0,
default=8,
help='Input bit width. Default: 0 (not quantized)')
parser.add_argument(
'--act-eq-alpha',
Expand All @@ -699,7 +707,7 @@ def input_zp_stats_type():
parser.add_argument(
'--linear-input-bit-width',
type=int,
default=0,
default=8,
help='Input bit width. Default: 0 (not quantized).')
parser.add_argument(
'--linear-output-bit-width',
Expand Down

0 comments on commit 5fef0b1

Please sign in to comment.