Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat (minifloat): cleanup minifloat impl #922

Merged
merged 3 commits into from
Apr 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions src/brevitas/core/function_wrapper/clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import brevitas
from brevitas.core.utils import StatelessBuffer
from brevitas.function import tensor_clamp
from brevitas.function.ops import max_float


class TensorClamp(brevitas.jit.ScriptModule):
Expand Down Expand Up @@ -90,39 +91,62 @@ class FloatClamp(brevitas.jit.ScriptModule):

def __init__(
self,
max_value: float,
tensor_clamp_impl: Module,
signed: bool,
inf_values: Optional[Tuple[str]] = None,
saturating: bool = True) -> None:
nan_values: Optional[Tuple[str]] = None,
max_available_float: Optional[Tensor] = None,
saturating: bool = True,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None) -> None:
super(FloatClamp, self).__init__()

self.tensor_clamp_impl = tensor_clamp_impl

self.max_value = StatelessBuffer(torch.tensor(max_value))
self.saturating = saturating
self.has_inf_values = bool(inf_values)
self.inf_values = inf_values
self.nan_values = nan_values
self.signed = signed

if max_available_float:
max_available_float = torch.tensor(max_available_float, device=device, dtype=dtype)
self.max_available_float = StatelessBuffer(max_available_float)
else:
self.max_available_float = None

@brevitas.jit.script_method
def forward(self, x: Tensor):
def forward(
self,
x: Tensor,
exponent_bit_width: Tensor,
mantissa_bit_width: Tensor,
exponent_bias: Tensor):
inf_mask = x.isinf()
p_max_val_mask = x > self.max_value()
n_max_val_mask = -x > self.max_value()
max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias)
max_value = max_value if self.max_available_float is None else torch.min(
max_value, self.max_available_float())
p_max_val_mask = x > max_value
n_max_val_mask = -x > max_value
min_float = torch.tensor(0.) if not self.signed else -max_value

# first clamp everything to +- max_value, basically the saturating case
x = self.tensor_clamp_impl(x, min_val=-self.max_value(), max_val=self.max_value())
x = self.tensor_clamp_impl(x, min_val=min_float, max_val=max_value)

if not self.saturating:
# if non-saturating, we need to map values greater than max_val to nan or inf
if self.has_inf_values:
if self.inf_values:
# we have inf values, so we set abs values > max_value to +- inf, and leave inf at inf
x[p_max_val_mask] = torch.tensor(float('inf'))
x[n_max_val_mask] = torch.tensor(float('-inf'))
else:
elif self.nan_values:
# no inf values, so we need to map them to NaN
full_max_val_mask = torch.logical_or(p_max_val_mask, n_max_val_mask)
x[full_max_val_mask] = torch.tensor(float('nan'))

# we also map the inf values to NaN in this case
x[inf_mask] = torch.tensor(float('nan'))
else:
raise RuntimeError(
"Clamping is not saturaing, but neither `inf_values` nor `nan_values` is specified"
Giuseppe5 marked this conversation as resolved.
Show resolved Hide resolved
)

return x
16 changes: 7 additions & 9 deletions src/brevitas/core/quant/float.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Optional
from typing import Optional, Tuple

import torch
import torch.nn as nn
Expand Down Expand Up @@ -46,8 +46,7 @@ def __init__(
(torch.tensor(float(mantissa_bit_width), device=device, dtype=dtype)))
self.exponent_bias = StatelessBuffer(
torch.tensor(float(exponent_bias), device=device, dtype=dtype))
self.fp_max_val = StatelessBuffer(
max_float(self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias()))

self.fp_internal_scale_min = StatelessBuffer(
1. - self.exponent_bias() - self.mantissa_bit_width())
if float_scaling_impl is None:
Expand All @@ -69,14 +68,12 @@ def internal_scale(self, x):

@brevitas.jit.script_method
def quantize(self, x: torch.Tensor):
scale = self.scaling_impl(x) / self.float_scaling_impl(x)
scale_impl_value = self.scaling_impl(
self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
scale = scale_impl_value / self.float_scaling_impl(x)
scaled_x = x / scale
internal_scale = self.internal_scale(scaled_x)
val_fp_quant = internal_scale * self.float_to_int_impl(scaled_x / internal_scale)
if self.signed:
val_fp_quant = torch.clip(val_fp_quant, -1. * self.fp_max_val(), self.fp_max_val())
else:
val_fp_quant = torch.clip(val_fp_quant, 0., self.fp_max_val())
return val_fp_quant, scale

@brevitas.jit.script_method
Expand All @@ -87,7 +84,8 @@ def dequantize(self, y, scale):
def forward(self, x):
y, scale = self.quantize(x)
# after quantizing, clamp to special cases like NaN/inf if they are set
y = self.float_clamp_impl(y)
y = self.float_clamp_impl(
y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
y = self.dequantize(y, scale)
# This is to respect the current interface of proxies
return y, scale, self.zero_point_impl(), self.bit_width()
1 change: 1 addition & 0 deletions src/brevitas/core/scaling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from brevitas.core.stats import SCALAR_SHAPE

from .float_scaling import FloatScaling
from .int_scaling import IntScaling
from .int_scaling import PowerOfTwoIntScaling
from .pre_scaling import AccumulatorAwareParameterPreScaling
Expand Down
32 changes: 21 additions & 11 deletions src/brevitas/core/scaling/float_scaling.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Optional
from typing import List, Optional, Tuple

import torch
from torch import Tensor
Expand All @@ -15,18 +15,28 @@ class FloatScaling(brevitas.jit.ScriptModule):

def __init__(
self,
exponent_bit_width: int,
mantissa_bit_width: int,
exponent_bias: int,
max_available_float: Optional[float] = None,
inf_values: Optional[Tuple[str]] = None,
nan_values: Optional[Tuple[str]] = None,
saturating: bool = True,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None):
super(FloatScaling, self).__init__()
exponent_bit_width = torch.tensor(exponent_bit_width, device=device, dtype=dtype)
mantissa_bit_width = torch.tensor(mantissa_bit_width, device=device, dtype=dtype)
exponent_bias = torch.tensor(exponent_bias, device=device, dtype=dtype)
self.max_float_val = StatelessBuffer(
max_float(exponent_bit_width, mantissa_bit_width, exponent_bias))
self.inf_values = inf_values
self.nan_values = nan_values
self.saturating = saturating

if max_available_float:
max_available_float = torch.tensor(max_available_float, device=device, dtype=dtype)
self.max_available_float = StatelessBuffer(max_available_float)
else:
self.max_available_float = None

@brevitas.jit.script_method
def forward(self, input: torch.Tensor) -> Tensor:
return self.max_float_val()
def forward(
self, exponent_bit_width: Tensor, mantissa_bit_width: Tensor,
exponent_bias: Tensor) -> Tensor:
max_value = max_float(exponent_bit_width, mantissa_bit_width, exponent_bias)
max_value = max_value if self.max_available_float is None else torch.min(
max_value, self.max_available_float())
return max_value
3 changes: 2 additions & 1 deletion src/brevitas/function/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
Implementation of various core operations often performed as part of quantization.
The implemented functions adheres to the restriction imposed by Pytorch 1.1.0's TorchScript compiler.
"""
from typing import Optional, Tuple
from typing import List, Optional, Tuple

import torch
from torch import Tensor

import brevitas
from brevitas.utils.float_quant_utils import get_minifloat_value


@brevitas.jit.script
Expand Down
17 changes: 0 additions & 17 deletions src/brevitas/quant/experimental/float_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from brevitas.quant.solver import ActQuantSolver
from brevitas.quant.solver import WeightQuantSolver
from brevitas.quant.solver.common import SolveTensorQuantFloatToIntImplFromEnum
from brevitas.utils.float_quant_utils import get_max_value


class FloatBase(SolveTensorQuantFloatToIntImplFromEnum):
Expand All @@ -27,22 +26,6 @@ class FloatBase(SolveTensorQuantFloatToIntImplFromEnum):
def exponent_bias(exponent_bit_width):
return 2 ** (exponent_bit_width - 1) - 1

@value
def max_value(
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
nan_values=None,
inf_values=None,
saturating=True):
return get_max_value(
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
nan_values,
inf_values,
saturating)


class FloatWeightBase(FloatBase):
proxy_class = WeightQuantProxyFromInjector
Expand Down
27 changes: 27 additions & 0 deletions src/brevitas/quant/experimental/float_quant_ocp.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,51 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from dependencies import value

from brevitas.quant.base import MSESymmetricScale
from brevitas.quant.experimental.float_base import FloatActBase
from brevitas.quant.experimental.float_base import FloatWeightBase
from brevitas.quant.experimental.float_base import Fp8e4m3Mixin
from brevitas.quant.experimental.float_base import Fp8e5m2Mixin
from brevitas.quant.experimental.float_base import ScaledFloatActBase
from brevitas.quant.experimental.float_base import ScaledFloatWeightBase
from brevitas.utils.float_quant_utils import get_max_available_float


class Fp8e4m3OCPMixin(Fp8e4m3Mixin):
nan_values = (('111',))
inf_values = None

@value
def max_available_float(
exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values,
saturating):
return get_max_available_float(
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
nan_values,
inf_values,
saturating)


class Fp8e5m2OCPMixin(Fp8e5m2Mixin):
nan_values = ('01', '11', '10')
inf_values = (('00',))

@value
def max_available_float(
exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values,
saturating):
return get_max_available_float(
exponent_bit_width,
mantissa_bit_width,
exponent_bias,
nan_values,
inf_values,
saturating)


class Fp8e4m3OCPWeight(Fp8e4m3OCPMixin, FloatWeightBase):
"""
Expand Down
14 changes: 11 additions & 3 deletions src/brevitas/utils/float_quant_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Tuple

import torch


def mantissa_bits_to_float(bits: str, frexp_compatible: bool = False) -> float:
Expand All @@ -21,11 +24,16 @@ def get_minifloat_value(exponent: str, mantissa: str, exponent_bias: int) -> flo
"""
exponent_value = int(exponent, 2)
mantissa_value = mantissa_bits_to_float(mantissa)
return 2 ** (exponent_value - exponent_bias) * mantissa_value
return (2 ** (exponent_value - exponent_bias)) * mantissa_value


def get_max_value(
exponent_bit_width, mantissa_bit_width, exponent_bias, nan_values, inf_values, saturating):
def get_max_available_float(
exponent_bit_width: torch.Tensor,
mantissa_bit_width: torch.Tensor,
exponent_bias: torch.Tensor,
nan_values: Tuple[str],
inf_values: Tuple[str],
saturating: bool) -> torch.Tensor:
# Idea: take the smallest NaN/inf value, set max_value to the next smaller one
# inf without NaN not possible
if inf_values is None and nan_values is None:
Expand Down
25 changes: 21 additions & 4 deletions tests/brevitas/core/test_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from hypothesis import given
import pytest
import torch

from brevitas.function.ops import max_float
from brevitas.quant.experimental.float import Fp8e4m3Weight
from brevitas.quant.experimental.float import Fp8e5m2Weight
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeight
Expand All @@ -20,26 +22,41 @@
'minifloat, expected_max_val',
((format, max_val) for format, max_val in FORMAT_MAXVAL_MAP.items()))
def test_max_value(minifloat, expected_max_val):
max_val = minifloat.float_clamp_impl.max_value()
max_val = max_float(
torch.tensor(minifloat.exponent_bit_width, dtype=torch.float32),
torch.tensor(minifloat.mantissa_bit_width, dtype=torch.float32),
torch.tensor(minifloat.exponent_bias, dtype=torch.float32))
max_available_float = minifloat.float_clamp_impl.max_available_float
max_val = max_val if max_available_float is None else torch.min(max_val, max_available_float())

assert expected_max_val == max_val


@given(inp=float_tensor_random_shape_st())
def test_float_clamp(inp, fp8_clamp):
max_val = fp8_clamp.float_clamp_impl.max_value()

max_val = max_float(
torch.tensor(fp8_clamp.exponent_bit_width, dtype=torch.float32),
torch.tensor(fp8_clamp.mantissa_bit_width, dtype=torch.float32),
torch.tensor(fp8_clamp.exponent_bias, dtype=torch.float32))
max_available_float = fp8_clamp.float_clamp_impl.max_available_float
max_val = max_val if max_available_float is None else torch.min(max_val, max_available_float())
# get values that exceed max_val
over_limit_mask = inp.abs() > max_val

# clamp inp
inp = fp8_clamp.float_clamp_impl(inp)
inp = fp8_clamp.float_clamp_impl(
inp,
torch.tensor(fp8_clamp.exponent_bit_width),
torch.tensor(fp8_clamp.mantissa_bit_width),
torch.tensor(fp8_clamp.exponent_bias))

if fp8_clamp.float_clamp_impl.saturating:
# should be clamped to +- max val
assert (inp[over_limit_mask].abs() == max_val).all()
else:
# if inf_values, over limit mask should now be all inf
if fp8_clamp.float_clamp_impl.has_inf_values:
if fp8_clamp.float_clamp_impl.inf_values:
# all values exceeding max_val should be inf
assert inp[over_limit_mask].isinf().all()
else:
Expand Down
Loading
Loading