Fix accumulator and result types (#14)
For example, perform f16 matmul with f32 as the accumulator type and
truncate the result to f16. This is more realistic than using f16 as the
accumulator type.

Keep track of operand, accumulator, and result types in `GemmConfig`.
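An illustrative sketch of the motivation (not part of the commit; sizes and values invented): with an f16 accumulator, a long reduction stalls once the running sum grows large enough that its floating-point spacing exceeds the addends, while an f32 accumulator truncated once at the end stays close to the true sum.

import numpy as np

x = np.full(4096, 0.01, dtype=np.float16)

# f16 accumulator: near a running sum of 32, the f16 spacing (0.03125)
# exceeds 0.01, so further addends round away entirely.
acc16 = np.float16(0.0)
for v in x:
    acc16 = np.float16(acc16 + v)

# f32 accumulator, truncated to f16 once at the end -- the scheme this
# commit switches the benchmarks to.
acc32 = np.float32(0.0)
for v in x:
    acc32 += np.float32(v)

print(acc16)              # stalls near 32
print(np.float16(acc32))  # ~40.97, close to the exact sum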
kuhar authored Oct 8, 2024
1 parent 218e39c commit 5010a9a
Showing 3 changed files with 201 additions and 79 deletions.
4 changes: 2 additions & 2 deletions gemmbench/gemm_bench.py
@@ -107,7 +107,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
     target = args.target
     extra_compiler_args = list(args.Xiree_compile)
     dump_dir = args.dump_dir
 
     args = itertools.starmap(
         lambda tag, config: (tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args, tk, dump_dir), configs
     )
@@ -171,7 +171,7 @@ def compile_gemm(tag, config, kernel_dir, vmfb_dir, target, extra_compiler_args,
         tflops_per_second = (flops / 1e12) / (benchmark_gemm_mean_time_us / 1e6)
 
         results.append((
-            index, tag, name, vmfb_hash, config.M, config.N, config.K, config.dtype, config.tA, config.tB,
+            index, tag, name, vmfb_hash, config.M, config.N, config.K, config.operand_element_type, config.tA, config.tB,
             round(benchmark_gemm_mean_time_us, 4),
             round(arithmetic_intensity, 4),
             round(tflops_per_second, 4),
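As a worked example of the metric above (numbers invented, and assuming get_flops() returns the usual 2*M*N*K for a GEMM):

# Hypothetical run: a 4096x4096x4096 GEMM averaging 2000 us.
flops = 2 * 4096 * 4096 * 4096          # ~137.4 GFLOP per run
benchmark_gemm_mean_time_us = 2000.0
tflops_per_second = (flops / 1e12) / (benchmark_gemm_mean_time_us / 1e6)
print(round(tflops_per_second, 4))      # 68.7195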
105 changes: 60 additions & 45 deletions gemmbench/gemm_utils.py
@@ -15,10 +15,12 @@ class GemmConfig:
     K: int
     tA: str
     tB: str
-    dtype: str
+    operand_element_type: str
+    accumulator_element_type: str
+    result_element_type: str
 
     def get_name(self) -> str:
-        name = f"gemm_{self.M}_{self.N}_{self.K}_{self.dtype}"
+        name = f"gemm_{self.M}_{self.N}_{self.K}_{self.operand_element_type}_{self.accumulator_element_type}"
         if self.tA == "T":
             name += "_tA"
         elif self.tB == "T":
@@ -27,30 +29,26 @@ def get_name(self) -> str:

     def get_inp1(self) -> str:
         if self.tA == "T":
-            inp1 = f"{self.K}x{self.M}x{self.dtype}"
-        else:
-            inp1 = f"{self.M}x{self.K}x{self.dtype}"
-        return inp1
+            return f"{self.K}x{self.M}x{self.operand_element_type}"
+        return f"{self.M}x{self.K}x{self.operand_element_type}"
 
     def get_inp2(self) -> str:
         if self.tB == "T":
-            inp2 = f"{self.N}x{self.K}x{self.dtype}"
-        else:
-            inp2 = f"{self.K}x{self.N}x{self.dtype}"
-        return inp2
+            return f"{self.N}x{self.K}x{self.operand_element_type}"
+        return f"{self.K}x{self.N}x{self.operand_element_type}"
 
     def get_byte_count(self) -> int:
-        dtype_bits_map = {
-            "f32": 32,
-            "f16": 16,
-            "bf16": 16,
-            "f8E4M3FNUZ": 8,
-            "i8": 8,
-            "i32": 32,
+        dtype_to_bytes = {
+            "f32": 4,
+            "f16": 2,
+            "bf16": 2,
+            "f8E4M3FNUZ": 1,
+            "i8": 1,
+            "i32": 4,
         }
-        bytes_per_element = dtype_bits_map[self.dtype] // 8
-        element_count = self.M * self.K + self.N * self.K + self.M * self.N
-        byte_count = element_count * bytes_per_element
+        operand_bytes_per_element = dtype_to_bytes[self.operand_element_type]
+        result_bytes_per_element = dtype_to_bytes[self.result_element_type]
+        byte_count = (self.M * self.K + self.N * self.K) * operand_bytes_per_element + (self.M * self.N) * result_bytes_per_element
         return byte_count
 
     def get_flops(self) -> int:
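A quick sanity check of the new get_byte_count() formula (sizes invented for illustration):

# M = N = K = 4096 with f16 operands (2 bytes) and an f16 result (2 bytes).
M = N = K = 4096
operand_bytes = (M * K + N * K) * 2   # A and B: 64 MiB
result_bytes = (M * N) * 2            # C: 32 MiB
print(operand_bytes + result_bytes)   # 100663296 bytes = 96 MiB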
@@ -61,40 +59,54 @@ def generate_mlir(config: GemmConfig):
     K = config.K
     M = config.M
     N = config.N
-    dtype = config.dtype
+    operand_element_type = config.operand_element_type
+    acc_element_type = config.accumulator_element_type
+    result_element_type = config.result_element_type
+    assert not operand_element_type.startswith('i'), "Integer types not supported yet"
 
     tA = config.tA
     tB = config.tB
     mlir_template_A = f"""
 module {{
-  func.func @main(%arg0: tensor<{K}x{M}x{dtype}>, %arg1: tensor<{K}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> {{
-    %cst = arith.constant 0.000000e+00 : {dtype}
-    %0 = tensor.empty() : tensor<{M}x{N}x{dtype}>
-    %1 = linalg.fill ins(%cst : {dtype}) outs(%0 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}>
-    %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<{K}x{M}x{dtype}>, tensor<{K}x{N}x{dtype}>) outs(%1 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}>
-    return %2 : tensor<{M}x{N}x{dtype}>
+  func.func @main(%arg0: tensor<{K}x{M}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
+    %cst = arith.constant 0.000000e+00 : {acc_element_type}
+    %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
+    %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
+    %2 = linalg.matmul_transpose_a ins(%arg0, %arg1 : tensor<{K}x{M}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>)
+        outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
+        -> tensor<{M}x{N}x{acc_element_type}>
+    %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
+    return %3 : tensor<{M}x{N}x{result_element_type}>
   }}
 }}
 """
 
     mlir_template_B = f"""
 module {{
-  func.func @main(%arg0: tensor<{M}x{K}x{dtype}>, %arg1: tensor<{N}x{K}x{dtype}>) -> tensor<{M}x{N}x{dtype}> {{
-    %cst = arith.constant 0.000000e+00 : {dtype}
-    %0 = tensor.empty() : tensor<{M}x{N}x{dtype}>
-    %1 = linalg.fill ins(%cst : {dtype}) outs(%0 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}>
-    %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<{M}x{K}x{dtype}>, tensor<{N}x{K}x{dtype}>) outs(%1 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}>
-    return %2 : tensor<{M}x{N}x{dtype}>
+  func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{N}x{K}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
+    %cst = arith.constant 0.000000e+00 : {acc_element_type}
+    %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
+    %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
+    %2 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{N}x{K}x{operand_element_type}>)
+        outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
+        -> tensor<{M}x{N}x{acc_element_type}>
+    %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
+    return %3 : tensor<{M}x{N}x{result_element_type}>
   }}
 }}
 """
 
-    mlir_template = f"""module {{
-  func.func @main(%arg0: tensor<{M}x{K}x{dtype}>, %arg1: tensor<{K}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}> {{
-    %cst = arith.constant 0.000000e+00 : {dtype}
-    %0 = tensor.empty() : tensor<{M}x{N}x{dtype}>
-    %1 = linalg.fill ins(%cst : {dtype}) outs(%0 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}>
-    %2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}x{dtype}>, tensor<{K}x{N}x{dtype}>) outs(%1 : tensor<{M}x{N}x{dtype}>) -> tensor<{M}x{N}x{dtype}>
-    return %2 : tensor<{M}x{N}x{dtype}>
+    mlir_template = f"""
+module {{
+  func.func @main(%arg0: tensor<{M}x{K}x{operand_element_type}>, %arg1: tensor<{K}x{N}x{operand_element_type}>) -> tensor<{M}x{N}x{result_element_type}> {{
+    %cst = arith.constant 0.000000e+00 : {acc_element_type}
+    %0 = tensor.empty() : tensor<{M}x{N}x{acc_element_type}>
+    %1 = linalg.fill ins(%cst : {acc_element_type}) outs(%0 : tensor<{M}x{N}x{acc_element_type}>) -> tensor<{M}x{N}x{acc_element_type}>
+    %2 = linalg.matmul ins(%arg0, %arg1 : tensor<{M}x{K}x{operand_element_type}>, tensor<{K}x{N}x{operand_element_type}>)
+        outs(%1 : tensor<{M}x{N}x{acc_element_type}>)
+        -> tensor<{M}x{N}x{acc_element_type}>
+    %3 = arith.truncf %2 : tensor<{M}x{N}x{acc_element_type}> to tensor<{M}x{N}x{result_element_type}>
+    return %3 : tensor<{M}x{N}x{result_element_type}>
   }}
 }}
 """
@@ -104,7 +116,10 @@ def generate_mlir(config: GemmConfig):
         return mlir_template_B
     return mlir_template
 
+
 def generate_tk_mlir(config: GemmConfig):
+    assert config.operand_element_type == 'f16', "Unsupported problem"
+    assert config.accumulator_element_type == 'f32', "Unsupported problem"
     # Input sizes
     M = tkl.sym.M
     N = tkl.sym.N
@@ -158,12 +173,12 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:

         # repeat represents the results of the loop
         tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
 
     shape = [config.M, config.N, config.K]
-    dtype_map = {
+    operand_element_type_map = {
         "f16": torch.float16,
     }
-    dtype = dtype_map[config.dtype]
+    operand_element_type = operand_element_type_map[config.operand_element_type]
 
     hyperparams = {
         ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
@@ -180,8 +195,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     with tk.gen.TestLaunchContext(
         hyperparams, canonicalize=True, run=True, run_config=config
     ):
-        a = torch.randn(shape[0], shape[2], dtype=dtype)
-        b = torch.randn(shape[1], shape[2], dtype=dtype)
+        a = torch.randn(shape[0], shape[2], dtype=operand_element_type)
+        b = torch.randn(shape[1], shape[2], dtype=operand_element_type)
         c = torch.zeros(shape[0], shape[1], dtype=torch.float32)
         mb = gemm(a, b, c)

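One way to sanity-check the launch above (not in the commit; assumes gemm writes the f32 result into c in place):

# Compare against a torch reference accumulated in f32; a is (M, K),
# b is (N, K), and c is (M, N).
ref = a.to(torch.float32) @ b.to(torch.float32).T
torch.testing.assert_close(c, ref, rtol=1e-3, atol=1e-3)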