Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat (minifloat): add support for user specified minifloat format #821

Merged
merged 15 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/brevitas/core/function_wrapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from .clamp import ClampMin
from .clamp import FloatClamp
from .clamp import ScalarClamp
from .clamp import TensorClamp
from .learned_round import LearnedRoundSte
Expand Down
53 changes: 53 additions & 0 deletions src/brevitas/core/function_wrapper/clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
"""
ScriptModule wrappers for various variants of clamping.
"""
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Module

import brevitas
from brevitas.core.utils import StatelessBuffer
from brevitas.function import tensor_clamp


Expand Down Expand Up @@ -73,3 +76,53 @@ def __init__(self, min_val: float) -> None:
@brevitas.jit.script_method
def forward(self, x: Tensor):
return x.clamp_min(self.min_val)


class FloatClamp(brevitas.jit.ScriptModule):
    """
    ScriptModule for clamping minifloat formats to their inf/NaN implementations.

    The input is first passed through ``tensor_clamp_impl`` with bounds
    ``[-max_value, max_value]``. When ``saturating`` is False, every element whose
    magnitude exceeded ``max_value`` before clamping is then overwritten: with +/-inf
    if the format defines inf codes, otherwise with NaN (infinite inputs are mapped
    to NaN as well in that case).

    Currently, inf/NaN codes have to be encoded through the mantissa.
    I.e. setting inf to 1101.111 (E4M3) is not a valid code.

    Args:
        max_value: largest finite magnitude representable by the format.
        tensor_clamp_impl: module called as ``impl(x, min_val=..., max_val=...)``.
        inf_values: inf encodings of the format; only whether any are given is used here.
        saturating: if True, out-of-range values stay clamped at +/-max_value.
    """

    __constants__ = ['saturating', 'has_inf_values']

    def __init__(
            self,
            max_value: float,
            tensor_clamp_impl: Module,
            inf_values: Optional[Tuple[str]] = None,
            saturating: bool = True) -> None:
        super(FloatClamp, self).__init__()

        self.tensor_clamp_impl = tensor_clamp_impl

        self.max_value = StatelessBuffer(torch.tensor(max_value))
        # None and an empty tuple both mean "this format has no inf codes"
        self.has_inf_values = bool(inf_values)
        self.saturating = saturating

    @brevitas.jit.script_method
    def forward(self, x: Tensor):
        max_val = self.max_value()

        # all masks are computed on the un-clamped input
        was_inf = x.isinf()
        above_max = x > max_val
        below_min = -x > max_val

        # saturating behaviour is the baseline: clamp everything to +- max_value
        x = self.tensor_clamp_impl(x, min_val=-max_val, max_val=max_val)

        if self.saturating:
            return x

        # non-saturating: overflowing values become inf or NaN
        if self.has_inf_values:
            # the format has inf codes: overflow goes to +- inf, and inf stays inf
            x[above_max] = torch.tensor(float('inf'))
            x[below_min] = torch.tensor(float('-inf'))
        else:
            # the format has no inf codes: overflow goes to NaN
            overflow = torch.logical_or(above_max, below_min)
            x[overflow] = torch.tensor(float('nan'))

            # infinite inputs are mapped to NaN too
            x[was_inf] = torch.tensor(float('nan'))

        return x
8 changes: 5 additions & 3 deletions src/brevitas/core/quant/float.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def __init__(
signed: bool,
exponent_bit_width: int,
mantissa_bit_width: int,
exponent_bias: Optional[int] = None,
exponent_bias: int,
float_clamp_impl: nn.Module,
scaling_impl: Optional[nn.Module] = None,
float_scaling_impl: Optional[nn.Module] = None,
float_to_int_impl: nn.Module = RoundSte(),
Expand All @@ -43,8 +44,6 @@ def __init__(
raise RuntimeError("Mantissa bit width cannot be 0.")
self.mantissa_bit_width = StatelessBuffer(
(torch.tensor(float(mantissa_bit_width), device=device, dtype=dtype)))
if exponent_bias is None:
exponent_bias = 2 ** (exponent_bit_width - 1) - 1
self.exponent_bias = StatelessBuffer(
torch.tensor(float(exponent_bias), device=device, dtype=dtype))
self.fp_max_val = StatelessBuffer(
Expand All @@ -59,6 +58,7 @@ def __init__(
self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype))
self.float_scaling_impl = float_scaling_impl
self.scaling_impl = scaling_impl
self.float_clamp_impl = float_clamp_impl

@brevitas.jit.script_method
def internal_scale(self, x):
Expand Down Expand Up @@ -86,6 +86,8 @@ def dequantize(self, y, scale):
@brevitas.jit.script_method
def forward(self, x):
y, scale = self.quantize(x)
# after quantizing, clamp to special cases like NaN/inf if they are set
y = self.float_clamp_impl(y)
y = self.dequantize(y, scale)
# This is to respect the current interface of proxies
return y, scale, self.zero_point_impl(), self.bit_width()
1 change: 1 addition & 0 deletions src/brevitas/function/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Implementation of various core operations often performed as part of quantization.
The implemented functions adhere to the restrictions imposed by PyTorch 1.1.0's TorchScript compiler.
"""
from typing import Optional, Tuple

import torch
from torch import Tensor
Expand Down
51 changes: 35 additions & 16 deletions src/brevitas/quant/experimental/float_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from brevitas.core.function_wrapper import FloatClamp
from brevitas.core.function_wrapper import TensorClamp
from brevitas.core.quant.float import FloatQuant
from brevitas.core.scaling.float_scaling import FloatScaling
from brevitas.inject import ExtendedInjector
Expand All @@ -10,22 +12,44 @@
from brevitas.quant.solver import ActQuantSolver
from brevitas.quant.solver import WeightQuantSolver
from brevitas.quant.solver.common import SolveTensorQuantFloatToIntImplFromEnum
from brevitas.utils.float_quant_utils import get_max_value


class FloatWeightBase(SolveTensorQuantFloatToIntImplFromEnum):
proxy_class = WeightQuantProxyFromInjector
class FloatBase(SolveTensorQuantFloatToIntImplFromEnum):
tensor_quant = FloatQuant
signed = True
float_to_int_impl_type = 'round'
scaling_min_val = 1e-10
float_clamp_impl = FloatClamp
tensor_clamp_impl = TensorClamp

@value
def exponent_bias(exponent_bit_width):
return 2 ** (exponent_bit_width - 1) - 1

class FloatActBase(SolveTensorQuantFloatToIntImplFromEnum):
@value
def max_value(
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
nan_values=None,
inf_values=None,
saturating=True):
return get_max_value(
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
nan_values,
inf_values,
saturating)


class FloatWeightBase(FloatBase):
proxy_class = WeightQuantProxyFromInjector


class FloatActBase(FloatBase):
proxy_class = ActQuantProxyFromInjector
tensor_quant = FloatQuant
signed = True
float_to_int_impl_type = 'round'
scaling_min_val = 1e-10


class ScaledFloatWeightBase(FloatWeightBase, WeightQuantSolver):
Expand All @@ -43,20 +67,15 @@ class ScaledFloatActBase(FloatActBase, ActQuantSolver):
float_scaling_impl = FloatScaling


class ExponentBiasMixin(ExtendedInjector):

@value
def exponent_bias(exponent_bit_width):
return 2 ** (exponent_bit_width - 1) - 1


class Fp8e4m3Mixin(ExponentBiasMixin):
class Fp8e4m3Mixin(ExtendedInjector):
bit_width = 8
exponent_bit_width = 4
mantissa_bit_width = 3
saturating = True


class Fp8e5m2Mixin(ExponentBiasMixin):
class Fp8e5m2Mixin(ExtendedInjector):
bit_width = 8
exponent_bit_width = 5
mantissa_bit_width = 2
saturating = True
150 changes: 150 additions & 0 deletions src/brevitas/quant/experimental/float_quant_ocp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from brevitas.quant.base import MSESymmetricScale
from brevitas.quant.experimental.float_base import FloatActBase
from brevitas.quant.experimental.float_base import FloatWeightBase
from brevitas.quant.experimental.float_base import Fp8e4m3Mixin
from brevitas.quant.experimental.float_base import Fp8e5m2Mixin
from brevitas.quant.experimental.float_base import ScaledFloatActBase
from brevitas.quant.experimental.float_base import ScaledFloatWeightBase


class Fp8e4m3OCPMixin(Fp8e4m3Mixin):
    # Special-value encodings for the OCP flavour of FP8 E4M3.
    # NOTE(review): the strings look like 3-bit mantissa patterns reserved for NaN
    # (presumably combined with an all-ones exponent) - confirm against get_max_value.
    nan_values = (('111',))
    # No inf encodings are declared for this format.
    inf_values = None


class Fp8e5m2OCPMixin(Fp8e5m2Mixin):
    # Special-value encodings for the OCP flavour of FP8 E5M2.
    # NOTE(review): the strings look like 2-bit mantissa patterns (presumably combined
    # with an all-ones exponent) - confirm against get_max_value.
    # Three mantissa patterns are reserved for NaN ...
    nan_values = ('01', '11', '10')
    # ... and the remaining one for inf.
    inf_values = (('00',))


class Fp8e4m3OCPWeight(Fp8e4m3OCPMixin, FloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer with the NaN/inf encodings of Fp8e4m3OCPMixin.
    """


class Fp8e5m2OCPWeight(Fp8e5m2OCPMixin, FloatWeightBase):
    """
    FP8 signed E5M2 weight quantizer with the NaN/inf encodings of Fp8e5m2OCPMixin.
    """


class Fp8e4m3OCPAct(Fp8e4m3OCPMixin, FloatActBase):
    """
    FP8 signed E4M3 activation quantizer with the NaN/inf encodings of Fp8e4m3OCPMixin.
    """


class Fp8e5m2OCPAct(Fp8e5m2OCPMixin, FloatActBase):
    """
    FP8 signed E5M2 activation quantizer with the NaN/inf encodings of Fp8e5m2OCPMixin.
    """


class Fp8e4m3OCPWeightPerTensorFloat(Fp8e4m3OCPMixin, ScaledFloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer with per-tensor absmax-based scaling.
    """
    scaling_per_output_channel = False


class Fp8e5m2OCPWeightPerTensorFloat(Fp8e5m2OCPMixin, ScaledFloatWeightBase):
    """
    FP8 signed E5M2 (OCP encodings) weight quantizer with per-tensor absmax-based scaling.
    """
    scaling_per_output_channel = False


class Fp8e4m3OCPActPerTensorFloat(Fp8e4m3OCPMixin, ScaledFloatActBase):
    """
    FP8 signed E4M3 (OCP encodings) activation quantizer with per-tensor static
    percentile-based scaling.
    """
    scaling_per_output_channel = False


class Fp8e5m2OCPActPerTensorFloat(Fp8e5m2OCPMixin, ScaledFloatActBase):
    """
    FP8 signed E5M2 (OCP encodings) activation quantizer with per-tensor static
    percentile-based scaling.
    """
    scaling_per_output_channel = False


class Fp8e4m3OCPWeightPerChannelFloat(Fp8e4m3OCPMixin, ScaledFloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer with per-channel absmax-based scaling.
    """
    scaling_per_output_channel = True


class Fp8e5m2OCPWeightPerChannelFloat(Fp8e5m2OCPMixin, ScaledFloatWeightBase):
    """
    FP8 signed E5M2 (OCP encodings) weight quantizer with per-channel absmax-based scaling.
    """
    scaling_per_output_channel = True


class Fp8e4m3OCPActPerChannelFloat2d(Fp8e4m3OCPMixin, ScaledFloatActBase):
    """
    FP8 signed E4M3 (OCP encodings) activation quantizer with per-channel static
    percentile-based scaling.
    """
    scaling_per_output_channel = True
    # NOTE(review): presumably swaps dims 0 and 1 of a 4D (N, C, H, W) activation so the
    # channel axis comes first when statistics are collected - confirm against the
    # scaling-stats implementation.
    scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e5m2OCPActPerChannelFloat2d(Fp8e5m2OCPMixin, ScaledFloatActBase):
    """
    FP8 signed E5M2 (OCP encodings) activation quantizer with per-channel static
    percentile-based scaling.
    """
    scaling_per_output_channel = True
    # NOTE(review): presumably swaps dims 0 and 1 of a 4D (N, C, H, W) activation so the
    # channel axis comes first when statistics are collected - confirm against the
    # scaling-stats implementation.
    scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e4m3OCPActPerTensorFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatActBase):
    """
    FP8 signed E4M3 (OCP encodings) activation quantizer with per-tensor static
    MSE-based scaling (via the MSESymmetricScale mixin).
    """
    scaling_per_output_channel = False


class Fp8e5m2OCPActPerTensorFloatMSE(Fp8e5m2OCPMixin, MSESymmetricScale, ScaledFloatActBase):
    """
    FP8 signed E5M2 (OCP encodings) activation quantizer with per-tensor static
    MSE-based scaling (via the MSESymmetricScale mixin).
    """
    scaling_per_output_channel = False


class Fp8e4m3OCPActPerChannelFloat2dMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatActBase):
    """
    FP8 signed E4M3 (OCP encodings) activation quantizer with per-channel static
    MSE-based scaling (via the MSESymmetricScale mixin).
    """
    scaling_per_output_channel = True
    # NOTE(review): presumably swaps dims 0 and 1 of a 4D (N, C, H, W) activation so the
    # channel axis comes first when statistics are collected - confirm against the
    # scaling-stats implementation.
    scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e5m2OCPActPerChannelFloat2dMSE(Fp8e5m2OCPMixin, MSESymmetricScale, ScaledFloatActBase):
    """
    FP8 signed E5M2 (OCP encodings) activation quantizer with per-channel static
    MSE-based scaling (via the MSESymmetricScale mixin).
    """
    scaling_per_output_channel = True
    # NOTE(review): presumably swaps dims 0 and 1 of a 4D (N, C, H, W) activation so the
    # channel axis comes first when statistics are collected - confirm against the
    # scaling-stats implementation.
    scaling_stats_permute_dims = (1, 0, 2, 3)


class Fp8e4m3OCPWeightPerChannelFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer with per-channel MSE-based scaling.
    """
    scaling_per_output_channel = True


class Fp8e4m3OCPWeightPerTensorFloatMSE(Fp8e4m3OCPMixin, MSESymmetricScale, ScaledFloatWeightBase):
    """
    FP8 signed E4M3 weight quantizer with per-tensor MSE-based scaling.
    """
    scaling_per_output_channel = False
Loading
Loading