diff --git a/.github/workflows/ci-tk.yaml b/.github/workflows/ci-tk.yaml
index f0b8fbec..fafbf355 100644
--- a/.github/workflows/ci-tk.yaml
+++ b/.github/workflows/ci-tk.yaml
@@ -21,7 +21,7 @@ jobs:
       fail-fast: false
       matrix:
         version: [3.11]
-        os: [ubuntu-latest, nodai-amdgpu-mi300-x86-64]
+        os: [ubuntu-latest, nodai-amdgpu-mi300-x86-64, nodai-amdgpu-mi250-x86-64]
     runs-on: ${{matrix.os}}
     env:
       PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
@@ -57,8 +57,8 @@ jobs:
       run: |
         pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/
 
-    - name: Run e2e tests on MI300
-      if: "contains(matrix.os, 'mi300') && !cancelled()"
+    - name: Run e2e tests on AMD GPU
+      if: "contains(matrix.os, 'amdgpu') && !cancelled()"
       run: |
         pip install --no-compile -r pytorch-rocm-requirements.txt
         export WAVE_RUN_E2E_TESTS=1
diff --git a/iree/turbine/kernel/wave/constraints.py b/iree/turbine/kernel/wave/constraints.py
index 83441af0..a12f222a 100644
--- a/iree/turbine/kernel/wave/constraints.py
+++ b/iree/turbine/kernel/wave/constraints.py
@@ -15,11 +15,45 @@
 from ..lang.global_symbols import *
 
+"""
+Formatting for different target intrinsics:
+    <kind>_<elem-type-C>_<M>x<N>x<K>_<elem-type-A>[_<elem-type-B>]
+
+Values: 0xABCD where:
+* A = vendor:
+  * 1 = AMD
+  * 2 = NVIDIA
+* B = architecture. When an intrinsic exists in multiple architectures, this
+      should be the architecture it was introduced in, as long as it still
+      has the same semantics. If a new architecture breaks an existing
+      intrinsic's semantics, we can use that field for versioning.
+  * For AMD:
+    * 0 = CDNA1
+    * 1 = CDNA2
+    * 2 = CDNA3
+    * 8 = RDNA3
+* C = element type of A-matrix:
+  * 0 = 64-bit float (e.g. IEEE754 double precision)
+  * 1 = 32-bit float (e.g. IEEE754 single precision, and "xf32" fast variants)
+  * 2 = 16-bit float (incl. IEEE754 half and bf16)
+  * 3 = 8-bit float (incl. f8E5M2, f8E4M3, and "FNUZ" variants)
+  * C = 8-bit integer (any signedness)
+* D enumerates intrinsics that share the same 0xABC* bits.
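+
+For example, 0x1020 decodes as vendor 1 (AMD), architecture 0 (CDNA1), A-matrix
+element type 2 (16-bit float), variant 0, which is F32_16x16x16_F16 below, and
+0x12C1 decodes as AMD, CDNA3, 8-bit integer A-matrix, variant 1, which is
+I32_32x32x16_I8.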
+""" + + class MMAType(Enum): - F32_16x16x16_F16 = 0 - F32_32x32x8_F16 = 1 - F32_16x16x32_F8 = 2 - F32_32x32x16_F8 = 3 + # Intrinsics introduced in CDNA1 + F32_16x16x16_F16 = 0x1020 + F32_32x32x8_F16 = 0x1021 + I32_16x16x16_I8 = 0x10C0 + I32_32x32x8_I8 = 0x10C1 + + # Intrinsics introduced in CDNA3 + F32_16x16x32_F8 = 0x1230 + F32_32x32x16_F8 = 0x1231 + I32_16x16x32_I8 = 0x12C0 + I32_32x32x16_I8 = 0x12C1 class MMAOperand(Enum): @@ -89,13 +123,13 @@ def get_thread_id_from_workgroup_dim(self, workgroup_dim: int) -> IndexSymbol: def mma_matrix_shapes(self) -> tuple[int]: # TODO: Eventually the shapes and indices should be provided by a tool match self.mma_type: - case MMAType.F32_16x16x16_F16: + case MMAType.F32_16x16x16_F16 | MMAType.I32_16x16x16_I8: return (16, 16, 16) - case MMAType.F32_32x32x8_F16: + case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8: return (32, 32, 8) - case MMAType.F32_16x16x32_F8: + case MMAType.F32_16x16x32_F8 | MMAType.I32_16x16x32_I8: return (16, 16, 32) - case MMAType.F32_32x32x16_F8: + case MMAType.F32_32x32x16_F8 | MMAType.I32_32x32x16_I8: return (32, 32, 16) case _: return () @@ -151,7 +185,7 @@ def apply( lane = self.linearized_thread_id % self.threads_per_wave match self.mma_type: # (M x K, N x K) -> M x N - case MMAType.F32_16x16x16_F16: + case MMAType.F32_16x16x16_F16 | MMAType.I32_16x16x16_I8: offset = [ Piecewise( (lane % 16, ~MMA_ACC), @@ -170,7 +204,7 @@ def apply( 1, # N 1, # K ] - case MMAType.F32_32x32x8_F16: + case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8: offset = [ Piecewise( (lane % 32, ~MMA_ACC), @@ -194,7 +228,7 @@ def apply( 1, # N 1, # K ] - case MMAType.F32_16x16x32_F8: + case MMAType.F32_16x16x32_F8 | MMAType.I32_16x16x32_I8: offset = [ Piecewise( (lane % 16, ~MMA_ACC), (4 * floor(lane / 16), MMA_ACC) @@ -212,7 +246,7 @@ def apply( 1, # N 1, # K ] - case MMAType.F32_32x32x16_F8: + case MMAType.F32_32x32x16_F8 | MMAType.I32_32x32x16_I8: offset = [ Piecewise( (lane % 32, ~MMA_ACC), diff --git a/iree/turbine/kernel/wave/iree_utils.py b/iree/turbine/kernel/wave/iree_utils.py index d1031bd5..7a435e15 100644 --- a/iree/turbine/kernel/wave/iree_utils.py +++ b/iree/turbine/kernel/wave/iree_utils.py @@ -7,7 +7,7 @@ import torch from typing import Any from .utils import compile_and_invoke -from ...support.conversions import TORCH_DTYPE_TO_MLIR_TYPE_ASM +from ...support.conversions import TORCH_DTYPE_TO_IREE_TYPE_ASM def get_chain_mmt_asm( @@ -93,7 +93,7 @@ def get_mmt_asm( if not cast_fp8: matmul_function = f""" func.func @{func_name}(%lhs: tensor<{lhs_type}>, %rhs: tensor<{rhs_type}>) -> tensor<{acc_type}> {{ - %c0 = arith.constant 0.0 : {acc_dtype} + %c0 = arith.constant {"0.0" if acc_dtype.startswith("f") else "0"} : {acc_dtype} %init = tensor.empty() : tensor<{acc_type}> %inital_result = linalg.fill ins(%c0 : {acc_dtype}) outs(%init : tensor<{acc_type}>) -> tensor<{acc_type}> %result = linalg.{operator} ins(%lhs, %rhs: tensor<{lhs_type}>, tensor<{rhs_type}>) @@ -138,7 +138,7 @@ def get_conv_asm( def dtype_str(dtype: torch.dtype) -> str: - dtype_str = TORCH_DTYPE_TO_MLIR_TYPE_ASM.get(dtype, None) + dtype_str = TORCH_DTYPE_TO_IREE_TYPE_ASM[dtype] if dtype_str is None: raise ValueError(f"Unsupported dtype: {dtype}") return dtype_str @@ -166,7 +166,11 @@ def generate_iree_ref( rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype) acc_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype) asm = get_mmt_asm( - lhs_type, rhs_type, acc_type, batch=False, cast_fp8=kernel_type == "mmt_f8" + lhs_type, + rhs_type, + 
acc_type, + batch=False, + cast_fp8=kernel_type == "mmt_f8", ) elif kernel_type == "bmmt": lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype) diff --git a/iree/turbine/kernel/wave/utils.py b/iree/turbine/kernel/wave/utils.py index a28ea2fa..aa257a41 100644 --- a/iree/turbine/kernel/wave/utils.py +++ b/iree/turbine/kernel/wave/utils.py @@ -97,8 +97,26 @@ def run_test(func: Callable[[], None]) -> Callable[[], None]: return func +def get_default_arch() -> str: + """Return default ROCM architecture""" + if not torch.cuda.is_available(): + return "cpu" + device = torch.device("cuda") + gcnArch = torch.cuda.get_device_properties(device).gcnArchName + assert "gfx" in gcnArch, "Currently only support GFX/ROCm for get_default_arch." + # The gcnArchName comes back like gfx90a:sramecc+:xnack. + colon_pos = gcnArch.find(":") + return gcnArch[0:colon_pos] + + def get_default_run_config() -> dict[Any, Any]: - """Return default config for testing.""" + """Return default config for running.""" + arch = get_default_arch() + return {"backend": "rocm", "device": "hip", "target": arch} + + +def get_default_compile_config() -> dict[Any, Any]: + """Return default config for compilation.""" return {"backend": "rocm", "device": "hip", "target": "gfx942"} @@ -880,25 +898,25 @@ def ceildiv(a: int, b: int) -> int: def get_mfma_load_elems_per_thread(mfma_variant: MMAType) -> int: match mfma_variant: - case MMAType.F32_16x16x16_F16: + case MMAType.F32_16x16x16_F16 | MMAType.I32_16x16x16_I8: return 4 - case MMAType.F32_32x32x8_F16: + case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8: return 4 - case MMAType.F32_16x16x32_F8: + case MMAType.F32_16x16x32_F8 | MMAType.I32_16x16x32_I8: return 8 - case MMAType.F32_32x32x16_F8: + case MMAType.F32_32x32x16_F8 | MMAType.I32_32x32x16_I8: return 8 def get_mfma_store_elems_per_thread(mfma_variant: MMAType) -> int: match mfma_variant: - case MMAType.F32_16x16x16_F16: + case MMAType.F32_16x16x16_F16 | MMAType.I32_16x16x16_I8: return 4 - case MMAType.F32_32x32x8_F16: + case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8: return 16 - case MMAType.F32_16x16x32_F8: + case MMAType.F32_16x16x32_F8 | MMAType.I32_16x16x32_I8: return 4 - case MMAType.F32_32x32x16_F8: + case MMAType.F32_32x32x16_F8 | MMAType.I32_32x32x16_I8: return 16 @@ -908,20 +926,28 @@ def all_equal(input_list: list[Any]) -> bool: return all(elem == input_list[0] for elem in input_list) +def get_default_device() -> str: + return "cuda" if torch.cuda.is_available() else "cpu" + + +def to_default_device(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(get_default_device()) + + def device_randn(*args, **kwargs): - return torch.randn(*args, **kwargs).to("cuda") + return to_default_device(torch.randn(*args, **kwargs)) def device_randint(*args, **kwargs): - return torch.randint(*args, **kwargs).to("cuda") + return to_default_device(torch.randint(*args, **kwargs)) def device_randperm(*args, **kwargs): - return torch.randperm(*args, **kwargs).to("cuda") + return to_default_device(torch.randperm(*args, **kwargs)) def device_zeros(*args, **kwargs): - return torch.zeros(*args, **kwargs).to("cuda") + return to_default_device(torch.zeros(*args, **kwargs)) def get_assumptions(constraints: list[Constraint]) -> list[Assumption]: diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 75924390..7f8a3302 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -8,7 +8,7 @@ from iree.turbine.kernel.lang.global_symbols import * from 
iree.turbine.kernel.wave.utils import (
     run_test,
-    get_default_run_config,
+    get_default_compile_config,
     get_mfma_load_elems_per_thread,
     get_mfma_store_elems_per_thread,
 )
@@ -656,6 +656,155 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK: return
 
+
+@run_test
+def test_cdna2_int_gemm():
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.TilingConstraint(K, BLOCK_K)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
+
+    constraints += [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(2, 2, 1),
+            mma_type=tkw.MMAType.I32_16x16x16_I8,
+        )
+    ]
+
+    @tkw.wave(constraints)
+    def cdna2_int_gemm(
+        a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.i8],
+        b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.i8],
+        c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.i32],
+    ):
+        c_reg = tkl.Register[M, N, tkl.i32](0.0)
+
+        @tkw.reduction(K, init_args=[c_reg])
+        def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]:
+            a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            acc = tkw.mma(a_reg, b_reg, acc)
+            return acc
+
+        tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
+
+    with tk.gen.TestLaunchContext(
+        {
+            M: 64,
+            N: 128,
+            K: 64,
+            BLOCK_M: 32,
+            BLOCK_N: 32,
+            BLOCK_K: 16,
+            LOAD_ELEMS_PER_THREAD: 4,
+            STORE_ELEMS_PER_THREAD: 4,
+            ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
+            ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
+        },
+        canonicalize=True,
+    ):
+        a = torch.ones(64, 32, dtype=torch.int8)
+        b = torch.ones(128, 32, dtype=torch.int8)
+        c = torch.zeros(64, 128, dtype=torch.int32)
+        print(cdna2_int_gemm(a, b, c).module_op)
+
+        # CHECK: func.func @cdna2_int_gemm(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: !stream.binding,
+        # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = #[[TRANSLATION:.+]]} {
+        # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+        # CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+        # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+        # CHECK-DAG: %[[CST:.+]] = arith.constant dense<0> : vector<4xi32>
+        # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x24xi8, #gpu.address_space<workgroup>>
+        # CHECK: %[[ALLOC_1:.+]] = memref.alloc() : memref<32x24xi8, #gpu.address_space<workgroup>>
+        # CHECK: %[[GLOBAL_0:.+]] = stream.binding.subspan %[[ARG0]]
+        # CHECK: %[[GLOBAL_1:.+]] = stream.binding.subspan %[[ARG1]]
+        # CHECK: scf.for %[[IVAR:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[CST]]) -> (vector<4xi32>) {
+        # CHECK: %[[REG_0:.+]] = vector.load %[[GLOBAL_0]]
+        # CHECK: vector.store %[[REG_0]], %[[ALLOC_0]]
+        # CHECK: %[[LHS:.+]] = vector.load %[[ALLOC_0]]{{.*}} : memref<32x24xi8, #gpu.address_space<workgroup>>, vector<4xi8>
+        # CHECK: %[[REG_1:.+]] = vector.load %[[GLOBAL_1]]
+        # CHECK: vector.store %[[REG_1]], %[[ALLOC_1]]
+        # CHECK: %[[RHS:.+]] = vector.load %[[ALLOC_1]]{{.*}} : memref<32x24xi8, #gpu.address_space<workgroup>>, vector<4xi8>
+        # CHECK: %[[MMA:.+]] = amdgpu.mfma %[[LHS]] * %[[RHS]] + %[[ACC]] {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xi8>, vector<4xi8>, vector<4xi32>
+        # CHECK: scf.yield %[[MMA]] : vector<4xi32>
+
+
+@run_test
+def test_cdna3_int_gemm():
+    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.TilingConstraint(K, BLOCK_K)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
+
+    mma_variant = tkw.MMAType.I32_16x16x32_I8
+    constraints += [
+        tkw.HardwareConstraint(
+            threads_per_wave=64,
+            waves_per_block=(2, 2, 1),
+            mma_type=mma_variant,
+        )
+    ]
+
+    @tkw.wave(constraints)
+    def cdna3_int_gemm(
+        a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.i8],
+        b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.i8],
+        c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.i32],
+    ):
+        c_reg = tkl.Register[M, N, tkl.i32](0.0)
+
+        @tkw.reduction(K, init_args=[c_reg])
+        def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]:
+            a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD)
+            acc = tkw.mma(a_reg, b_reg, acc)
+            return acc
+
+        tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
+
+    with tk.gen.TestLaunchContext(
+        {
+            M: 64,
+            N: 128,
+            K: 64,
+            BLOCK_M: 32,
+            BLOCK_N: 32,
+            BLOCK_K: 32,
+            LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mma_variant),
+            STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mma_variant),
+            ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
+            ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
+        },
+        canonicalize=True,
+    ):
+        a = torch.ones(64, 32, dtype=torch.int8)
+        b = torch.ones(128, 32, dtype=torch.int8)
+        c = torch.zeros(64, 128, dtype=torch.int32)
+        print(cdna3_int_gemm(a, b, c).module_op)
+
+        # CHECK: func.func @cdna3_int_gemm(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: !stream.binding,
+        # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = #[[TRANSLATION:.+]]} {
+        # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+        # CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+        # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+        # CHECK-DAG: %[[CST:.+]] = arith.constant dense<0> : vector<4xi32>
+        # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x40xi8, #gpu.address_space<workgroup>>
+        # CHECK: %[[ALLOC_1:.+]] = memref.alloc() : memref<32x40xi8, #gpu.address_space<workgroup>>
+        # CHECK: %[[GLOBAL_0:.+]] = stream.binding.subspan %[[ARG0]]
+        # CHECK: %[[GLOBAL_1:.+]] = stream.binding.subspan %[[ARG1]]
+        # CHECK: scf.for %[[IVAR:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[CST]]) -> (vector<4xi32>) {
+        # CHECK: %[[REG_0:.+]] = vector.load %[[GLOBAL_0]]
+        # CHECK: vector.store %[[REG_0]], %[[ALLOC_0]]
+        # CHECK: %[[LHS:.+]] = vector.load %[[ALLOC_0]]{{.*}} : memref<32x40xi8, #gpu.address_space<workgroup>>, vector<8xi8>
+        # CHECK: %[[REG_1:.+]] = vector.load %[[GLOBAL_1]]
+        # CHECK: vector.store %[[REG_1]], %[[ALLOC_1]]
+        # CHECK: %[[RHS:.+]] = vector.load %[[ALLOC_1]]{{.*}} : memref<32x40xi8, #gpu.address_space<workgroup>>, vector<8xi8>
+        # CHECK: %[[MMA:.+]] = amdgpu.mfma %[[LHS]] * %[[RHS]] + %[[ACC]] {blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32>
+        # CHECK: scf.yield %[[MMA]] : vector<4xi32>
+
+
 @run_test
 def test_batched_gemm():
     constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
@@ -1668,7 +1817,7 @@ def test(
         res = tkw.sum(res, dim=N)
         tkw.write(res, c, elements_per_thread=1)
 
-    config = get_default_run_config()
+    config = get_default_compile_config()
     shape = (256, 128)
     a = torch.randn(shape, dtype=torch.float16)
@@ -1744,7 +1893,7 @@ def test(
         res = tkw.sum([lhs, rhs], dim=N)
         tkw.write(res, c, elements_per_thread=1)
 
-    config = get_default_run_config()
+    config = get_default_compile_config()
     shape = (256, 128)
     a =
torch.randn(shape, dtype=torch.float16) @@ -1816,7 +1965,7 @@ def repeat( result = repeat + repeat tkw.write(result, c, elements_per_thread=1) - config = get_default_run_config() + config = get_default_compile_config() shape = (256, 512) a = torch.randn(shape, dtype=torch.float16) @@ -1904,7 +2053,7 @@ def repeat( tkw.write(repeat, c, elements_per_thread=1) - config = get_default_run_config() + config = get_default_compile_config() shape = (256, 512) a = torch.randn(shape, dtype=torch.float16) @@ -2003,7 +2152,7 @@ def repeat( tkw.write(res_max, c, elements_per_thread=1) tkw.write(res_sum, d, elements_per_thread=1) - config = get_default_run_config() + config = get_default_compile_config() shape = (256, 512) a = torch.randn(shape, dtype=torch.float16) @@ -2102,7 +2251,7 @@ def repeat( res_max, res_sum = repeat tkw.write(res_sum, c, elements_per_thread=1) - config = get_default_run_config() + config = get_default_compile_config() shape = (256, 1024) a = torch.randn(shape, dtype=torch.float32) @@ -2163,7 +2312,7 @@ def test( res = lhs + rhs tkw.write(res, c, elements_per_thread=STORE_ELEMS_PER_THREAD) - config = get_default_run_config() + config = get_default_compile_config() shape = (256, 128) a = torch.ones(shape, dtype=torch.float16) diff --git a/tests/kernel/wave/wave_attention_test.py b/tests/kernel/wave/wave_attention_test.py index 2ac974e3..01aa70c6 100644 --- a/tests/kernel/wave/wave_attention_test.py +++ b/tests/kernel/wave/wave_attention_test.py @@ -16,6 +16,7 @@ from iree.turbine.kernel.wave.iree_utils import generate_iree_ref from iree.turbine.kernel.wave.utils import ( get_default_run_config, + get_default_arch, get_mfma_load_elems_per_thread, get_mfma_store_elems_per_thread, device_randn, @@ -28,6 +29,9 @@ _run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0)) require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled") +require_cdna3 = pytest.mark.skipif( + "gfx94" not in get_default_arch(), reason="Default device is not CDNA3" +) # Whether to dump the generated MLIR module. test_dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0)) # Whether to use scheduling group barriers (needs LLVM fix). @@ -205,6 +209,7 @@ def repeat( # This test requires some more analysis on the index sequences between # the two chained GEMMs. @require_e2e +@require_cdna3 @pytest.mark.xfail @pytest.mark.parametrize("shape", get_test_shapes("test_attention")) @pytest.mark.parametrize("enable_scheduling", [False]) diff --git a/tests/kernel/wave/wave_e2e_test.py b/tests/kernel/wave/wave_e2e_test.py index 1414c0c6..fe8e8285 100644 --- a/tests/kernel/wave/wave_e2e_test.py +++ b/tests/kernel/wave/wave_e2e_test.py @@ -11,6 +11,7 @@ from iree.turbine.kernel.lang.global_symbols import * from iree.turbine.kernel.wave.iree_utils import generate_iree_ref from iree.turbine.kernel.wave.utils import ( + get_default_arch, get_default_run_config, device_randn, device_randint, @@ -27,6 +28,9 @@ _run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0)) require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled") +require_cdna3 = pytest.mark.skipif( + "gfx94" not in get_default_arch(), reason="Default device is not CDNA3" +) default_test_shapes = [ (1, 27), (111, 813), @@ -689,6 +693,7 @@ def test( assert_close(b, expected) +# TODO: Fix test for CDNA2. 
CDNA2 seem to have worse accuracy, atol=0.0094, rtol=10.2405 @require_e2e def test_im2col_mma(request): run_bench = request.config.getoption("--runperf") @@ -891,6 +896,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]: @require_e2e +@require_cdna3 @pytest.mark.parametrize("n, h, w, c, hf, wf, nf, stride", _igemm_cases) @pytest.mark.parametrize("mem_space", _mem_spaces) @pytest.mark.parametrize("layout", _layouts) diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index 324089e0..8db6c8a7 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -15,9 +15,11 @@ from iree.turbine.kernel.wave.iree_utils import generate_iree_ref from iree.turbine.kernel.wave.utils import ( get_default_run_config, + get_default_arch, get_mfma_load_elems_per_thread, get_mfma_store_elems_per_thread, device_randn, + device_randint, device_zeros, ) from iree.turbine.kernel.wave.constraints import MMAType @@ -28,6 +30,12 @@ _run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0)) require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled") +require_cdna2 = pytest.mark.skipif( + "gfx90" not in get_default_arch(), reason="Default device is not CDNA2" +) +require_cdna3 = pytest.mark.skipif( + "gfx94" not in get_default_arch(), reason="Default device is not CDNA3" +) # Whether to dump the generated MLIR module. test_dump_generated_mlir = int(os.environ.get("WAVE_DUMP_MLIR", 0)) # Whether to use scheduling group barriers (needs LLVM fix). @@ -219,6 +227,295 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: @require_e2e +@require_cdna2 +@pytest.mark.parametrize("shape", get_test_shapes("test_gemm")) +@pytest.mark.parametrize("enable_scheduling", [False, True]) +@pytest.mark.parametrize("dynamic_dims", [False, True]) +@pytest.mark.parametrize( + "mfma_variant", + [ + MMAType.I32_16x16x16_I8, + MMAType.I32_32x32x8_I8, + ], +) +def testCDNA2IntGemm( + shape: tuple[int], + enable_scheduling: bool, + dynamic_dims: bool, + mfma_variant: MMAType, + request, +): + run_bench = request.config.getoption("--runperf") + dump_perf = request.config.getoption("--dump-perf-files-path") + # Input sizes + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + # Workgroup tile sizes + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, waves_per_block=(2, 2, 1), mma_type=mfma_variant + ) + ] + + # With dynamic dimensions, we need to add an assumption on how big + # the reduction dimension is to determine whether we can schedule or not. + if dynamic_dims: + constraints += [tkw.Assumption(K > BLOCK_K * 4)] + + # Wave-level micro-kernel. + # Since warps are not directly addressable, there is no + # explicit notion of a warp id (like a workgroup or thread id). 
+ # This kernel uses the input sizes M, N, K throughout, as the tiling + # and data movement strategy is determined during the compilation process. + # These can be influenced by introducing constraints. + @tkw.wave(constraints) + def gemm( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.i8], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.i8], + c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.i32], + ): + c_reg = tkl.Register[M, N, tkl.i32](0.0) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. + @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]: + # a_reg: tkw.Register[M, K, tkl.i8] + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # b_reg: tkw.Register[N, K, tkl.i8] + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # acc: tkw.Register[M, N, tkl.i32] + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + # repeat represents the results of the loop + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + M: shape[0], + N: shape[1], + K: shape[2], + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + VALU_DELAY: 1, + SHUFFLE_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, + VALU_UNITS: 8, + SHUFFLE_UNITS: 8, + } + config = get_default_run_config() + if run_bench: + config["benchmark_batch_size"] = 10 + config["benchmark_repetitions"] = 3 + if dump_perf is not None: + perf_filename = request.node.name + ".json" + config["benchmark_results_file"] = os.path.join( + dump_perf, "tk_" + perf_filename + ) + + dynamic_symbols = [] + dynamic_symbols_map = {} + if dynamic_dims: + dynamic_symbols_map[M] = hyperparams[M] + dynamic_symbols_map[N] = hyperparams[N] + dynamic_symbols_map[K] = hyperparams[K] + dynamic_symbols.append(M) + dynamic_symbols.append(N) + dynamic_symbols.append(K) + del hyperparams[M] + del hyperparams[N] + del hyperparams[K] + + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=True, + run_bench=run_bench, + run_config=config, + schedule=enable_scheduling, + use_scheduling_barriers=enable_scheduling_barriers, + dynamic_symbols=dynamic_symbols, + dynamic_symbols_map=dynamic_symbols_map, + ): + randint_hi = 4 + a = device_randint(randint_hi, (shape[0], shape[2]), dtype=torch.int8) + b = device_randint(randint_hi, (shape[1], shape[2]), dtype=torch.int8) + c = device_zeros(shape[0], shape[1], dtype=torch.int32) + mb = gemm(a, b, c) + + if test_dump_generated_mlir: + filename = f"wave_gemm_{'x'.join(map(str, shape))}.mlir" + with open(filename, "w") as f: + f.write(mb.module_op.get_asm()) + + if run_bench: + if dump_perf is not None: + config["benchmark_results_file"] = os.path.join( + dump_perf, "iree_" + perf_filename + ) + iree_ref = torch.zeros(shape[0], shape[1], dtype=torch.int32) + generate_iree_ref("mmt", [a, b], [iree_ref], config, run_bench=run_bench) + assert_close(c, iree_ref, check_device=False) + + +@require_e2e +@require_cdna3 +@pytest.mark.parametrize("shape", get_test_shapes("test_gemm")) +@pytest.mark.parametrize("enable_scheduling", [False, True]) +@pytest.mark.parametrize( + "mfma_variant", + [ + MMAType.I32_16x16x32_I8, + 
MMAType.I32_32x32x16_I8, + ], +) +def testCDNA3IntGemm( + shape: tuple[int], enable_scheduling: bool, mfma_variant: MMAType, request +): + run_bench = request.config.getoption("--runperf") + dump_perf = request.config.getoption("--dump-perf-files-path") + # Input sizes + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + # Workgroup tile sizes + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + # Other hyperparameters + LOAD_ELEMS_PER_THREAD = tkl.sym.LOAD_ELEMS_PER_THREAD + STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD + + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, waves_per_block=(2, 2, 1), mma_type=mfma_variant + ) + ] + + @tkw.wave(constraints) + def gemm( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.i8], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.i8], + c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.i32], + ): + c_reg = tkl.Register[M, N, tkl.i32](0.0) + + # This microkernel encodes the fact that if the reduction + # dimension were tiled, then we would need to materialize a loop. + @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.i32]: + # a_reg: tkw.Register[M, K, tkl.i8] + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # b_reg: tkw.Register[N, K, tkl.i8] + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + # acc: tkw.Register[M, N, tkl.i32] + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + LOAD_ELEMS_PER_THREAD: get_mfma_load_elems_per_thread(mfma_variant), + STORE_ELEMS_PER_THREAD: get_mfma_store_elems_per_thread(mfma_variant), + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + M: shape[0], + N: shape[1], + K: shape[2], + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + VALU_DELAY: 1, + SHUFFLE_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, + VALU_UNITS: 8, + SHUFFLE_UNITS: 8, + } + config = get_default_run_config() + if run_bench: + config["benchmark_batch_size"] = 10 + config["benchmark_repetitions"] = 3 + if dump_perf is not None: + perf_filename = request.node.name + ".json" + config["benchmark_results_file"] = os.path.join( + dump_perf, "tk_" + perf_filename + ) + + with tk.gen.TestLaunchContext( + hyperparams, + canonicalize=True, + run=True, + run_bench=run_bench, + run_config=config, + schedule=enable_scheduling, + use_scheduling_barriers=enable_scheduling_barriers, + ): + randint_hi = 4 + a = device_randint(randint_hi, (shape[0], shape[2]), dtype=torch.int8) + b = device_randint(randint_hi, (shape[1], shape[2]), dtype=torch.int8) + c = device_zeros(shape[0], shape[1], dtype=torch.int32) + mb = gemm(a, b, c) + + if test_dump_generated_mlir: + filename = f"wave_gemm_{'x'.join(map(str, shape))}_f8.mlir" + with open(filename, "w") as f: + f.write(mb.module_op.get_asm()) + + if run_bench: + if dump_perf is not None: + config["benchmark_results_file"] = os.path.join( + dump_perf, "iree_" + perf_filename + 
) + iree_ref = torch.zeros(shape[0], shape[1], dtype=torch.int32) + generate_iree_ref("mmt", [a, b], [iree_ref], config, run_bench=run_bench) + assert_close(c, iree_ref, check_device=False) + + +@require_e2e +@require_cdna3 @pytest.mark.parametrize("shape", get_test_shapes("test_gemm")) @pytest.mark.parametrize("enable_scheduling", [False, True]) @pytest.mark.parametrize(