Feat (tests/quant_tensor): Quant Tensor tests (#894)
1 parent a8159f9 · commit a106a6d
Showing 1 changed file with 106 additions and 0 deletions.
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
from enum import Enum

import pytest
import torch

from brevitas.inject.enum import QuantType
from brevitas.nn import QuantIdentity
from brevitas.quant_tensor import QuantTensor

class Operator(Enum):
    ADD = 0
    SUBTRACT = 1
    DIVIDE = 2
    MULTIPLY = 3
    MATMUL = 4

def to_quant_tensor(input: torch.Tensor) -> QuantTensor:
    # Quantize to 8 bits with a QuantIdentity layer and return the QuantTensor.
    mod = QuantIdentity(bit_width=8, return_quant_tensor=True)
    return mod(input)
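
# The returned QuantTensor behaves much like a torch.Tensor but also carries
# quantization metadata (illustrative note, not part of the original commit):
#   qt = to_quant_tensor(torch.randn(2, 2))
#   qt.scale, qt.zero_point  # metadata used by qdq() below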

def qdq(normal_tensor, quant_tensor):
    # Reference quantize-dequantize round trip using the QuantTensor's
    # scale and zero-point: scale * (round(x / scale + zp) - zp).
    return (
        torch.round(normal_tensor / quant_tensor.scale + quant_tensor.zero_point) -
        quant_tensor.zero_point) * quant_tensor.scale
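
# Worked example (illustrative, not part of the original commit), assuming
# scale = 0.1 and zero_point = 0:
#   qdq(0.34) = (round(0.34 / 0.1 + 0) - 0) * 0.1 = 0.3
# i.e. qdq snaps each value to the nearest point on the quantization grid.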

def test_quant_tensor_init():
    x = torch.randn(4, 4)
    quant_tensor = to_quant_tensor(x)
    normal_tensor = torch.Tensor(x)
    assert torch.allclose(qdq(normal_tensor, quant_tensor), quant_tensor, rtol=0.01)

@pytest.mark.parametrize(
    'op', [Operator.ADD, Operator.SUBTRACT, Operator.DIVIDE, Operator.MULTIPLY, Operator.MATMUL])
def test_quant_tensor_operators(op):
    x = torch.randn(4, 4)

    a = torch.Tensor(x)
    b = torch.Tensor(x)

    qa = to_quant_tensor(a)
    qb = to_quant_tensor(b)

    # Capture the quantisation error so the float reference below operates
    # on exactly the dequantized values.
    e_a = a - qa
    e_b = b - qb

    if op == Operator.ADD:
        quant = qa + qb
        normal = (a - e_a) + (b - e_b)
    elif op == Operator.SUBTRACT:
        quant = qa - qb
        normal = (a - e_a) - (b - e_b)
    elif op == Operator.DIVIDE:
        quant = qa / qb
        normal = (a - e_a) / (b - e_b)
    elif op == Operator.MULTIPLY:
        quant = qa * qb
        normal = (a - e_a) * (b - e_b)
    elif op == Operator.MATMUL:
        # The @ operator is not implemented for QuantTensor, so call
        # torch.matmul directly.
        quant = torch.matmul(qa, qb)
        normal = (a - e_a) @ (b - e_b)
    else:
        # Unrecognised operator.
        assert False

    assert torch.allclose(normal, quant)
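
# Why subtracting the error works (illustrative note, not part of the original
# commit): e_a = a - qa is computed against the dequantized values of qa, so
# a - e_a is exactly dequant(qa). The float reference therefore sees the same
# grid values as the QuantTensor arithmetic, and the two results should agree
# to floating-point precision.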

def test_quant_tensor_div_by_zero():
    # Float division by a zero-valued QuantTensor should yield inf everywhere.
    a = to_quant_tensor(torch.ones(4, 4))
    b = to_quant_tensor(torch.zeros(4, 4))
    assert torch.isinf(a / b).all().item()

def test_quant_tensor_div_by_fraction():
    a = to_quant_tensor(torch.ones(4, 4))
    b = to_quant_tensor(torch.ones(4, 4) * 0.5)
    assert torch.allclose(a / b, torch.ones(4, 4) * 2, atol=0.1)

# TODO: need to deal with quant metadata
def test_quant_tensor_transpose():
    x = torch.ones(4, 4).tril()
    a = x.clone()
    b = to_quant_tensor(x)
    assert torch.allclose(a.transpose(0, 1), b.transpose(0, 1), atol=0.01)

# TODO: need to deal with quant metadata
def test_quant_tensor_view():
    x = torch.ones(4, 4)
    a = to_quant_tensor(x)
    b = torch.Tensor(x)

    assert torch.allclose(a.view(-1), b.view(-1), atol=0.01)
    assert torch.allclose(a.view(2, -1), b.view(2, -1), atol=0.01)
    assert torch.allclose(a.view(16, -1), b.view(16, -1), atol=0.01)
    assert torch.allclose(a.view(8, 2), b.view(8, 2), atol=0.01)
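
To run just these tests locally (a sketch, assuming a Brevitas checkout with pytest installed; selection is by test-name substring rather than a hard-coded file path):

    pytest -k quant_tensor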