Skip to content

Commit

Permalink
add supports_fp8 to git_8bit_types
Browse files Browse the repository at this point in the history
Signed-off-by: Randall Smith <[email protected]>
  • Loading branch information
rasmith committed Nov 6, 2024
1 parent f003676 commit 4f1e62e
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/kernels/test_scaled_mm_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def scaled_mm_torch(a: torch.Tensor,

def get_8bit_types():
types = [torch.int8]
if current_platform.is_rocm():
supports_fp8 = current_platform.has_device_capability(89)
if current_platform.is_rocm() and supports_fp8:
types.append(torch.float8_e4m3fnuz)
elif (current_platform.is_cuda()
and current_platform.has_device_capability(89)):
elif current_platform.is_cuda() and supports_fp8:
types.append(torch.float8_e4m3fn)
return types

Expand Down

0 comments on commit 4f1e62e

Please sign in to comment.