From 1c8dc36eb95c6b7105bd651f03daf584b495e93d Mon Sep 17 00:00:00 2001
From: Zihao Ye
Date: Tue, 24 Dec 2024 23:59:04 -0800
Subject: [PATCH] bugfix: only use sm90 group gemm when torch cuda >= 12.3 (#699)

wgmma is only available for cuda 12.3 or later, turn to use sm80 version
when torch.cuda version is lower.

cc @xslingcn
---
 flashinfer/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flashinfer/utils.py b/flashinfer/utils.py
index 7164038e..873271ee 100644
--- a/flashinfer/utils.py
+++ b/flashinfer/utils.py
@@ -257,7 +257,7 @@ def get_cuda_stream(device: torch.device) -> int:
 
 def determine_gemm_backend(device: torch.device) -> str:
     major, _ = get_compute_capability(device)
-    if major >= 9:
+    if major >= 9 and torch.version.cuda >= "12.3":
         return "sm90"
     else:
         return "sm80"