@@ -46,7 +46,7 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
 # impl
 
 
-def pytorch_i8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
+def pytorch_mm_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
                     scale_b: torch.tensor,
                     out_dtype: torch.dtype) -> torch.tensor:
     return torch.mm(a, b)
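Note that `bench_fn` itself lies outside these hunks. As context, here is a minimal sketch of how such a helper could look, with the signature inferred from the call sites below and built on torch.utils.benchmark; the file's actual implementation may differ:

import torch
import torch.utils.benchmark as TBenchmark
from typing import Callable


def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
             scale_b: torch.Tensor, out_dtype: torch.dtype, label: str,
             sub_label: str, fn: Callable,
             description: str) -> TBenchmark.Measurement:
    # Hypothetical reconstruction: time fn(a, b, scale_a, scale_b, out_dtype)
    # and tag the measurement so results can later be grouped by
    # label/sub_label and compared across descriptions.
    return TBenchmark.Timer(
        stmt="fn(a, b, scale_a, scale_b, out_dtype)",
        globals={
            "a": a, "b": b, "scale_a": scale_a, "scale_b": scale_b,
            "out_dtype": out_dtype, "fn": fn,
        },
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=1)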
@@ -115,7 +115,7 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
     timers.append(
         bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
                  b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
-                 torch.bfloat16, label, sub_label, pytorch_i8_impl,
+                 torch.bfloat16, label, sub_label, pytorch_mm_impl,
                  "pytorch_bf16_bf16_bf16_matmul-no-scales"))
 
     # cutlass impl
@@ -136,6 +136,13 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
 
     timers = []
 
+    # pytorch impl w. bf16
+    timers.append(
+        bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
+                 b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
+                 torch.bfloat16, label, sub_label, pytorch_mm_impl,
+                 "pytorch_bf16_bf16_bf16_matmul-no-scales"))
+
     # pytorch impl: bf16 output, without fp8 fast accum
     timers.append(
         bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
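The fp8 pytorch impl wrapped here is also outside the hunk. A sketch of what the "without fp8 fast accum" variant could look like, assuming the private torch._scaled_mm API; its exact signature and return type have varied across PyTorch releases, so this is illustrative only:

def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
                     scale_b: torch.Tensor,
                     out_dtype: torch.dtype) -> torch.Tensor:
    # Illustrative fp8 scaled matmul without fast accumulation. Note that
    # torch._scaled_mm is a private API: some releases return (output, amax)
    # rather than a single tensor, and the argument order has changed over time.
    return torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                            out_dtype=out_dtype, use_fast_accum=False)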
@@ -160,14 +167,12 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
 
     # cutlass impl: bf16 output
     timers.append(
-        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
-                 torch.bfloat16, label, sub_label, cutlass_impl,
-                 "cutlass_fp8_fp8_bf16_scaled_mm"))
+        bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
+                 cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm"))
     # cutlass impl: fp16 output
     timers.append(
-        bench_fn(a, b, scale_a.to(device="cpu"), scale_b.to(device="cpu"),
-                 torch.float16, label, sub_label, cutlass_impl,
-                 "cutlass_fp8_fp8_fp16_scaled_mm"))
+        bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
+                 cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm"))
     return timers
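Since each bench_* function returns a list of torch.utils.benchmark measurements, a driver could aggregate and print them as follows. This is a hypothetical usage example; the shapes, dtype, and labels are illustrative and the script's real entry point is not shown in this diff:

# Hypothetical driver code; shapes and labels are made up for illustration.
results = bench_fp8(torch.float8_e4m3fn, m=16, k=4096, n=4096,
                    label="scaled_mm_fp8", sub_label="m=16, k=4096, n=4096")
TBenchmark.Compare(results).print()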