Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TKW] Add CDNA2 + CDNA3 Int8 intrinsics and refactor intrinsic enums #279

Merged
merged 11 commits into from
Nov 20, 2024
6 changes: 3 additions & 3 deletions .github/workflows/ci-tk.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
fail-fast: false
matrix:
version: [3.11]
os: [ubuntu-latest, nodai-amdgpu-mi300-x86-64]
os: [ubuntu-latest, nodai-amdgpu-mi300-x86-64, nodai-amdgpu-mi250-x86-64]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very nice :)

runs-on: ${{matrix.os}}
env:
PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
Expand Down Expand Up @@ -57,8 +57,8 @@ jobs:
run: |
pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/
- name: Run e2e tests on MI300
if: "contains(matrix.os, 'mi300') && !cancelled()"
- name: Run e2e tests on AMD GPU
if: "contains(matrix.os, 'amdgpu') && !cancelled()"
run: |
pip install --no-compile -r pytorch-rocm-requirements.txt
export WAVE_RUN_E2E_TESTS=1
Expand Down
58 changes: 46 additions & 12 deletions iree/turbine/kernel/wave/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,45 @@
from ..lang.global_symbols import *


"""
Formatting for different target intrinsics:
<kind>_<elem-type-C>_<M>x<N>x<K>_<elem-type-A>[_<elem-type-B>]
Values: 0xABCD where:
* A = vendor:
* 1 = AMD
* 2 = NVIDIA
* B = architecture. When an intrinsic exists in multiple architectures, this
should be the architecture it was introduced in, as long as it still
has the same semantics. If a new architecture breaks an existing
intrinsic's semantics, we can use that field for versioning.
* For AMD:
* 0 = CDNA1
* 1 = CDNA2
* 2 = CDNA3
* 8 = RDNA3
* C = element type of A-matrix:
* 0 = 64-bit float (e.g. IEEE754 double precision)
* 1 = 32-bit float (e.g. IEEE754 single precision, and "xf32" fast variants)
* 2 = 16-bit float (incl. IREE754 half and bf16)
* 3 = 8-bit float (incl. f8E5M2, f8E4M3, and "FNUZ" variants)
harsh-nod marked this conversation as resolved.
Show resolved Hide resolved
* C = 8-bit integer (any signedness)
* D enumerates intrinsics that share the same 0xABC* bits.
"""


class MMAType(Enum):
F32_16x16x16_F16 = 0
F32_32x32x8_F16 = 1
F32_16x16x32_F8 = 2
F32_32x32x16_F8 = 3
# Intrinsics introduced in CDNA1
F32_16x16x16_F16 = 0x1020
F32_32x32x8_F16 = 0x1021
I32_16x16x16_I8 = 0x10C0
I32_32x32x8_I8 = 0x10C1

# Intrinsics introduced in CDNA3
F32_16x16x32_F8 = 0x1230
F32_32x32x16_F8 = 0x1231
I32_16x16x32_I8 = 0x12C0
I32_32x32x16_I8 = 0x12C1


class MMAOperand(Enum):
Expand Down Expand Up @@ -89,13 +123,13 @@ def get_thread_id_from_workgroup_dim(self, workgroup_dim: int) -> IndexSymbol:
def mma_matrix_shapes(self) -> tuple[int]:
# TODO: Eventually the shapes and indices should be provided by a tool
match self.mma_type:
case MMAType.F32_16x16x16_F16:
case MMAType.F32_16x16x16_F16 | MMAType.I32_16x16x16_I8:
raikonenfnu marked this conversation as resolved.
Show resolved Hide resolved
return (16, 16, 16)
case MMAType.F32_32x32x8_F16:
case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8:
return (32, 32, 8)
case MMAType.F32_16x16x32_F8:
case MMAType.F32_16x16x32_F8 | MMAType.I32_16x16x32_I8:
return (16, 16, 32)
case MMAType.F32_32x32x16_F8:
case MMAType.F32_32x32x16_F8 | MMAType.I32_32x32x16_I8:
return (32, 32, 16)
case _:
return ()
Expand Down Expand Up @@ -151,7 +185,7 @@ def apply(
lane = self.linearized_thread_id % self.threads_per_wave
match self.mma_type:
# (M x K, N x K) -> M x N
case MMAType.F32_16x16x16_F16:
case MMAType.F32_16x16x16_F16 | MMAType.I32_16x16x16_I8:
offset = [
Piecewise(
(lane % 16, ~MMA_ACC),
Expand All @@ -170,7 +204,7 @@ def apply(
1, # N
1, # K
]
case MMAType.F32_32x32x8_F16:
case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8:
offset = [
Piecewise(
(lane % 32, ~MMA_ACC),
Expand All @@ -194,7 +228,7 @@ def apply(
1, # N
1, # K
]
case MMAType.F32_16x16x32_F8:
case MMAType.F32_16x16x32_F8 | MMAType.I32_16x16x32_I8:
offset = [
Piecewise(
(lane % 16, ~MMA_ACC), (4 * floor(lane / 16), MMA_ACC)
Expand All @@ -212,7 +246,7 @@ def apply(
1, # N
1, # K
]
case MMAType.F32_32x32x16_F8:
case MMAType.F32_32x32x16_F8 | MMAType.I32_32x32x16_I8:
offset = [
Piecewise(
(lane % 32, ~MMA_ACC),
Expand Down
12 changes: 8 additions & 4 deletions iree/turbine/kernel/wave/iree_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from typing import Any
from .utils import compile_and_invoke
from ...support.conversions import TORCH_DTYPE_TO_MLIR_TYPE_ASM
from ...support.conversions import TORCH_DTYPE_TO_IREE_TYPE_ASM


def get_chain_mmt_asm(
Expand Down Expand Up @@ -93,7 +93,7 @@ def get_mmt_asm(
if not cast_fp8:
matmul_function = f"""
func.func @{func_name}(%lhs: tensor<{lhs_type}>, %rhs: tensor<{rhs_type}>) -> tensor<{acc_type}> {{
%c0 = arith.constant 0.0 : {acc_dtype}
%c0 = arith.constant {"0.0" if acc_dtype.startswith("f") else "0"} : {acc_dtype}
%init = tensor.empty() : tensor<{acc_type}>
%inital_result = linalg.fill ins(%c0 : {acc_dtype}) outs(%init : tensor<{acc_type}>) -> tensor<{acc_type}>
%result = linalg.{operator} ins(%lhs, %rhs: tensor<{lhs_type}>, tensor<{rhs_type}>)
Expand Down Expand Up @@ -138,7 +138,7 @@ def get_conv_asm(


def dtype_str(dtype: torch.dtype) -> str:
dtype_str = TORCH_DTYPE_TO_MLIR_TYPE_ASM.get(dtype, None)
dtype_str = TORCH_DTYPE_TO_IREE_TYPE_ASM[dtype]
raikonenfnu marked this conversation as resolved.
Show resolved Hide resolved
if dtype_str is None:
raise ValueError(f"Unsupported dtype: {dtype}")
return dtype_str
Expand Down Expand Up @@ -166,7 +166,11 @@ def generate_iree_ref(
rhs_type = get_type_str(kernel_inputs[1].shape, kernel_inputs[1].dtype)
acc_type = get_type_str(kernel_outputs[0].shape, kernel_outputs[0].dtype)
asm = get_mmt_asm(
lhs_type, rhs_type, acc_type, batch=False, cast_fp8=kernel_type == "mmt_f8"
lhs_type,
rhs_type,
acc_type,
batch=False,
cast_fp8=kernel_type == "mmt_f8",
)
elif kernel_type == "bmmt":
lhs_type = get_type_str(kernel_inputs[0].shape, kernel_inputs[0].dtype)
Expand Down
52 changes: 39 additions & 13 deletions iree/turbine/kernel/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,26 @@ def run_test(func: Callable[[], None]) -> Callable[[], None]:
return func


def get_default_arch() -> str:
"""Return default ROCM architecture"""
if not torch.cuda.is_available():
return "cpu"
device = torch.device("cuda")
gcnArch = torch.cuda.get_device_properties(device).gcnArchName
assert "gfx" in gcnArch, "Currently only support GFX/ROCm for get_default_arch."
# The gcnArchName comes back like gfx90a:sramecc+:xnack.
colon_pos = gcnArch.find(":")
return gcnArch[0:colon_pos]


def get_default_run_config() -> dict[Any, Any]:
"""Return default config for testing."""
"""Return default config for running."""
arch = get_default_arch()
return {"backend": "rocm", "device": "hip", "target": arch}


def get_default_compile_config() -> dict[Any, Any]:
"""Return default config for compilation."""
return {"backend": "rocm", "device": "hip", "target": "gfx942"}


Expand Down Expand Up @@ -880,25 +898,25 @@ def ceildiv(a: int, b: int) -> int:

def get_mfma_load_elems_per_thread(mfma_variant: MMAType) -> int:
match mfma_variant:
case MMAType.F32_16x16x16_F16:
case MMAType.F32_16x16x16_F16 | MMAType.I32_16x16x16_I8:
return 4
case MMAType.F32_32x32x8_F16:
case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8:
return 4
case MMAType.F32_16x16x32_F8:
case MMAType.F32_16x16x32_F8 | MMAType.I32_16x16x32_I8:
return 8
case MMAType.F32_32x32x16_F8:
case MMAType.F32_32x32x16_F8 | MMAType.I32_32x32x16_I8:
return 8


def get_mfma_store_elems_per_thread(mfma_variant: MMAType) -> int:
match mfma_variant:
case MMAType.F32_16x16x16_F16:
case MMAType.F32_16x16x16_F16 | MMAType.I32_16x16x16_I8:
return 4
case MMAType.F32_32x32x8_F16:
case MMAType.F32_32x32x8_F16 | MMAType.I32_32x32x8_I8:
return 16
case MMAType.F32_16x16x32_F8:
case MMAType.F32_16x16x32_F8 | MMAType.I32_16x16x32_I8:
return 4
case MMAType.F32_32x32x16_F8:
case MMAType.F32_32x32x16_F8 | MMAType.I32_32x32x16_I8:
return 16


Expand All @@ -908,20 +926,28 @@ def all_equal(input_list: list[Any]) -> bool:
return all(elem == input_list[0] for elem in input_list)


def get_default_device() -> str:
return "cuda" if torch.cuda.is_available() else "cpu"
raikonenfnu marked this conversation as resolved.
Show resolved Hide resolved


def to_default_device(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(get_default_device())


def device_randn(*args, **kwargs):
return torch.randn(*args, **kwargs).to("cuda")
return to_default_device(torch.randn(*args, **kwargs))


def device_randint(*args, **kwargs):
return torch.randint(*args, **kwargs).to("cuda")
return to_default_device(torch.randint(*args, **kwargs))


def device_randperm(*args, **kwargs):
return torch.randperm(*args, **kwargs).to("cuda")
return to_default_device(torch.randperm(*args, **kwargs))


def device_zeros(*args, **kwargs):
return torch.zeros(*args, **kwargs).to("cuda")
return to_default_device(torch.zeros(*args, **kwargs))


def get_assumptions(constraints: list[Constraint]) -> list[Assumption]:
Expand Down
Loading