From 89af08f2761f495ab9bbc27877da868b119d6690 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 21 Dec 2023 23:00:47 +0000 Subject: [PATCH] Feat (pytorch_utils): kthvalue for (b)float16 --- src/brevitas/core/stats/stats_op.py | 18 +++++++------ src/brevitas/utils/torch_utils.py | 34 ++++++++++++++++++++++++ tests/brevitas/core/test_stats.py | 6 +++-- tests/brevitas/graph/test_calibration.py | 4 ++- 4 files changed, 51 insertions(+), 11 deletions(-) diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py index 194631953..bfcfbb58f 100644 --- a/src/brevitas/core/stats/stats_op.py +++ b/src/brevitas/core/stats/stats_op.py @@ -12,6 +12,8 @@ from brevitas import config from brevitas.core.utils import StatelessBuffer from brevitas.function.ops import max_int +# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations +from brevitas.utils.torch_utils import kthvalue from .stats_wrapper import SCALAR_SHAPE @@ -64,7 +66,7 @@ def forward(self, x: Tensor): if self.stats_reduce_dim is None: # k is 1-indexed, so round away from zero k = int(math.floor(.01 * self.q * x.numel() + 0.5)) - result = x.abs().view(-1).kthvalue(k).values + result = kthvalue(x.abs().view(-1), k)[0] else: # assuming x is two dimensional, get the other dimension assert len(x.size()) == 2, "Only 2-dim input is supported." @@ -72,7 +74,7 @@ def forward(self, x: Tensor): dim_slice = torch.narrow(x, dim=other_dim, start=0, length=1) # k is 1-indexed, so round away from zero k = int(math.floor(.01 * self.q * dim_slice.numel() + 0.5)) - result = x.abs().kthvalue(k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values + result = kthvalue(x.abs(), k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0] return result @@ -97,7 +99,7 @@ def forward(self, x: Tensor) -> Tensor: if self.stats_reduce_dim is None: # k is 1-indexed, so round away from zero k = int(math.ceil(.01 * self.q * x.numel())) - result = x.view(-1).kthvalue(k).values + result = kthvalue(x.view(-1), k)[0] else: # assuming x is two dimensional, get the other dimension assert len(x.size()) == 2, "Only 2-dim input is supported." @@ -105,7 +107,7 @@ def forward(self, x: Tensor) -> Tensor: dim_slice = torch.narrow(x, dim=other_dim, start=0, length=1) # k is 1-indexed, so round away from zero k = int(math.ceil(.01 * self.q * dim_slice.numel())) - result = x.kthvalue(k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values + result = kthvalue(x, k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0] result = torch.clamp(result, max=self.zero()) return result @@ -134,8 +136,8 @@ def forward(self, x: Tensor) -> Tensor: low_k = int(math.ceil(.01 * self.low_q * x.numel())) # k is 1-indexed, so round away from zero high_k = int(math.floor(.01 * self.high_q * x.numel() + 0.5)) - low_result = x.view(-1).kthvalue(low_k).values - high_result = x.view(-1).kthvalue(high_k).values + low_result = kthvalue(x.view(-1), low_k)[0] + high_result = kthvalue(x.view(-1), high_k)[0] else: # assuming x is two dimensional, get the other dimension assert len(x.size()) == 2, "Only 2-dim input is supported." @@ -144,8 +146,8 @@ def forward(self, x: Tensor) -> Tensor: low_k = int(math.ceil(.01 * self.low_q * dim_slice.numel())) # k is 1-indexed, so round away from zero high_k = int(math.floor(.01 * self.high_q * dim_slice.numel() + 0.5)) - low_result = x.kthvalue(low_k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values - high_result = x.kthvalue(high_k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values + low_result = kthvalue(x, low_k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0] + high_result = kthvalue(x, high_k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0] # We need to make sure the lower bound is not positive to align with zero-point statistics low_result = torch.clamp(low_result, max=self.zero()) interval = high_result - low_result diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index 7105ea874..ec7d6fac4 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import copy +from typing import Optional, Tuple import torch from torch.nn import Sequential @@ -46,3 +47,36 @@ def torch_partial_deepcopy(model): memo[id(p)] = copy.copy(p) # Shallow copy of parameters model_copy = copy.deepcopy(model, memo) return model_copy + + +def kthvalue( + x: torch.Tensor, + k: int, + dim: Optional[int] = None, + keepdim: bool = False, + out: Optional[Tuple[torch.Tensor, torch.LongTensor]] = None +) -> Tuple[torch.Tensor, torch.LongTensor]: + # As of torch 2.1, there is no kthvalue implementation: + # - In CPU for float16 + # - In GPU for bfloat16 + # In these cases we cast to float32 and then go back to the original dtype + dtype = x.dtype + device = str(x.device) + + # We do not support out as buffer for the output, since we cannot control its dtype + if out is not None: + raise RuntimeError("out argument for kthvalue not supported") + + if (dtype == torch.float16 and 'cpu' in device) or \ + (dtype == torch.bfloat16 and 'cuda' in device): + x = x.type(torch.float32) + + # PyTorch specify None as default for `dim` but it breaks if we specifically pass None + if dim is not None: + x, indices = torch.kthvalue(x, k, dim=dim, keepdim=keepdim) + else: + x, indices = torch.kthvalue(x, k, keepdim=keepdim) + + if x.dtype != dtype: + x = x.type(dtype) + return (x, indices) diff --git a/tests/brevitas/core/test_stats.py b/tests/brevitas/core/test_stats.py index 224f323fc..4d397e457 100644 --- a/tests/brevitas/core/test_stats.py +++ b/tests/brevitas/core/test_stats.py @@ -8,6 +8,8 @@ from brevitas.core.stats import AbsPercentile from brevitas.core.stats import NegativePercentileOrZero from brevitas.core.stats import PercentileInterval +# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations +from brevitas.utils.torch_utils import kthvalue def test_abs_percentile_per_tensor(): @@ -35,10 +37,10 @@ def compute_percentile(self, x, low_q=None, high_q=None): low_p, high_p = None, None if low_q is not None: k = int(math.ceil(.01 * low_q * x.numel())) - low_p = x.view(-1).kthvalue(k).values + low_p = kthvalue(x.view(-1), k=k)[0] if high_q is not None: k = int(math.floor(.01 * high_q * x.numel() + 0.5)) - high_p = x.view(-1).kthvalue(k).values + high_p = kthvalue(x.view(-1), k=k)[0] return low_p, high_p def test_negative_percentile(self): diff --git a/tests/brevitas/graph/test_calibration.py b/tests/brevitas/graph/test_calibration.py index 58561cbd1..6580d971b 100644 --- a/tests/brevitas/graph/test_calibration.py +++ b/tests/brevitas/graph/test_calibration.py @@ -12,6 +12,8 @@ from brevitas.graph.calibrate import calibration_mode import brevitas.nn as qnn from brevitas.quant import Int8ActPerTensorFixedPoint +# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations +from brevitas.utils.torch_utils import kthvalue from tests.brevitas.hyp_helper import float_tensor_random_size_st IN_CH = 8 @@ -21,7 +23,7 @@ def compute_quantile(x, q): k = int(math.floor(.01 * q * x.numel() + 0.5)) - result = x.abs().view(-1).kthvalue(k).values + result = kthvalue(x.abs().view(-1), k=k)[0] return result