Hello,
Casting from fp8 to bf16 in Triton fails on SM89 with a `ptxas` error.
I am opening this issue because it causes `torch.compile` with `{"triton.codegen_upcast_to_fp32": False}` to fail. Should this be filed as a PyTorch bug instead?
Triton code:

```python
import triton
import triton.language as tl
import torch


@triton.jit
def cast_kernel(X_PTR, BLOCK_SIZE: tl.constexpr):
    idx = tl.arange(0, BLOCK_SIZE)
    x = tl.load(X_PTR + idx)          # float8_e4m3fn input
    x_bf16 = x.to(tl.bfloat16)        # fp8 -> bf16: the cast that fails to assemble on sm_89
    x_fp8 = x_bf16.to(tl.float8e4nv)  # bf16 -> fp8 (tl.float8e4nv corresponds to torch.float8_e4m3fn)
    tl.store(X_PTR + idx, x_fp8)


BLOCK_SIZE = 128
x = torch.empty(BLOCK_SIZE, device='cuda').to(dtype=torch.float8_e4m3fn)
cast_kernel[(1,)](x, BLOCK_SIZE=BLOCK_SIZE)
```
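The kernel casts fp8 (e4m3fn) to bf16 and back. The `ptxas` message below mentions `cvt with .bf16.f16`, which suggests the fp8 to bf16 step is lowered through an fp16 intermediate. As a sanity check that such an intermediate (or an fp32 one) cannot change any value, here is a minimal pure-Python sketch. It decodes every float8_e4m3fn bit pattern and confirms that each finite value survives both a bf16 and an fp16 round trip unchanged. It assumes the standard E4M3FN layout (exponent bias 7, no infinities, `S.1111.111` = NaN). It does not model how Triton actually lowers the conversion.

```python
import math
import struct


def decode_e4m3fn(byte: int) -> float:
    """Decode one float8_e4m3fn byte: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits."""
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> 3) & 0xF
    mantissa = byte & 0x7
    if exponent == 0xF and mantissa == 0x7:
        return math.nan  # "fn" = finite: no infinities, S.1111.111 encodes NaN
    if exponent == 0:
        return sign * (mantissa / 8) * 2.0 ** -6  # subnormal: 0.mmm * 2^(1 - bias)
    return sign * (1 + mantissa / 8) * 2.0 ** (exponent - 7)


def to_bf16(value: float) -> float:
    """Round-trip through bf16 by keeping only the top 16 bits of the fp32 encoding."""
    bits = struct.unpack("<I", struct.pack("<f", value))[0]
    return struct.unpack("<f", struct.pack("<I", (bits >> 16) << 16))[0]


def to_f16(value: float) -> float:
    """Round-trip through IEEE half precision using struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


finite = [decode_e4m3fn(b) for b in range(256) if not math.isnan(decode_e4m3fn(b))]
print(len(finite), max(finite), min(v for v in finite if v > 0))  # -> 254 448.0 0.001953125
print(all(to_bf16(v) == v for v in finite), all(to_f16(v) == v for v in finite))  # -> True True
```

Because every e4m3fn value has at most 4 significant bits and lies between 2^-9 and 448 in magnitude, both bf16 (7 mantissa bits, fp32 exponent range) and fp16 (10 mantissa bits, max 65504) hold it exactly. The failure is therefore about which PTX instruction is emitted, not about precision.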
PyTorch code:

```python
import torch


@torch.compile(options={"triton.codegen_upcast_to_fp32": False})
def test(tensor: torch.Tensor) -> torch.Tensor:
    tensor = tensor.to(torch.bfloat16)
    return tensor.to(dtype=torch.float8_e4m3fn)


tensor = torch.randn((100, 100), device="cuda").to(dtype=torch.float8_e4m3fn)
test(tensor)
```
Error logs:

```
RuntimeError: Internal Triton PTX codegen error
`ptxas` stderr:
ptxas /tmp/tmp3xn11kox.ptx, line 47; error : Feature 'cvt with .bf16.f16' requires .target sm_90 or higher
ptxas /tmp/tmp3xn11kox.ptx, line 48; error : Feature 'cvt with .bf16.f16' requires .target sm_90 or higher
ptxas fatal : Ptx assembly aborted due to errors
```
Environment details:
Triton: 3.1.0
GPU: RTX 6000 Ada Edition
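The mismatch can be read directly from the log. `ptxas` states the minimum PTX target for the offending instruction, while the RTX 6000 Ada Edition is an Ada Lovelace GPU with compute capability 8.9 (sm_89). The small pure-Python sketch below makes that comparison explicit. The `unsupported_features` helper and its regex are hypothetical, written only for the message format quoted above; they are not part of Triton or PyTorch. On a live system, the device's SM number could be built from `torch.cuda.get_device_capability()` as `major * 10 + minor`.

```python
import re

# Hypothetical pattern for ptxas diagnostics of the form:
#   ptxas /tmp/x.ptx, line 47; error : Feature '...' requires .target sm_90 or higher
PTXAS_FEATURE_RE = re.compile(
    r"line (\d+); error\s*: Feature '([^']+)' requires \.target sm_(\d+) or higher"
)


def unsupported_features(stderr: str, device_sm: int) -> list[tuple[int, str, int]]:
    """Return (ptx_line, feature, required_sm) for each feature the device's SM cannot run."""
    hits = []
    for line_no, feature, required in PTXAS_FEATURE_RE.findall(stderr):
        if device_sm < int(required):
            hits.append((int(line_no), feature, int(required)))
    return hits


log = (
    "ptxas /tmp/tmp3xn11kox.ptx, line 47; error : Feature 'cvt with .bf16.f16' "
    "requires .target sm_90 or higher "
    "ptxas /tmp/tmp3xn11kox.ptx, line 48; error : Feature 'cvt with .bf16.f16' "
    "requires .target sm_90 or higher"
)

# RTX 6000 Ada Edition: compute capability 8.9 -> sm_89
print(unsupported_features(log, device_sm=89))
# -> [(47, 'cvt with .bf16.f16', 90), (48, 'cvt with .bf16.f16', 90)]
```

Both failing PTX lines need sm_90 (Hopper) or newer, one SM version above this card.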