Skip to content

Commit

Permalink
fix kthvalue dim init
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 21, 2023
1 parent 06b810f commit f6bd017
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/brevitas/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,13 @@ def kthvalue(x, k, dim=None, keepdim=False, out=None) -> torch.Tensor:
if (dtype == torch.float16 and 'cpu' in device) or \
(dtype == torch.bfloat16 and 'cuda' in device):
x = x.type(torch.float32)
x, indices = torch.kthvalue(x, k, dim=dim, keepdim=keepdim)

# PyTorch specify None as default for `dim` but it breaks if we specifically pass None
if dim is not None:
x, indices = torch.kthvalue(x, k, dim=dim, keepdim=keepdim)
else:
x, indices = torch.kthvalue(x, k, keepdim=keepdim)

if x.dtype != dtype:
x = x.type(dtype)
return (x, indices)

0 comments on commit f6bd017

Please sign in to comment.