[TKW] Add CDNA2 + CDNA3 Int8 intrinsics and refactor intrinsic enums (#279)

- Added CDNA2 int8 intrinsic layouts
- Modified iree_ref to handle int GEMMs
- Modified certain e2e tests to require a specific GPU arch to be available
- Modified the intrinsic enum for easier handling in the future
- Added a function to get the default architecture
- Borrowed device_randint from Ivan
- Turned on the CDNA2 runner for TK CI

Manually tested that the generated iree_ref for int GEMMs works as expected!
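
As a rough illustration (not part of this commit), the sketch below shows how one of the new int8 intrinsics might be selected based on the detected architecture, using the MMAType enum from constraints.py and the new get_default_arch helper; the gfx90a/gfx942 mapping is an assumption made for this example.

# Hypothetical sketch: pick an int8 MMA intrinsic for the detected GPU.
from iree.turbine.kernel.wave.constraints import MMAType
from iree.turbine.kernel.wave.utils import get_default_arch

arch = get_default_arch()  # e.g. "gfx90a" (MI250/CDNA2) or "gfx942" (MI300/CDNA3)
if arch == "gfx942":
    mma = MMAType.I32_16x16x32_I8  # CDNA3 int8 intrinsic, value 0x12C0
else:
    mma = MMAType.I32_16x16x16_I8  # int8 intrinsic available on CDNA2, value 0x10C0
print(hex(mma.value))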

---------

Signed-off-by: Stanley Winata <[email protected]>
Co-authored-by: Ivan Butygin <[email protected]>
raikonenfnu and Hardcode84 authored Nov 20, 2024
1 parent c499d32 commit f8e0cbb
Showing 8 changed files with 561 additions and 40 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci-tk.yaml
@@ -21,7 +21,7 @@ jobs:
fail-fast: false
matrix:
version: [3.11]
os: [ubuntu-latest, nodai-amdgpu-mi300-x86-64]
os: [ubuntu-latest, nodai-amdgpu-mi300-x86-64, nodai-amdgpu-mi250-x86-64]
runs-on: ${{matrix.os}}
env:
PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
@@ -57,8 +57,8 @@ jobs:
run: |
pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/
- name: Run e2e tests on MI300
if: "contains(matrix.os, 'mi300') && !cancelled()"
- name: Run e2e tests on AMD GPU
if: "contains(matrix.os, 'amdgpu') && !cancelled()"
run: |
pip install --no-compile -r pytorch-rocm-requirements.txt
export WAVE_RUN_E2E_TESTS=1
58 changes: 46 additions & 12 deletions iree/turbine/kernel/wave/constraints.py
@@ -15,11 +15,45 @@
from ..lang.global_symbols import *


"""
Formatting for different target intrinsics:
<kind>_<elem-type-C>_<M>x<N>x<K>_<elem-type-A>[_<elem-type-B>]
Values: 0xABCD where:
* A = vendor:
* 1 = AMD
* 2 = NVIDIA
* B = architecture. When an intrinsic exists in multiple architectures, this
should be the architecture it was introduced in, as long as it still
has the same semantics. If a new architecture breaks an existing
intrinsic's semantics, we can use that field for versioning.
* For AMD:
* 0 = CDNA1
* 1 = CDNA2
* 2 = CDNA3
* 8 = RDNA3
* C = element type of A-matrix:
* 0 = 64-bit float (e.g. IEEE754 double precision)
* 1 = 32-bit float (e.g. IEEE754 single precision, and "xf32" fast variants)
* 2 = 16-bit float (incl. IEEE754 half and bf16)
* 3 = 8-bit float (incl. f8E5M2, f8E4M3, and "FNUZ" variants)
* C = 8-bit integer (any signedness)
* D enumerates intrinsics that share the same 0xABC* bits.
"""


class MMAType(Enum):
F32_16x16x16_F16 = 0
F32_32x32x8_F16 = 1
F32_16x16x32_F8 = 2
F32_32x32x16_F8 = 3
# Intrinsics introduced in CDNA1
F32_16x16x16_F16 = 0x1020
F32_32x32x8_F16 = 0x1021
I32_16x16x16_I8 = 0x10C0
I32_32x32x8_I8 = 0x10C1

# Intrinsics introduced in CDNA3
F32_16x16x32_F8 = 0x1230
F32_32x32x16_F8 = 0x1231
I32_16x16x32_I8 = 0x12C0
I32_32x32x16_I8 = 0x12C1


class MMAOperand(Enum):
@@ -89,13 +123,13 @@ def get_thread_id_from_workgroup_dim(self, workgroup_dim: int) -> IndexSymbol:
def mma_matrix_shapes(self) -> tuple[int]:
# TODO: Eventually the shapes and indices should be provided by a tool
match self.mma_type:
case MMAType.F32_16x16x16_F16:
case MMAType.F32_16x16x16_F16 | MMAType.I32_16x16x16_I8:
return (16, 16, 16)
case MMAType.F32_32x32x8_F16:
case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8:
return (32, 32, 8)
case MMAType.F32_16x16x32_F8:
case MMAType.F32_16x16x32_F8 | MMAType.I32_16x16x32_I8:
return (16, 16, 32)
case MMAType.F32_32x32x16_F8:
case MMAType.F32_32x32x16_F8 | MMAType.I32_32x32x16_I8:
return (32, 32, 16)
case _:
return ()
@@ -151,7 +185,7 @@ def apply(
lane = self.linearized_thread_id % self.threads_per_wave
match self.mma_type:
# (M x K, N x K) -> M x N
case MMAType.F32_16x16x16_F16:
case MMAType.F32_16x16x16_F16 | MMAType.I32_16x16x16_I8:
offset = [
Piecewise(
(lane % 16, ~MMA_ACC),
@@ -170,7 +204,7 @@
1, # N
1, # K
]
case MMAType.F32_32x32x8_F16:
case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8:
offset = [
Piecewise(
(lane % 32, ~MMA_ACC),
@@ -194,7 +228,7 @@
1, # N
1, # K
]
case MMAType.F32_16x16x32_F8:
case MMAType.F32_16x16x32_F8 | MMAType.I32_16x16x32_I8:
offset = [
Piecewise(
(lane % 16, ~MMA_ACC), (4 * floor(lane / 16), MMA_ACC)
@@ -212,7 +246,7 @@
1, # N
1, # K
]
case MMAType.F32_32x32x16_F8:
case MMAType.F32_32x32x16_F8 | MMAType.I32_32x32x16_I8:
offset = [
Piecewise(
(lane % 32, ~MMA_ACC),
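To make the 0xABCD encoding documented in constraints.py concrete, here is a small decoding sketch; the helper is hypothetical (not part of the commit) and simply unpacks the nibbles the docstring describes.

# Hypothetical helper that unpacks the 0xABCD intrinsic encoding described in constraints.py.
from iree.turbine.kernel.wave.constraints import MMAType

def decode_mma_value(value: int) -> dict:
    return {
        "vendor": (value >> 12) & 0xF,      # A: 1 = AMD, 2 = NVIDIA
        "arch": (value >> 8) & 0xF,         # B: 0 = CDNA1, 1 = CDNA2, 2 = CDNA3, 8 = RDNA3
        "a_elem_type": (value >> 4) & 0xF,  # C: 2 = 16-bit float, 3 = 8-bit float, 0xC = 8-bit int
        "variant": value & 0xF,             # D: disambiguates intrinsics sharing 0xABC*
    }

# MMAType.I32_16x16x32_I8 has value 0x12C0: AMD (1), CDNA3 (2), int8 A-matrix (0xC), variant 0.
print(decode_mma_value(MMAType.I32_16x16x32_I8.value))
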
12 changes: 8 additions & 4 deletions iree/turbine/kernel/wave/iree_utils.py
@@ -7,7 +7,7 @@
import torch
from typing import Any
from .utils import compile_and_invoke
from ...support.conversions import TORCH_DTYPE_TO_MLIR_TYPE_ASM
from ...support.conversions import TORCH_DTYPE_TO_IREE_TYPE_ASM


def get_chain_mmt_asm(
@@ -93,7 +93,7 @@ def get_mmt_asm(
if not cast_fp8:
matmul_function = f"""
func.func @{func_name}(%lhs: tensor<{lhs_type}>, %rhs: tensor<{rhs_type}>) -> tensor<{acc_type}> {{
%c0 = arith.constant 0.0 : {acc_dtype}
%c0 = arith.constant {"0.0" if acc_dtype.startswith("f") else "0"} : {acc_dtype}
%init = tensor.empty() : tensor<{acc_type}>
%inital_result = linalg.fill ins(%c0 : {acc_dtype}) outs(%init : tensor<{acc_type}>) -> tensor<{acc_type}>
%result = linalg.{operator} ins(%lhs, %rhs: tensor<{lhs_type}>, tensor<{rhs_type}>)
@@ -138,7 +138,7 @@ def get_conv_asm(


def dtype_str(dtype: torch.dtype) -> str:
dtype_str = TORCH_DTYPE_TO_MLIR_TYPE_ASM.get(dtype, None)
dtype_str = TORCH_DTYPE_TO_IREE_TYPE_ASM[dtype]
if dtype_str is None:
raise ValueError(f"Unsupported dtype: {dtype}")
return dtype_str
@@ -166,7 +166,11 @@ def generate_iree_ref(
rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
acc_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype)
asm = get_mmt_asm(
lhs_type, rhs_type, acc_type, batch=False, cast_fp8=kernel_type == "mmt_f8"
lhs_type,
rhs_type,
acc_type,
batch=False,
cast_fp8=kernel_type == "mmt_f8",
)
elif kernel_type == "bmmt":
lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype)
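As a rough, illustrative sketch of the int-GEMM path in iree_utils.py above (the tensor shapes, the exact type-string format, and the dtype-to-ASM mapping are assumptions, not part of the commit), one could build the reference matmul ASM for int8 inputs like this:

# Illustrative sketch: generate reference matmul ASM for an int8 GEMM with an i32 accumulator.
import torch
from iree.turbine.kernel.wave.iree_utils import get_mmt_asm, dtype_str

def type_str(shape, dtype):
    # Assumed "<dims>x<elem-type>" tensor string format, e.g. "64x128xi8".
    return "x".join(str(d) for d in shape) + "x" + dtype_str(dtype)

lhs = torch.zeros(64, 128, dtype=torch.int8)
rhs = torch.zeros(32, 128, dtype=torch.int8)
acc = torch.zeros(64, 32, dtype=torch.int32)

asm = get_mmt_asm(
    type_str(lhs.shape, lhs.dtype),
    type_str(rhs.shape, rhs.dtype),
    type_str(acc.shape, acc.dtype),
    batch=False,
    cast_fp8=False,
)
# With an integer accumulator, the fill constant should now be emitted as
# "arith.constant 0 : i32" rather than the float literal "0.0".
print(asm)
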
52 changes: 39 additions & 13 deletions iree/turbine/kernel/wave/utils.py
@@ -97,8 +97,26 @@ def run_test(func: Callable[[], None]) -> Callable[[], None]:
return func


def get_default_arch() -> str:
"""Return default ROCM architecture"""
if not torch.cuda.is_available():
return "cpu"
device = torch.device("cuda")
gcnArch = torch.cuda.get_device_properties(device).gcnArchName
assert "gfx" in gcnArch, "Currently only support GFX/ROCm for get_default_arch."
# The gcnArchName comes back like gfx90a:sramecc+:xnack.
colon_pos = gcnArch.find(":")
return gcnArch[0:colon_pos]


def get_default_run_config() -> dict[Any, Any]:
"""Return default config for testing."""
"""Return default config for running."""
arch = get_default_arch()
return {"backend": "rocm", "device": "hip", "target": arch}


def get_default_compile_config() -> dict[Any, Any]:
"""Return default config for compilation."""
return {"backend": "rocm", "device": "hip", "target": "gfx942"}


@@ -880,25 +898,25 @@ def ceildiv(a: int, b: int) -> int:

def get_mfma_load_elems_per_thread(mfma_variant: MMAType) -> int:
match mfma_variant:
case MMAType.F32_16x16x16_F16:
case MMAType.F32_16x16x16_F16 | MMAType.I32_16x16x16_I8:
return 4
case MMAType.F32_32x32x8_F16:
case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8:
return 4
case MMAType.F32_16x16x32_F8:
case MMAType.F32_16x16x32_F8 | MMAType.I32_16x16x32_I8:
return 8
case MMAType.F32_32x32x16_F8:
case MMAType.F32_32x32x16_F8 | MMAType.I32_32x32x16_I8:
return 8


def get_mfma_store_elems_per_thread(mfma_variant: MMAType) -> int:
match mfma_variant:
case MMAType.F32_16x16x16_F16:
case MMAType.F32_16x16x16_F16 | MMAType.I32_16x16x16_I8:
return 4
case MMAType.F32_32x32x8_F16:
case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8:
return 16
case MMAType.F32_16x16x32_F8:
case MMAType.F32_16x16x32_F8 | MMAType.I32_16x16x32_I8:
return 4
case MMAType.F32_32x32x16_F8:
case MMAType.F32_32x32x16_F8 | MMAType.I32_32x32x16_I8:
return 16


@@ -908,20 +926,28 @@ def all_equal(input_list: list[Any]) -> bool:
return all(elem == input_list[0] for elem in input_list)


def get_default_device() -> str:
return "cuda" if torch.cuda.is_available() else "cpu"


def to_default_device(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(get_default_device())


def device_randn(*args, **kwargs):
return torch.randn(*args, **kwargs).to("cuda")
return to_default_device(torch.randn(*args, **kwargs))


def device_randint(*args, **kwargs):
return torch.randint(*args, **kwargs).to("cuda")
return to_default_device(torch.randint(*args, **kwargs))


def device_randperm(*args, **kwargs):
return torch.randperm(*args, **kwargs).to("cuda")
return to_default_device(torch.randperm(*args, **kwargs))


def device_zeros(*args, **kwargs):
return torch.zeros(*args, **kwargs).to("cuda")
return to_default_device(torch.zeros(*args, **kwargs))


def get_assumptions(constraints: list[Constraint]) -> list[Assumption]:
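Finally, a quick usage sketch of the helpers added or updated in utils.py above (the int8 tensor arguments are arbitrary illustrative values):

# Illustrative usage of get_default_arch, get_default_run_config, and device_randint.
import torch
from iree.turbine.kernel.wave.utils import (
    get_default_arch,
    get_default_run_config,
    device_randint,
)

print(get_default_arch())        # e.g. "gfx90a" on an MI250 host, or "cpu" without a GPU
print(get_default_run_config())  # {"backend": "rocm", "device": "hip", "target": <arch>}

# device_randint forwards to torch.randint and then moves the result to the
# default device ("cuda" when available, otherwise "cpu").
lhs = device_randint(-128, 128, (64, 128), dtype=torch.int8)
print(lhs.device, lhs.dtype)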