Update benchmarking code #10982

Closed
wants to merge 1 commit into from
6 changes: 3 additions & 3 deletions CMakeLists.txt
@@ -205,7 +205,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")

# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use")
set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use")

# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -222,13 +222,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
-      GIT_TAG v3.5.1
+      GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227
GIT_PROGRESS TRUE

# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
-      GIT_SHALLOW TRUE
+      # GIT_SHALLOW TRUE
)
endif()
FetchContent_MakeAvailable(cutlass)
184 changes: 184 additions & 0 deletions benchmarks/cutlass_benchmarks/dense_mm/bench_v1.py
@@ -0,0 +1,184 @@
## Cutlass benchmark V1

from typing import Callable, Iterable

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from utils import make_rand_tensors

import vllm._custom_ops as ops


# bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
**kwargs) -> TMeasurement:
min_run_time = 1

globals = {
"args": args,
"kwargs": kwargs,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(*args, **kwargs)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
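
# Usage sketch (illustrative only, not part of the benchmark suite):
# bench_fn can time any callable, e.g. a plain bf16 matmul:
#   t = bench_fn("gemm", "m=16 k=16 n=16", "torch.mm", torch.mm,
#                torch.randn(16, 16, device="cuda", dtype=torch.bfloat16),
#                torch.randn(16, 16, device="cuda", dtype=torch.bfloat16))
#   print(t.median)  # median runtime in seconds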


def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
assert dtype == torch.int8

# Create tensors
a, b = make_rand_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)

    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
    # Reference: with per-tensor scales, the scaled mm computes
    # scale_a * scale_b * (a @ b), cast to the output dtype.
    out_ref = (scale_a * scale_b *
               torch.mm(a.to(dtype=torch.float32),
                        b.to(dtype=torch.float32))).to(dtype=torch.bfloat16)

    # Allow ~1 ulp of bf16 rounding difference between the two paths.
    if not torch.allclose(out, out_ref, rtol=1e-2, atol=1e-3):
        print("Incorrect result")
        exit()

timers = []

# pytorch impl - bfloat16
timers.append(
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm, a.to(dtype=torch.bfloat16),
b.to(dtype=torch.bfloat16)))

# pytorch impl - float16
timers.append(
bench_fn(label, sub_label,
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))

# cutlass impl: bf16 output
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
torch.bfloat16))

# cutlass with bias: bf16 output
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
bias))

# cutlass impl: fp16 output
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_fp16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
torch.float16))

# cutlass with bias: fp16 output
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_fp16_scaled_mm_bias",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16,
bias.to(dtype=torch.float16)))

return timers


def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
assert dtype == torch.float8_e4m3fn

# Create tensors
    a, b = make_rand_tensors(dtype, m, n, k)
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
    out_ref = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                               out_dtype=torch.bfloat16)

    # Allow for small rounding differences between the two implementations.
    if not torch.allclose(out, out_ref, rtol=1e-2, atol=1e-3):
        print("Incorrect result")
        exit()

timers = []

# pytorch impl w. bf16
timers.append(
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda")))

# pytorch impl: bf16 output, without fp8 fast accum
timers.append(
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16))

# pytorch impl: bf16 output, with fp8 fast accum
timers.append(
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16,
use_fast_accum=True))

# pytorch impl: fp16 output, without fp8 fast accum
timers.append(
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16))

# pytorch impl: fp16 output, with fp8 fast accum
timers.append(
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16,
use_fast_accum=True))

# cutlass impl: bf16 output
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
torch.bfloat16))
# cutlass impl: fp16 output
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16))

return timers


def bench_v1(dtype: torch.dtype, m: int, k: int, n: int, label: str,
sub_label: str) -> Iterable[TMeasurement]:
if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label)
if dtype == torch.float8_e4m3fn:
return bench_fp8(dtype, m, k, n, label, sub_label)
    raise ValueError(f"unsupported dtype: {dtype}")
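
For context, here is a minimal driver sketch (hypothetical; the actual entry point is not part of this diff) showing how bench_v1 could be invoked to sweep a few GEMM shapes:

import torch

from bench_v1 import bench_v1

if __name__ == "__main__":
    # Sweep a few (m, k, n) shapes for the int8 path; printing a
    # TMeasurement shows its timing summary.
    for m, k, n in [(16, 4096, 4096), (256, 4096, 4096)]:
        timers = bench_v1(torch.int8, m, k, n,
                          label="scaled-mm",
                          sub_label=f"MKN=({m}x{k}x{n})")
        for timer in timers:
            print(timer)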