Add padding to reduce shared memory bank conflicts (#193)
harsh-nod authored Oct 4, 2024
1 parent 4fec47c commit da3436d
Showing 6 changed files with 44 additions and 24 deletions.
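The commit pads the trailing (fastest-varying) dimension of each shared-memory tile by 64 bits (4 f16 elements), so the lit tests below change (BLOCK_M, BLOCK_N) to (BLOCK_M, BLOCK_N + 4) and memref<32x16xf16> to memref<32x20xf16>. As an illustration of why the pad helps (not part of the commit), the sketch below assumes an LDS with 32 banks of 4 bytes each, a common AMD configuration, and counts how many distinct banks a column access touches with and without the extra 4 elements:

# Illustrative sketch only (not from the repository): model how row padding
# changes LDS bank usage. Bank count and width are assumed (32 banks x 4 bytes).
ELEM_BYTES = 2   # f16
NUM_BANKS = 32
BANK_BYTES = 4

def banks_for_column(rows: int, row_elems: int, col: int = 0) -> set[int]:
    """Distinct banks touched when `rows` threads each read element (r, col)."""
    return {
        ((r * row_elems + col) * ELEM_BYTES // BANK_BYTES) % NUM_BANKS
        for r in range(rows)
    }

print(len(banks_for_column(32, 16)))  # 4 banks  -> heavy conflicts for a 32x16 tile
print(len(banks_for_column(32, 20)))  # 16 banks -> far fewer conflicts for 32x20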
6 changes: 3 additions & 3 deletions lit_tests/kernel/wave/barriers.py
@@ -98,7 +98,7 @@ def test_read_write_equal_sizes():
# CHECK-NEXT: %read_0_1
# CHECK-SAME: (%a, 4, None, None)
# CHECK-NEXT: %allocate
-# CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE)
+# CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N + 4), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %write_shared_0_0
# CHECK-SAME: (%read_0_0, %allocate, 4, None)
# CHECK-NEXT: %write_shared_1_1
@@ -182,9 +182,9 @@ def test_gemm():
# CHECK-NEXT: %register_1_0_0
# CHECK-NEXT: %register_0_1_0
# CHECK-NEXT: %allocate
-# CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
+# CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %allocate_1
-# CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
+# CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: reduction
# CHECK-SAME (K, [%register_0_0_0, %register_1_1_0, %register_1_0_0, %register_0_1_0]
# CHECK-NEXT: %getresult_1_1_0
24 changes: 12 additions & 12 deletions lit_tests/kernel/wave/codegen.py
@@ -474,12 +474,12 @@ def mma(
# CHECK: %[[D9:.+]] = arith.muli %[[D8]], %[[C4]] : index
# CHECK: %[[D10:.+]] = vector.load %[[D0]][%[[D6]], %[[D9]]] : memref<64x16xf16, strided<[16, 1], offset:
# CHECK-SAME: ?>>, vector<4xf16>
-# CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU:.+]].address_space<workgroup>>
+# CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x20xf16, #[[GPU:.+]].address_space<workgroup>>
# CHECK: %[[D11:.+]] = arith.addi %[[D4]], %[[D2]] : index
-# CHECK: vector.store %[[D10]], %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x16xf16,
+# CHECK: vector.store %[[D10]], %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x20xf16,
# CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
# CHECK: amdgpu.lds_barrier
-# CHECK: %[[D12:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x16xf16,
+# CHECK: %[[D12:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x20xf16,
# CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
# CHECK: %[[D13:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x16xf16,
# CHECK-SAME: strided<[16, 1], offset: ?>>
@@ -489,13 +489,13 @@ def mma(
# CHECK: %[[D17:.+]] = arith.addi %[[D16]], %[[D14]] : index
# CHECK: %[[D18:.+]] = vector.load %[[D13]][%[[D17]], %[[D9]]] : memref<128x16xf16, strided<[16, 1], offset:
# CHECK-SAME: ?>>, vector<4xf16>
-# CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU]].address_space<workgroup>>
+# CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x20xf16, #[[GPU]].address_space<workgroup>>
# CHECK: amdgpu.lds_barrier
# CHECK: %[[D19:.+]] = arith.addi %[[D4]], %[[D14]] : index
-# CHECK: vector.store %[[D18]], %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x16xf16,
+# CHECK: vector.store %[[D18]], %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x20xf16,
# CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
# CHECK: amdgpu.lds_barrier
-# CHECK: %[[D20:.+]] = vector.load %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x16xf16,
+# CHECK: %[[D20:.+]] = vector.load %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x20xf16,
# CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
# CHECK: %[[D21:.+]] = amdgpu.mfma %[[D12]] * %[[D20]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 :
# CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
@@ -593,8 +593,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# CHECK: %[[WORKGROUP_ID_1:.+]] = stream.dispatch.workgroup.id[1] : index
# CHECK-DAG: %[[THREAD_ID_X:.+]] = gpu.thread_id x
# CHECK-DAG: %[[THREAD_ID_Y:.+]] = gpu.thread_id y
-# CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU:.+]].address_space<workgroup>>
-# CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU]].address_space<workgroup>>
+# CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x20xf16, #[[GPU:.+]].address_space<workgroup>>
+# CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x20xf16, #[[GPU]].address_space<workgroup>>
# CHECK: %[[D0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<64x64xf16,
# CHECK-SAME: strided<[64, 1], offset: ?>>
# CHECK: %[[D1:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x64xf16,
@@ -620,18 +620,18 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# CHECK: %[[D40:.+]] = arith.addi %[[D39]], %[[D10]] : index
# CHECK: %[[D41:.+]] = vector.load %[[D0]][%[[D7]], %[[D40]]] : memref<64x64xf16, strided<[64, 1], offset:
# CHECK-SAME: ?>>, vector<4xf16>
-# CHECK: vector.store %[[D41]], %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x16xf16,
+# CHECK: vector.store %[[D41]], %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x20xf16,
# CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
# CHECK: amdgpu.lds_barrier
-# CHECK: %[[D42:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x16xf16,
+# CHECK: %[[D42:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x20xf16,
# CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
# CHECK: %[[D43:.+]] = vector.load %[[D1]][%[[D15]], %[[D40]]] : memref<128x64xf16, strided<[64, 1],
# CHECK-SAME: offset: ?>>, vector<4xf16>
# CHECK: amdgpu.lds_barrier
-# CHECK: vector.store %[[D43]], %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x16xf16,
+# CHECK: vector.store %[[D43]], %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x20xf16,
# CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
# CHECK: amdgpu.lds_barrier
-# CHECK: %[[D44:.+]] = vector.load %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x16xf16,
+# CHECK: %[[D44:.+]] = vector.load %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x20xf16,
# CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
# CHECK: %[[D45:.+]] = amdgpu.mfma %[[D42]] * %[[D44]] + %[[ARG4]] {blocks = 1 : i32, k = 16 : i32, m = 16
# CHECK-SAME: : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
4 changes: 2 additions & 2 deletions lit_tests/kernel/wave/index_sequence_analysis.py
@@ -98,9 +98,9 @@ def test_gemm():
# CHECK-NEXT: %register_1_0_0
# CHECK-NEXT: %register_0_1_0
# CHECK-NEXT: %allocate
-# CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
+# CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %allocate_1
-# CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
+# CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: reduction
# CHECK-SAME (K, [%register_0_0_0, %register_1_1_0, %register_1_0_0, %register_0_1_0]
# CHECK-NEXT: %getresult_1_1_0
4 changes: 2 additions & 2 deletions lit_tests/kernel/wave/minimize_global_loads.py
@@ -103,9 +103,9 @@ def test_gemm():
# CHECK-NEXT: %register_1_0_0
# CHECK-NEXT: %register_0_1_0
# CHECK-NEXT: %allocate
-# CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
+# CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %allocate_1
-# CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
+# CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: reduction
# CHECK-SAME (K, [%register_0_0_0, %register_1_1_0, %register_1_0_0, %register_0_1_0]
# CHECK-NEXT: %getresult_1_1_0
8 changes: 4 additions & 4 deletions lit_tests/kernel/wave/promotion.py
@@ -74,7 +74,7 @@ def test_read_write_equal_sizes():
# CHECK-NEXT: %read
# CHECK-SAME: (%a, 4, None, None)
# CHECK-NEXT: %allocate
-# CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE)
+# CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N + 4), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %write_1
# CHECK-SAME: (%read, %allocate, 4, None)
# CHECK-NEXT: %read_1
@@ -123,7 +123,7 @@ def test_read_write_equal_sizes_different_address_spaces():
# CHECK-NEXT: %read
# CHECK-SAME: (%a, 4, None, None)
# CHECK-NEXT: %allocate
-# CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE)
+# CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N + 4), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %write_1
# CHECK-SAME: (%read, %allocate, 4, None)
# CHECK-NEXT: %read_1
@@ -181,9 +181,9 @@ def test_gemm():
# CHECK-NEXT: %c
# CHECK-NEXT: %register
# CHECK-NEXT: %allocate
-# CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
+# CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: %allocate_1
-# CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
+# CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
# CHECK-NEXT: reduction
# CHECK-NEXT: %write
# CHECK-SAME: (%reduction, %c, 4, None)
22 changes: 21 additions & 1 deletion shark_turbine/kernel/wave/promotion.py
@@ -15,6 +15,25 @@
logger = get_logger("turbine.wave.promotion")


+def apply_padding(
+    shape: tuple[IndexSymbol | int], dtype: DataType
+) -> tuple[IndexSymbol | int]:
+    """
+    When accessing shared memory, we need to be cognizant of bank conflicts
+    that can have a significant impact on performance. One way to mitigate
+    these conflicts is by applying padding to the shared memory allocation.
+    This function applies padding of 64 bits to the shared memory allocation.
+    While this approach accomplishes the goal of reducing bank conflicts, it
+    is inefficient in terms of memory usage. A more sophisticated approach
+    would involve swizzling of the shared memory access patterns.
+    """
+    padding = 64 // dtype.bitwidth()
+    return tuple(
+        value + padding if i == len(shape) - 1 else value
+        for i, value in enumerate(shape)
+    )


def apply_promotion_pattern(custom_node: Read | Write, allocate_node: Allocate):
    match custom_node:
        case Read(memory, elements_per_thread) if get_custom(
@@ -47,9 +66,10 @@ def promote_node(
    assert isinstance(node, Read) or isinstance(node, Write)
    with node.graph.inserting_before(node.fx_node.next):
        constrained_shape = get_constrained_shape(node.type.symbolic_shape, constraints)
+        padded_shape = apply_padding(constrained_shape, node.type.dtype)
        allocate_node = Allocate(
            node.type.symbolic_shape,
-            constrained_shape,
+            padded_shape,
            node.type.dtype,
            address_space,
        )
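
For reference, and not part of the diff, here is a minimal sketch of what apply_padding yields for an f16 tile; F16Stub is a hypothetical stand-in for the real DataType, assuming only that bitwidth() is consulted:

# Hypothetical stand-in for DataType; only bitwidth() is exercised here.
class F16Stub:
    def bitwidth(self) -> int:
        return 16

def apply_padding_sketch(shape, dtype):
    # Mirrors the logic added above: pad the last dimension by 64 bits.
    padding = 64 // dtype.bitwidth()  # 4 elements for f16
    return tuple(
        value + padding if i == len(shape) - 1 else value
        for i, value in enumerate(shape)
    )

print(apply_padding_sketch((32, 16), F16Stub()))  # (32, 20), i.e. memref<32x20xf16>
# With symbolic sizes the same rule gives (BLOCK_M, BLOCK_K + 4), matching the
# updated CHECK lines in the lit tests above.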