
Commit 6a1fd4a

Update tests
Giuseppe5 committed Oct 29, 2024
1 parent bf08af4 commit 6a1fd4a
Showing 1 changed file with 11 additions and 27 deletions.
tests/brevitas/core/test_quant_mx.py
@@ -4,21 +4,17 @@
# pylint: disable=missing-function-docstring, redefined-outer-name

import struct
+from typing import Tuple

-from brevitas.nn.quant_linear import QuantLinear
-from tests.brevitas.hyp_helper import float_tensor_nz_st
-
-try:
-    from mx.mx_ops import _quantize_mx as mx
-except:
-    mx = None
from hypothesis import given
import pytest_cases
import torch

from brevitas.nn.quant_activation import QuantIdentity
+from brevitas.nn.quant_linear import QuantLinear
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
+from tests.brevitas.hyp_helper import float_tensor_nz_st

torch.manual_seed(0)

@@ -39,7 +35,7 @@ def scalar_to_string(val: float, spaced: bool) -> str:


# debug utility
-def check_bits(val: torch.Tensor | float, mbits: int) -> (bool, int):
+def check_bits(val: torch.Tensor | float, mbits: int) -> Tuple[bool, int]:
    """ return (too many precision bits, lowest mantissa bit) """
    strings = to_string(val, spaced=False)
    if isinstance(strings, str):
@@ -132,21 +128,15 @@ def quantize(self, tensor: torch.Tensor, axis: int = -1, select: bool = False):
        # offset = (shared - exp - (self.emax - self.emin)).clamp_min(0)
        # The shift left will be mbits - offset - exp, which for negative exponents gets them into the right range.
        maxval = self.maxval * (shared - self.emax).exp2() # scale maxval per tile
-        return ((tensor * scale).round() / scale).clamp(-maxval, maxval), scale
+        return ((tensor * scale).round() / scale).clamp(-maxval, maxval)


-MAP = {
-    "fp8_e4m3": (4, 3),
-    "fp8_e5m2": (5, 2),
-    "fp6_e2m3": (2, 3),
-    "fp6_e3m2": (3, 2),
-    "fp4_e2m1": (2, 1)}
+MAP = {"e4m3": (4, 3), "e5m2": (5, 2), "e2m3": (2, 3), "e3m2": (3, 2), "e2m1": (2, 1)}


@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=1e10, min_val=-1e10))
@pytest_cases.parametrize('bit_widths', list(MAP.keys()))
def test_act_mx(inp, bit_widths):
-    # print("-------------------------------------------")
    torch.set_printoptions(precision=12, sci_mode=False)
    exp, mant = MAP[bit_widths]
@@ -160,21 +150,18 @@ def test_act_mx(inp, bit_widths):
    act_quant.eval()
    x = inp

+    quantizer = MXFP(bit_widths)
+
    qx = act_quant(x)

-    if mx is None:
-        print("Install microscaling library, --no-deps flag recommended")
-    else:
-        y = mx(
-            x, 8, elem_format=bit_widths, block_size=32, axes=-1, round='even', custom_cuda=False)
+    y = quantizer.quantize(x)
    assert torch.allclose(qx.value, y, atol=1e-8)


@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=1e10, min_val=-1e10))
@pytest_cases.parametrize('bit_widths', list(MAP.keys()))
@pytest_cases.parametrize('weight_quant_type', ['stats', 'parameter_from_stats'])
def test_weight_mx(inp, bit_widths, weight_quant_type):
-    # print("-------------------------------------------")
    torch.set_printoptions(precision=12, sci_mode=False)
    exp, mant = MAP[bit_widths]
    weight_quant = QuantLinear(
@@ -190,14 +177,11 @@ def test_weight_mx(inp, bit_widths, weight_quant_type):
    x = inp
    weight_quant.weight.data = x
    weight_quant.weight_quant.init_tensor_quant()
+    quantizer = MXFP(bit_widths)

    qx_weight = weight_quant.quant_weight()
    qx_weight_two = weight_quant.quant_weight()

-    if mx is None:
-        print("Install microscaling library, --no-deps flag recommended")
-    else:
-        y = mx(
-            x, 8, elem_format=bit_widths, block_size=32, axes=-1, round='even', custom_cuda=False)
+    y = quantizer.quantize(x)
    assert torch.allclose(qx_weight.value, y, atol=1e-8)
    assert torch.allclose(qx_weight_two.value, y, atol=1e-8)
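
Note: this commit drops the optional microscaling (mx) dependency and checks Brevitas' MX quantizers against a local MXFP reference class instead, whose definition is mostly collapsed in the hunks above. For orientation, here is a minimal sketch of OCP-style MX block quantization in the spirit of that reference; the function names, the epsilon guard, and the symmetric exponent-range simplification are illustrative assumptions, not the actual MXFP implementation:

import torch


def quantize_to_fp(x: torch.Tensor, ebits: int, mbits: int) -> torch.Tensor:
    # Round each element to a signed e{ebits}m{mbits} float grid.
    # Simplified: no inf/nan codes and a symmetric exponent range, so the
    # e4m3 max here is 240 rather than the 448 of the OCP FP8 E4M3 format.
    emax = 2 ** (ebits - 1) - 1
    emin = 1 - emax
    # Per-element exponent; the clamp keeps tiny values on the subnormal
    # grid and avoids log2(0).
    exp = torch.floor(torch.log2(x.abs().clamp_min(2.0 ** emin))).clamp_min(emin)
    lsb = torch.exp2(exp - mbits)  # grid spacing at that exponent
    maxval = (2.0 - 2.0 ** -mbits) * 2.0 ** emax
    return ((x / lsb).round() * lsb).clamp(-maxval, maxval)


def quantize_mx(x: torch.Tensor, ebits: int, mbits: int, block: int = 32) -> torch.Tensor:
    # One shared power-of-two scale per block of `block` elements (the last
    # dim must be a multiple of the block size), re-centered on emax.
    emax = 2 ** (ebits - 1) - 1
    xb = x.reshape(-1, block)
    absmax = xb.abs().amax(dim=-1, keepdim=True).clamp_min(2.0 ** -126)
    scale = torch.exp2(torch.floor(torch.log2(absmax)) - emax)
    return (quantize_to_fp(xb / scale, ebits, mbits) * scale).reshape(x.shape)

For example, quantize_mx(torch.randn(1, 32), ebits=4, mbits=3) mirrors the e4m3 case exercised by the tests above: a shared power-of-two scale per 32-element block, with elements rounded to nearest-even on the element format's grid.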
