Skip to content

Commit

Permalink
Reference impl
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Sep 24, 2024
1 parent 213dc29 commit a073264
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions tests/brevitas/graph/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
OUT_CH = 16
BATCH = 1
REFERENCE_SCALES = {
'int_quant': (0.00935234408825635910, 0.00859776325523853302),
'fp_quant': (0.00249395845457911491, 0.00190271728206425905)}
'int_quant': (0.00935234408825635910, 0.01362917013466358185),
'fp_quant': (0.00249395845457911491, 0.00363444536924362183)}
REFERNECE_INP = torch.tensor([[-1.8645, -0.4071, 1.1971]])


Expand Down Expand Up @@ -86,12 +86,14 @@ class TestModel(nn.Module):
def __init__(self):
super(TestModel, self).__init__()
self.act = qnn.QuantReLU(act_quant=act_quant)
self.linear = qnn.QuantLinear(3, 8)
self.linear_weights = torch.tensor([[1.0023, 0.0205,
1.4604], [-0.2918, -1.8218, -0.7010],
[1.4573, -0.9074, -0.2708]])
self.act_1 = qnn.QuantIdentity(act_quant=act_quant)

def forward(self, x):
o = self.act(x)
o = self.linear(o)
o = torch.matmul(o, self.linear_weights)
return self.act_1(o)

# Reference input
Expand Down

0 comments on commit a073264

Please sign in to comment.