Merge branch 'iree-org:main' into main
amd-chrissosa authored Aug 26, 2024
2 parents c8f6b7a + 03eccd9 commit 060daac
Showing 20 changed files with 1,072 additions and 147 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/ci.yaml
@@ -21,7 +21,7 @@ jobs:
     strategy:
       matrix:
         version: [3.11]
-        os: [ubuntu-latest]
+        os: [ubuntu-latest,nodai-amdgpu-mi300-x86-64]
     runs-on: ${{matrix.os}}
     env:
       PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
@@ -57,6 +57,12 @@ jobs:
         run: |
           pytest -n 4 .
+      - name: Run e2e tests on MI300
+        if: "contains(matrix.os, 'mi300') && !cancelled()"
+        run: |
+          export WAVE_RUN_E2E_TESTS=1
+          pytest -n 4 ./tests/kernel/wave/
       - name: Run LIT tests
         if: ${{ !cancelled() }}
         run: |
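The added step gates the e2e suite behind the WAVE_RUN_E2E_TESTS environment variable and only fires on the MI300 runner. For a local run outside CI, the same invocation can be reproduced with a short driver; a sketch, assuming the repository root as the working directory and pytest plus pytest-xdist installed (the -n 4 worker count comes from the workflow invocation above):

# Local reproduction of the "Run e2e tests on MI300" step (a sketch; the
# WAVE_RUN_E2E_TESTS flag and test path are taken from the workflow above).
import os
import sys

import pytest

os.environ["WAVE_RUN_E2E_TESTS"] = "1"  # the workflow exports this before pytest
# -n 4 spreads tests across four workers via pytest-xdist, as in CI.
sys.exit(pytest.main(["-n", "4", "./tests/kernel/wave/"]))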
10 changes: 5 additions & 5 deletions lit_tests/kernel/wave/barriers.py
@@ -28,7 +28,7 @@ def tweak_index(graph: fx.Graph):
     ]
     # Modify the write dependency index to trigger a barrier.
     for promoted_read_node in promoted_read_nodes:
-        write_dependency = promoted_read_node.write_dependency
+        write_dependency = promoted_read_node.write_dependency[0]
         for key, value in write_dependency.index.items():
             write_dependency.index[key].start = value.start + 1

@@ -109,13 +109,13 @@ def test_read_write_equal_sizes():
     # CHECK-SAME: (%read_0_1, %allocate, 4, None)
     # CHECK-NEXT: %barrier
     # CHECK-NEXT: %read_shared_0_0
-    # CHECK-SAME: (%allocate, 4, None, %write_shared_0_0)
+    # CHECK-SAME: (%allocate, 4, None, [%write_shared_0_0])
     # CHECK-NEXT: %read_shared_1_1
-    # CHECK-SAME: (%allocate, 4, None, %write_shared_1_1)
+    # CHECK-SAME: (%allocate, 4, None, [%write_shared_1_1])
     # CHECK-NEXT: %read_shared_1_0
-    # CHECK-SAME: (%allocate, 4, None, %write_shared_1_0)
+    # CHECK-SAME: (%allocate, 4, None, [%write_shared_1_0])
     # CHECK-NEXT: %read_shared_0_1
-    # CHECK-SAME: (%allocate, 4, None, %write_shared_0_1)
+    # CHECK-SAME: (%allocate, 4, None, [%write_shared_0_1])
     # CHECK-NEXT: %write_0_0
     # CHECK-SAME: (%read_shared_0_0, %c, 4, None)
     # CHECK-NEXT: %write_1_1
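The new [0] subscript in tweak_index and the bracketed [%write_shared_*] operands in the CHECK lines indicate that write_dependency is now list-valued rather than a single node. A minimal sketch of a pass that bumps every dependency instead of only the first, assuming the wave compiler attaches a list-valued write_dependency (or None) to each promoted read node as the new operand printing suggests:

import torch.fx as fx

def bump_write_dependencies(promoted_read_nodes: list[fx.Node]) -> None:
    # Shift the start index of every write dependency by one to force a
    # barrier between the shared-memory write and the promoted read
    # (sketch; `write_dependency` is assumed to be list[fx.Node] | None).
    for read_node in promoted_read_nodes:
        for write_dependency in read_node.write_dependency or []:
            for key, value in write_dependency.index.items():
                write_dependency.index[key].start = value.start + 1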
102 changes: 38 additions & 64 deletions lit_tests/kernel/wave/codegen.py
@@ -289,7 +289,7 @@ def mma(
     c = torch.zeros(64, 128, dtype=torch.float32)
     print(mma(a, b, c).module_op)

-    # CHECK: func.func @mma(%[[ARG0:.+]]: !stream.binding, %[[ARG1:.+]]: !stream.binding, %[[ARG2:.+]]: !stream.binding) {
+    # CHECK: func.func @mma(%[[ARG0:.+]]: !stream.binding, %[[ARG1:.+]]: !stream.binding, %[[ARG2:.+]]: !stream.binding)
     # CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
     # CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
     # CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
@@ -305,37 +305,19 @@ def mma(
     # CHECK: %[[R3:.+]] = arith.addi %[[R2]], %[[R1]] : index
     # CHECK: %[[R4:.+]] = vector.load %0[%[[R3]], %[[C0]]] : memref<64x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>
     # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
-    # CHECK: %[[R5:.+]] = arith.muli %[[WG0]], %[[C32]] : index
-    # CHECK: %[[R6:.+]] = arith.divsi %[[TX]], %[[C4]] : index
-    # CHECK: %[[R7:.+]] = arith.addi %[[R6]], %[[R5]] : index
-    # CHECK: vector.store %4, %[[ALLOC]][%[[R7]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
-    # CHECK: %[[R8:.+]] = arith.muli %[[WG0]], %[[C32]] : index
-    # CHECK: %[[R9:.+]] = arith.divsi %[[TX]], %[[C4]] : index
-    # CHECK: %[[R10:.+]] = arith.addi %[[R9]], %[[R8]] : index
-    # CHECK: %[[R11:.+]] = vector.load %[[ALLOC]][%[[R10]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
+    # CHECK: vector.store %4, %[[ALLOC]][%[[R3]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
+    # CHECK: %[[R11:.+]] = vector.load %[[ALLOC]][%[[R3]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
     # CHECK: %[[R12:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x16xf16, strided<[16, 1], offset: ?>>
     # CHECK: %[[R13:.+]] = arith.muli %[[TY]], %[[C16]] : index
     # CHECK: %[[R14:.+]] = arith.muli %[[WG1]], %[[C32]] : index
     # CHECK: %[[R15:.+]] = arith.addi %[[R14]], %[[R13]] : index
     # CHECK: %[[R16:.+]] = vector.load %[[R12]][%[[R15]], %[[C0]]] : memref<128x16xf16, strided<[16, 1], offset: ?>>, vector<4xf16>
     # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #gpu.address_space<workgroup>>
-    # CHECK: %[[R17:.+]] = arith.muli %[[TY]], %[[C16]] : index
-    # CHECK: %[[R18:.+]] = arith.muli %[[WG1]], %[[C32]] : index
-    # CHECK: %[[R19:.+]] = arith.addi %[[R18]], %[[R17]] : index
-    # CHECK: vector.store %16, %[[ALLOC_0]][%[[R19]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
-    # CHECK: %[[R20:.+]] = arith.muli %[[TY]], %[[C16]] : index
-    # CHECK: %[[R21:.+]] = arith.muli %[[WG1]], %[[C32]] : index
-    # CHECK: %[[R22:.+]] = arith.addi %[[R21]], %[[R20]] : index
-    # CHECK: %[[R23:.+]] = vector.load %[[ALLOC_0]][%[[R22]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
+    # CHECK: vector.store %[[R16]], %[[ALLOC_0]][%[[R15]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
+    # CHECK: %[[R23:.+]] = vector.load %[[ALLOC_0]][%[[R15]], %[[C0]]] : memref<32x16xf16, #gpu.address_space<workgroup>>, vector<4xf16>
     # CHECK: %[[R24:.+]] = amdgpu.mfma %[[R11]] * %[[R23]] + %[[ACC]] {blocks = 1 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
     # CHECK: %[[R25:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<64x128xf32, strided<[128, 1], offset: ?>>
-    # CHECK: %[[R26:.+]] = arith.muli %[[WG0]], %[[C32]] : index
-    # CHECK: %[[R27:.+]] = arith.divsi %[[TX]], %[[C4]] : index
-    # CHECK: %[[R28:.+]] = arith.addi %[[R27]], %[[R26]] : index
-    # CHECK: %[[R29:.+]] = arith.muli %[[TY]], %[[C16]] : index
-    # CHECK: %[[R30:.+]] = arith.muli %[[WG1]], %[[C32]] : index
-    # CHECK: %[[R31:.+]] = arith.addi %[[R30]], %[[R29]] : index
-    # CHECK: vector.store %[[R24]], %[[R25]][%[[R28]], %[[R31]]] : memref<64x128xf32, strided<[128, 1], offset: ?>>, vector<4xf32>
+    # CHECK: vector.store %[[R24]], %[[R25]][%[[R3]], %[[R15]]] : memref<64x128xf32, strided<[128, 1], offset: ?>>, vector<4xf32>


@run_test
@@ -392,7 +374,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     print(gemm(a, b, c).module_op)

     # CHECK: func.func @gemm(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: !stream.binding,
-    # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) {
+    # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding)
     # CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
     # CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
     # CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
@@ -417,35 +399,18 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK: %[[D13:.+]] = arith.muli %[[ARG3]], %[[C16]] : index
     # CHECK: %[[D14:.+]] = vector.load %[[D8]][%[[D12]], %[[D13]]] : memref<64x64xf16, strided<[64, 1],
     # CHECK-SAME: offset: ?>>, vector<4xf16>
-    # CHECK: %[[D15:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index
-    # CHECK: %[[D16:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C4]] : index
-    # CHECK: %[[D17:.+]] = arith.addi %[[D16]], %[[D15]] : index
-    # CHECK: %[[D18:.+]] = arith.muli %[[ARG3]], %[[C16]] : index
-    # CHECK: vector.store %[[D14]], %[[ALLOC]][%[[D17]], %[[D18]]] : memref<32x64xf16,
+    # CHECK: vector.store %[[D14]], %[[ALLOC]][%[[D12]], %[[D13]]] : memref<32x64xf16,
     # CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
-    # CHECK: %[[D19:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C32]] : index
-    # CHECK: %[[D20:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C4]] : index
-    # CHECK: %[[D21:.+]] = arith.addi %[[D20]], %[[D19]] : index
-    # CHECK: %[[D22:.+]] = arith.muli %[[ARG3]], %[[C16]] : index
-    # CHECK: %[[D23:.+]] = vector.load %[[ALLOC]][%[[D21]], %[[D22]]] : memref<32x64xf16,
+    # CHECK: %[[D23:.+]] = vector.load %[[ALLOC]][%[[D12]], %[[D13]]] : memref<32x64xf16,
     # CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
     # CHECK: %[[D24:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index
     # CHECK: %[[D25:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index
     # CHECK: %[[D26:.+]] = arith.addi %[[D25]], %[[D24]] : index
-    # CHECK: %[[D27:.+]] = arith.muli %[[ARG3]], %[[C16]] : index
-    # CHECK: %[[D28:.+]] = vector.load %[[D9]][%[[D26]], %[[D27]]] : memref<128x64xf16, strided<[64, 1],
+    # CHECK: %[[D28:.+]] = vector.load %[[D9]][%[[D26]], %[[D13]]] : memref<128x64xf16, strided<[64, 1],
     # CHECK-SAME: offset: ?>>, vector<4xf16>
-    # CHECK: %[[D29:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index
-    # CHECK: %[[D30:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index
-    # CHECK: %[[D31:.+]] = arith.addi %[[D30]], %[[D29]] : index
-    # CHECK: %[[D32:.+]] = arith.muli %[[ARG3]], %[[C16]] : index
-    # CHECK: vector.store %[[D28]], %[[ALLOC_0]][%[[D31]], %[[D32]]] : memref<32x64xf16,
+    # CHECK: vector.store %[[D28]], %[[ALLOC_0]][%[[D26]], %[[D13]]] : memref<32x64xf16,
     # CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
-    # CHECK: %[[D33:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C16]] : index
-    # CHECK: %[[D34:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C32]] : index
-    # CHECK: %[[D35:.+]] = arith.addi %[[D34]], %[[D33]] : index
-    # CHECK: %[[D36:.+]] = arith.muli %[[ARG3]], %[[C16]] : index
-    # CHECK: %[[D37:.+]] = vector.load %[[ALLOC_0]][%[[D35]], %[[D36]]] : memref<32x64xf16,
+    # CHECK: %[[D37:.+]] = vector.load %[[ALLOC_0]][%[[D26]], %[[D13]]] : memref<32x64xf16,
     # CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
     # CHECK: %[[D38:.+]] = amdgpu.mfma %[[D23]] * %[[D37]] + %[[ARG4]] {blocks = 1 : i32, k = 16 : i32, m = 16
     # CHECK-SAME: : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
@@ -511,9 +476,8 @@ def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32]):
     # CHECK: arith.addi %[[SLICE]], %[[SLICE]] : vector<16xi32>


-@launch
-@pytest.mark.skip(reason="neg: Currently only stub implementation")
-def test_neg():
+@run_test
+def test_unary_lowerings():
     constraints: list[tkw.Constraint] = [
         tkw.HardwareConstraint(
             threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16}
@@ -526,19 +490,20 @@ def test_neg():

     @tkw.wave(constraints)
     def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
-        res = -a
+        a_reg = tkw.read(a, elements_per_thread=4)
+        res = -a_reg
+        res = tkw.exp2(res)
+        tkw.write(res, a, elements_per_thread=4)

     a = torch.randn(16, 16, dtype=torch.float16)
-    with pytest.raises(
-        NotImplementedError, match="neg: Currently only stub implementation"
-    ):
-        test(a)
+    with codegen_test_context():
+        print(test(a).module_op)
+        # CHECK: %[[NEG:.+]] = arith.negf
+        # CHECK: math.exp2 %[[NEG]]


-@launch
-@pytest.mark.skip(reason="sub: Currently only stub implementation")
-def test_sub():
+@run_test
+def test_binary_lowerings():
     constraints: list[tkw.Constraint] = [
         tkw.HardwareConstraint(
             threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16}
@@ -550,15 +515,24 @@ def test_sub():
     constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]

     @tkw.wave(constraints)
-    def test(a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16]):
-        res = a - a
+    def test(
+        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
+        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f16],
+    ):
+        a_reg = tkw.read(a, elements_per_thread=4)
+        b_reg = tkw.read(b, elements_per_thread=4)
+        res = a_reg - b_reg
+        res = res * a_reg
+        res = res / b_reg
         tkw.write(res, a, elements_per_thread=4)

     a = torch.randn(16, 16, dtype=torch.float16)
-    with pytest.raises(
-        NotImplementedError, match="sub: Currently only stub implementation"
-    ):
-        test(a)
+    b = torch.randn(16, 16, dtype=torch.float16)
+    with codegen_test_context():
+        print(test(a, b).module_op)
+        # CHECK: %[[SUB:.+]] = arith.subf
+        # CHECK: %[[MUL:.+]] = arith.mulf %[[SUB]]
+        # CHECK: %[[DIV:.+]] = arith.divf %[[MUL]]


 @launch
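For reference, the kernel body in test_binary_lowerings computes (a - b) * a / b elementwise, which the arith.subf / arith.mulf / arith.divf CHECK sequence mirrors. A plain-PyTorch sketch of the same computation, using the shapes and dtype from the test:

import torch

# Eager-mode reference for the (a - b) * a / b sequence checked above.
a = torch.randn(16, 16, dtype=torch.float16)
b = torch.randn(16, 16, dtype=torch.float16)
expected = (a - b) * a / b  # subf, then mulf, then divf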