Skip to content

Commit

Permalink
Feat (pytorch_utils): kthvalue for (b)float16
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 22, 2023
1 parent fcdc623 commit 89af08f
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 11 deletions.
18 changes: 10 additions & 8 deletions src/brevitas/core/stats/stats_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from brevitas import config
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops import max_int
# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
from brevitas.utils.torch_utils import kthvalue

from .stats_wrapper import SCALAR_SHAPE

Expand Down Expand Up @@ -64,15 +66,15 @@ def forward(self, x: Tensor):
if self.stats_reduce_dim is None:
# k is 1-indexed, so round away from zero
k = int(math.floor(.01 * self.q * x.numel() + 0.5))
result = x.abs().view(-1).kthvalue(k).values
result = kthvalue(x.abs().view(-1), k)[0]
else:
# assuming x is two dimensional, get the other dimension
assert len(x.size()) == 2, "Only 2-dim input is supported."
other_dim = abs(self.stats_reduce_dim - 1)
dim_slice = torch.narrow(x, dim=other_dim, start=0, length=1)
# k is 1-indexed, so round away from zero
k = int(math.floor(.01 * self.q * dim_slice.numel() + 0.5))
result = x.abs().kthvalue(k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
result = kthvalue(x.abs(), k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
return result


Expand All @@ -97,15 +99,15 @@ def forward(self, x: Tensor) -> Tensor:
if self.stats_reduce_dim is None:
# k is 1-indexed, so round away from zero
k = int(math.ceil(.01 * self.q * x.numel()))
result = x.view(-1).kthvalue(k).values
result = kthvalue(x.view(-1), k)[0]
else:
# assuming x is two dimensional, get the other dimension
assert len(x.size()) == 2, "Only 2-dim input is supported."
other_dim = abs(self.stats_reduce_dim - 1)
dim_slice = torch.narrow(x, dim=other_dim, start=0, length=1)
# k is 1-indexed, so round away from zero
k = int(math.ceil(.01 * self.q * dim_slice.numel()))
result = x.kthvalue(k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
result = kthvalue(x, k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
result = torch.clamp(result, max=self.zero())
return result

Expand Down Expand Up @@ -134,8 +136,8 @@ def forward(self, x: Tensor) -> Tensor:
low_k = int(math.ceil(.01 * self.low_q * x.numel()))
# k is 1-indexed, so round away from zero
high_k = int(math.floor(.01 * self.high_q * x.numel() + 0.5))
low_result = x.view(-1).kthvalue(low_k).values
high_result = x.view(-1).kthvalue(high_k).values
low_result = kthvalue(x.view(-1), low_k)[0]
high_result = kthvalue(x.view(-1), high_k)[0]
else:
# assuming x is two dimensional, get the other dimension
assert len(x.size()) == 2, "Only 2-dim input is supported."
Expand All @@ -144,8 +146,8 @@ def forward(self, x: Tensor) -> Tensor:
low_k = int(math.ceil(.01 * self.low_q * dim_slice.numel()))
# k is 1-indexed, so round away from zero
high_k = int(math.floor(.01 * self.high_q * dim_slice.numel() + 0.5))
low_result = x.kthvalue(low_k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
high_result = x.kthvalue(high_k, dim=self.stats_reduce_dim, keepdim=self.keepdim).values
low_result = kthvalue(x, low_k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
high_result = kthvalue(x, high_k, dim=self.stats_reduce_dim, keepdim=self.keepdim)[0]
# We need to make sure the lower bound is not positive to align with zero-point statistics
low_result = torch.clamp(low_result, max=self.zero())
interval = high_result - low_result
Expand Down
34 changes: 34 additions & 0 deletions src/brevitas/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import copy
from typing import Optional, Tuple

import torch
from torch.nn import Sequential
Expand Down Expand Up @@ -46,3 +47,36 @@ def torch_partial_deepcopy(model):
memo[id(p)] = copy.copy(p) # Shallow copy of parameters
model_copy = copy.deepcopy(model, memo)
return model_copy


def kthvalue(
x: torch.Tensor,
k: int,
dim: Optional[int] = None,
keepdim: bool = False,
out: Optional[Tuple[torch.Tensor, torch.LongTensor]] = None
) -> Tuple[torch.Tensor, torch.LongTensor]:
# As of torch 2.1, there is no kthvalue implementation:
# - In CPU for float16
# - In GPU for bfloat16
# In these cases we cast to float32 and then go back to the original dtype
dtype = x.dtype
device = str(x.device)

# We do not support out as buffer for the output, since we cannot control its dtype
if out is not None:
raise RuntimeError("out argument for kthvalue not supported")

if (dtype == torch.float16 and 'cpu' in device) or \
(dtype == torch.bfloat16 and 'cuda' in device):
x = x.type(torch.float32)

# PyTorch specify None as default for `dim` but it breaks if we specifically pass None
if dim is not None:
x, indices = torch.kthvalue(x, k, dim=dim, keepdim=keepdim)
else:
x, indices = torch.kthvalue(x, k, keepdim=keepdim)

if x.dtype != dtype:
x = x.type(dtype)
return (x, indices)
6 changes: 4 additions & 2 deletions tests/brevitas/core/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from brevitas.core.stats import AbsPercentile
from brevitas.core.stats import NegativePercentileOrZero
from brevitas.core.stats import PercentileInterval
# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
from brevitas.utils.torch_utils import kthvalue


def test_abs_percentile_per_tensor():
Expand Down Expand Up @@ -35,10 +37,10 @@ def compute_percentile(self, x, low_q=None, high_q=None):
low_p, high_p = None, None
if low_q is not None:
k = int(math.ceil(.01 * low_q * x.numel()))
low_p = x.view(-1).kthvalue(k).values
low_p = kthvalue(x.view(-1), k=k)[0]
if high_q is not None:
k = int(math.floor(.01 * high_q * x.numel() + 0.5))
high_p = x.view(-1).kthvalue(k).values
high_p = kthvalue(x.view(-1), k=k)[0]
return low_p, high_p

def test_negative_percentile(self):
Expand Down
4 changes: 3 additions & 1 deletion tests/brevitas/graph/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from brevitas.graph.calibrate import calibration_mode
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFixedPoint
# Use custom implementation of kthvalue as work around to (b)float16 kernel limitations
from brevitas.utils.torch_utils import kthvalue
from tests.brevitas.hyp_helper import float_tensor_random_size_st

IN_CH = 8
Expand All @@ -21,7 +23,7 @@

def compute_quantile(x, q):
k = int(math.floor(.01 * q * x.numel() + 0.5))
result = x.abs().view(-1).kthvalue(k).values
result = kthvalue(x.abs().view(-1), k=k)[0]
return result


Expand Down

0 comments on commit 89af08f

Please sign in to comment.