Commit e2b85cf

Fix w8a8 benchmark and add Llama-3-8B (vllm-project#5562)
1 parent 845a3f2 commit e2b85cf

File tree

2 files changed (+19, -8 lines)

benchmarks/cutlass_benchmarks/w8a8_benchmarks.py (+13, -8)

@@ -46,7 +46,7 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
 # impl
 
 
-def pytorch_i8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
+def pytorch_mm_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
                     scale_b: torch.tensor,
                     out_dtype: torch.dtype) -> torch.tensor:
     return torch.mm(a, b)
@@ -115,7 +115,7 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     timers.append(
         bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
                  b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
-                 torch.bfloat16, label, sub_label, pytorch_i8_impl,
+                 torch.bfloat16, label, sub_label, pytorch_mm_impl,
                  "pytorch_bf16_bf16_bf16_matmul-no-scales"))
 
     # cutlass impl
@@ -136,6 +136,13 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
 
     timers = []
 
+    # pytorch impl w. bf16
+    timers.append(
+        bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
+                 b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
+                 torch.bfloat16, label, sub_label, pytorch_mm_impl,
+                 "pytorch_bf16_bf16_bf16_matmul-no-scales"))
+
     # pytorch impl: bf16 output, without fp8 fast accum
     timers.append(
         bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
@@ -160,14 +167,12 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
 
     # cutlass impl: bf16 output
     timers.append(
-        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
-                 torch.bfloat16, label, sub_label, cutlass_impl,
-                 "cutlass_fp8_fp8_bf16_scaled_mm"))
+        bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
+                 cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm"))
     # cutlass impl: fp16 output
     timers.append(
-        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
-                 torch.float16, label, sub_label, cutlass_impl,
-                 "cutlass_fp8_fp8_fp16_scaled_mm"))
+        bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
+                 cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm"))
     return timers
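
The int8 and fp8 paths now share a single pytorch_mm_impl baseline (a plain torch.mm on bf16 inputs that ignores the scale arguments), and the scale tensors are passed to the cutlass implementations as-is rather than being moved to the CPU first, so all operands stay on the GPU. Below is a minimal, self-contained sketch of how one such baseline entry could be timed; time_impl is a hypothetical stand-in for the script's bench_fn, shown only to illustrate the torch.utils.benchmark plumbing, not the repository's actual helper.

import torch
import torch.utils.benchmark as TBenchmark


def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
                    scale_b: torch.Tensor,
                    out_dtype: torch.dtype) -> torch.Tensor:
    # Unscaled bf16 baseline: the scale arguments are accepted but ignored.
    return torch.mm(a, b)


def time_impl(a, b, scale_a, scale_b, out_dtype, label, sub_label, fn, desc):
    # Hypothetical stand-in for the script's bench_fn: time fn with
    # torch.utils.benchmark.Timer and return the measurement.
    return TBenchmark.Timer(
        stmt="fn(a, b, scale_a, scale_b, out_dtype)",
        globals={"a": a, "b": b, "scale_a": scale_a, "scale_b": scale_b,
                 "out_dtype": out_dtype, "fn": fn},
        label=label, sub_label=sub_label,
        description=desc).blocked_autorange(min_run_time=1)


if __name__ == "__main__":
    m, k, n = 16, 4096, 4096
    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    b = torch.randn((k, n), device="cuda", dtype=torch.bfloat16)
    # Scales stay on the GPU, mirroring the fix above (no .to(device="cpu")).
    scale_a = torch.ones(1, device="cuda", dtype=torch.float32)
    scale_b = torch.ones(1, device="cuda", dtype=torch.float32)
    print(time_impl(a, b, scale_a, scale_b, torch.bfloat16,
                    "gemm", f"{m}x{k}x{n}", pytorch_mm_impl,
                    "pytorch_bf16_bf16_bf16_matmul-no-scales"))

Because the same "pytorch_bf16_bf16_bf16_matmul-no-scales" entry now appears in both bench_int8 and bench_fp8, the unscaled bf16 matmul serves as a common reference point for the two sweeps.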

benchmarks/cutlass_benchmarks/weight_shapes.py (+6)

@@ -22,6 +22,12 @@
         ([4096, 22016], 1),
         ([11008, 4096], 0),
     ],
+    "meta-llama/Llama-3-8b": [
+        ([4096, 6144], 1),
+        ([4096, 4096], 0),
+        ([4096, 28672], 1),
+        ([14336, 4096], 0),
+    ],
     "meta-llama/Llama-2-13b-hf": [
         ([5120, 15360], 1),
         ([5120, 5120], 0),
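
The new entry follows the same tuple format as the existing models and lists each layer's GEMM weight shape as a [K, N] pair. As a sanity check, the four Llama-3-8B shapes can be derived from the model's published configuration (hidden size 4096, 32 query heads and 8 KV heads with head dimension 128, MLP intermediate size 14336); the sketch below is purely illustrative and is not part of the benchmark script.

# Sketch: reproducing the Llama-3-8B [K, N] shapes from its public config.
hidden_size = 4096
num_heads, num_kv_heads, head_dim = 32, 8, 128
intermediate_size = 14336

shapes = [
    [hidden_size, (num_heads + 2 * num_kv_heads) * head_dim],  # qkv_proj  -> [4096, 6144]
    [hidden_size, hidden_size],                                 # o_proj    -> [4096, 4096]
    [hidden_size, 2 * intermediate_size],                       # gate_up   -> [4096, 28672]
    [intermediate_size, hidden_size],                           # down_proj -> [14336, 4096]
]
assert shapes == [[4096, 6144], [4096, 4096], [4096, 28672], [14336, 4096]]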
