[Kernel] Turn off CUTLASS scaled_mm for Ada Lovelace #6384

Merged
This PR disables the CUTLASS FP8 scaled_mm path on Ada Lovelace (SM 8.9) GPUs, where the untuned CUTLASS kernels are slower than torch.mm, and refreshes the lm-eval GSM8K baselines, which are now measured with `limit: 1000` instead of `250`.

First baseline config (`nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test`):

```diff
@@ -1,11 +1,11 @@
-# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 250 -f 5 -t 1
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test -b 32 -l 1000 -f 5 -t 1
 model_name: "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
 tasks:
 - name: "gsm8k"
   metrics:
   - name: "exact_match,strict-match"
-    value: 0.752
+    value: 0.755
   - name: "exact_match,flexible-extract"
-    value: 0.752
-limit: 250
+    value: 0.755
+limit: 1000
 num_fewshot: 5
```
A second baseline config gets the same treatment:

```diff
@@ -4,8 +4,8 @@ tasks:
 - name: "gsm8k"
   metrics:
   - name: "exact_match,strict-match"
-    value: 0.756
+    value: 0.753
   - name: "exact_match,flexible-extract"
-    value: 0.752
-limit: 250
+    value: 0.753
+limit: 1000
 num_fewshot: 5
```
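These YAML files are CI reference baselines: the `run-lm-eval-gsm-vllm-baseline.sh` command in the header comment produces the scores, and the harness later re-runs GSM8K and asserts the measured metrics stay near the recorded values. Raising `limit` from 250 to 1000 evaluates four times as many samples, which should make those reference values less noisy. A minimal sketch of the comparison step, assuming PyYAML and NumPy; the tolerance and the shape of `measured` are illustrative assumptions, not taken from this PR:

```python
# Hypothetical checker for baseline configs like the ones above. RTOL and the
# `measured` plumbing are illustrative assumptions, not taken from this PR.
import numpy
import yaml

RTOL = 0.05  # assumed relative tolerance for the CI comparison


def check_baseline(config_path: str, measured: dict[str, float]) -> None:
    """`measured` maps metric names (e.g. "exact_match,strict-match") to scores."""
    with open(config_path) as f:
        config = yaml.safe_load(f)
    for task in config["tasks"]:
        for metric in task["metrics"]:
            name, expected = metric["name"], metric["value"]
            assert numpy.isclose(expected, measured[name], rtol=RTOL), (
                f"{task['name']}/{name}: measured {measured[name]:.3f}, "
                f"expected {expected:.3f} (rtol={RTOL})")
```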
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu (8 additions, 2 deletions):

```diff
@@ -38,7 +38,13 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
   if (cuda_device_capability >= 90) {
     return CUDA_VERSION >= 12000;
   } else if (cuda_device_capability >= 89) {
-    return CUDA_VERSION >= 12040;
+    // CUTLASS kernels have not been tuned for Ada Lovelace systems
+    // and are slower than torch.mm. Return false unconditionally in this case.
+    return false;
+
+    // Once the CUTLASS kernels have been optimized for Lovelace systems,
+    // use the following check:
+    // return CUDA_VERSION >= 12040;
   }
 #endif

@@ -98,4 +104,4 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
   TORCH_CHECK(version_num >= 75);
   cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
 }
-}
\ No newline at end of file
+}
```

The second hunk only adds the missing trailing newline at the end of the file.
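The practical effect: callers that consult `cutlass_scaled_mm_supports_fp8` before dispatching now take the torch fallback on SM 8.9 GPUs. A hedged sketch of that dispatch pattern follows — the `ops` module path mirrors vLLM's Python bindings for the C++ entry points above, but the exact names, signatures, and the dequantize-and-mm fallback are illustrative assumptions:

```python
# Illustrative dispatch only: the binding module and signatures below are
# assumptions that mirror the C++ entry points in scaled_mm_entry.cu.
import torch
from vllm import _custom_ops as ops  # assumed location of the bindings


def scaled_mm_fp8(a: torch.Tensor, b: torch.Tensor,
                  scale_a: torch.Tensor, scale_b: torch.Tensor,
                  out_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor  # e.g. (8, 9) -> 89 for Ada Lovelace

    if ops.cutlass_scaled_mm_supports_fp8(capability):
        # Hopper (SM 9.0) with a new enough CUDA toolkit: tuned CUTLASS path.
        return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)

    # After this PR, Ada Lovelace (SM 8.9) always lands here, since the
    # untuned CUTLASS kernels are slower than the plain torch path. Shown
    # as dequantize + mm (per-tensor scalar scales assumed) to stay
    # version-agnostic; vLLM's actual fallback uses torch._scaled_mm.
    return torch.mm(a.to(torch.float32) * scale_a,
                    b.to(torch.float32) * scale_b).to(out_dtype)
```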