More fixes
Giuseppe5 committed Oct 28, 2024
1 parent 07bbff0 commit 3340972
Showing 1 changed file with 37 additions and 11 deletions.
tests/brevitas/core/test_quant_mx.py: 37 additions, 11 deletions
@@ -5,6 +5,7 @@
 
 import struct
 
+from brevitas.nn.quant_linear import QuantLinear
 from tests.brevitas.hyp_helper import float_tensor_nz_st
 
 try:
@@ -17,7 +18,7 @@
 
 from brevitas.nn.quant_activation import QuantIdentity
 from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
-from brevitas.utils.torch_utils import float_internal_scale
+from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
 
 torch.manual_seed(0)
 
@@ -144,10 +145,11 @@ def quantize(self, tensor: torch.Tensor, axis: int = -1, select: bool = False):
 
 @given(inp=float_tensor_nz_st(shape=(1, 32), max_val=1e10, min_val=-1e10))
 @pytest_cases.parametrize('bit_widths', list(MAP.keys()))
-def test_mx(inp, bit_widths):
+def test_act_mx(inp, bit_widths):
     # print("-------------------------------------------")
     torch.set_printoptions(precision=12, sci_mode=False)
     exp, mant = MAP[bit_widths]
+
     act_quant = QuantIdentity(
         MXFloat8e4m3Act,
         exponent_bit_width=exp,
@@ -158,20 +160,44 @@ def test_mx(inp, bit_widths):
     act_quant.eval()
     x = inp
 
-    # dtype = MXFP(bit_widths)
-    # q, scale = dtype.quantize(x, select=False)
     qx = act_quant(x)
-    # error, lowest = check_bits(q, dtype.mbits)
 
-    exp_bias = torch.tensor(2 ** (exp - 1) - 1)
     if mx is None:
         print("Install microscaling library, --no-deps flag recommended")
     else:
         y = mx(
             x, 8, elem_format=bit_widths, block_size=32, axes=-1, round='even', custom_cuda=False)
+        assert torch.allclose(qx.value, y, atol=1e-8)
+
+
+@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=1e10, min_val=-1e10))
+@pytest_cases.parametrize('bit_widths', list(MAP.keys()))
+@pytest_cases.parametrize('weight_quant_type', ['stats', 'parameter_from_stats'])
+def test_weight_mx(inp, bit_widths, weight_quant_type):
+    # print("-------------------------------------------")
+    torch.set_printoptions(precision=12, sci_mode=False)
+    exp, mant = MAP[bit_widths]
+    weight_quant = QuantLinear(
+        32,
+        1,
+        bias=False,
+        weight_quant=MXFloat8e4m3Weight,
+        weight_scaling_impl_type=weight_quant_type,
+        weight_exponent_bit_width=exp,
+        weight_mantissa_bit_width=mant,
+        weight_bit_width=mant + exp + 1)
+
+    x = inp
+    weight_quant.weight.data = x
+    weight_quant.weight_quant.init_tensor_quant()
+
+    qx_weight = weight_quant.quant_weight()
+    qx_weight_two = weight_quant.quant_weight()
 
-        int_scale = float_internal_scale(
-            x / qx.scale, torch.tensor(mant), 1. - exp_bias - torch.tensor(mant), torch.tensor(1e-8))
-        brev_scale = 1 / (int_scale * qx.scale)
-        assert torch.allclose(qx.value, y, atol=1e-4)
-        # assert torch.allclose(brev_scale, scale, atol=1e-4)
+    if mx is None:
+        print("Install microscaling library, --no-deps flag recommended")
+    else:
+        y = mx(
+            x, 8, elem_format=bit_widths, block_size=32, axes=-1, round='even', custom_cuda=False)
+        assert torch.allclose(qx_weight.value, y, atol=1e-8)
+        assert torch.allclose(qx_weight_two.value, y, atol=1e-8)
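
Both tests compare Brevitas' MX quantizers against the microscaling reference with block_size=32: under the OCP Microscaling (MX) rule, each block of 32 elements shares a single power-of-two scale derived from the block's largest magnitude. Below is a minimal sketch of that shared-scale step, assuming the OCP formula 2 ** (floor(log2(max|x|)) - emax_elem); the helper name and the e4m3 value emax_elem = 8 are illustrative assumptions, not Brevitas or microscaling API.

import torch

def mx_shared_scale(x: torch.Tensor, emax_elem: int = 8, block_size: int = 32) -> torch.Tensor:
    # One power-of-two scale per block of 32 elements:
    # 2 ** (floor(log2(max|x|)) - emax_elem), so the block's largest element
    # lands near the top of the element format's representable range.
    # Assumes no all-zero blocks (the hypothesis strategy above draws nonzero values).
    blocks = x.reshape(-1, block_size)
    max_abs = blocks.abs().amax(dim=-1, keepdim=True)
    shared_exp = torch.floor(torch.log2(max_abs))
    return 2.0 ** (shared_exp - emax_elem)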
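
Once the microscaling reference is installed (each test prints an install hint when it is missing), both paths can be exercised with:

pytest tests/brevitas/core/test_quant_mx.py -k "test_act_mx or test_weight_mx"

Calling quant_weight() twice in test_weight_mx appears intended to check that the quantized weight is stable across repeated calls, in particular for 'parameter_from_stats', where the scale parameter is initialized from statistics on first use.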
