From c20c03cefddc6d996124e42c058fa13a54c42b8d Mon Sep 17 00:00:00 2001 From: Francesco Conti Date: Wed, 21 Aug 2024 15:57:05 +0200 Subject: [PATCH] Fix corner cases global shift gen In some corner cases, the global shift factor was generated as a number < 0 (down to -inf...). This makes no sense, so now the global shift must be 0 at a minimum. --- test/NnxTestClasses.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/NnxTestClasses.py b/test/NnxTestClasses.py index 8d4eed1..7e0e3a0 100644 --- a/test/NnxTestClasses.py +++ b/test/NnxTestClasses.py @@ -213,7 +213,11 @@ def _calculate_global_shift( """Calculate global shift so that the output values are in the range of out_type""" s = tensor.type(torch.float64).std() target_s = 2 ** (out_type._bits - 1) - return torch.ceil(torch.log2(s / target_s)).type(torch.int32) + shift = torch.ceil(torch.log2(s / target_s)).type(torch.int32) + if shift < 1: + return torch.zeros((1,)).type(torch.int32) + else: + return shift @staticmethod def _random_data(_type: IntegerType, shape: Tuple, extremes: Tuple = None):