Add usage of int8-mixed-bf16 quantization with X86InductorQuantizer #2668

Merged
32 changes: 29 additions & 3 deletions prototype_source/pt2e_quant_ptq_x86_inductor.rst
@@ -165,11 +165,37 @@ After we get the quantized model, we will further lower it to the inductor backend

::

    with torch.no_grad():
        optimized_model = torch.compile(converted_model)

        # Running some benchmark
        optimized_model(*example_inputs)
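As a side note, the "Running some benchmark" call above can be expanded into a simple latency measurement. Below is a minimal sketch, assuming ``optimized_model`` and ``example_inputs`` are defined as above; the warm-up and iteration counts are arbitrary choices, not part of the original example.

::

    import time

    with torch.no_grad():
        # Warm-up runs so that compilation happens before timing
        for _ in range(3):
            optimized_model(*example_inputs)

        # Measure average latency over a few iterations
        start = time.perf_counter()
        for _ in range(10):
            optimized_model(*example_inputs)
        end = time.perf_counter()
        print(f"Average latency: {(end - start) / 10 * 1000:.2f} ms")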
In a more advanced scenario, int8-mixed-bf16 quantization comes into play. In this case, a Convolution or GEMM operator produces BFloat16 output instead of Float32 when it is not followed by a quantization node. The BFloat16 tensor then propagates seamlessly through subsequent pointwise operators, reducing memory usage and potentially improving performance. Using this feature is as simple as with regular BFloat16 Autocast: wrap the script within the BFloat16 Autocast context.

::

    with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True), torch.no_grad():
        # Turn on Autocast to use int8-mixed-bf16 quantization. After lowering into Inductor CPP Backend,
        # for operators such as QConvolution and QLinear:
        # * The input data type is consistently defined as int8, attributable to the presence of a pair
        #   of quantization and dequantization nodes inserted at the input.
        # * The computation precision remains at int8.
        # * The output data type may vary, being either int8 or BFloat16, contingent on the presence
        #   of a pair of quantization and dequantization nodes at the output.
        # For non-quantizable pointwise operators, the data type will be inherited from the previous node,
        # potentially resulting in a data type of BFloat16 in this scenario.
        # For quantizable pointwise operators such as QMaxpool2D, it continues to operate with the int8
        # data type for both input and output.
        optimized_model = torch.compile(converted_model)

        # Running some benchmark
        optimized_model(*example_inputs)

Putting all these code snippets together, we have the toy example code.
Please note that since the Inductor ``freeze`` feature is not turned on by default yet, run your example code with ``TORCHINDUCTOR_FREEZING=1``.
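For orientation, here is a condensed sketch of how the pieces might fit together. It uses a torchvision ResNet18 as a stand-in model, and the export and quantizer helper names (``capture_pre_autograd_graph``, ``get_default_x86_inductor_quantization_config``) follow the PT2E quantization flow current at the time of writing, so treat it as an illustration rather than the canonical toy example from the tutorial.

::

    import torch
    import torchvision.models as models
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
    import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq

    model = models.resnet18().eval()
    example_inputs = (torch.randn(1, 3, 224, 224),)

    # 1. Capture the FX graph to be quantized
    exported_model = capture_pre_autograd_graph(model, example_inputs)

    # 2. Prepare, calibrate, and convert with X86InductorQuantizer
    quantizer = xiq.X86InductorQuantizer()
    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
    prepared_model = prepare_pt2e(exported_model, quantizer)
    prepared_model(*example_inputs)  # calibration with representative data
    converted_model = convert_pt2e(prepared_model)

    # 3. Lower into Inductor with int8-mixed-bf16
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=True), torch.no_grad():
        optimized_model = torch.compile(converted_model)
        optimized_model(*example_inputs)

With the freezing flag enabled, this could be run as, for example, ``TORCHINDUCTOR_FREEZING=1 python example.py``.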