diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 81bf2d62d8f42..605166930ccc6 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -38,7 +38,13 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) { if (cuda_device_capability >= 90) { return CUDA_VERSION >= 12000; } else if (cuda_device_capability >= 89) { - return CUDA_VERSION >= 12040; + // CUTLASS kernels have not been tuned for Ada Lovelace systems + // and are slower than torch.mm. Return false unconditionally in this case. + return false; + + // Once the CUTLASS kernels have been optimized for Ada Lovelace systems, + // use the following check: + // return CUDA_VERSION >= 12040; } #endif @@ -98,4 +104,4 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, TORCH_CHECK(version_num >= 75); cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias); } -} \ No newline at end of file +}