
Commit

update
Giuseppe5 committed May 31, 2024
1 parent 5bc480f commit 32f9ffe
Showing 4 changed files with 139 additions and 51 deletions.
33 changes: 20 additions & 13 deletions src/brevitas/export/torch/qcdq/handler.py
@@ -32,15 +32,18 @@ def __init__(self) -> None:
self.symbolic_kwargs = {}

def dequantize_fn(self, x, scale, zero_point, axis):
if axis is None:
axis = -1
# cast zero_point to float, otherwise if both x
# and zero_point are uint (as in asym quant)
# uint - uint can lead to errors. Don't cast x to float
# as the main float datatype might not be float32 (e.g float16)
if isinstance(zero_point, torch.Tensor):
zero_point = zero_point.to(torch.float)
else:
zero_point = float(zero_point)
return (x - zero_point) * scale
return torch.ops.brevitas.dequantize(x, scale, zero_point, axis)
# if isinstance(zero_point, torch.Tensor):
# zero_point = zero_point.to(torch.float)
# else:
# zero_point = float(zero_point)
# return (x - zero_point) * scale

def cast_fn(self, x, dtype):
return x.type(dtype)
@@ -51,7 +54,7 @@ def flatten_dequantize_params(self):

@property
def itemize_quantize_scalar_params(self):
return True
return False

def validate(self, module):
assert module.bit_width() > 1., 'Binary quant not supported'
@@ -67,15 +70,15 @@ class TorchQCDQCastMixin(QMixin, TorchCDQCastMixin, ABC):

@classmethod
def int8_dtype(cls):
return torch.qint8
return torch.int8

@classmethod
def uint8_dtype(cls):
return torch.quint8
return torch.uint8

@classmethod
def int32_dtype(cls):
return torch.qint32
return torch.int32

def validate(self, module):
super().validate(module)
@@ -85,10 +88,14 @@ def validate(self, module):

def quantize_fn(self, x, scale, zero_point, dtype, axis):
if axis is None:
y = torch.quantize_per_tensor(x, scale, zero_point, dtype)
else:
y = torch.quantize_per_channel(x, scale, zero_point, axis, dtype)
return y.int_repr()
axis = -1
return torch.ops.brevitas.quantize(x, scale, zero_point, axis)

# if axis is None:
# y = torch.quantize_per_tensor(x, scale, zero_point, dtype)
# else:
# y = torch.quantize_per_channel(x, scale, zero_point, axis, dtype)
# return y.int_repr()


class TorchQCDQHandler(BaseHandler):
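The handler change above swaps PyTorch's quantized dtypes (torch.qint8, torch.quint8, torch.qint32) for plain integer dtypes and delegates both directions to the custom torch.ops.brevitas.quantize / torch.ops.brevitas.dequantize ops, keeping the previous implementations as commented-out references. A minimal sketch of the affine semantics those ops are presumably expected to preserve, based on the replaced code paths (illustrative only, not the registered op implementations):

```python
import torch


def reference_dequantize(x, scale, zero_point):
    # Cast zero_point to float so that uint - uint (as in asymmetric
    # quantization) cannot wrap around; x is left in its own float dtype
    # (e.g. float16), which is why x itself is not cast.
    if isinstance(zero_point, torch.Tensor):
        zero_point = zero_point.to(torch.float)
    else:
        zero_point = float(zero_point)
    return (x - zero_point) * scale


def reference_quantize(x, scale, zero_point, dtype=torch.int8):
    # Affine quantization: scale, offset, round, then clamp to the target
    # integer range. Per-channel behaviour falls out of broadcasting
    # scale/zero_point along the channel axis.
    info = torch.iinfo(dtype)
    q = torch.round(x / scale + zero_point)
    return q.clamp(info.min, info.max).to(dtype)
```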
14 changes: 2 additions & 12 deletions src/brevitas_examples/common/generative/quantize.py
@@ -183,7 +183,6 @@ def quantize_model(
input_quant_type=None,
input_quant_granularity=None,
input_group_size=None,
input_stats_op='percentile',
quantize_input_zero_point=False,
quantize_embedding=False,
use_ocp=False,
@@ -201,7 +200,7 @@ def quantize_model(
weight_float_format = {
'exponent_bit_width': int(weight_quant_format[1]),
'mantissa_bit_width': int(weight_quant_format[3])}
if ocp_weight_format:
if use_ocp:
weight_quant_format += '_ocp'
ocp_weight_format = weight_quant_format
weight_quant_format = 'float'
@@ -211,7 +210,7 @@ def quantize_model(
input_float_format = {
'exponent_bit_width': int(input_quant_format[1]),
'mantissa_bit_width': int(input_quant_format[3])}
if ocp_weight_format:
if use_ocp:
input_quant_format += '_ocp'
ocp_input_format = input_quant_format
input_quant_format = 'float'
@@ -255,15 +254,6 @@ def quantize_model(
if input_kwargs is None:
input_kwargs = dict()

if input_stats_op == 'minmax':
if input_quant_type == 'asym':
input_scaling_stats_op = StatsOp.MIN_MAX
zero_point_stats_impl = NegativeMinOrZero
input_kwargs['zero_point_stats_impl'] = zero_point_stats_impl
else:
input_scaling_stats_op = StatsOp.MAX
input_kwargs['scaling_stats_op'] = input_scaling_stats_op

input_quant = input_quant.let(**input_kwargs)
sym_input_quant = sym_input_quant.let(**input_kwargs)
linear_input_quant = linear_input_quant.let(**input_kwargs)
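Two things change in quantize_model above: the built-in input_stats_op handling (mapping 'minmax' to StatsOp.MIN_MAX / StatsOp.MAX, plus NegativeMinOrZero for asymmetric zero points) is dropped, presumably so callers can supply those settings through input_kwargs, which is still applied via .let(**input_kwargs); and the OCP suffix is now gated on the use_ocp argument rather than on the derived ocp_weight_format / ocp_input_format values. A self-contained sketch of the fixed format handling (hypothetical wrapper name; assumes the quant-format string looks like 'e4m3' at this point, as the [1]/[3] indexing suggests):

```python
# Hypothetical standalone wrapper around the inline logic shown above;
# only the gating condition changed: it now checks the use_ocp argument
# instead of ocp_weight_format.
def resolve_weight_format(weight_quant_format: str, use_ocp: bool):
    ocp_weight_format = None
    weight_float_format = {
        'exponent_bit_width': int(weight_quant_format[1]),
        'mantissa_bit_width': int(weight_quant_format[3])}
    if use_ocp:
        weight_quant_format += '_ocp'
        ocp_weight_format = weight_quant_format
    # Reset to the generic 'float' key; its exact placement relative to the
    # if-branch is an assumption, since indentation is not visible here.
    weight_quant_format = 'float'
    return weight_quant_format, weight_float_format, ocp_weight_format


print(resolve_weight_format('e4m3', use_ocp=True))
# ('float', {'exponent_bit_width': 4, 'mantissa_bit_width': 3}, 'e4m3_ocp')
```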
50 changes: 43 additions & 7 deletions src/brevitas_examples/stable_diffusion/README.md
@@ -60,7 +60,8 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
[--calibration-prompt-path CALIBRATION_PROMPT_PATH]
[--checkpoint-name CHECKPOINT_NAME]
[--load-checkpoint LOAD_CHECKPOINT]
[--path-to-latents PATH_TO_LATENTS] [--resolution RESOLUTION]
[--path-to-latents PATH_TO_LATENTS]
[--path-to-coco PATH_TO_COCO] [--resolution RESOLUTION]
[--guidance-scale GUIDANCE_SCALE]
[--calibration-steps CALIBRATION_STEPS]
[--output-path OUTPUT_PATH | --no-output-path]
@@ -77,7 +78,8 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
[--linear-input-bit-width LINEAR_INPUT_BIT_WIDTH]
[--weight-param-method {stats,mse}]
[--input-param-method {stats,mse}]
[--input-stats-op {minmax,percentile}]
[--input-scale-stats-op {minmax,percentile}]
[--input-zp-stats-op {minmax,percentile}]
[--weight-scale-precision {float_scale,po2_scale}]
[--input-scale-precision {float_scale,po2_scale}]
[--weight-quant-type {sym,asym}]
@@ -89,11 +91,16 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
[--input-scale-type {static,dynamic}]
[--weight-group-size WEIGHT_GROUP_SIZE]
[--quantize-weight-zero-point | --no-quantize-weight-zero-point]
[--quantize-input-zero-point | --no-quantize-input-zero-point]
[--export-cuda-float16 | --no-export-cuda-float16]
[--use-mlperf-inference | --no-use-mlperf-inference]
[--use-ocp | --no-use-ocp]
[--use-negative-prompts | --no-use-negative-prompts]
[--dry-run | --no-dry-run]
[--quantize-time-emb | --no-quantize-time-emb]
[--quantize-conv-in | --no-quantize-conv-in]
[--quantize-input-time-emb | --no-quantize-input-time-emb]
[--quantize-input-conv-in | --no-quantize-input-conv-in]

Stable Diffusion quantization

@@ -113,21 +120,24 @@ options:
--calibration-prompt-path CALIBRATION_PROMPT_PATH
Path to calibration prompt
--checkpoint-name CHECKPOINT_NAME
Name to use to store the checkpoint. If not provided,
no checkpoint is saved.
Name to use to store the checkpoint in the output dir.
If not provided, no checkpoint is saved.
--load-checkpoint LOAD_CHECKPOINT
Path to checkpoint to load. If provided, PTQ
techniques are skipped.
--path-to-latents PATH_TO_LATENTS
Load pre-defined latents. If not provided, they are
generated based on an internal seed.
--path-to-coco PATH_TO_COCO
Path to MLPerf compliant Coco dataset. Used when the
--use-mlperf flag is set. Default: None
--resolution RESOLUTION
Resolution along height and width dimension. Default:
512.
--guidance-scale GUIDANCE_SCALE
Guidance scale.
--calibration-steps CALIBRATION_STEPS
Percentage of steps used during calibration
Steps used during calibration
--output-path OUTPUT_PATH
Path where to generate output folder.
--no-output-path Disable Path where to generate output folder.
@@ -169,8 +179,12 @@ options:
How scales/zero-point are determined. Default: stats.
--input-param-method {stats,mse}
How scales/zero-point are determined. Default: stats.
--input-stats-op {minmax,percentile}
Define what statistics op to use. Default: minmax.
--input-scale-stats-op {minmax,percentile}
Define what statistics op to use for the input scale.
Default: minmax.
--input-zp-stats-op {minmax,percentile}
Define what statistics op to use for the input zero point.
Default: minmax.
--weight-scale-precision {float_scale,po2_scale}
Whether scale is a float value or a po2. Default:
float_scale.
@@ -203,6 +217,10 @@ options:
Enable Quantize weight zero-point. Default: Enabled
--no-quantize-weight-zero-point
Disable Quantize weight zero-point. Default: Enabled
--quantize-input-zero-point
Enable Quantize input zero-point. Default: Enabled
--no-quantize-input-zero-point
Disable Quantize input zero-point. Default: Enabled
--export-cuda-float16
Enable Export FP16 on CUDA. Default: Disabled
--no-export-cuda-float16
@@ -227,5 +245,23 @@ options:
calibration. Default: Disabled
--no-dry-run Disable Generate a quantized model without any
calibration. Default: Disabled
--quantize-time-emb Enable Quantize time embedding layers. Default: True
--no-quantize-time-emb
Disable Quantize time embedding layers. Default: True
--quantize-conv-in Enable Quantize first conv layer. Default: True
--no-quantize-conv-in
Disable Quantize first conv layer. Default: True
--quantize-input-time-emb
Enable Quantize input to time embedding layers.
Default: Disabled
--no-quantize-input-time-emb
Disable Quantize input to time embedding layers.
Default: Disabled
--quantize-input-conv-in
Enable Quantize input to first conv layer. Default:
Enabled
--no-quantize-input-conv-in
Disable Quantize input to first conv layer. Default:
Enabled

```
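As an illustration, a possible invocation exercising some of the new options (device, values, and flag choices are placeholders, not recommended settings):

```
python main.py -d cuda --resolution 512 \
    --input-scale-type static \
    --input-scale-stats-op minmax --input-zp-stats-op minmax \
    --quantize-input-zero-point \
    --quantize-time-emb --no-quantize-input-conv-in
```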
