Commit

fix type
Signed-off-by: Xin Yao <[email protected]>
yaox12 committed Aug 27, 2024
1 parent 7b9ac33 commit e42a6d4
Showing 1 changed file with 3 additions and 3 deletions.
transformer_engine/pytorch/cpp_extensions/gemm.py (3 additions, 3 deletions)
@@ -404,7 +404,7 @@ def fp8_grouped_gemm(
     use_bias: bool = False,
     use_split_accumulator: bool = False,
     D_dtype: Optional[tex.DType] = None,
-) -> Tuple[Union[List[torch.Tensor], None], ...]:
+) -> Tuple[Union[List[torch.Tensor], torch.Tensor, None], ...]:
     """
     TN layout Grouped GEMM with fp8 inputs.
     This function accepts two combinations of inputs:
@@ -416,7 +416,7 @@ def fp8_grouped_gemm(
     if isinstance(A_scale_inv, list):
         assert isinstance(out, torch.Tensor) and m_splits is not None
     elif isinstance(A_scale_inv, torch.Tensor):
-        assert isinstance(out, list)
+        assert isinstance(out, (list, tuple))
     else:
         raise ValueError("A_scale_inv should be a list of tensors or a single tensor.")
 
@@ -436,7 +436,7 @@ def fp8_grouped_gemm(
         bias_dtype = TE_DType[bias_dtype]
     gelu_input = empty_tensors
 
-    if isinstance(out, list):
+    if not isinstance(out, torch.Tensor):
         if gelu:
             gelu_input = [
                 torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
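Why this matters: callers may pass the per-group output buffers as a tuple rather than a list (for example, the tuple returned by torch.split), so the assertion and the later branch now accept any non-tensor sequence, and the return annotation records that plain torch.Tensor entries can appear in the returned tuple. Below is a minimal, hypothetical sketch of the dispatch logic this patch corrects; the wrapper name dispatch_out, the tensor shapes, and the print statements are invented for illustration and are not part of the Transformer Engine API.

import torch
from typing import List, Optional, Sequence, Union

def dispatch_out(
    A_scale_inv: Union[List[torch.Tensor], torch.Tensor],
    out: Union[torch.Tensor, Sequence[torch.Tensor]],
    m_splits: Optional[List[int]] = None,
) -> None:
    # Mirrors the corrected checks from the diff: a list of scale inverses
    # requires a single fused output tensor plus m_splits, while a single
    # scale-inverse tensor allows per-group outputs as a list *or* a tuple.
    if isinstance(A_scale_inv, list):
        assert isinstance(out, torch.Tensor) and m_splits is not None
    elif isinstance(A_scale_inv, torch.Tensor):
        assert isinstance(out, (list, tuple))
    else:
        raise ValueError("A_scale_inv should be a list of tensors or a single tensor.")

    # The later branch keys off "not a single tensor" instead of "is a list",
    # so tuples take the grouped path as well.
    if not isinstance(out, torch.Tensor):
        print(f"grouped path with {len(out)} output buffers")
    else:
        print("single fused output tensor")

# Hypothetical usage: torch.split returns a tuple, which would have failed
# the old `isinstance(out, list)` assertion.
buf = torch.empty(6, 4)
outs = torch.split(buf, [2, 4], dim=0)               # tuple of views
dispatch_out(torch.ones(1), outs)                     # grouped path
dispatch_out([torch.ones(1)], buf, m_splits=[2, 4])   # fused path

Checking `not isinstance(out, torch.Tensor)` rather than `isinstance(out, list)` keeps the two code paths consistent: anything that is not a single fused tensor is treated as a per-group sequence.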
