diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 7164038e..873271ee 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -257,7 +257,7 @@ def get_cuda_stream(device: torch.device) -> int: def determine_gemm_backend(device: torch.device) -> str: major, _ = get_compute_capability(device) - if major >= 9: + if major >= 9 and torch.version.cuda >= "12.3": return "sm90" else: return "sm80"