
Feat (float): adding new attributes to proxy and quant tensor #1072

Merged (9 commits, Oct 28, 2024)
7 changes: 7 additions & 0 deletions src/brevitas/proxy/float_parameter_quant.py
@@ -14,6 +14,13 @@

class WeightFloatQuantProxyFromInjectorBase(WeightQuantProxyFromInjectorBase, ABC):

    def bit_width(self):
        if not self.is_quant_enabled:
            return None
        x = self.__call__(self.tracked_parameter_list[0])
        bit_width = x.mantissa_bit_width + x.exponent_bit_width + 1
        return bit_width

    def scale(self):
        if not self.is_quant_enabled:
            return None
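For context, the new `bit_width()` accessor derives the total width of the minifloat format from the quantized tensor's format parameters: mantissa bits plus exponent bits plus one sign bit. A minimal standalone sketch of the same arithmetic (the helper name is ours, not Brevitas API):

```python
# Total bit width of a minifloat format: mantissa + exponent + 1 sign bit.
def minifloat_bit_width(exponent_bit_width: int, mantissa_bit_width: int) -> int:
    return mantissa_bit_width + exponent_bit_width + 1

assert minifloat_bit_width(4, 3) == 8  # FP8 E4M3
assert minifloat_bit_width(5, 2) == 8  # FP8 E5M2
```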
4 changes: 4 additions & 0 deletions src/brevitas/quant_tensor/float_quant_tensor.py
@@ -109,6 +109,10 @@ def _pre_round_float_value(self):
        minifloat_value = minifloat_value / int_scale
        return minifloat_value

    def int(self):
        fx_value = torch.round(self._pre_round_float_value)
        return fx_value

    @property
    def is_valid(self):
        with torch.no_grad():
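The new `int()` method gives `FloatQuantTensor` the same integer-style accessor that `IntQuantTensor` already has: it rounds the pre-round fixed-point value, i.e. the minifloat values rescaled so their mantissas land on an integer grid. A rough usage sketch, assuming the same FP8 quantizer the tests below use:

```python
import torch

from brevitas.nn import QuantIdentity
from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPActPerTensorFloat

mod = QuantIdentity(act_quant=Fp8e5m2OCPActPerTensorFloat, return_quant_tensor=True)
qx = mod(torch.randn(4, 4))  # returns a FloatQuantTensor
print(qx.int())  # rounded fixed-point representation of the minifloat values
```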
4 changes: 4 additions & 0 deletions src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
@@ -150,6 +150,10 @@ def _pre_round_float_value(self):
        minifloat_value = minifloat_value / int_scale
        return minifloat_value

    def int(self):
        fx_value = torch.round(self._pre_round_float_value)
        return fx_value

    @property
    def is_valid(self):
        with torch.no_grad():
4 changes: 2 additions & 2 deletions src/brevitas/utils/quant_utils.py
@@ -91,7 +91,7 @@ def __init__(self, quant_tensor: GroupwiseFloatQuantTensor, metadata_only: bool)
        self.shape = quant_tensor.value.shape
        if metadata_only:
            self.value = None
-            self.quant_tensor = quant_tensor.set(value=None)
+            self.quant_tensor = quant_tensor.set(value_=None)
        else:
            self.quant_tensor = quant_tensor
        # torch.compile compatibility
@@ -146,7 +146,7 @@ def __init__(self, quant_tensor: GroupwiseIntQuantTensor, metadata_only: bool):
        self.shape = quant_tensor.value.shape
        if metadata_only:
            self.value = None
-            self.quant_tensor = quant_tensor.set(value=None)
+            self.quant_tensor = quant_tensor.set(value_=None)
        else:
            self.quant_tensor = quant_tensor
        # torch.compile compatibility
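The `value` → `value_` rename in both cached-IO classes suggests that groupwise quant tensors store their payload in a `value_` field, with `value` exposed as a derived view, so `set()` must target the stored field. That reading is our inference from the rename, not something the diff states; the pattern would look roughly like this hypothetical sketch:

```python
from typing import NamedTuple, Optional

import torch


# Hypothetical illustration: the stored tuple field is `value_`, while `value`
# is a derived view, so clearing the payload has to go through `value_`.
class GroupwiseTensorSketch(NamedTuple):
    value_: Optional[torch.Tensor]  # grouped storage, the actual field
    group_size: int

    @property
    def value(self) -> Optional[torch.Tensor]:
        # Reconstruct the flat view from grouped storage (simplified).
        return None if self.value_ is None else self.value_.flatten(-2)

    def set(self, **kwargs):
        return self._replace(**kwargs)  # only real fields can be replaced


t = GroupwiseTensorSketch(value_=torch.randn(4, 8, 32), group_size=32)
t = t.set(value_=None)  # set(value=None) would fail: `value` is not a field
```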
94 changes: 90 additions & 4 deletions tests/brevitas/quant_tensor/test_quant_tensor.py
@@ -1,7 +1,9 @@
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from enum import Enum

import numpy as np
from packaging import version
import pytest
import pytest_cases
@@ -13,7 +15,11 @@
from brevitas.quant.experimental.float_quant_ocp import Fp8e5m2OCPActPerTensorFloat
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
from brevitas.quant_tensor import FloatQuantTensor
from brevitas.quant_tensor import GroupwiseFloatQuantTensor
from brevitas.quant_tensor import IntQuantTensor
from brevitas.utils.quant_utils import _CachedIO
from brevitas.utils.quant_utils import _CachedIOFloat
from brevitas.utils.quant_utils import _CachedIOGroupwiseFloat


class Operator(Enum):
@@ -24,14 +30,40 @@ class Operator(Enum):
    MATMUL = 4


-def to_quant_tensor(input: torch.Tensor) -> IntQuantTensor:
-    mod = QuantIdentity(bit_width=8, return_quant_tensor=True)
+def to_quant_tensor(input: torch.Tensor, bit_width=8) -> IntQuantTensor:
+    mod = QuantIdentity(bit_width=bit_width, return_quant_tensor=True)
    return mod(input)


-def to_float_quant_tensor(input: torch.Tensor) -> FloatQuantTensor:
+def to_float_quant_tensor(
+        input: torch.Tensor,
+        bit_width=8,
+        exponent_bit_width=4,
+        mantissa_bit_width=3) -> FloatQuantTensor:
    mod = QuantIdentity(
-        bit_width=8, return_quant_tensor=True, act_quant=Fp8e5m2OCPActPerTensorFloat)
+        bit_width=bit_width,
+        exponent_bit_width=exponent_bit_width,
+        mantissa_bit_width=mantissa_bit_width,
+        return_quant_tensor=True,
+        act_quant=Fp8e5m2OCPActPerTensorFloat)
    return mod(input)


def to_mx_quant_tensor(
        input: torch.Tensor,
        bit_width=8,
        exponent_bit_width=4,
        mantissa_bit_width=3,
        group_size=32,
        group_dim=1) -> GroupwiseFloatQuantTensor:
    mod = QuantIdentity(
        bit_width=bit_width,
        group_size=group_size,
        group_dim=group_dim,
        exponent_bit_width=exponent_bit_width,
        mantissa_bit_width=mantissa_bit_width,
        return_quant_tensor=True,
        act_quant=MXFloat8e4m3Act)
    return mod(input)
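MX formats are groupwise: here, `group_size=32` consecutive elements along `group_dim=1` share one scale. A usage sketch for the helper above, assuming an input whose grouped dimension is divisible by 32:

```python
w = torch.randn(32, 1024)
q = to_mx_quant_tensor(w)  # 1024 / 32 = 32 groups per row, one shared scale each
```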


@@ -138,3 +170,57 @@ def test_minifloat(quant_class_key_vale):
    qx = q(x)
    # Check that minifloat doesn't raise error
    qx.minifloat()


@pytest.mark.parametrize("metadata_only", [True, False])
def test_int_quant_tensor(metadata_only, bit_width=8):
    limit = np.exp2(bit_width) - 1
    w = torch.randn(32, 1024)
    q = to_quant_tensor(w, bit_width=bit_width)
    i = q.int().float()
    assert ((i.max() - i.min()) <= limit).all()
    # test caching works
    cache = _CachedIO(q, metadata_only=metadata_only)
    assert cache.bit_width == bit_width
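The bound checked here is simply the span of the signed integer grid; for `bit_width=8` the arithmetic works out as follows:

```python
# An 8-bit signed grid [-128, 127] spans exactly 2**8 - 1 = 255 steps, so
# max(i) - min(i) can never exceed 255 for a valid 8-bit quantization.
assert 127 - (-128) == 2 ** 8 - 1 == 255
```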


@pytest.mark.parametrize("metadata_only", [True, False])
def test_float_quant_tensor(metadata_only, bit_width=8, exponent_bit_width=4, mantissa_bit_width=3):
    assert mantissa_bit_width + exponent_bit_width + 1 == bit_width
    limit = (np.exp2(mantissa_bit_width + 1) - 1) * np.exp2(np.exp2(exponent_bit_width) - 2)
    w = torch.randn(32, 1024)
    q = to_float_quant_tensor(
        w,
        bit_width=bit_width,
        exponent_bit_width=exponent_bit_width,
        mantissa_bit_width=mantissa_bit_width)
    # test that the integer API returns fixed point values in the right range
    i = q.int().float()
    assert ((i.max() - i.min()) <= limit).all()
    # test caching works
    cache = _CachedIOFloat(q, metadata_only=metadata_only)
    assert cache.mantissa_bit_width == mantissa_bit_width
    assert cache.exponent_bit_width == exponent_bit_width
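As used in this test, the bound is the largest mantissa integer of the format, `2**(mantissa_bit_width + 1) - 1`, scaled by the exponent factor `2**(2**exponent_bit_width - 2)`. Evaluating it for e=4, m=3:

```python
import numpy as np

mantissa_max = np.exp2(3 + 1) - 1          # 15, largest 4-bit mantissa integer
exponent_factor = np.exp2(np.exp2(4) - 2)  # 2**14 = 16384
assert mantissa_max * exponent_factor == 245760.0  # the `limit` above
```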


@pytest.mark.parametrize("metadata_only", [True, False])
def test_mx_quant_tensor(metadata_only, bit_width=8, exponent_bit_width=4, mantissa_bit_width=3):
    assert mantissa_bit_width + exponent_bit_width + 1 == bit_width
    limit = (np.exp2(mantissa_bit_width + 1) - 1) * np.exp2(np.exp2(exponent_bit_width) - 2)
    w = torch.randn(32, 1024)
    q = to_mx_quant_tensor(
        w,
        bit_width=bit_width,
        exponent_bit_width=exponent_bit_width,
        mantissa_bit_width=mantissa_bit_width,
        group_size=32,
        group_dim=1)
    # test that the integer API returns fixed point values in the right range
    i = q.int().float()
    assert ((i.max() - i.min()) <= limit).all()
    # test caching works
    cache = _CachedIOGroupwiseFloat(q, metadata_only=metadata_only)
    assert cache.mantissa_bit_width == mantissa_bit_width
    assert cache.exponent_bit_width == exponent_bit_width
    assert cache.group_size == 32
    assert cache.group_dim == 1