diff --git a/notebooks/quantized_recurrent.ipynb b/notebooks/quantized_recurrent.ipynb index 6569662ad..5d28c0bb9 100644 --- a/notebooks/quantized_recurrent.ipynb +++ b/notebooks/quantized_recurrent.ipynb @@ -636,6 +636,7 @@ "from torch.nn import RNN\n", "from brevitas.nn import QuantRNN\n", "from brevitas import config\n", + "ATOL = 1e-6\n", "\n", "config.IGNORE_MISSING_KEYS = True\n", "torch.manual_seed(123456)\n", @@ -648,12 +649,11 @@ "\n", "# Generate random input\n", "inp = torch.randn(5, 2, 10)\n", - "\n", + "torch.allclose\n", "# Check outputs are the same\n", - "assert torch.isclose(quant_rnn(inp)[0], float_rnn(inp)[0]).all().item(), f\"inp {inp} \\n max error {torch.max(torch.abs(quant_rnn(inp)[0] - float_rnn(inp)[0]))} \\n QuantOut {quant_rnn(inp)[0]} \\n FloatOut {float_rnn(inp)[0]}\"\n", - "\n", + "assert torch.allclose(quant_rnn(inp)[0], float_rnn(inp)[0], atol=ATOL)\n", "# Check hidden states are the same\n", - "assert torch.isclose(quant_rnn(inp)[1], float_rnn(inp)[1]).all().item()" + "assert torch.allclose(quant_rnn(inp)[1], float_rnn(inp)[1], atol=ATOL)" ] }, {