Skip to content

Commit

Permalink
JIT annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 21, 2023
1 parent 4d2c0dc commit e996018
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/brevitas/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ def torch_partial_deepcopy(model):
return model_copy


def kthvalue(x, k, dim=None, keepdim=False, out=None) -> torch.Tensor:
def kthvalue(
x: torch.Tensor,
k: int,
dim: Optional[int] = None,
keepdim: bool = False,
out: Optional[Tuple] = None) -> torch.Tensor:
# As of torch 2.1, there is no kthvalue implementation:
# - In CPU for float16
# - In GPU for bfloat16
Expand Down

0 comments on commit e996018

Please sign in to comment.