[GPU] Add gather fusion tests for vector distribution (iree-org#19209)
VectorDistribution now supports gather fusion on producers. This PR adds
pipeline tests for that. There are still numerical issues, related to the
distribution of gather, being tracked separately.
Groverkss authored Dec 4, 2024 · 1 parent 6ff85df · commit 29229df
Showing 1 changed file with 142 additions and 0 deletions.
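
For context, below is a distilled sketch of the producer pattern these tests exercise (the shapes and value names are illustrative, not taken from the patch): a linalg.generic whose body performs a data-dependent tensor.extract, which VectorDistribution can now fuse into its consumer and lower to vector.gather.

// Sketch only: a gather expressed as a linalg.generic producer, where
// out[i, j] = %data[i, %indices[i, j]]. %data, %indices, and %init are
// assumed to be defined earlier with the shapes shown.
%gathered = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]}
    ins(%indices : tensor<128x64xi64>) outs(%init : tensor<128x64xf16>) {
^bb0(%in: i64, %out: f16):
  %row = linalg.index 0 : index
  %col = arith.index_cast %in : i64 to index
  %element = tensor.extract %data[%row, %col] : tensor<128x64xf16>
  linalg.yield %element : f16
} -> tensor<128x64xf16>

Both tests in the diff instantiate this pattern: once as the RHS producer of a matmul, and once as the K producer of attention.
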
@@ -884,6 +884,77 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb">) {

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>]>
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64, {}>
#config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, promote_operands = [0, 1], reduction = [0, 0, 64], subgroup_m_count = 2 : i64, subgroup_n_count = 2 : i64, workgroup = [64, 128, 0]}>

hal.executable public @matmul_gather_rhs {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb">) {
hal.executable.export public @matmul_gather_rhs ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @matmul_gather_rhs() attributes {translation_info = #translation} {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x64xf16>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x64xi64>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<4096x64xf16>>
%3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<4096x4096xf16>>
%4 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [4096, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x64xf16>> -> tensor<4096x64xf16>
%5 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [4096, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x64xi64>> -> tensor<4096x64xi64>
%6 = flow.dispatch.tensor.load %2, offsets = [0, 0], sizes = [4096, 64], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<4096x64xf16>> -> tensor<4096x64xf16>
%7 = tensor.empty() : tensor<4096x4096xf16>
%8 = tensor.empty() : tensor<4096x4096xf32>
%9 = tensor.empty() : tensor<4096x64xf16>
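// Gather producer: out[i, j] = %4[i, %5[i, j]] via a data-dependent tensor.extract.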
%10 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%5 : tensor<4096x64xi64>) outs(%9 : tensor<4096x64xf16>) {
^bb0(%in: i64, %out: f16):
%14 = linalg.index 0 : index
%15 = arith.index_cast %in : i64 to index
%extracted = tensor.extract %4[%14, %15] : tensor<4096x64xf16>
linalg.yield %extracted : f16
} -> tensor<4096x64xf16>
%11 = linalg.fill ins(%cst : f32) outs(%8 : tensor<4096x4096xf32>) -> tensor<4096x4096xf32>
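// Matmul with f32 accumulation (C[i, j] += A[i, k] * B[j, k]); the gathered tensor %10 is the RHS.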
%12 = linalg.generic {indexing_maps = [#map1, #map2, #map3],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%6, %10 : tensor<4096x64xf16>, tensor<4096x64xf16>)
outs(%11 : tensor<4096x4096xf32>)
attrs = {lowering_config = #config} {
^bb0(%in: f16, %in_0: f16, %out: f32):
%14 = arith.extf %in : f16 to f32
%15 = arith.extf %in_0 : f16 to f32
%16 = arith.mulf %14, %15 : f32
%17 = arith.addf %out, %16 : f32
linalg.yield %17 : f32
} -> tensor<4096x4096xf32>
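// Truncate the f32 accumulator to f16 for the final store.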
%13 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%12 : tensor<4096x4096xf32>) outs(%7 : tensor<4096x4096xf16>) {
^bb0(%in: f32, %out: f16):
%14 = arith.truncf %in : f32 to f16
linalg.yield %14 : f16
} -> tensor<4096x4096xf16>
flow.dispatch.tensor.store %13, %3, offsets = [0, 0], sizes = [4096, 4096], strides = [1, 1] : tensor<4096x4096xf16> -> !flow.dispatch.tensor<writeonly:tensor<4096x4096xf16>>
return
}
}
}
}

// CHECK-LABEL: func.func @matmul_gather_rhs
// CHECK: vector.gather
// CHECK-COUNT-32: amdgpu.mfma

// -----

#config = #iree_gpu.lowering_config<{workgroup = [1, 64, 0, 0, 64], reduction = [0, 0, 0, 64, 0], promote_operands = [0, 1, 2]}>
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64>

@@ -1169,3 +1240,74 @@ hal.executable private @online_attention_split_k2 {
// MEMORY-LABEL: func.func @online_attention_split_k2()
// MEMORY-COUNT-3: memref.alloc
// MEMORY-NOT: memref.alloc

// -----

#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb">
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
#map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
#map5 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>, #hal.pipeline.binding<storage_buffer>, #hal.pipeline.binding<storage_buffer>, #hal.pipeline.binding<storage_buffer>, #hal.pipeline.binding<storage_buffer>]>
#translation = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [128, 1, 1] subgroup_size = 64, {}>

#qk_config = {attention_qk_matmul, lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, promote_operands = [0, 1], subgroup_m_count = 2 : i64, subgroup_n_count = 1 : i64}>}
#pv_config = {attention_pv_matmul, lowering_config = #iree_gpu.lowering_config<{mma_kind = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, promote_operands = [1], subgroup_m_count = 2 : i64, subgroup_n_count = 1 : i64}>}
#config = #iree_gpu.lowering_config<{promote_operands = [0, 1, 2], reduction = [0, 0, 0, 0, 0, 64], workgroup = [1, 1, 64, 64, 0, 0]}>

module {
hal.executable public @attention_gather_k {
hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) {
hal.executable.export public @attention_gather_k ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device):
%x, %y, %z = flow.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @attention_gather_k() attributes {translation_info = #translation} {
%cst = arith.constant 1.250000e-01 : f16
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x10x4096x64xf16>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x10x4096x64xi64>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x10x4096x64xf16>>
%3 = hal.interface.binding.subspan layout(#pipeline_layout) binding(3) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<2x10x4096x64xf16>>
%4 = hal.interface.binding.subspan layout(#pipeline_layout) binding(4) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<2x10x4096x64xf16>>
%5 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [2, 10, 4096, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x10x4096x64xf16>> -> tensor<2x10x4096x64xf16>
%6 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [2, 10, 4096, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x10x4096x64xi64>> -> tensor<2x10x4096x64xi64>
%7 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, 0], sizes = [2, 10, 4096, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x10x4096x64xf16>> -> tensor<2x10x4096x64xf16>
%8 = flow.dispatch.tensor.load %3, offsets = [0, 0, 0, 0], sizes = [2, 10, 4096, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<2x10x4096x64xf16>> -> tensor<2x10x4096x64xf16>
%9 = tensor.empty() : tensor<2x10x4096x64xf16>
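// Gather producer for K: the i64 indices in %6 select along the sequence dimension (dim 2) of %5.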
%10 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6 : tensor<2x10x4096x64xi64>) outs(%9 : tensor<2x10x4096x64xf16>) {
^bb0(%in: i64, %out: f16):
%12 = linalg.index 0 : index
%13 = linalg.index 1 : index
%14 = arith.index_cast %in : i64 to index
%15 = linalg.index 3 : index
%extracted = tensor.extract %5[%12, %13, %14, %15] : tensor<2x10x4096x64xf16>
linalg.yield %extracted : f16
} -> tensor<2x10x4096x64xf16>
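// Attention with the gathered tensor %10 as K; decomposition_config attaches MFMA lowering configs to the QK and PV matmuls.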
%11 = iree_linalg_ext.attention {
indexing_maps = [#map1, #map2, #map3, #map4, #map5],
decomposition_config = { qk_attrs = #qk_config, pv_attrs = #pv_config },
lowering_config = #config} ins(%7, %10, %8, %cst : tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, tensor<2x10x4096x64xf16>, f16) outs(%9 : tensor<2x10x4096x64xf16>) {
^bb0(%arg0: f32):
iree_linalg_ext.yield %arg0 : f32
} -> tensor<2x10x4096x64xf16>
flow.dispatch.tensor.store %11, %4, offsets = [0, 0, 0, 0], sizes = [2, 10, 4096, 64], strides = [1, 1, 1, 1] : tensor<2x10x4096x64xf16> -> !flow.dispatch.tensor<writeonly:tensor<2x10x4096x64xf16>>
return
}
}
}
}
}

// CHECK-LABEL: func.func @attention_gather_k
// CHECK: scf.for %{{.*}} = %c0 to %c4096 step %c64
// CHECK: vector.gather
// CHECK-SAME: into vector<4x1x1x1x1x8xf16>
// CHECK: scf.yield

// MEMORY-LABEL: func.func @attention_gather_k
// MEMORY-COUNT-3: memref.alloc
