diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py
index 059149a408..58764b6bcc 100644
--- a/transformer_engine/pytorch/cpp_extensions/gemm.py
+++ b/transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -404,7 +404,7 @@ def fp8_grouped_gemm(
     use_bias: bool = False,
     use_split_accumulator: bool = False,
     D_dtype: Optional[tex.DType] = None,
-) -> Tuple[Union[List[torch.Tensor], None], ...]:
+) -> Tuple[Union[List[torch.Tensor], torch.Tensor, None], ...]:
     """
     TN layout Grouped GEMM with fp8 inputs.
     This function accepts two combinations of inputs:
@@ -416,7 +416,7 @@
     if isinstance(A_scale_inv, list):
         assert isinstance(out, torch.Tensor) and m_splits is not None
     elif isinstance(A_scale_inv, torch.Tensor):
-        assert isinstance(out, list)
+        assert isinstance(out, (list, tuple))
     else:
         raise ValueError("A_scale_inv should be a list of tensors or a single tensor.")

@@ -436,7 +436,7 @@
     bias_dtype = TE_DType[bias_dtype]
     gelu_input = empty_tensors

-    if isinstance(out, list):
+    if not isinstance(out, torch.Tensor):
         if gelu:
             gelu_input = [
                 torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
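
For context, here is a minimal, self-contained sketch of the validation behavior after this change. `check_grouped_gemm_inputs` is a hypothetical stand-in, not `fp8_grouped_gemm`'s real signature; only the `isinstance` checks are taken from the diff, everything else is illustrative:

```python
from typing import List, Optional, Sequence, Union

import torch


def check_grouped_gemm_inputs(
    A_scale_inv: Union[torch.Tensor, List[torch.Tensor]],
    out: Union[torch.Tensor, Sequence[torch.Tensor]],
    m_splits: Optional[List[int]] = None,
) -> str:
    """Hypothetical helper mirroring only the dispatch logic in the diff."""
    if isinstance(A_scale_inv, list):
        # Combination 1: per-group scale_inv tensors writing into one
        # pre-allocated output tensor, split along dim 0 by m_splits.
        assert isinstance(out, torch.Tensor) and m_splits is not None
        return "single-output path"
    if isinstance(A_scale_inv, torch.Tensor):
        # Combination 2: one scale_inv tensor, one output per group.
        # After this change, tuples are accepted as well as lists.
        assert isinstance(out, (list, tuple))
        return "multi-output path"
    raise ValueError("A_scale_inv should be a list of tensors or a single tensor.")


# A tuple of per-group outputs no longer trips the assertion:
outs = tuple(torch.empty(4, 8) for _ in range(3))
print(check_grouped_gemm_inputs(torch.ones(3), outs))

# The list-of-scale_inv path still requires a single output plus m_splits:
print(check_grouped_gemm_inputs([torch.ones(1)] * 3, torch.empty(12, 8), m_splits=[4, 4, 4]))
```

Note that the third hunk switches the `gelu_input` allocation branch from `isinstance(out, list)` to `not isinstance(out, torch.Tensor)`, which keeps the tuple case consistent with the relaxed `(list, tuple)` assertion above.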