From e42a6d4386a7362bfa2ef03e82b88c615764bd74 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Tue, 27 Aug 2024 00:35:50 -0700 Subject: [PATCH] fix type Signed-off-by: Xin Yao --- transformer_engine/pytorch/cpp_extensions/gemm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 059149a408..58764b6bcc 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -404,7 +404,7 @@ def fp8_grouped_gemm( use_bias: bool = False, use_split_accumulator: bool = False, D_dtype: Optional[tex.DType] = None, -) -> Tuple[Union[List[torch.Tensor], None], ...]: +) -> Tuple[Union[List[torch.Tensor], torch.Tensor, None], ...]: """ TN layout Grouped GEMM with fp8 inputs. This function accepts two combinations of inputs: @@ -416,7 +416,7 @@ def fp8_grouped_gemm( if isinstance(A_scale_inv, list): assert isinstance(out, torch.Tensor) and m_splits is not None elif isinstance(A_scale_inv, torch.Tensor): - assert isinstance(out, list) + assert isinstance(out, (list, tuple)) else: raise ValueError("A_scale_inv should be a list of tensors or a single tensor.") @@ -436,7 +436,7 @@ def fp8_grouped_gemm( bias_dtype = TE_DType[bias_dtype] gelu_input = empty_tensors - if isinstance(out, list): + if not isinstance(out, torch.Tensor): if gelu: gelu_input = [ torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)