From 5010a9a4cdcb5fbb1e0a6b6d6887baf6b9eca601 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Tue, 8 Oct 2024 02:17:53 -0400 Subject: [PATCH] Fix accumulator and result types (#14) For example, perform f16 matmul with f32 as the accumulator type and truncate the result to f16. This is more realistic than using f16 as the accumulator type. Keep track of operand, accumulator, and result types in `GemmConfig`. --- gemmbench/gemm_bench.py | 4 +- gemmbench/gemm_utils.py | 105 +++++++++++++----------- gemmbench/problems.py | 171 ++++++++++++++++++++++++++++++++-------- 3 files changed, 201 insertions(+), 79 deletions(-) diff --git a/gemmbench/gemm_bench.py b/gemmbench/gemm_bench.py index 56c6dbf..c3afe14 100644 --- a/gemmbench/gemm_bench.py +++ b/gemmbench/gemm_bench.py @@ -107,7 +107,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, target = args.target extra_compiler_args = list(args.Xiree_compile) dump_dir = args.dump_dir - + args = itertools.starmap( lambda tag, config: (tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tk, dump_dir), configs ) @@ -171,7 +171,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tflops_per_second = (flops / 1e12) / (benchmark_gemm_mean_time_us / 1e6) results.append(( - index, tag, name, vmfb_hash, config.M, config.N, config.K, config.dtype, config.tA, config.tB, + index, tag, name, vmfb_hash, config.M, config.N, config.K, config.operand_element_type, config.tA, config.tB, round(benchmark_gemm_mean_time_us, 4), round(arithmetic_intensity, 4), round(tflops_per_second, 4), diff --git a/gemmbench/gemm_utils.py b/gemmbench/gemm_utils.py index 03c4228..7dbbdcd 100644 --- a/gemmbench/gemm_utils.py +++ b/gemmbench/gemm_utils.py @@ -15,10 +15,12 @@ class GemmConfig: K: int tA: str tB: str - dtype: str + operand_element_type: str + accumulator_element_type: str + result_element_type: str def get_name(self) -> str: - name = f"gemm_{self.M}_{self.N}_{self.K}_{self.dtype}" + name = f"gemm_{self.M}_{self.N}_{self.K}_{self.operand_element_type}_{self.accumulator_element_type}" if self.tA == "T": name += "_tA" elif self.tB == "T": @@ -27,30 +29,26 @@ def get_name(self) -> str: def get_inp1(self) -> str: if self.tA == "T": - inp1 = f"{self.K}x{self.M}x{self.dtype}" - else: - inp1 = f"{self.M}x{self.K}x{self.dtype}" - return inp1 + return f"{self.K}x{self.M}x{self.operand_element_type}" + return f"{self.M}x{self.K}x{self.operand_element_type}" def get_inp2(self) -> str: if self.tB == "T": - inp2 = f"{self.N}x{self.K}x{self.dtype}" - else: - inp2 = f"{self.K}x{self.N}x{self.dtype}" - return inp2 + return f"{self.N}x{self.K}x{self.operand_element_type}" + return f"{self.K}x{self.N}x{self.operand_element_type}" def get_byte_count(self) -> int: - dtype_bits_map = { - "f32": 32, - "f16": 16, - "bf16": 16, - "f8E4M3FNUZ": 8, - "i8": 8, - "i32": 32, + dtype_to_bytes = { + "f32": 4, + "f16": 2, + "bf16": 2, + "f8E4M3FNUZ": 1, + "i8": 1, + "i32": 4, } - bytes_per_element = dtype_bits_map[self.dtype] // 8 - element_count = self.M * self.K + self.N * self.K + self.M * self.N - byte_count = element_count * bytes_per_element + operand_bytes_per_element = dtype_to_bytes[self.operand_element_type] + result_bytes_per_element = dtype_to_bytes[self.result_element_type] + byte_count = (self.M * self.K + self.N * self.K) * operand_bytes_per_element + (self.M * self.N) * result_bytes_per_element return byte_count def get_flops(self) -> int: @@ -61,40 +59,54 @@ def generate_mlir(config: GemmConfig): K = config.K M = config.M N = config.N - dtype = config.dtype + operand_element_type = config.operand_element_type + acc_element_type = config.accumulator_element_type + result_element_type = config.result_element_type + assert not operand_element_type.startswith('i'), "Integer types not supported yet" + tA = config.tA tB = config.tB mlir_template_A = f""" module {{ - func.func @main(%arg0: tensor<{K}x{M}x{dtype}>, %arg1: tensor<{K}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> {{ - %cst = arith.constant 0.000000e+00 : {dtype} - %0 = tensor.empty() : tensor<{M}x{N}x{dtype}> - %1 = linalg.fill ins(%cst : {dtype}) outs(%0 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> - %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<{K}x{M}x{dtype}>, tensor<{K}x{N}x{dtype}>) outs(%1 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> - return %2 : tensor<{M}x{N}x{dtype}> + func.func @main(%arg0: tensor<{K}x{M}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{ + %cst = arith.constant 0.000000e+00 : {acc_element_type} + %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}> + %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> + %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<{K}x{M}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>) + outs(%1 : tensor<{M}x{N}x{acc_element_type}>) + -> tensor<{M}x{N}x{acc_element_type}> + %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}> + return %3 : tensor<{M}x{N}x{result_element_type}> }} }} """ mlir_template_B = f""" module {{ - func.func @main(%arg0: tensor<{M}x{K}x{dtype}>, %arg1: tensor<{N}x{K}x{dtype}>) -> tensor<{M}x{N}x{dtype}> {{ - %cst = arith.constant 0.000000e+00 : {dtype} - %0 = tensor.empty() : tensor<{M}x{N}x{dtype}> - %1 = linalg.fill ins(%cst : {dtype}) outs(%0 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> - %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<{M}x{K}x{dtype}>, tensor<{N}x{K}x{dtype}>) outs(%1 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> - return %2 : tensor<{M}x{N}x{dtype}> + func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{N}x{K}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{ + %cst = arith.constant 0.000000e+00 : {acc_element_type} + %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}> + %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> + %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{N}x{K}x{operand_element_type}>) + outs(%1 : tensor<{M}x{N}x{acc_element_type}>) + -> tensor<{M}x{N}x{acc_element_type}> + %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}> + return %3 : tensor<{M}x{N}x{result_element_type}> }} }} """ - mlir_template = f"""module {{ - func.func @main(%arg0: tensor<{M}x{K}x{dtype}>, %arg1: tensor<{K}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> {{ - %cst = arith.constant 0.000000e+00 : {dtype} - %0 = tensor.empty() : tensor<{M}x{N}x{dtype}> - %1 = linalg.fill ins(%cst : {dtype}) outs(%0 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> - %2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}x{dtype}>, tensor<{K}x{N}x{dtype}>) outs(%1 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> - return %2 : tensor<{M}x{N}x{dtype}> + mlir_template = f""" +module {{ + func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{ + %cst = arith.constant 0.000000e+00 : {acc_element_type} + %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}> + %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>) + outs(%1 : tensor<{M}x{N}x{acc_element_type}>) + -> tensor<{M}x{N}x{acc_element_type}> + %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}> + return %3 : tensor<{M}x{N}x{result_element_type}> }} }} """ @@ -104,7 +116,10 @@ def generate_mlir(config: GemmConfig): return mlir_template_B return mlir_template + def generate_tk_mlir(config: GemmConfig): + assert config.operand_element_type == 'f16', "Unsupported problem" + assert config.accumulator_element_type == 'f32', "Unsupported problem" # Input sizes M = tkl.sym.M N = tkl.sym.N @@ -158,12 +173,12 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # repeat represents the results of the loop tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) - + shape = [config.M, config.N, config.K] - dtype_map = { + operand_element_type_map = { "f16": torch.float16, } - dtype = dtype_map[config.dtype] + operand_element_type = operand_element_type_map[config.operand_element_type] hyperparams = { ADDRESS_SPACE: SHARED_ADDRESS_SPACE, @@ -180,8 +195,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: with tk.gen.TestLaunchContext( hyperparams, canonicalize=True, run=True, run_config=config ): - a = torch.randn(shape[0], shape[2], dtype=dtype) - b = torch.randn(shape[1], shape[2], dtype=dtype) + a = torch.randn(shape[0], shape[2], dtype=operand_element_type) + b = torch.randn(shape[1], shape[2], dtype=operand_element_type) c = torch.zeros(shape[0], shape[1], dtype=torch.float32) mb = gemm(a, b, c) diff --git a/gemmbench/problems.py b/gemmbench/problems.py index c0cc8c5..69dbe47 100644 --- a/gemmbench/problems.py +++ b/gemmbench/problems.py @@ -9,11 +9,27 @@ import re -def is_compute_bound(M, N, K, bpe): +def num_bytes(dtype: str) -> int: + return {"f16": 2, "bf16": 2, "f32": 4, "i8": 1, "i32": 4}[dtype] + + +def get_default_accumulator_element_type(operand_element_type: str) -> str: + return {"f16": "f32", "bf16": "f32", "f32": "f32", "i8": "i32", "i32": "i32"}[ + operand_element_type + ] + + +def get_default_result_element_type(operand_element_type: str) -> str: + return operand_element_type + + +def is_compute_bound(M: int, N: int, K: int, dtype: str) -> bool: """Is this GEMM compute (or memory) bound?""" magic_ratio = 64 flops = 2 * M * N * K - bytes = bpe * (M * K + K * N + M * N) + elem_type_bytes = num_bytes(dtype) + result_bytes = num_bytes(get_default_result_element_type(dtype)) + bytes = elem_type_bytes * (M * K + K * N) + result_bytes * (M * N) return flops > magic_ratio * bytes @@ -654,19 +670,24 @@ def is_compute_bound(M, N, K, bpe): (4096, 5120, 640), ] + def llama13bmatvec(dtype: str) -> list[GemmConfig]: configs = [] """LLAMA 13b, single batch, FP16.""" for m, n, k, model, gcount in LLAMA: if n == 1 and model == "13b": - configs.append(GemmConfig( - m, - n, - k, - "T", - "N", - dtype - )) + configs.append( + GemmConfig( + m, + n, + k, + "T", + "N", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), + ) + ) return configs @@ -681,7 +702,9 @@ def llama13bmatvecbf16(dtype: str) -> list[GemmConfig]: k, "T", "N", - dtype + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), )) return configs @@ -697,7 +720,9 @@ def llama70bmatvec(dtype: str) -> list[GemmConfig]: k, "T", "N", - dtype + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), )) return configs @@ -713,7 +738,9 @@ def llama70bmatvecbf16(dtype: str) -> list[GemmConfig]: k, "T", "N", - dtype + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), )) return configs @@ -730,7 +757,9 @@ def llama13bskinny(dtype: str) -> list[GemmConfig]: k, "T", "N", - dtype + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), )) return configs @@ -747,7 +776,9 @@ def llama13bskinnybf16(dtype: str) -> list[GemmConfig]: k, "T", "N", - dtype + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), )) return configs @@ -764,7 +795,9 @@ def llama70bskinny(dtype: str) -> list[GemmConfig]: k, "T", "N", - dtype + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), )) return configs @@ -781,7 +814,9 @@ def llama70bskinnybf16(dtype: str) -> list[GemmConfig]: k, "T", "N", - dtype + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), )) return configs @@ -790,8 +825,17 @@ def gpt4memory(dtype: str) -> list[GemmConfig]: """GPT4 memory bound GEMMs; FP16.""" configs = [] for m, n, k in GPT4: - hgemm = GemmConfig(m, n, k, "N", "N", dtype) - if not is_compute_bound(m, n, k, 2): + hgemm = GemmConfig( + m, + n, + k, + "N", + "N", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), + ) + if not is_compute_bound(m, n, k, dtype): configs.append(hgemm) return configs @@ -800,28 +844,51 @@ def gpt4compute(dtype: str) -> list[GemmConfig]: """GPT4 compute bound GEMMs; FP16.""" configs = [] for m, n, k in GPT4: - hgemm = GemmConfig(m, n, k, "N", "N", dtype) - if is_compute_bound(m, n, k, 2): + hgemm = GemmConfig( + m, + n, + k, + "N", + "N", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), + ) + if is_compute_bound(m, n, k, dtype): configs.append(hgemm) return configs def tk_default(dtype: str) -> list[GemmConfig]: """TK Shapes.""" + acc_type = get_default_accumulator_element_type(dtype) + res_type = get_default_result_element_type(dtype) configs = [] M, N, K = 1024, 5120, 640 - configs.append(GemmConfig(M, N, K, "N", "T", dtype)) + configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type)) M, N, K = 2048, 10240, 1280 - configs.append(GemmConfig(M, N, K, "N", "T", dtype)) + configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type)) M, N, K = 4096, 20480, 2560 - configs.append(GemmConfig(M, N, K, "N", "T", dtype)) + configs.append(GemmConfig(M, N, K, "N", "T", dtype, acc_type, res_type)) return configs + def tk_unet(dtype: str) -> list[GemmConfig]: """UNET Shapes for TK.""" configs = [] for m, n, k in UNET: - configs.append(GemmConfig(m, n, k, "N", "T", dtype)) + configs.append( + GemmConfig( + m, + n, + k, + "N", + "T", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), + ) + ) return configs @@ -829,25 +896,59 @@ def llama70bmemory(dtype: str) -> list[GemmConfig]: """LLAMA 70b memory bound GEMMs; NT; BF16.""" configs = [] for n in [1280, 3584, 7168]: - configs.append(GemmConfig(2, n, 8192, "N", "T", dtype)) + configs.append( + GemmConfig( + 2, + n, + 8192, + "N", + "T", + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), + ) + ) return configs def compute(dtype: str) -> list[GemmConfig]: """Compute bound GEMMs.""" - #for dtype in ["fp16", "bf16", "fp8"]: configs = [] for tA, tB in [("N", "N"), ("N", "T"), ("T", "N")]: - configs.append(GemmConfig(4096, 4096, 8192, tA, tB, dtype)) + configs.append( + GemmConfig( + 4096, + 4096, + 8192, + tA, + tB, + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), + ) + ) return configs + def unet(dtype: str) -> list[GemmConfig]: configs = [] for tA, tB in [("N", "N"), ("N", "T")]: for m, n, k in UNET: - configs.append(GemmConfig(m, n, k, tA, tB, dtype)) + configs.append( + GemmConfig( + m, + n, + k, + tA, + tB, + dtype, + get_default_accumulator_element_type(dtype), + get_default_result_element_type(dtype), + ) + ) return configs + def get_gemm_configs() -> list[tuple[str, GemmConfig]]: llama13bmatvec_configs: list[GemmConfig] = [] llama13bmatvec_configs += llama13bmatvec("f16") @@ -890,6 +991,7 @@ def get_gemm_configs() -> list[tuple[str, GemmConfig]]: return all_configs + def get_tk_gemm_configs() -> list[tuple[str, GemmConfig]]: configs: list[tuple[str, GemmConfig]] = [] tk_default_configs = tk_default("f16") @@ -899,12 +1001,17 @@ def get_tk_gemm_configs() -> list[tuple[str, GemmConfig]]: configs += [("unet", x) for x in tk_unet_configs] return configs -def get_matching_configs(tagged_configs: list[tuple[str, GemmConfig]], - dtypes: list[str], variants: list[str], tag_regex: str) -> list[tuple[str, GemmConfig]]: + +def get_matching_configs( + tagged_configs: list[tuple[str, GemmConfig]], + dtypes: list[str], + variants: list[str], + tag_regex: str, +) -> list[tuple[str, GemmConfig]]: tag_re = re.compile(tag_regex) matching_configs: list[tuple[str, GemmConfig]] = [] for tag, config in tagged_configs: - if config.dtype not in dtypes: + if config.operand_element_type not in dtypes: continue if f"{config.tA}{config.tB}" not in variants: continue