From 7494e2ed5ad6e31fa89f46f2a28949ad02b855d5 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 28 Oct 2024 13:56:34 +0000
Subject: [PATCH] Update tests

---
 tests/brevitas/core/test_quant_mx.py | 38 ++++++++--------------------
 1 file changed, 11 insertions(+), 27 deletions(-)

diff --git a/tests/brevitas/core/test_quant_mx.py b/tests/brevitas/core/test_quant_mx.py
index 44f375962..b2ab279d4 100644
--- a/tests/brevitas/core/test_quant_mx.py
+++ b/tests/brevitas/core/test_quant_mx.py
@@ -4,21 +4,17 @@
 # pylint: disable=missing-function-docstring, redefined-outer-name
 
 import struct
+from typing import Tuple
 
-from brevitas.nn.quant_linear import QuantLinear
-from tests.brevitas.hyp_helper import float_tensor_nz_st
-
-try:
-    from mx.mx_ops import _quantize_mx as mx
-except:
-    mx = None
 from hypothesis import given
 import pytest_cases
 import torch
 
 from brevitas.nn.quant_activation import QuantIdentity
+from brevitas.nn.quant_linear import QuantLinear
 from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
 from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
+from tests.brevitas.hyp_helper import float_tensor_nz_st
 
 torch.manual_seed(0)
 
@@ -39,7 +35,7 @@ def scalar_to_string(val: float, spaced: bool) -> str:
 
 
 # debug utility
-def check_bits(val: torch.Tensor | float, mbits: int) -> (bool, int):
+def check_bits(val: torch.Tensor | float, mbits: int) -> Tuple[bool, int]:
     """ return (too many precision bits, lowest mantissa bit) """
     strings = to_string(val, spaced=False)
     if isinstance(strings, str):
@@ -132,21 +128,15 @@ def quantize(self, tensor: torch.Tensor, axis: int = -1, select: bool = False):
         # offset = (shared - exp - (self.emax - self.emin)).clamp_min(0)
         # The shift left will be mbits - offset - exp, which for negative exponents gets them into the right range.
         maxval = self.maxval * (shared - self.emax).exp2()  # scale maxval per tile
-        return ((tensor * scale).round() / scale).clamp(-maxval, maxval), scale
+        return ((tensor * scale).round() / scale).clamp(-maxval, maxval)
 
 
-MAP = {
-    "fp8_e4m3": (4, 3),
-    "fp8_e5m2": (5, 2),
-    "fp6_e2m3": (2, 3),
-    "fp6_e3m2": (3, 2),
-    "fp4_e2m1": (2, 1)}
+MAP = {"e4m3": (4, 3), "e5m2": (5, 2), "e2m3": (2, 3), "e3m2": (3, 2), "e2m1": (2, 1)}
 
 
 @given(inp=float_tensor_nz_st(shape=(1, 32), max_val=1e10, min_val=-1e10))
 @pytest_cases.parametrize('bit_widths', list(MAP.keys()))
 def test_act_mx(inp, bit_widths):
-    # print("-------------------------------------------")
     torch.set_printoptions(precision=12, sci_mode=False)
     exp, mant = MAP[bit_widths]
 
@@ -160,13 +150,11 @@ def test_act_mx(inp, bit_widths):
 
     act_quant.eval()
     x = inp
+    quantizer = MXFP(bit_widths)
+
     qx = act_quant(x)
 
-    if mx is None:
-        print("Install microscaling library, --no-deps flag recommended")
-    else:
-        y = mx(
-            x, 8, elem_format=bit_widths, block_size=32, axes=-1, round='even', custom_cuda=False)
+    y = quantizer.quantize(x)
 
     assert torch.allclose(qx.value, y, atol=1e-8)
 
@@ -174,7 +162,6 @@
 @pytest_cases.parametrize('bit_widths', list(MAP.keys()))
 @pytest_cases.parametrize('weight_quant_type', ['stats', 'parameter_from_stats'])
 def test_weight_mx(inp, bit_widths, weight_quant_type):
-    # print("-------------------------------------------")
     torch.set_printoptions(precision=12, sci_mode=False)
     exp, mant = MAP[bit_widths]
     weight_quant = QuantLinear(
@@ -190,14 +177,11 @@
     x = inp
     weight_quant.weight.data = x
     weight_quant.weight_quant.init_tensor_quant()
+    quantizer = MXFP(bit_widths)
 
     qx_weight = weight_quant.quant_weight()
     qx_weight_two = weight_quant.quant_weight()
-    if mx is None:
-        print("Install microscaling library, --no-deps flag recommended")
-    else:
-        y = mx(
-            x, 8, elem_format=bit_widths, block_size=32, axes=-1, round='even', custom_cuda=False)
+    y = quantizer.quantize(x)
 
     assert torch.allclose(qx_weight.value, y, atol=1e-8)
     assert torch.allclose(qx_weight_two.value, y, atol=1e-8)
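After this patch both tests compare the Brevitas output against the file-local MXFP reference quantizer instead of the optional external microscaling (mx) package, so the tests no longer need a conditional import. A minimal sketch of the reference round trip, assuming MXFP is importable from the test module and is constructed from a MAP format key as the hunks above show (the import path and the sample input are illustrative, not part of the patch):

    import torch

    # Hypothetical import path: MXFP is defined inside the test module itself.
    from tests.brevitas.core.test_quant_mx import MXFP

    torch.manual_seed(0)
    x = torch.randn(1, 32) * 100.0      # one tile of 32 elements, matching the test shape
    quantizer = MXFP("e4m3")            # MAP key: 4 exponent bits, 3 mantissa bits
    y = quantizer.quantize(x, axis=-1)  # per-tile shared-exponent round trip, clamped to +/-maxval
    # quantize() now returns only the quantized tensor (the scale is dropped),
    # so y can be compared directly to the Brevitas value with torch.allclose.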