From 262ce12a0df0cb10931c1acdb11ea8fb81a6c1c7 Mon Sep 17 00:00:00 2001
From: Stanley Winata
Date: Tue, 19 Nov 2024 14:54:17 -0800
Subject: [PATCH 01/11] [TKW] Add CDNA2 Int8 intrinsics and refactor intrinsic enums

- Added CDNA2 int8 intrinsic layouts
- Modified iree_ref to handle int gemms
- Modified certain e2e tests to require a specific GPU arch to be available
- Modified enum for easier handling in the future
- Added a function to get the default architecture
- Borrowed device_randint from Ivan

Signed-off-by: Stanley Winata
Co-authored-by: Ivan Butygin
Signed-off-by: Stanley Winata
---
 iree/turbine/kernel/wave/constraints.py |  48 +++++--
 iree/turbine/kernel/wave/iree_utils.py  |  15 ++-
 iree/turbine/kernel/wave/utils.py       |  20 ++-
 tests/kernel/wave/wave_gemm_test.py     | 171 ++++++++++++++++++++++++
 4 files changed, 238 insertions(+), 16 deletions(-)

diff --git a/iree/turbine/kernel/wave/constraints.py b/iree/turbine/kernel/wave/constraints.py
index 83441af0..5b04e00d 100644
--- a/iree/turbine/kernel/wave/constraints.py
+++ b/iree/turbine/kernel/wave/constraints.py
@@ -15,11 +15,43 @@ from ..lang.global_symbols import *
+"""
+Formatting for different target intrinsics:
+    <kind>_<elem-type-C>_<M>x<N>x<K>_<elem-type-A>[_<elem-type-B>]
+
+Values: 0xABCD where:
+* A = vendor:
+  * 1 = AMD
+  * 2 = NVIDIA
+* B = architecture. When an intrinsic exists in multiple architectures, this
+  should be the architecture it was introduced in, as long as it still
+  has the same semantics. If a new architecture breaks an existing
+  intrinsic's semantics, we can use that field for versioning.
+  * For AMD:
+    * 0 = CDNA1
+    * 1 = CDNA2
+    * 2 = CDNA3
+    * 8 = RDNA3
+* C = element type of A-matrix:
+  * 0 = 64-bit float (e.g. IEEE754 double precision)
+  * 1 = 32-bit float (e.g. IEEE754 single precision, and "xf32" fast variants)
+  * 2 = 16-bit float (incl. IEEE754 half and bf16)
+  * 3 = 8-bit float (incl. f8E5M2, f8E4M3, and "FNUZ" variants)
+  * C = 8-bit integer (any signedness)
+* D enumerates intrinsics that share the same 0xABC* bits.
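+
+As a worked example (using values from the MMAType enum below):
+F32_16x16x32_F8 = 0x1230 decodes as A = 1 (AMD), B = 2 (introduced in CDNA3),
+C = 3 (8-bit float A-matrix) and D = 0, while I32_32x32x8_I8 = 0x10C1 decodes
+as A = 1 (AMD), B = 0 (CDNA1), C = 0xC (8-bit integer) and D = 1.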
+""" + + class MMAType(Enum): - F32_16x16x16_F16 = 0 - F32_32x32x8_F16 = 1 - F32_16x16x32_F8 = 2 - F32_32x32x16_F8 = 3 + # Intrinsics introduced in CDNA1 + F32_16x16x16_F16 = 0x1020 + F32_32x32x8_F16 = 0x1021 + I32_16x16x16_I8 = 0x10C0 + I32_32x32x8_I8 = 0x10C1 + + # Intrinsics introduced in CDNA3 + F32_16x16x32_F8 = 0x1230 + F32_32x32x16_F8 = 0x1231 class MMAOperand(Enum): @@ -89,9 +121,9 @@ def get_thread_id_from_workgroup_dim(self, workgroup_dim: int) -> IndexSymbol: def mma_matrix_shapes(self) -> tuple[int]: # TODO: Eventually the shapes and indices should be provided by a tool match self.mma_type: - case MMAType.F32_16x16x16_F16: + case MMAType.F32_16x16x16_F16 | MMAType.I32_16x16x16_I8: return (16, 16, 16) - case MMAType.F32_32x32x8_F16: + case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8: return (32, 32, 8) case MMAType.F32_16x16x32_F8: return (16, 16, 32) @@ -151,7 +183,7 @@ def apply( lane = self.linearized_thread_id % self.threads_per_wave match self.mma_type: # (M x K, N x K) -> M x N - case MMAType.F32_16x16x16_F16: + case MMAType.F32_16x16x16_F16 | MMAType.I32_16x16x16_I8: offset = [ Piecewise( (lane % 16, ~MMA_ACC), @@ -170,7 +202,7 @@ def apply( 1, # N 1, # K ] - case MMAType.F32_32x32x8_F16: + case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8: offset = [ Piecewise( (lane % 32, ~MMA_ACC), diff --git a/iree/turbine/kernel/wave/iree_utils.py b/iree/turbine/kernel/wave/iree_utils.py index d1031bd5..9c36e7e7 100644 --- a/iree/turbine/kernel/wave/iree_utils.py +++ b/iree/turbine/kernel/wave/iree_utils.py @@ -7,7 +7,7 @@ import torch from typing import Any from .utils import compile_and_invoke -from ...support.conversions import TORCH_DTYPE_TO_MLIR_TYPE_ASM +from ...support.conversions import TORCH_DTYPE_TO_IREE_TYPE_ASM def get_chain_mmt_asm( @@ -83,6 +83,7 @@ def get_mmt_asm( lhs_type: str, rhs_type: str, acc_type: str, + zero: str, batch: bool = False, cast_fp8: bool = False, ) -> str: @@ -93,7 +94,7 @@ def get_mmt_asm( if not cast_fp8: matmul_function = f""" func.func @{func_name}(%lhs: tensor<{lhs_type}>, %rhs: tensor<{rhs_type}>) -> tensor<{acc_type}> {{ - %c0 = arith.constant 0.0 : {acc_dtype} + %c0 = arith.constant {zero} : {acc_dtype} %init = tensor.empty() : tensor<{acc_type}> %inital_result = linalg.fill ins(%c0 : {acc_dtype}) outs(%init : tensor<{acc_type}>) -> tensor<{acc_type}> %result = linalg.{operator} ins(%lhs, %rhs: tensor<{lhs_type}>, tensor<{rhs_type}>) @@ -138,7 +139,7 @@ def get_conv_asm( def dtype_str(dtype: torch.dtype) -> str: - dtype_str = TORCH_DTYPE_TO_MLIR_TYPE_ASM.get(dtype, None) + dtype_str = TORCH_DTYPE_TO_IREE_TYPE_ASM[dtype] if dtype_str is None: raise ValueError(f"Unsupported dtype: {dtype}") return dtype_str @@ -165,8 +166,14 @@ def generate_iree_ref( lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype) rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype) acc_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype) + zero = "0.0" if kernel_outputs[0].dtype.is_floating_point else "0" asm = get_mmt_asm( - lhs_type, rhs_type, acc_type, batch=False, cast_fp8=kernel_type == "mmt_f8" + lhs_type, + rhs_type, + acc_type, + zero, + batch=False, + cast_fp8=kernel_type == "mmt_f8", ) elif kernel_type == "bmmt": lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype) diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index a28ea2fa..e79eafb3 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -102,6 +102,18 @@ 
def get_default_run_config() -> dict[Any, Any]: return {"backend": "rocm", "device": "hip", "target": "gfx942"} +def get_default_arch() -> str: + """Return default ROCM architecture""" + if not torch.cuda.is_available(): + return "cpu" + device = torch.device("cuda") + gcnArch = torch.cuda.get_device_properties(device).gcnArchName + assert "gfx" in gcnArch, "Currently only support GFX/ROCm for get_default_arch." + # The gcnArchName comes back like gfx90a:sramecc+:xnack. + colon_pos = gcnArch.find(":") + return gcnArch[0:colon_pos] + + def print_trace(trace: CapturedTrace, custom_print: bool = True): """ Prints all subgraphs of a trace starting with the root graph. @@ -880,9 +892,9 @@ def ceildiv(a: int, b: int) -> int: def get_mfma_load_elems_per_thread(mfma_variant: MMAType) -> int: match mfma_variant: - case MMAType.F32_16x16x16_F16: + case MMAType.F32_16x16x16_F16 | MMAType.I32_16x16x16_I8: return 4 - case MMAType.F32_32x32x8_F16: + case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8: return 4 case MMAType.F32_16x16x32_F8: return 8 @@ -892,9 +904,9 @@ def get_mfma_load_elems_per_thread(mfma_variant: MMAType) -> int: def get_mfma_store_elems_per_thread(mfma_variant: MMAType) -> int: match mfma_variant: - case MMAType.F32_16x16x16_F16: + case MMAType.F32_16x16x16_F16 | MMAType.I32_16x16x16_I8: return 4 - case MMAType.F32_32x32x8_F16: + case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8: return 16 case MMAType.F32_16x16x32_F8: return 4 diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index 324089e0..d7d5b56d 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -15,9 +15,11 @@ from iree.turbine.kernel.wave.iree_utils import generate_iree_ref from iree.turbine.kernel.wave.utils import ( get_default_run_config, + get_default_arch, get_mfma_load_elems_per_thread, get_mfma_store_elems_per_thread, device_randn, + device_randint, device_zeros, ) from iree.turbine.kernel.wave.constraints import MMAType @@ -28,6 +30,12 @@ _run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0)) require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled") +require_cdna2 = pytest.mark.skipif( + "gfx90" not in get_default_arch(), reason="Default device is not CDNA2" +) +require_cdna3 = pytest.mark.skipif( + "gfx94" not in get_default_arch(), reason="Default device is not CDNA3" +) # Whether to dump the generated MLIR module. test_dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0)) # Whether to use scheduling group barriers (needs LLVM fix). 
@@ -219,6 +227,169 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: @require_e2e +@require_cdna2 +@pytest.mark.parametrize("shape", get_test_shapes("test_gemm")) +@pytest.mark.parametrize("enable_scheduling", [False, True]) +@pytest.mark.parametrize("dynamic_dims", [False, True]) +@pytest.mark.parametrize( + "mfma_variant", + [ + MMAType.I32_16x16x16_I8, + MMAType.I32_32x32x8_I8, + ], +) +def testCDNA2IntGemm( + shape: tuple[int], + enable_scheduling: bool, + dynamic_dims: bool, + mfma_variant: MMAType, + request, +): + run_bench = request.config.getoption("--runperf") + dump_perf = request.config.getoption("--dump-perf-files-path") + # Input sizes + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + # Workgroup tile sizes + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, waves_per_block=(2, 2, 1), mma_type=mfma_variant + ) + ] + + # With dynamic dimensions, we need to add an assumption on how big + # the reduction dimension is to determine whether we can schedule or not. + if dynamic_dims: + constraints += [tkw.Assumption(K > BLOCK_K * 4)] + + # Wave-level micro-kernel. + # Since warps are not directly addressable, there is no + # explicit notion of a warp id (like a workgroup or thread id). + # This kernel uses the input sizes M, N, K throughout, as the tiling + # and data movement strategy is determined during the compilation process. + # These can be influenced by introducing constraints. + @tkw.wave(constraints) + def gemm( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.i16], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.i16], + c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.i32], + ): + c_reg = tkl.Register[M, N, tkl.i32](0.0) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. 
+ @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]: + # a_reg: tkw.Register[M, K, tkl.i16] + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + a_reg = tkw.cast(a_reg, tkl.i8) + # b_reg: tkw.Register[N, K, tkl.i16] + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + b_reg = tkw.cast(b_reg, tkl.i8) + # acc: tkw.Register[M, N, tkl.i32] + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + # repeat represents the results of the loop + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + M: shape[0], + N: shape[1], + K: shape[2], + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + VALU_DELAY: 1, + SHUFFLE_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, + VALU_UNITS: 8, + SHUFFLE_UNITS: 8, + } + config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + if run_bench: + config["benchmark_batch_size"] = 10 + config["benchmark_repetitions"] = 3 + if dump_perf is not None: + perf_filename = request.node.name + ".json" + config["benchmark_results_file"] = os.path.join( + dump_perf, "tk_" + perf_filename + ) + + dynamic_symbols = [] + dynamic_symbols_map = {} + if dynamic_dims: + dynamic_symbols_map[M] = hyperparams[M] + dynamic_symbols_map[N] = hyperparams[N] + dynamic_symbols_map[K] = hyperparams[K] + dynamic_symbols.append(M) + dynamic_symbols.append(N) + dynamic_symbols.append(K) + del hyperparams[M] + del hyperparams[N] + del hyperparams[K] + + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=True, + run_bench=run_bench, + run_config=config, + schedule=enable_scheduling, + use_scheduling_barriers=enable_scheduling_barriers, + dynamic_symbols=dynamic_symbols, + dynamic_symbols_map=dynamic_symbols_map, + ): + randint_hi = 4 + a = device_randint(randint_hi, (shape[0], shape[2]), dtype=torch.int16) + b = device_randint(randint_hi, (shape[1], shape[2]), dtype=torch.int16) + c = device_zeros(shape[0], shape[1], dtype=torch.int32) + mb = gemm(a, b, c) + + if test_dump_generated_mlir: + filename = f"wave_gemm_{'x'.join(map(str, shape))}.mlir" + with open(filename, "w") as f: + f.write(mb.module_op.get_asm()) + + if run_bench: + if dump_perf is not None: + config["benchmark_results_file"] = os.path.join( + dump_perf, "iree_" + perf_filename + ) + iree_ref = torch.zeros(shape[0], shape[1], dtype=torch.int32) + generate_iree_ref("mmt", [a, b], [iree_ref], config, run_bench=run_bench) + assert_close(c, iree_ref, check_device=False) + + +@require_e2e +@require_cdna3 @pytest.mark.parametrize("shape", get_test_shapes("test_gemm")) @pytest.mark.parametrize("enable_scheduling", [False, True]) @pytest.mark.parametrize( From 7e60f37af6430c1c39602fb6703fdc634299533b Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Tue, 19 Nov 2024 15:30:03 -0800 Subject: [PATCH 02/11] Fix NIT and add CDNA3 int gemms Signed-off-by: Stanley Winata --- iree/turbine/kernel/wave/constraints.py | 10 +- iree/turbine/kernel/wave/iree_utils.py | 5 +- iree/turbine/kernel/wave/utils.py | 8 +- tests/kernel/wave/wave_gemm_test.py | 130 ++++++++++++++++++++++++ 4 files changed, 141 insertions(+), 12 deletions(-) diff --git a/iree/turbine/kernel/wave/constraints.py 
b/iree/turbine/kernel/wave/constraints.py index 5b04e00d..a12f222a 100644 --- a/iree/turbine/kernel/wave/constraints.py +++ b/iree/turbine/kernel/wave/constraints.py @@ -52,6 +52,8 @@ class MMAType(Enum): # Intrinsics introduced in CDNA3 F32_16x16x32_F8 = 0x1230 F32_32x32x16_F8 = 0x1231 + I32_16x16x32_I8 = 0x12C0 + I32_32x32x16_I8 = 0x12C1 class MMAOperand(Enum): @@ -125,9 +127,9 @@ def mma_matrix_shapes(self) -> tuple[int]: return (16, 16, 16) case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8: return (32, 32, 8) - case MMAType.F32_16x16x32_F8: + case MMAType.F32_16x16x32_F8 | MMAType.I32_16x16x32_I8: return (16, 16, 32) - case MMAType.F32_32x32x16_F8: + case MMAType.F32_32x32x16_F8 | MMAType.I32_32x32x16_I8: return (32, 32, 16) case _: return () @@ -226,7 +228,7 @@ def apply( 1, # N 1, # K ] - case MMAType.F32_16x16x32_F8: + case MMAType.F32_16x16x32_F8 | MMAType.I32_16x16x32_I8: offset = [ Piecewise( (lane % 16, ~MMA_ACC), (4 * floor(lane / 16), MMA_ACC) @@ -244,7 +246,7 @@ def apply( 1, # N 1, # K ] - case MMAType.F32_32x32x16_F8: + case MMAType.F32_32x32x16_F8 | MMAType.I32_32x32x16_I8: offset = [ Piecewise( (lane % 32, ~MMA_ACC), diff --git a/iree/turbine/kernel/wave/iree_utils.py b/iree/turbine/kernel/wave/iree_utils.py index 9c36e7e7..7a435e15 100644 --- a/iree/turbine/kernel/wave/iree_utils.py +++ b/iree/turbine/kernel/wave/iree_utils.py @@ -83,7 +83,6 @@ def get_mmt_asm( lhs_type: str, rhs_type: str, acc_type: str, - zero: str, batch: bool = False, cast_fp8: bool = False, ) -> str: @@ -94,7 +93,7 @@ def get_mmt_asm( if not cast_fp8: matmul_function = f""" func.func @{func_name}(%lhs: tensor<{lhs_type}>, %rhs: tensor<{rhs_type}>) -> tensor<{acc_type}> {{ - %c0 = arith.constant {zero} : {acc_dtype} + %c0 = arith.constant {"0.0" if acc_dtype.startswith("f") else "0"} : {acc_dtype} %init = tensor.empty() : tensor<{acc_type}> %inital_result = linalg.fill ins(%c0 : {acc_dtype}) outs(%init : tensor<{acc_type}>) -> tensor<{acc_type}> %result = linalg.{operator} ins(%lhs, %rhs: tensor<{lhs_type}>, tensor<{rhs_type}>) @@ -166,12 +165,10 @@ def generate_iree_ref( lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype) rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype) acc_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype) - zero = "0.0" if kernel_outputs[0].dtype.is_floating_point else "0" asm = get_mmt_asm( lhs_type, rhs_type, acc_type, - zero, batch=False, cast_fp8=kernel_type == "mmt_f8", ) diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index e79eafb3..122a1978 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -896,9 +896,9 @@ def get_mfma_load_elems_per_thread(mfma_variant: MMAType) -> int: return 4 case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8: return 4 - case MMAType.F32_16x16x32_F8: + case MMAType.F32_16x16x32_F8 | MMAType.I32_16x16x32_I8: return 8 - case MMAType.F32_32x32x16_F8: + case MMAType.F32_32x32x16_F8 | MMAType.I32_32x32x16_I8: return 8 @@ -908,9 +908,9 @@ def get_mfma_store_elems_per_thread(mfma_variant: MMAType) -> int: return 4 case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8: return 16 - case MMAType.F32_16x16x32_F8: + case MMAType.F32_16x16x32_F8 | MMAType.I32_16x16x32_I8: return 4 - case MMAType.F32_32x32x16_F8: + case MMAType.F32_32x32x16_F8 | MMAType.I32_32x32x16_I8: return 16 diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index d7d5b56d..4573932a 100644 --- 
a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -388,6 +388,136 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]: assert_close(c, iree_ref, check_device=False) +@require_e2e +@require_cdna3 +@pytest.mark.parametrize("shape", get_test_shapes("test_gemm")) +@pytest.mark.parametrize("enable_scheduling", [False, True]) +@pytest.mark.parametrize( + "mfma_variant", + [ + MMAType.F32_16x16x32_F8, + MMAType.F32_32x32x16_F8, + ], +) +def testCDNA3IntGemm( + shape: tuple[int], enable_scheduling: bool, mfma_variant: MMAType, request +): + run_bench = request.config.getoption("--runperf") + dump_perf = request.config.getoption("--dump-perf-files-path") + # Input sizes + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + # Workgroup tile sizes + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, waves_per_block=(2, 2, 1), mma_type=mfma_variant + ) + ] + + @tkw.wave(constraints) + def gemm( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.i16], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.i16], + c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.i32], + ): + c_reg = tkl.Register[M, N, tkl.i32](0.0) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. 
+ @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]: + # a_reg: tkw.Register[M, K, tkl.i16] + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + a_reg = tkw.cast(a_reg, tkl.i8) + # b_reg: tkw.Register[N, K, tkl.i16] + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + b_reg = tkw.cast(b_reg, tkl.i8) + # acc: tkw.Register[M, N, tkl.i32] + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + M: shape[0], + N: shape[1], + K: shape[2], + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + VALU_DELAY: 1, + SHUFFLE_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, + VALU_UNITS: 8, + SHUFFLE_UNITS: 8, + } + config = get_default_run_config() + if run_bench: + config["benchmark_batch_size"] = 10 + config["benchmark_repetitions"] = 3 + if dump_perf is not None: + perf_filename = request.node.name + ".json" + config["benchmark_results_file"] = os.path.join( + dump_perf, "tk_" + perf_filename + ) + + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=True, + run_bench=run_bench, + run_config=config, + schedule=enable_scheduling, + use_scheduling_barriers=enable_scheduling_barriers, + ): + randint_hi = 4 + a = device_randint(randint_hi, (shape[0], shape[2]), dtype=torch.int16) + b = device_randint(randint_hi, (shape[1], shape[2]), dtype=torch.int16) + c = device_zeros(shape[0], shape[1], dtype=torch.int32) + mb = gemm(a, b, c) + + if test_dump_generated_mlir: + filename = f"wave_gemm_{'x'.join(map(str, shape))}_f8.mlir" + with open(filename, "w") as f: + f.write(mb.module_op.get_asm()) + + if run_bench: + if dump_perf is not None: + config["benchmark_results_file"] = os.path.join( + dump_perf, "iree_" + perf_filename + ) + iree_ref = torch.zeros(shape[0], shape[1], dtype=torch.int32) + generate_iree_ref("mmt", [a, b], [iree_ref], config, run_bench=run_bench) + assert_close(c, iree_ref, check_device=False) + + @require_e2e @require_cdna3 @pytest.mark.parametrize("shape", get_test_shapes("test_gemm")) From 93bc2382907aabe4d2d429b973d7761fcd8babfa Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Tue, 19 Nov 2024 15:34:31 -0800 Subject: [PATCH 03/11] Fix NIT Signed-off-by: Stanley Winata --- iree/turbine/kernel/wave/utils.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index 122a1978..3f7ca5b7 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -920,12 +920,18 @@ def all_equal(input_list: list[Any]) -> bool: return all(elem == input_list[0] for elem in input_list) +def get_default_device(): + return "cuda" if torch.cuda.is_available() else "cpu" + + def device_randn(*args, **kwargs): - return torch.randn(*args, **kwargs).to("cuda") + device = get_default_device() + return torch.randn(*args, **kwargs).to(device) def device_randint(*args, **kwargs): - return torch.randint(*args, **kwargs).to("cuda") + device = get_default_device() + return torch.randint(*args, **kwargs).to(device) def device_randperm(*args, **kwargs): @@ -933,7 +939,8 @@ def 
device_randperm(*args, **kwargs): def device_zeros(*args, **kwargs): - return torch.zeros(*args, **kwargs).to("cuda") + device = get_default_device() + return torch.zeros(*args, **kwargs).to(device) def get_assumptions(constraints: list[Constraint]) -> list[Assumption]: From 9b08f6cb78557b9f05aee451c16de3b709767348 Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Tue, 19 Nov 2024 16:11:40 -0800 Subject: [PATCH 04/11] fix some more nits Signed-off-by: Stanley Winata --- iree/turbine/kernel/wave/utils.py | 9 +++------ tests/kernel/wave/wave_gemm_test.py | 2 +- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index 3f7ca5b7..35677d27 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -925,13 +925,11 @@ def get_default_device(): def device_randn(*args, **kwargs): - device = get_default_device() - return torch.randn(*args, **kwargs).to(device) + return torch.randn(*args, **kwargs).to(get_default_device()) def device_randint(*args, **kwargs): - device = get_default_device() - return torch.randint(*args, **kwargs).to(device) + return torch.randint(*args, **kwargs).to(get_default_device()) def device_randperm(*args, **kwargs): @@ -939,8 +937,7 @@ def device_randperm(*args, **kwargs): def device_zeros(*args, **kwargs): - device = get_default_device() - return torch.zeros(*args, **kwargs).to(device) + return torch.zeros(*args, **kwargs).to(get_default_device()) def get_assumptions(constraints: list[Constraint]) -> list[Assumption]: diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index 4573932a..66de9f12 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -333,7 +333,7 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]: VALU_UNITS: 8, SHUFFLE_UNITS: 8, } - config = {"backend": "rocm", "device": "hip", "target": "gfx942"} + config = {"backend": "rocm", "device": "hip", "target": "gfx90a"} if run_bench: config["benchmark_batch_size"] = 10 config["benchmark_repetitions"] = 3 From 48fa66dd3ca5ddbbb4e2791b6db16eb1e7bf6214 Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Tue, 19 Nov 2024 16:14:42 -0800 Subject: [PATCH 05/11] naming nit Signed-off-by: Stanley Winata --- iree/turbine/kernel/wave/utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index 35677d27..8ba763b3 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -920,16 +920,20 @@ def all_equal(input_list: list[Any]) -> bool: return all(elem == input_list[0] for elem in input_list) -def get_default_device(): +def get_default_device() -> str: return "cuda" if torch.cuda.is_available() else "cpu" +def to_default_device(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(get_default_device()) + + def device_randn(*args, **kwargs): - return torch.randn(*args, **kwargs).to(get_default_device()) + return to_default_device(torch.randn(*args, **kwargs)) def device_randint(*args, **kwargs): - return torch.randint(*args, **kwargs).to(get_default_device()) + return to_default_device(torch.randint(*args, **kwargs)) def device_randperm(*args, **kwargs): @@ -937,7 +941,7 @@ def device_randperm(*args, **kwargs): def device_zeros(*args, **kwargs): - return torch.zeros(*args, **kwargs).to(get_default_device()) + return to_default_device(torch.zeros(*args, **kwargs)) def 
get_assumptions(constraints: list[Constraint]) -> list[Assumption]: From 1385eb48f0dcc0b55ab7266d7a17fc24da280ec2 Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Tue, 19 Nov 2024 16:26:38 -0800 Subject: [PATCH 06/11] Add Mi250 runner Signed-off-by: Stanley Winata --- .github/workflows/perf.yaml | 2 +- .github/workflows/perf_mi250.yaml | 63 +++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/perf_mi250.yaml diff --git a/.github/workflows/perf.yaml b/.github/workflows/perf.yaml index c574d542..238fdbf3 100644 --- a/.github/workflows/perf.yaml +++ b/.github/workflows/perf.yaml @@ -1,4 +1,4 @@ -name: PERF +name: PERF on MI300 on: workflow_dispatch: diff --git a/.github/workflows/perf_mi250.yaml b/.github/workflows/perf_mi250.yaml new file mode 100644 index 00000000..fe80a3b7 --- /dev/null +++ b/.github/workflows/perf_mi250.yaml @@ -0,0 +1,63 @@ +name: PERF on MI250 + +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + schedule: + - cron: '30 5 * * *' + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + test: + name: "Unit Tests and Type Checking" + strategy: + fail-fast: false + matrix: + version: [3.11] + os: [nodai-amdgpu-mi250-x86-64] + runs-on: ${{matrix.os}} + env: + PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" + steps: + - name: "Setting up Python" + id: setup_python + uses: actions/setup-python@v3 + with: + python-version: ${{matrix.version}} + + - name: "Checkout Code" + uses: actions/checkout@v3 + + - name: Cache Pip Packages + uses: actions/cache@v4 + id: cache-pip + with: + path: ${{ env.PIP_CACHE_DIR }} + key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} + + - name: Install pip deps + run: | + python -m pip install --no-compile --upgrade pip + # Note: We install in three steps in order to satisfy requirements + # from non default locations first. Installing the PyTorch CPU + # wheels saves multiple minutes and a lot of bandwidth on runner setup. + pip install --no-compile -r pytorch-rocm-requirements.txt + pip install --no-cache-dir -r iree-requirements-ci.txt --upgrade + pip install -r requirements.txt -e . + + - name: Run e2e tests on MI300 + if: "contains(matrix.os, 'mi300') && !cancelled()" + run: | + export WAVE_RUN_E2E_TESTS=1 + export TEST_PARAMS_PATH="tests/kernel/wave/test_param.json" + pytest -n 1 --capture=tee-sys -vv ./tests/kernel/wave/ From 968c2e319524c68ded6e30da9b88a3f080fdbe23 Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Tue, 19 Nov 2024 16:29:18 -0800 Subject: [PATCH 07/11] Clean up runner yaml Signed-off-by: Stanley Winata --- .github/workflows/perf_mi250.yaml | 4 ++-- .github/workflows/{perf.yaml => perf_mi300.yaml} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename .github/workflows/{perf.yaml => perf_mi300.yaml} (100%) diff --git a/.github/workflows/perf_mi250.yaml b/.github/workflows/perf_mi250.yaml index fe80a3b7..ea54168d 100644 --- a/.github/workflows/perf_mi250.yaml +++ b/.github/workflows/perf_mi250.yaml @@ -55,8 +55,8 @@ jobs: pip install --no-cache-dir -r iree-requirements-ci.txt --upgrade pip install -r requirements.txt -e . 
- - name: Run e2e tests on MI300 - if: "contains(matrix.os, 'mi300') && !cancelled()" + - name: Run e2e tests on MI250 + if: "contains(matrix.os, 'mi250') && !cancelled()" run: | export WAVE_RUN_E2E_TESTS=1 export TEST_PARAMS_PATH="tests/kernel/wave/test_param.json" diff --git a/.github/workflows/perf.yaml b/.github/workflows/perf_mi300.yaml similarity index 100% rename from .github/workflows/perf.yaml rename to .github/workflows/perf_mi300.yaml From e2f798359a80b29b7351b5699a454ab86301fb91 Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Tue, 19 Nov 2024 16:34:45 -0800 Subject: [PATCH 08/11] Integrate back perf yamls Signed-off-by: Stanley Winata --- .../workflows/{perf_mi250.yaml => perf.yaml} | 11 +++- .github/workflows/perf_mi300.yaml | 63 ------------------- 2 files changed, 9 insertions(+), 65 deletions(-) rename .github/workflows/{perf_mi250.yaml => perf.yaml} (84%) delete mode 100644 .github/workflows/perf_mi300.yaml diff --git a/.github/workflows/perf_mi250.yaml b/.github/workflows/perf.yaml similarity index 84% rename from .github/workflows/perf_mi250.yaml rename to .github/workflows/perf.yaml index ea54168d..30328d98 100644 --- a/.github/workflows/perf_mi250.yaml +++ b/.github/workflows/perf.yaml @@ -1,4 +1,4 @@ -name: PERF on MI250 +name: PERF on: workflow_dispatch: @@ -24,7 +24,7 @@ jobs: fail-fast: false matrix: version: [3.11] - os: [nodai-amdgpu-mi250-x86-64] + os: [nodai-amdgpu-mi300-x86-64, nodai-amdgpu-mi250-x86-64] runs-on: ${{matrix.os}} env: PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" @@ -55,6 +55,13 @@ jobs: pip install --no-cache-dir -r iree-requirements-ci.txt --upgrade pip install -r requirements.txt -e . + - name: Run e2e tests on MI300 + if: "contains(matrix.os, 'mi300') && !cancelled()" + run: | + export WAVE_RUN_E2E_TESTS=1 + export TEST_PARAMS_PATH="tests/kernel/wave/test_param.json" + pytest -n 1 --capture=tee-sys -vv ./tests/kernel/wave/ + - name: Run e2e tests on MI250 if: "contains(matrix.os, 'mi250') && !cancelled()" run: | diff --git a/.github/workflows/perf_mi300.yaml b/.github/workflows/perf_mi300.yaml deleted file mode 100644 index 238fdbf3..00000000 --- a/.github/workflows/perf_mi300.yaml +++ /dev/null @@ -1,63 +0,0 @@ -name: PERF on MI300 - -on: - workflow_dispatch: - pull_request: - push: - branches: - - main - schedule: - - cron: '30 5 * * *' - -concurrency: - # A PR number if a pull request and otherwise the commit hash. This cancels - # queued and in-progress runs for the same PR (presubmit) or commit - # (postsubmit). The workflow name is prepended to avoid conflicts between - # different workflows. - group: ${{ github.workflow }}-${{ github.event.number || github.sha }} - cancel-in-progress: true - -jobs: - test: - name: "Unit Tests and Type Checking" - strategy: - fail-fast: false - matrix: - version: [3.11] - os: [nodai-amdgpu-mi300-x86-64] - runs-on: ${{matrix.os}} - env: - PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" - steps: - - name: "Setting up Python" - id: setup_python - uses: actions/setup-python@v3 - with: - python-version: ${{matrix.version}} - - - name: "Checkout Code" - uses: actions/checkout@v3 - - - name: Cache Pip Packages - uses: actions/cache@v4 - id: cache-pip - with: - path: ${{ env.PIP_CACHE_DIR }} - key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }} - - - name: Install pip deps - run: | - python -m pip install --no-compile --upgrade pip - # Note: We install in three steps in order to satisfy requirements - # from non default locations first. 
Installing the PyTorch CPU - # wheels saves multiple minutes and a lot of bandwidth on runner setup. - pip install --no-compile -r pytorch-rocm-requirements.txt - pip install --no-cache-dir -r iree-requirements-ci.txt --upgrade - pip install -r requirements.txt -e . - - - name: Run e2e tests on MI300 - if: "contains(matrix.os, 'mi300') && !cancelled()" - run: | - export WAVE_RUN_E2E_TESTS=1 - export TEST_PARAMS_PATH="tests/kernel/wave/test_param.json" - pytest -n 1 --capture=tee-sys -vv ./tests/kernel/wave/ From 54a6345876b06d2a298f16b0b1d980a6088fdf92 Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Tue, 19 Nov 2024 16:46:35 -0800 Subject: [PATCH 09/11] Make get_default_run_config use default device + move test to CI-TK.yaml Signed-off-by: Stanley Winata --- .github/workflows/ci-tk.yaml | 9 ++++++++- .github/workflows/perf.yaml | 9 +-------- iree/turbine/kernel/wave/utils.py | 16 +++++++++++----- lit_tests/kernel/wave/codegen.py | 16 ++++++++-------- tests/kernel/wave/wave_attention_test.py | 5 +++++ tests/kernel/wave/wave_gemm_test.py | 2 +- 6 files changed, 34 insertions(+), 23 deletions(-) diff --git a/.github/workflows/ci-tk.yaml b/.github/workflows/ci-tk.yaml index f0b8fbec..52ebc5ee 100644 --- a/.github/workflows/ci-tk.yaml +++ b/.github/workflows/ci-tk.yaml @@ -21,7 +21,7 @@ jobs: fail-fast: false matrix: version: [3.11] - os: [ubuntu-latest, nodai-amdgpu-mi300-x86-64] + os: [ubuntu-latest, nodai-amdgpu-mi300-x86-64, nodai-amdgpu-mi250-x86-64] runs-on: ${{matrix.os}} env: PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" @@ -64,6 +64,13 @@ jobs: export WAVE_RUN_E2E_TESTS=1 pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/ + - name: Run e2e tests on MI250 + if: "contains(matrix.os, 'mi250') && !cancelled()" + run: | + pip install --no-compile -r pytorch-rocm-requirements.txt + export WAVE_RUN_E2E_TESTS=1 + pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/ + - name: Run LIT tests if: ${{ !cancelled() }} run: | diff --git a/.github/workflows/perf.yaml b/.github/workflows/perf.yaml index 30328d98..c574d542 100644 --- a/.github/workflows/perf.yaml +++ b/.github/workflows/perf.yaml @@ -24,7 +24,7 @@ jobs: fail-fast: false matrix: version: [3.11] - os: [nodai-amdgpu-mi300-x86-64, nodai-amdgpu-mi250-x86-64] + os: [nodai-amdgpu-mi300-x86-64] runs-on: ${{matrix.os}} env: PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache" @@ -61,10 +61,3 @@ jobs: export WAVE_RUN_E2E_TESTS=1 export TEST_PARAMS_PATH="tests/kernel/wave/test_param.json" pytest -n 1 --capture=tee-sys -vv ./tests/kernel/wave/ - - - name: Run e2e tests on MI250 - if: "contains(matrix.os, 'mi250') && !cancelled()" - run: | - export WAVE_RUN_E2E_TESTS=1 - export TEST_PARAMS_PATH="tests/kernel/wave/test_param.json" - pytest -n 1 --capture=tee-sys -vv ./tests/kernel/wave/ diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index 8ba763b3..14f046b8 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -97,11 +97,6 @@ def run_test(func: Callable[[], None]) -> Callable[[], None]: return func -def get_default_run_config() -> dict[Any, Any]: - """Return default config for testing.""" - return {"backend": "rocm", "device": "hip", "target": "gfx942"} - - def get_default_arch() -> str: """Return default ROCM architecture""" if not torch.cuda.is_available(): @@ -114,6 +109,17 @@ def get_default_arch() -> str: return gcnArch[0:colon_pos] +def get_default_run_config() -> dict[Any, Any]: + """Return default config for running.""" + arch = get_default_arch() + return 
{"backend": "rocm", "device": "hip", "target": arch} + + +def get_default_compile_config() -> dict[Any, Any]: + """Return default config for compilation.""" + return {"backend": "rocm", "device": "hip", "target": "gfx942"} + + def print_trace(trace: CapturedTrace, custom_print: bool = True): """ Prints all subgraphs of a trace starting with the root graph. diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 75924390..12bd917e 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -8,7 +8,7 @@ from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel.wave.utils import ( run_test, - get_default_run_config, + get_default_compile_config, get_mfma_load_elems_per_thread, get_mfma_store_elems_per_thread, ) @@ -1668,7 +1668,7 @@ def test( res = tkw.sum(res, dim=N) tkw.write(res, c, elements_per_thread=1) - config = get_default_run_config() + config = get_default_compile_config() shape = (256, 128) a = torch.randn(shape, dtype=torch.float16) @@ -1744,7 +1744,7 @@ def test( res = tkw.sum([lhs, rhs], dim=N) tkw.write(res, c, elements_per_thread=1) - config = get_default_run_config() + config = get_default_compile_config() shape = (256, 128) a = torch.randn(shape, dtype=torch.float16) @@ -1816,7 +1816,7 @@ def repeat( result = repeat + repeat tkw.write(result, c, elements_per_thread=1) - config = get_default_run_config() + config = get_default_compile_config() shape = (256, 512) a = torch.randn(shape, dtype=torch.float16) @@ -1904,7 +1904,7 @@ def repeat( tkw.write(repeat, c, elements_per_thread=1) - config = get_default_run_config() + config = get_default_compile_config() shape = (256, 512) a = torch.randn(shape, dtype=torch.float16) @@ -2003,7 +2003,7 @@ def repeat( tkw.write(res_max, c, elements_per_thread=1) tkw.write(res_sum, d, elements_per_thread=1) - config = get_default_run_config() + config = get_default_compile_config() shape = (256, 512) a = torch.randn(shape, dtype=torch.float16) @@ -2102,7 +2102,7 @@ def repeat( res_max, res_sum = repeat tkw.write(res_sum, c, elements_per_thread=1) - config = get_default_run_config() + config = get_default_compile_config() shape = (256, 1024) a = torch.randn(shape, dtype=torch.float32) @@ -2163,7 +2163,7 @@ def test( res = lhs + rhs tkw.write(res, c, elements_per_thread=STORE_ELEMS_PER_THREAD) - config = get_default_run_config() + config = get_default_compile_config() shape = (256, 128) a = torch.ones(shape, dtype=torch.float16) diff --git a/tests/kernel/wave/wave_attention_test.py b/tests/kernel/wave/wave_attention_test.py index 2ac974e3..01aa70c6 100644 --- a/tests/kernel/wave/wave_attention_test.py +++ b/tests/kernel/wave/wave_attention_test.py @@ -16,6 +16,7 @@ from iree.turbine.kernel.wave.iree_utils import generate_iree_ref from iree.turbine.kernel.wave.utils import ( get_default_run_config, + get_default_arch, get_mfma_load_elems_per_thread, get_mfma_store_elems_per_thread, device_randn, @@ -28,6 +29,9 @@ _run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0)) require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled") +require_cdna3 = pytest.mark.skipif( + "gfx94" not in get_default_arch(), reason="Default device is not CDNA3" +) # Whether to dump the generated MLIR module. test_dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0)) # Whether to use scheduling group barriers (needs LLVM fix). @@ -205,6 +209,7 @@ def repeat( # This test requires some more analysis on the index sequences between # the two chained GEMMs. 
@require_e2e +@require_cdna3 @pytest.mark.xfail @pytest.mark.parametrize("shape", get_test_shapes("test_attention")) @pytest.mark.parametrize("enable_scheduling", [False]) diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index 66de9f12..3b8d8492 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -333,7 +333,7 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]: VALU_UNITS: 8, SHUFFLE_UNITS: 8, } - config = {"backend": "rocm", "device": "hip", "target": "gfx90a"} + config = get_default_run_config() if run_bench: config["benchmark_batch_size"] = 10 config["benchmark_repetitions"] = 3 From 4b2d4525d102d79bfdf66be40833e841492ad26d Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Tue, 19 Nov 2024 17:07:47 -0800 Subject: [PATCH 10/11] turn off igemm for cdna2 temporarily + fix yaml Signed-off-by: Stanley Winata --- .github/workflows/ci-tk.yaml | 11 ++--------- tests/kernel/wave/wave_e2e_test.py | 6 ++++++ 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/.github/workflows/ci-tk.yaml b/.github/workflows/ci-tk.yaml index 52ebc5ee..fafbf355 100644 --- a/.github/workflows/ci-tk.yaml +++ b/.github/workflows/ci-tk.yaml @@ -57,15 +57,8 @@ jobs: run: | pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/ - - name: Run e2e tests on MI300 - if: "contains(matrix.os, 'mi300') && !cancelled()" - run: | - pip install --no-compile -r pytorch-rocm-requirements.txt - export WAVE_RUN_E2E_TESTS=1 - pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/ - - - name: Run e2e tests on MI250 - if: "contains(matrix.os, 'mi250') && !cancelled()" + - name: Run e2e tests on AMD GPU + if: "contains(matrix.os, 'amdgpu') && !cancelled()" run: | pip install --no-compile -r pytorch-rocm-requirements.txt export WAVE_RUN_E2E_TESTS=1 diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 1414c0c6..fe8e8285 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -11,6 +11,7 @@ from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel.wave.iree_utils import generate_iree_ref from iree.turbine.kernel.wave.utils import ( + get_default_arch, get_default_run_config, device_randn, device_randint, @@ -27,6 +28,9 @@ _run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0)) require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled") +require_cdna3 = pytest.mark.skipif( + "gfx94" not in get_default_arch(), reason="Default device is not CDNA3" +) default_test_shapes = [ (1, 27), (111, 813), @@ -689,6 +693,7 @@ def test( assert_close(b, expected) +# TODO: Fix test for CDNA2. 
CDNA2 seem to have worse accuracy, atol=0.0094, rtol=10.2405 @require_e2e def test_im2col_mma(request): run_bench = request.config.getoption("--runperf") @@ -891,6 +896,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: @require_e2e +@require_cdna3 @pytest.mark.parametrize("n, h, w, c, hf, wf, nf, stride", _igemm_cases) @pytest.mark.parametrize("mem_space", _mem_spaces) @pytest.mark.parametrize("layout", _layouts) From 3a3f9f31f11baf857afb44af388a544a4f00a9f2 Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Wed, 20 Nov 2024 08:43:56 -0800 Subject: [PATCH 11/11] rebase, nit, and lit Signed-off-by: Stanley Winata --- iree/turbine/kernel/wave/utils.py | 2 +- lit_tests/kernel/wave/codegen.py | 149 ++++++++++++++++++++++++++++ tests/kernel/wave/wave_gemm_test.py | 32 +++--- 3 files changed, 164 insertions(+), 19 deletions(-) diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index 14f046b8..aa257a41 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -943,7 +943,7 @@ def device_randint(*args, **kwargs): def device_randperm(*args, **kwargs): - return torch.randperm(*args, **kwargs).to("cuda") + return to_default_device(torch.randperm(*args, **kwargs)) def device_zeros(*args, **kwargs): diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 12bd917e..7f8a3302 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -656,6 +656,155 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: return +@run_test +def test_cdna2_int_gemm(): + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=tkw.MMAType.I32_16x16x16_I8, + ) + ] + + @tkw.wave(constraints) + def cdna2_int_gemm( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.i8], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.i8], + c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.i32], + ): + c_reg = tkl.Register[M, N, tkl.i32](0.0) + + @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + with tk.gen.TestLaunchContext( + { + M: 64, + N: 128, + K: 64, + BLOCK_M: 32, + BLOCK_N: 32, + BLOCK_K: 16, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + }, + canonicalize=True, + ): + a = torch.ones(64, 32, dtype=torch.int8) + b = torch.ones(128, 32, dtype=torch.int8) + c = torch.zeros(64, 128, dtype=torch.int32) + print(cdna2_int_gemm(a, b, c).module_op) + + # CHECK: func.func @cdna2_int_gemm(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: !stream.binding, + # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = #[[TRANSLATION:.+]]} { + # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + # CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index + # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + # 
CHECK-DAG: %[[CST:.+]] = arith.constant dense<0> : vector<4xi32> + # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x24xi8, #gpu.address_space> + # CHECK: %[[ALLOC_1:.+]] = memref.alloc() : memref<32x24xi8, #gpu.address_space> + # CHECK: %[[GLOBAL_0:.+]] = stream.binding.subspan %[[ARG0]] + # CHECK: %[[GLOBAL_1:.+]] = stream.binding.subspan %[[ARG1]] + # CHECK: scf.for %[[IVAR:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[CST]]) -> (vector<4xi32>) { + # CHECK: %[[REG_0:.+]] = vector.load %[[GLOBAL_0]] + # CHECK: vector.store %[[REG_0]], %[[ALLOC_0]] + # CHECK: %[[LHS:.+]] = vector.load %[[ALLOC]]{{.*}} : memref<32x24xi8, #gpu.address_space>, vector<4xi8> + # CHECK: %[[REG_1:.+]] = vector.load %[[GLOBAL_1]] + # CHECK: vector.store %[[REG_1]], %[[ALLOC_1]] + # CHECK: %[[RHS:.+]] = vector.load %[[ALLOC_1]]{{.*}} : memref<32x24xi8, #gpu.address_space>, vector<4xi8> + # CHECK: %[[MMA:.+]] = amdgpu.mfma %[[LHS]] * %[[RHS]] + %[[ACC]] {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xi8>, vector<4xi8>, vector<4xi32> + # CHECK: scf.yield %[[MMA]] : vector<4xi32> + + +@run_test +def test_cdna3_int_gemm(): + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + mma_variant = tkw.MMAType.I32_16x16x32_I8 + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=mma_variant, + ) + ] + + @tkw.wave(constraints) + def cdna3_int_gemm( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.i8], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.i8], + c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.i32], + ): + c_reg = tkl.Register[M, N, tkl.i32](0.0) + + @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + with tk.gen.TestLaunchContext( + { + M: 64, + N: 128, + K: 64, + BLOCK_M: 32, + BLOCK_N: 32, + BLOCK_K: 32, + LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mma_variant), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mma_variant), + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + }, + canonicalize=True, + ): + a = torch.ones(64, 32, dtype=torch.int8) + b = torch.ones(128, 32, dtype=torch.int8) + c = torch.zeros(64, 128, dtype=torch.int32) + print(cdna3_int_gemm(a, b, c).module_op) + + # CHECK: func.func @cdna3_int_gemm(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: !stream.binding, + # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = #[[TRANSLATION:.+]]} { + # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + # CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index + # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + # CHECK-DAG: %[[CST:.+]] = arith.constant dense<0> : vector<4xi32> + # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x40xi8, #gpu.address_space> + # CHECK: %[[ALLOC_1:.+]] = memref.alloc() : memref<32x40xi8, #gpu.address_space> + # CHECK: %[[GLOBAL_0:.+]] = stream.binding.subspan %[[ARG0]] + # CHECK: %[[GLOBAL_1:.+]] = stream.binding.subspan %[[ARG1]] + # 
CHECK: scf.for %[[IVAR:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[CST]]) -> (vector<4xi32>) { + # CHECK: %[[REG_0:.+]] = vector.load %[[GLOBAL_0]] + # CHECK: vector.store %[[REG_0]], %[[ALLOC_0]] + # CHECK: %[[LHS:.+]] = vector.load %[[ALLOC]]{{.*}} : memref<32x40xi8, #gpu.address_space>, vector<8xi8> + # CHECK: %[[REG_1:.+]] = vector.load %[[GLOBAL_1]] + # CHECK: vector.store %[[REG_1]], %[[ALLOC_1]] + # CHECK: %[[RHS:.+]] = vector.load %[[ALLOC_1]]{{.*}} : memref<32x40xi8, #gpu.address_space>, vector<8xi8> + # CHECK: %[[MMA:.+]] = amdgpu.mfma %[[LHS]] * %[[RHS]] + %[[ACC]] {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32> + # CHECK: scf.yield %[[MMA]] : vector<4xi32> + + @run_test def test_batched_gemm(): constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index 3b8d8492..8db6c8a7 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -287,8 +287,8 @@ def testCDNA2IntGemm( # These can be influenced by introducing constraints. @tkw.wave(constraints) def gemm( - a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.i16], - b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.i16], + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.i8], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.i8], c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.i32], ): c_reg = tkl.Register[M, N, tkl.i32](0.0) @@ -297,12 +297,10 @@ def gemm( # dimension were tiled, then we would need to materialize a loop. @tkw.reduction(K, init_args=[c_reg]) def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]: - # a_reg: tkw.Register[M, K, tkl.i16] + # a_reg: tkw.Register[M, K, tkl.i8] a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) - a_reg = tkw.cast(a_reg, tkl.i8) - # b_reg: tkw.Register[N, K, tkl.i16] + # b_reg: tkw.Register[N, K, tkl.i8] b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) - b_reg = tkw.cast(b_reg, tkl.i8) # acc: tkw.Register[M, N, tkl.i32] acc = tkw.mma(a_reg, b_reg, acc) return acc @@ -368,8 +366,8 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]: dynamic_symbols_map=dynamic_symbols_map, ): randint_hi = 4 - a = device_randint(randint_hi, (shape[0], shape[2]), dtype=torch.int16) - b = device_randint(randint_hi, (shape[1], shape[2]), dtype=torch.int16) + a = device_randint(randint_hi, (shape[0], shape[2]), dtype=torch.int8) + b = device_randint(randint_hi, (shape[1], shape[2]), dtype=torch.int8) c = device_zeros(shape[0], shape[1], dtype=torch.int32) mb = gemm(a, b, c) @@ -395,8 +393,8 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]: @pytest.mark.parametrize( "mfma_variant", [ - MMAType.F32_16x16x32_F8, - MMAType.F32_32x32x16_F8, + MMAType.I32_16x16x32_I8, + MMAType.I32_32x32x16_I8, ], ) def testCDNA3IntGemm( @@ -433,8 +431,8 @@ def testCDNA3IntGemm( @tkw.wave(constraints) def gemm( - a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.i16], - b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.i16], + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.i8], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.i8], c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.i32], ): c_reg = tkl.Register[M, N, tkl.i32](0.0) @@ -443,12 +441,10 @@ def gemm( # dimension were tiled, then we would need to materialize a loop. 
@tkw.reduction(K, init_args=[c_reg]) def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]: - # a_reg: tkw.Register[M, K, tkl.i16] + # a_reg: tkw.Register[M, K, tkl.i8] a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) - a_reg = tkw.cast(a_reg, tkl.i8) - # b_reg: tkw.Register[N, K, tkl.i16] + # b_reg: tkw.Register[N, K, tkl.i8] b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) - b_reg = tkw.cast(b_reg, tkl.i8) # acc: tkw.Register[M, N, tkl.i32] acc = tkw.mma(a_reg, b_reg, acc) return acc @@ -498,8 +494,8 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]: use_scheduling_barriers=enable_scheduling_barriers, ): randint_hi = 4 - a = device_randint(randint_hi, (shape[0], shape[2]), dtype=torch.int16) - b = device_randint(randint_hi, (shape[1], shape[2]), dtype=torch.int16) + a = device_randint(randint_hi, (shape[0], shape[2]), dtype=torch.int8) + b = device_randint(randint_hi, (shape[1], shape[2]), dtype=torch.int8) c = device_zeros(shape[0], shape[1], dtype=torch.int32) mb = gemm(a, b, c)