Skip to content

Commit

Permalink
[LLVMGPU] Add multi-row vector reduction configuration (#73)
Browse files Browse the repository at this point in the history
This is to speed up matvec. The new configuration is experimental and
only applied on ROCm targets.
  • Loading branch information
kuhar authored Dec 6, 2023
1 parent 70bfff7 commit 04753d5
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ class VectorReductionToGPUPass
bool expandSubgroupReduction,
std::function<int(func::FuncOp)> getWarpSize)
: expandSubgroupReduction(expandSubgroupReduction),
getWarpSize(getWarpSize) {}
getWarpSize(std::move(getWarpSize)) {}

void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,68 @@ hal.executable private @shared_memory_copy {
// CHECK: vector.transfer_read %[[ALLOC]]{{.*}} : memref<32xf32, #gpu.address_space<workgroup>>, vector<1xf32>
// CHECK: vector.transfer_write {{.*}} : vector<1xf32>, memref<128x32xf32>
// CHECK: return


// -----

// Check that we multi-row matvec gets distributed across subgoroup threads.

#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {target_arch = "gfx940"}>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
hal.executable private @multirow {
hal.executable.variant @rocm target(#executable_target_rocm_hsaco_fb) {
hal.executable.export @multirow layout(#pipeline_layout) attributes {
workgroup_size = [64 : index, 1 : index, 1 : index]
}
builtin.module {
func.func @multirow() {
%cst = arith.constant dense<0.000000e+00> : vector<4x512xf16>
%c0 = arith.constant 0 : index
%cst_0 = arith.constant dense<0.000000e+00> : vector<1x4xf16>
%c4096 = arith.constant 4096 : index
%c512 = arith.constant 512 : index
%cst_1 = arith.constant 0.000000e+00 : f16
%id = gpu.thread_id x
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<1x4096xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<1x4096xf16, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<32000x4096xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<32000x4096xf16, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<1x32000xf16, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<1x32000xf16, #hal.descriptor_type<storage_buffer>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%3 = affine.apply affine_map<()[s0] -> (s0 * 4)>()[%workgroup_id_x]
%4 = scf.for %arg0 = %c0 to %c4096 step %c512 iter_args(%arg1 = %cst) -> (vector<4x512xf16>) {
%8 = vector.transfer_read %0[%c0, %arg0], %cst_1 {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (0, d1)>} : memref<1x4096xf16, #hal.descriptor_type<storage_buffer>>, vector<4x512xf16>
%9 = vector.transfer_read %1[%3, %arg0], %cst_1 {in_bounds = [true, true]} : memref<32000x4096xf16, #hal.descriptor_type<storage_buffer>>, vector<4x512xf16>
%10 = arith.mulf %8, %9 : vector<4x512xf16>
%11 = arith.addf %arg1, %10 : vector<4x512xf16>
scf.yield %11 : vector<4x512xf16>
}
%5 = vector.broadcast %4 : vector<4x512xf16> to vector<1x4x512xf16>
%6 = vector.multi_reduction <add>, %5, %cst_0 [2] : vector<1x4x512xf16> to vector<1x4xf16>
%7 = vector.extract %6[0] : vector<4xf16> from vector<1x4xf16>
vector.transfer_write %7, %2[%c0, %3] {in_bounds = [true]} : vector<4xf16>, memref<1x32000xf16, #hal.descriptor_type<storage_buffer>>
return
}
}
}
}

// CHECK-LABEL: func.func @multirow() {
// CHECK: scf.for {{.*}} -> (vector<4x8xf16>) {
// CHECK: vector.transfer_read {{.*}} : memref<32000x4096xf16, #hal.descriptor_type<storage_buffer>>, vector<4x8xf16>
// CHECK: vector.transfer_read {{.*}} : memref<1x4096xf16, #hal.descriptor_type<storage_buffer>>, vector<4x8xf16>
// CHECK: arith.mulf %{{.*}}, %{{.*}} : vector<4x8xf16>
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4x8xf16>
// CHECK: }
// CHECK: gpu.shuffle xor
// CHECK: scf.if {{.*}} {
// CHECK: vector.transfer_write {{.*}} : vector<4xf16>, memref<1x32000xf16, #hal.descriptor_type<storage_buffer>>
// CHECK: }
// CHECK-NEXT: return
21 changes: 21 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
Expand Down Expand Up @@ -926,6 +928,25 @@ static LogicalResult setWarpReductionConfig(func::FuncOp entryPoint,
if ((groupSize / subgroupSize) > subgroupSize)
return failure();

// With just one subgroup per workgroup, make each subgroup do more work and
// process a few reductions along the last parallel dimension.
// TODO: We should also check that this will result in data reuse for at least
// one argument.
// TODO: This is experimental for matvec (matmul_transpose_b) on rocm-only for
// now.
if (numDynamicReductionDims == 0 && numParallelDims == 2 &&
isRocmTarget(entryPoint)) {
if (*parallelSize && !parallelDims.empty() && groupSize == subgroupSize) {
int maxParallelFactor = 4; // Keeping this conservative for now.
int64_t lastParallelBound = bounds[parallelDims.back()];
if (!ShapedType::isDynamic(lastParallelBound) &&
(lastParallelBound % maxParallelFactor == 0) &&
lastParallelBound > maxParallelFactor) {
workgroupTileSizes.back() = maxParallelFactor;
}
}
}

std::array<int64_t, 3> workgroupSize = {groupSize, 1, 1};
SmallVector<int64_t> reductionTileSizes(op.getNumLoops(), 0);
int64_t remainingGroupSize = groupSize;
Expand Down
47 changes: 47 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/test/config_matvec.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,50 @@ hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {target_arch = "gf
// CHECK: func.func @dynamic_batch_matvec()
// CHECK: linalg.batch_matmul
// CHECK-SAME: lowering_config = #[[$CONFIG]]

// -----

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>

hal.executable @vmt {
hal.executable.variant @rocm target(<"rocm", "rocm-hsaco-fb", {target_arch = "gfx940"}>) {
hal.executable.export @vmt layout(#pipeline_layout)
builtin.module {
func.func @vmt() {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.000000e+00 : f16
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<1x4096xf16>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<32000x4096xf16>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1x32000xf16>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [1, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<1x4096xf16>> -> tensor<1x4096xf16>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0], sizes = [32000, 4096], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<32000x4096xf16>> -> tensor<32000x4096xf16>
%5 = tensor.empty() : tensor<1x32000xf16>
%6 = linalg.fill ins(%cst : f16) outs(%5 : tensor<1x32000xf16>) -> tensor<1x32000xf16>
%7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%3, %4 : tensor<1x4096xf16>, tensor<32000x4096xf16>) outs(%6 : tensor<1x32000xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
%8 = arith.mulf %in, %in_0 : f16
%9 = arith.addf %out, %8 : f16
linalg.yield %9 : f16
} -> tensor<1x32000xf16>
flow.dispatch.tensor.store %7, %2, offsets = [0, 0], sizes = [1, 32000], strides = [1, 1] : tensor<1x32000xf16> -> !flow.dispatch.tensor<writeonly:tensor<1x32000xf16>>
return
}
}
}
}

// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 4], [0, 0, 512]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<LLVMGPUWarpReduction>
// CHECK-LABEL: hal.executable.export public @vmt
// CHECK-SAME: subgroup_size = 64 : index
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK-SAME: workgroup_size = [64 : index, 1 : index, 1 : index]
// CHECK: func.func @vmt()
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[$CONFIG]]
2 changes: 1 addition & 1 deletion third_party/llvm-project

0 comments on commit 04753d5

Please sign in to comment.