[Bugfix] Fix the CUDA version check for FP8 support in the CUTLASS kernels
tlrmchlsmth authored Jun 20, 2024
1 parent a7dcc62 commit 3f3b6b2
Showing 5 changed files with 30 additions and 13 deletions.
2 changes: 2 additions & 0 deletions csrc/ops.h
@@ -92,6 +92,8 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);

bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);

void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
16 changes: 16 additions & 0 deletions csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -25,6 +25,22 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b_scales);
#endif

bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
// CUTLASS FP8 kernels need at least
// CUDA 12.0 on SM90 systems (Hopper)
// CUDA 12.4 on SM89 systems (Lovelace)

#if defined CUDA_VERSION
if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000;
} else if (cuda_device_capability >= 89) {
return CUDA_VERSION >= 12040;
}
#endif

return false;
}

void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
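
For context on the integer encodings the new check relies on: CUDA_VERSION comes from the CUDA headers and packs the toolkit version the extension was compiled with as major * 1000 + minor * 10, while the capability argument packs the SM version as major * 10 + minor (90 for SM90/Hopper, 89 for SM89/Ada Lovelace). A small illustrative decoding sketch, not part of the commit:

# Illustrative only: decode the integer encodings used by the C++ check above.
def decode_cuda_version(v: int):
    # e.g. 12040 -> (12, 4), the CUDA 12.4 minimum for SM89
    return v // 1000, (v % 1000) // 10

def decode_capability(cc: int):
    # e.g. 90 -> (9, 0) for Hopper, 89 -> (8, 9) for Ada Lovelace
    return cc // 10, cc % 10

assert decode_cuda_version(12000) == (12, 0)
assert decode_cuda_version(12040) == (12, 4)
assert decode_capability(89) == (8, 9)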
6 changes: 6 additions & 0 deletions csrc/torch_bindings.cpp
@@ -144,6 +144,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor b, Tensor a_scales,"
" Tensor b_scales) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCUDA, &cutlass_scaled_mm);

// Check if cutlass scaled_mm is supported for CUDA devices of the given
// capability
ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
&cutlass_scaled_mm_supports_fp8);
#endif

// Quantized GEMM for GPTQ.
4 changes: 4 additions & 0 deletions vllm/_custom_ops.py
@@ -216,6 +216,10 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,


# cutlass
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)


def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype]) -> torch.Tensor:
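
With the schema registered in torch_bindings.cpp and the thin Python wrapper added here, callers can query FP8 support for a packed capability value. A usage sketch; the expected results assume the extension's compile-time CUDA version drives the C++ check shown earlier:

from vllm import _custom_ops as ops

# On an extension built with CUDA >= 12.4, both SM90 and SM89 report True;
# built with CUDA 12.0-12.3, only SM90 does; pre-SM89 GPUs always report False.
print(ops.cutlass_scaled_mm_supports_fp8(90))  # SM90 / Hopper
print(ops.cutlass_scaled_mm_supports_fp8(89))  # SM89 / Ada Lovelace
print(ops.cutlass_scaled_mm_supports_fp8(80))  # SM80 / Ampere -> False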
15 changes: 2 additions & 13 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -20,19 +20,8 @@
def cutlass_fp8_supported() -> bool:
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
-    major, minor = torch.version.cuda.split(".")
-    version = int(major) * 10 + int(minor)
-
-    # CUTLASS FP8 kernels need at least
-    # CUDA 12.0 on SM90 systems (Hopper)
-    # CUDA 12.4 on SM89 systems (Lovelace)
-    gpu_is_supported = False
-    if capability >= 90:
-        gpu_is_supported = version > 120
-    elif capability >= 89:
-        gpu_is_supported = version > 124
-
-    return gpu_is_supported
+
+    return ops.cutlass_scaled_mm_supports_fp8(capability)


class Fp8Config(QuantizationConfig):
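
The net effect of the change: the removed Python check parsed torch.version.cuda at runtime and used strict '>' comparisons, so the documented minimums (CUDA 12.0 on SM90, CUDA 12.4 on SM89) were themselves rejected; the replacement asks the compiled extension, whose CUDA_VERSION reflects the toolkit the CUTLASS kernels were actually built against. An illustrative sketch of the boundary bug, not part of the commit:

# The removed check packed "12.0" as 12 * 10 + 0 == 120.
version = 120
print(version > 120)    # False -> CUDA 12.0 on SM90 was wrongly reported unsupported
print(version >= 120)   # True  -> the intended "at least CUDA 12.0" behavior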