From 06b810f299187f1c10f2cd324b36b545e9c0bac9 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 21 Dec 2023 10:39:48 +0000
Subject: [PATCH] Fix for jit

---
 src/brevitas/utils/torch_utils.py        | 9 +++++++--
 tests/brevitas/core/test_stats.py        | 6 ++++--
 tests/brevitas/graph/test_calibration.py | 4 +++-
 3 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py
index 7074e1e62..88a8f96f5 100644
--- a/src/brevitas/utils/torch_utils.py
+++ b/src/brevitas/utils/torch_utils.py
@@ -49,17 +49,22 @@ def torch_partial_deepcopy(model):
     return model_copy
 
 
-def kthvalue(x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+def kthvalue(x, k, dim=None, keepdim=False, out=None) -> torch.Tensor:
     # As of torch 2.1, there is no kthvalue implementation:
     # - In CPU for float16
     # - In GPU for bfloat16
     # In these cases we cast to float32 and then go back to the original dtype
     dtype = x.dtype
     device = str(x.device)
+
+    # We do not support out as an output buffer, since we cannot control its dtype
+    if out is not None:
+        raise RuntimeError("out argument for kthvalue not supported")
+
     if (dtype == torch.float16 and 'cpu' in device) or \
             (dtype == torch.bfloat16 and 'cuda' in device):
         x = x.type(torch.float32)
-    x, indices = torch.kthvalue(x, *args, **kwargs)
+    x, indices = torch.kthvalue(x, k, dim=dim, keepdim=keepdim)
     if x.dtype != dtype:
         x = x.type(dtype)
     return (x, indices)
diff --git a/tests/brevitas/core/test_stats.py b/tests/brevitas/core/test_stats.py
index 224f323fc..4d397e457 100644
--- a/tests/brevitas/core/test_stats.py
+++ b/tests/brevitas/core/test_stats.py
@@ -8,6 +8,8 @@
 from brevitas.core.stats import AbsPercentile
 from brevitas.core.stats import NegativePercentileOrZero
 from brevitas.core.stats import PercentileInterval
+# Use custom implementation of kthvalue as a workaround for (b)float16 kernel limitations
+from brevitas.utils.torch_utils import kthvalue
 
 
 def test_abs_percentile_per_tensor():
@@ -35,10 +37,10 @@ def compute_percentile(self, x, low_q=None, high_q=None):
         low_p, high_p = None, None
         if low_q is not None:
             k = int(math.ceil(.01 * low_q * x.numel()))
-            low_p = x.view(-1).kthvalue(k).values
+            low_p = kthvalue(x.view(-1), k=k)[0]
         if high_q is not None:
             k = int(math.floor(.01 * high_q * x.numel() + 0.5))
-            high_p = x.view(-1).kthvalue(k).values
+            high_p = kthvalue(x.view(-1), k=k)[0]
         return low_p, high_p
 
     def test_negative_percentile(self):
diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py
index 58561cbd1..6580d971b 100644
--- a/tests/brevitas/graph/test_calibration.py
+++ b/tests/brevitas/graph/test_calibration.py
@@ -12,6 +12,8 @@
 from brevitas.graph.calibrate import calibration_mode
 import brevitas.nn as qnn
 from brevitas.quant import Int8ActPerTensorFixedPoint
+# Use custom implementation of kthvalue as a workaround for (b)float16 kernel limitations
+from brevitas.utils.torch_utils import kthvalue
 from tests.brevitas.hyp_helper import float_tensor_random_size_st
 
 IN_CH = 8
@@ -21,7 +23,7 @@
 
 def compute_quantile(x, q):
     k = int(math.floor(.01 * q * x.numel() + 0.5))
-    result = x.abs().view(-1).kthvalue(k).values
+    result = kthvalue(x.abs().view(-1), k=k)[0]
     return result
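
Usage note (not part of the patch): below is a minimal sketch of how the patched
wrapper behaves, assuming brevitas with this change applied; the tensor size and
the k values are illustrative. The explicit k/dim/keepdim signature replaces
*args/**kwargs, presumably because TorchScript cannot compile variadic arguments
(the subject line "Fix for jit" points at this), and the float32 round trip
exists because, as of torch 2.1, torch.kthvalue has no float16 kernel on CPU and
no bfloat16 kernel on CUDA.

    import math

    import torch

    from brevitas.utils.torch_utils import kthvalue

    x = torch.randn(1000, dtype=torch.float16)  # float16 tensor on CPU

    # torch.kthvalue(x, 10) would raise here (no float16 CPU kernel);
    # the wrapper casts to float32, computes, and restores the dtype.
    values, indices = kthvalue(x, k=10)
    assert values.dtype == torch.float16

    # The updated tests use the same wrapper to take percentiles of
    # flattened tensors, e.g. the 99.9th percentile of the magnitudes:
    k = int(math.floor(.01 * 99.9 * x.numel() + 0.5))
    p999 = kthvalue(x.abs().view(-1), k=k)[0]

    # An out= buffer is rejected explicitly, since the internal float32
    # round trip could not honor the buffer's dtype.
    try:
        kthvalue(x, k=10, out=(torch.empty(0), torch.empty(0, dtype=torch.long)))
    except RuntimeError as err:
        print(err)  # "out argument for kthvalue not supported"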