diff --git a/lit_tests/kernel/wave/barriers.py b/lit_tests/kernel/wave/barriers.py
index 8f8b4a6fc..fcb7dbeda 100644
--- a/lit_tests/kernel/wave/barriers.py
+++ b/lit_tests/kernel/wave/barriers.py
@@ -98,7 +98,7 @@ def test_read_write_equal_sizes():
     # CHECK-NEXT: %read_0_1
     # CHECK-SAME: (%a, 4, None, None)
     # CHECK-NEXT: %allocate
-    # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE)
+    # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N + 4), f16, $SHARED_ADDRESS_SPACE)
     # CHECK-NEXT: %write_shared_0_0
     # CHECK-SAME: (%read_0_0, %allocate, 4, None)
     # CHECK-NEXT: %write_shared_1_1
@@ -182,9 +182,9 @@ def test_gemm():
     # CHECK-NEXT: %register_1_0_0
     # CHECK-NEXT: %register_0_1_0
     # CHECK-NEXT: %allocate
-    # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
+    # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
     # CHECK-NEXT: %allocate_1
-    # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
+    # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
     # CHECK-NEXT: reduction
     # CHECK-SAME (K, [%register_0_0_0, %register_1_1_0, %register_1_0_0, %register_0_1_0]
     # CHECK-NEXT: %getresult_1_1_0
diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index b65a90355..4800e9bd9 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -474,12 +474,12 @@ def mma(
     # CHECK: %[[D9:.+]] = arith.muli %[[D8]], %[[C4]] : index
     # CHECK: %[[D10:.+]] = vector.load %[[D0]][%[[D6]], %[[D9]]] : memref<64x16xf16, strided<[16, 1], offset:
     # CHECK-SAME: ?>>, vector<4xf16>
-    # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU:.+]].address_space<workgroup>>
+    # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x20xf16, #[[GPU:.+]].address_space<workgroup>>
     # CHECK: %[[D11:.+]] = arith.addi %[[D4]], %[[D2]] : index
-    # CHECK: vector.store %[[D10]], %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x16xf16,
+    # CHECK: vector.store %[[D10]], %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x20xf16,
     # CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
     # CHECK: amdgpu.lds_barrier
-    # CHECK: %[[D12:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x16xf16,
+    # CHECK: %[[D12:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D9]]] : memref<32x20xf16,
     # CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
     # CHECK: %[[D13:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x16xf16,
     # CHECK-SAME: strided<[16, 1], offset: ?>>
@@ -489,13 +489,13 @@ def mma(
     # CHECK: %[[D17:.+]] = arith.addi %[[D16]], %[[D14]] : index
     # CHECK: %[[D18:.+]] = vector.load %[[D13]][%[[D17]], %[[D9]]] : memref<128x16xf16, strided<[16, 1], offset:
     # CHECK-SAME: ?>>, vector<4xf16>
-    # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU]].address_space<workgroup>>
+    # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x20xf16, #[[GPU]].address_space<workgroup>>
     # CHECK: amdgpu.lds_barrier
     # CHECK: %[[D19:.+]] = arith.addi %[[D4]], %[[D14]] : index
-    # CHECK: vector.store %[[D18]], %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x16xf16,
+    # CHECK: vector.store %[[D18]], %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x20xf16,
     # CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
     # CHECK: amdgpu.lds_barrier
-    # CHECK: %[[D20:.+]] = vector.load %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x16xf16,
+    # CHECK: %[[D20:.+]] = vector.load %[[ALLOC_0]][%[[D19]], %[[D9]]] : memref<32x20xf16,
     # CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
     # CHECK: %[[D21:.+]] = amdgpu.mfma %[[D12]] * %[[D20]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 :
     # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
@@ -593,8 +593,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK: %[[WORKGROUP_ID_1:.+]] = stream.dispatch.workgroup.id[1] : index
     # CHECK-DAG: %[[THREAD_ID_X:.+]] = gpu.thread_id x
     # CHECK-DAG: %[[THREAD_ID_Y:.+]] = gpu.thread_id y
-    # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU:.+]].address_space<workgroup>>
-    # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x16xf16, #[[GPU]].address_space<workgroup>>
+    # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x20xf16, #[[GPU:.+]].address_space<workgroup>>
+    # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x20xf16, #[[GPU]].address_space<workgroup>>
     # CHECK: %[[D0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<64x64xf16,
     # CHECK-SAME: strided<[64, 1], offset: ?>>
     # CHECK: %[[D1:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x64xf16,
@@ -620,18 +620,18 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK: %[[D40:.+]] = arith.addi %[[D39]], %[[D10]] : index
     # CHECK: %[[D41:.+]] = vector.load %[[D0]][%[[D7]], %[[D40]]] : memref<64x64xf16, strided<[64, 1], offset:
     # CHECK-SAME: ?>>, vector<4xf16>
-    # CHECK: vector.store %[[D41]], %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x16xf16,
+    # CHECK: vector.store %[[D41]], %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x20xf16,
     # CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
     # CHECK: amdgpu.lds_barrier
-    # CHECK: %[[D42:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x16xf16,
+    # CHECK: %[[D42:.+]] = vector.load %[[ALLOC]][%[[D11]], %[[D10]]] : memref<32x20xf16,
     # CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
     # CHECK: %[[D43:.+]] = vector.load %[[D1]][%[[D15]], %[[D40]]] : memref<128x64xf16, strided<[64, 1],
     # CHECK-SAME: offset: ?>>, vector<4xf16>
     # CHECK: amdgpu.lds_barrier
-    # CHECK: vector.store %[[D43]], %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x16xf16,
+    # CHECK: vector.store %[[D43]], %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x20xf16,
     # CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
     # CHECK: amdgpu.lds_barrier
-    # CHECK: %[[D44:.+]] = vector.load %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x16xf16,
+    # CHECK: %[[D44:.+]] = vector.load %[[ALLOC_0]][%[[D16]], %[[D10]]] : memref<32x20xf16,
     # CHECK-SAME: #[[GPU]].address_space<workgroup>>, vector<4xf16>
     # CHECK: %[[D45:.+]] = amdgpu.mfma %[[D42]] * %[[D44]] + %[[ARG4]] {blocks = 1 : i32, k = 16 : i32, m = 16
     # CHECK-SAME: : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py
index 2bebc6908..7dd266ee7 100644
--- a/lit_tests/kernel/wave/index_sequence_analysis.py
+++ b/lit_tests/kernel/wave/index_sequence_analysis.py
@@ -98,9 +98,9 @@ def test_gemm():
     # CHECK-NEXT: %register_1_0_0
     # CHECK-NEXT: %register_0_1_0
     # CHECK-NEXT: %allocate
-    # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
+    # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
     # CHECK-NEXT: %allocate_1
-    # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
+    # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
     # CHECK-NEXT: reduction
     # CHECK-SAME (K, [%register_0_0_0, %register_1_1_0, %register_1_0_0, %register_0_1_0]
     # CHECK-NEXT: %getresult_1_1_0
diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py
index dcf6b2258..7596a94b7 100644
--- a/lit_tests/kernel/wave/minimize_global_loads.py
+++ b/lit_tests/kernel/wave/minimize_global_loads.py
@@ -103,9 +103,9 @@ def test_gemm():
     # CHECK-NEXT: %register_1_0_0
     # CHECK-NEXT: %register_0_1_0
     # CHECK-NEXT: %allocate
-    # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
+    # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
     # CHECK-NEXT: %allocate_1
-    # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
+    # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
     # CHECK-NEXT: reduction
     # CHECK-SAME (K, [%register_0_0_0, %register_1_1_0, %register_1_0_0, %register_0_1_0]
     # CHECK-NEXT: %getresult_1_1_0
diff --git a/lit_tests/kernel/wave/promotion.py b/lit_tests/kernel/wave/promotion.py
index 01db88cc6..3843c406b 100644
--- a/lit_tests/kernel/wave/promotion.py
+++ b/lit_tests/kernel/wave/promotion.py
@@ -74,7 +74,7 @@ def test_read_write_equal_sizes():
     # CHECK-NEXT: %read
     # CHECK-SAME: (%a, 4, None, None)
     # CHECK-NEXT: %allocate
-    # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE)
+    # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N + 4), f16, $SHARED_ADDRESS_SPACE)
     # CHECK-NEXT: %write_1
     # CHECK-SAME: (%read, %allocate, 4, None)
     # CHECK-NEXT: %read_1
@@ -123,7 +123,7 @@ def test_read_write_equal_sizes_different_address_spaces():
     # CHECK-NEXT: %read
     # CHECK-SAME: (%a, 4, None, None)
     # CHECK-NEXT: %allocate
-    # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N), f16, $SHARED_ADDRESS_SPACE)
+    # CHECK-SAME: ((M, N), (BLOCK_M, BLOCK_N + 4), f16, $SHARED_ADDRESS_SPACE)
     # CHECK-NEXT: %write_1
     # CHECK-SAME: (%read, %allocate, 4, None)
     # CHECK-NEXT: %read_1
@@ -181,9 +181,9 @@ def test_gemm():
     # CHECK-NEXT: %c
     # CHECK-NEXT: %register
     # CHECK-NEXT: %allocate
-    # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
+    # CHECK-SAME: ((M, K), (BLOCK_M, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
     # CHECK-NEXT: %allocate_1
-    # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K), f16, $SHARED_ADDRESS_SPACE)
+    # CHECK-SAME: ((N, K), (BLOCK_N, BLOCK_K + 4), f16, $SHARED_ADDRESS_SPACE)
     # CHECK-NEXT: reduction
     # CHECK-NEXT: %write
     # CHECK-SAME: (%reduction, %c, 4, None)
diff --git a/shark_turbine/kernel/wave/promotion.py b/shark_turbine/kernel/wave/promotion.py
index fd1aa541c..3711436f1 100644
--- a/shark_turbine/kernel/wave/promotion.py
+++ b/shark_turbine/kernel/wave/promotion.py
@@ -15,6 +15,25 @@
 logger = get_logger("turbine.wave.promotion")
 
 
+def apply_padding(
+    shape: tuple[IndexSymbol | int], dtype: DataType
+) -> tuple[IndexSymbol | int]:
+    """
+    When accessing shared memory, we need to be cognizant of bank conflicts
+    that can have a significant impact on performance. One way to mitigate
+    these conflicts is by applying padding to the shared memory allocation.
+    This function applies padding of 64 bits to the shared memory allocation.
+    While this approach accomplishes the goal of reducing bank conflicts, it
+    is inefficient in terms of memory usage. A more sophisticated approach
+    would involve swizzling of the shared memory access patterns.
+    """
+    padding = 64 // dtype.bitwidth()
+    return tuple(
+        value + padding if i == len(shape) - 1 else value
+        for i, value in enumerate(shape)
+    )
+
+
 def apply_promotion_pattern(custom_node: Read | Write, allocate_node: Allocate):
     match custom_node:
         case Read(memory, elements_per_thread) if get_custom(
@@ -47,9 +66,10 @@ def promote_node(
     assert isinstance(node, Read) or isinstance(node, Write)
     with node.graph.inserting_before(node.fx_node.next):
         constrained_shape = get_constrained_shape(node.type.symbolic_shape, constraints)
+        padded_shape = apply_padding(constrained_shape, node.type.dtype)
         allocate_node = Allocate(
            node.type.symbolic_shape,
-           constrained_shape,
+           padded_shape,
            node.type.dtype,
            address_space,
        )