From edba3421564064318d8b525052dbd00d36dfe7d7 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 21 Dec 2023 13:42:21 +0000 Subject: [PATCH] typing --- src/brevitas/utils/torch_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index 76b6af9d3..43540cc72 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -54,7 +54,7 @@ def kthvalue( k: int, dim: Optional[int] = None, keepdim: bool = False, - out: Optional[Tuple] = None) -> torch.Tensor: + out: Optional[Tuple[torch.Tensor, torch.LongTensor]] = None) -> torch.Tensor: # As of torch 2.1, there is no kthvalue implementation: # - In CPU for float16 # - In GPU for bfloat16