Skip to content

Commit

Permalink
Reapply "[Codegen][GPU] Add range information to GPU dispatch IDs" (#…
Browse files Browse the repository at this point in the history
…19361) (#19372)

This reverts commit cb5be1d.

Compared to the previous revision, this one works around a correctness
bug in dataflow analysis that's being fixed by removing the analysis
after SCF->CF.

---

First, this patch implements InferIntRangeInterface for
hal.interface.workgroup.{size,id,count} using a local upper_bound
attribute.

Then, it adds a -iree-codegen-gpu-propagate-dispatch-size-bounds pass
that adds these upper_bounds identifiers to the interface.workgroup
operations and to gpu.thread_id based on static information available
late in the codegen pipeline.

Then, it uses -optimize-int-arithmetic to optimize indexing after
-lower-affine, getting rid of a bunch of "if the input's negative" logic
that isn't actually needed in many of our kernels.

It also ensures that these upper_bound values propagate to LLVM.
  • Loading branch information
krzysz00 authored Dec 13, 2024
1 parent 442956c commit 63cdc7d
Show file tree
Hide file tree
Showing 16 changed files with 317 additions and 58 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ iree_compiler_cc_library(
"GPUPatterns.cpp",
"GPUPipelining.cpp",
"GPUPromoteMatmulOperands.cpp",
"GPUPropagateDispatchSizeBounds.cpp",
"GPUReduceBankConflicts.cpp",
"GPUReuseSharedMemoryAllocs.cpp",
"GPUTensorAlloc.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ iree_cc_library(
"GPUPatterns.cpp"
"GPUPipelining.cpp"
"GPUPromoteMatmulOperands.cpp"
"GPUPropagateDispatchSizeBounds.cpp"
"GPUReduceBankConflicts.cpp"
"GPUReuseSharedMemoryAllocs.cpp"
"GPUTensorAlloc.cpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/Passes.h"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_GPUPROPAGATEDISPATCHSIZEBOUNDSPASS
#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"

namespace {

/// Attaches `upper_bound` attributes to the dispatch-geometry queries nested
/// under `funcOp`.
///
/// `workgroupSizes` supplies the per-dimension bounds used for `gpu.thread_id`
/// and `hal.interface.workgroup.size`; `workgroupCounts` supplies those used
/// for `hal.interface.workgroup.id` and `hal.interface.workgroup.count`. Both
/// arrays are indexed by the op's dimension number.
///
/// Either array may hold fewer than three entries (target info without
/// workgroup sizes produces a zero-length array). An op whose dimension has no
/// entry is left unannotated rather than reading past the end of the array.
static void applyBounds(FunctionOpInterface funcOp,
                        ArrayRef<int32_t> workgroupSizes,
                        ArrayRef<int32_t> workgroupCounts) {
  Builder b(funcOp->getContext());

  // Looks up the bound for `dim` in `bounds`. Returns a null attribute when
  // the table has no entry for that dimension so callers can skip the op.
  auto boundFor = [&](ArrayRef<int32_t> bounds, uint64_t dim) -> IntegerAttr {
    if (dim >= bounds.size())
      return IntegerAttr();
    return b.getIndexAttr(bounds[dim]);
  };

  funcOp->walk([&](Operation *op) {
    TypeSwitch<Operation *>(op)
        .Case([&](gpu::ThreadIdOp tidOp) {
          if (IntegerAttr bound =
                  boundFor(workgroupSizes,
                           static_cast<uint32_t>(tidOp.getDimension())))
            tidOp.setUpperBoundAttr(bound);
        })
        .Case([&](IREE::HAL::InterfaceWorkgroupSizeOp wgSizeOp) {
          if (IntegerAttr bound = boundFor(
                  workgroupSizes, wgSizeOp.getDimension().getZExtValue()))
            wgSizeOp.setUpperBoundAttr(bound);
        })
        .Case([&](IREE::HAL::InterfaceWorkgroupIDOp wgIdOp) {
          if (IntegerAttr bound = boundFor(
                  workgroupCounts, wgIdOp.getDimension().getZExtValue()))
            wgIdOp.setUpperBoundAttr(bound);
        })
        .Case([&](IREE::HAL::InterfaceWorkgroupCountOp wgCountOp) {
          if (IntegerAttr bound = boundFor(
                  workgroupCounts, wgCountOp.getDimension().getZExtValue()))
            wgCountOp.setUpperBoundAttr(bound);
        })
        // Every other op carries no dispatch-size information.
        .Default([](Operation *) {});
  });
}

/// Function-level pass that computes per-dimension bounds on workgroup sizes
/// and workgroup counts and hands them to `applyBounds` for annotation.
///
/// The bounds start out as the maxima advertised by the GPU target attribute
/// and are then overwritten, dimension by dimension, with any statically known
/// workgroup size or workgroup count.
struct GPUPropagateDispatchSizeBoundsPass final
    : impl::GPUPropagateDispatchSizeBoundsPassBase<
          GPUPropagateDispatchSizeBoundsPass> {
  using Base::Base;

  void runOnOperation() override {
    FunctionOpInterface entryFunc = getOperation();

    // Without a target there are no limits to start from; warn and bail out.
    IREE::GPU::TargetAttr gpuTarget = getGPUTargetAttr(entryFunc);
    if (!gpuTarget) {
      entryFunc.emitWarning("no known target attribute late in GPU codegen");
      return;
    }

    // Seed both bound tables with the target's advertised maxima.
    SmallVector<int32_t, 3> sizeBounds(
        gpuTarget.getWgp().getMaxWorkgroupSizes().asArrayRef());
    SmallVector<int32_t, 3> countBounds(
        gpuTarget.getWgp().getMaxWorkgroupCounts().asArrayRef());

    // Prefer the workgroup size recorded on the export op when there is one:
    // late in codegen the size has been reconciled onto it. Otherwise keep
    // whatever `getWorkgroupSize` reported for the function.
    std::optional<SmallVector<int64_t>> knownSizes =
        getWorkgroupSize(entryFunc);
    std::optional<IREE::HAL::ExecutableExportOp> entryPoint =
        getEntryPoint(entryFunc);
    if (entryPoint) {
      std::optional<ArrayAttr> exportedSizes = entryPoint->getWorkgroupSize();
      if (exportedSizes) {
        knownSizes =
            llvm::map_to_vector(exportedSizes->getAsRange<IntegerAttr>(),
                                [](IntegerAttr a) { return a.getInt(); });
      }
    }

    // `zip` (not `zip_equal`) on purpose: target info with no workgroup sizes
    // yields a 0-length array, so the two ranges may differ in length.
    if (knownSizes) {
      for (auto [bound, known] : llvm::zip(sizeBounds, *knownSizes)) {
        bound = known;
      }
    }

    // Tighten the count bounds wherever the count is static; dynamic entries
    // keep the target maximum.
    SmallVector<int64_t> knownCounts = getStaticNumWorkgroups(entryFunc);
    assert(knownCounts.size() <= 3 && "workgroup counts are 3D at most");
    for (auto [bound, known] : llvm::zip(countBounds, knownCounts)) {
      if (known == ShapedType::kDynamic) {
        continue;
      }
      bound = known;
    }

    applyBounds(entryFunc, sizeBounds, countBounds);
  }
};
} // namespace

} // namespace mlir::iree_compiler
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@ def GPUPromoteMatmulOperandsPass :
];
}

// Function-level pass, registered on the command line as
// `iree-codegen-gpu-propagate-dispatch-size-bounds`, that annotates workitem
// and workgroup ID ops with known bounds.
def GPUPropagateDispatchSizeBoundsPass :
InterfacePass<"iree-codegen-gpu-propagate-dispatch-size-bounds", "mlir::FunctionOpInterface"> {
let summary = "Pass to annotate workitem and workgroup IDs with known bounds";
}

def GPUReduceBankConflictsPass :
InterfacePass<"iree-codegen-gpu-reduce-bank-conflicts", "mlir::FunctionOpInterface"> {
let summary = "Pass to try to reduce the number of bank conflicts by padding memref.alloc ops.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ iree_lit_test_suite(
"gpu_pad_operands.mlir",
"gpu_pipeline.mlir",
"gpu_promote_matmul_operands.mlir",
"gpu_propagate_dispatch_size_bounds.mlir",
"gpu_reorder_workgroups_static.mlir",
"gpu_reorder_workgroups.mlir",
"gpu_reuse_shared_memory_allocs.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ iree_lit_test_suite(
"gpu_pad_operands.mlir"
"gpu_pipeline.mlir"
"gpu_promote_matmul_operands.mlir"
"gpu_propagate_dispatch_size_bounds.mlir"
"gpu_reorder_workgroups.mlir"
"gpu_reorder_workgroups_static.mlir"
"gpu_reuse_shared_memory_allocs.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// RUN: iree-opt %s --split-input-file \
// RUN: --pass-pipeline="builtin.module(hal.executable(hal.executable.variant(builtin.module(func.func(iree-codegen-gpu-propagate-dispatch-size-bounds)))))" \
// RUN: | FileCheck %s

// Note: not the real target definition, missing types
#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "",
wgp = <compute = fp32,
storage = b32,
subgroup = arithmetic,
dot = none, mma = [],
subgroup_size_choices = [32, 64],
max_workgroup_sizes = [1024, 1024, 1024],
max_thread_count_per_workgroup = 1024,
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>}>
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>

// Static case: the export op carries workgroup_size = [64, 2, 1] and its
// workgroup-count region returns the constants (32, 8, 1). The expected upper
// bounds below are therefore those values rather than the target's maxima.
hal.executable private @static {
hal.executable.variant public @rocm_hsaco_fb target(#executable_target) {
hal.executable.export public @static ordinal(0) layout(#pipeline_layout) attributes {workgroup_size = [64 : index, 2 : index, 1 : index]} {
^bb0(%arg0: !hal.device):
%c32 = arith.constant 32 : index
%c8 = arith.constant 8 : index
%c1 = arith.constant 1 : index
hal.return %c32, %c8, %c1 : index, index, index
}
builtin.module {
// CHECK-LABEL: func.func @static
func.func @static() {
// CHECK: gpu.thread_id x upper_bound 64
// CHECK: gpu.thread_id y upper_bound 2
// CHECK: gpu.thread_id z upper_bound 1
%thread_id_x = gpu.thread_id x
%thread_id_y = gpu.thread_id y
%thread_id_z = gpu.thread_id z

// CHECK: hal.interface.workgroup.size[0] upper_bound 64
// CHECK: hal.interface.workgroup.size[1] upper_bound 2
// CHECK: hal.interface.workgroup.size[2] upper_bound 1
%workgroup_size_x = hal.interface.workgroup.size[0] : index
%workgroup_size_y = hal.interface.workgroup.size[1] : index
%workgroup_size_z = hal.interface.workgroup.size[2] : index

// CHECK: hal.interface.workgroup.id[0] upper_bound 32
// CHECK: hal.interface.workgroup.id[1] upper_bound 8
// CHECK: hal.interface.workgroup.id[2] upper_bound 1
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_id_z = hal.interface.workgroup.id[2] : index

// CHECK: hal.interface.workgroup.count[0] upper_bound 32
// CHECK: hal.interface.workgroup.count[1] upper_bound 8
// CHECK: hal.interface.workgroup.count[2] upper_bound 1
%workgroup_conut_x = hal.interface.workgroup.count[0] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%workgroup_count_z = hal.interface.workgroup.count[2] : index

return
}
}
}
}

// -----

#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb",
{iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "",
wgp = <compute = fp32,
storage = b32,
subgroup = arithmetic,
dot = none, mma = [],
subgroup_size_choices = [32, 64],
max_workgroup_sizes = [1024, 1024, 1024],
max_thread_count_per_workgroup = 1024,
max_workgroup_memory_bytes = 65536,
max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>}>
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>

// Dynamic case: the export op has no workgroup_size attribute and its x/y
// workgroup counts are computed from runtime arguments. The expected bounds
// are therefore the target's max_workgroup_sizes and max_workgroup_counts;
// only the z count, the constant 1, is expected to be picked up.
hal.executable private @dynamic {
hal.executable.variant public @rocm_hsaco_fb target(#executable_target) {
hal.executable.export public @dynamic ordinal(0) layout(#pipeline_layout) {
^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
%count_x = affine.apply affine_map<()[s0] -> (s0 ceildiv 32)>()[%arg1]
%count_y = affine.apply affine_map<()[s0] -> (s0 ceildiv 8)>()[%arg2]
%count_z = arith.constant 1 : index
hal.return %count_x, %count_y, %count_z : index, index, index
}
builtin.module {
func.func @dynamic() {
// CHECK: gpu.thread_id x upper_bound 1024
// CHECK: gpu.thread_id y upper_bound 1024
// CHECK: gpu.thread_id z upper_bound 1024
%thread_id_x = gpu.thread_id x
%thread_id_y = gpu.thread_id y
%thread_id_z = gpu.thread_id z

// CHECK: hal.interface.workgroup.size[0] upper_bound 1024
// CHECK: hal.interface.workgroup.size[1] upper_bound 1024
// CHECK: hal.interface.workgroup.size[2] upper_bound 1024
%workgroup_size_x = hal.interface.workgroup.size[0] : index
%workgroup_size_y = hal.interface.workgroup.size[1] : index
%workgroup_size_z = hal.interface.workgroup.size[2] : index

// CHECK: hal.interface.workgroup.id[0] upper_bound 2147483647
// CHECK: hal.interface.workgroup.id[1] upper_bound 2147483647
// CHECK: hal.interface.workgroup.id[2] upper_bound 1
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_id_y = hal.interface.workgroup.id[1] : index
%workgroup_id_z = hal.interface.workgroup.id[2] : index

// CHECK: hal.interface.workgroup.count[0] upper_bound 2147483647
// CHECK: hal.interface.workgroup.count[1] upper_bound 2147483647
// CHECK: hal.interface.workgroup.count[2] upper_bound 1
%workgroup_conut_x = hal.interface.workgroup.count[0] : index
%workgroup_count_y = hal.interface.workgroup.count[1] : index
%workgroup_count_z = hal.interface.workgroup.count[2] : index

return
}
}
}
}
5 changes: 4 additions & 1 deletion compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,10 @@ struct HALInterfaceWorkgroupOpsConverter final
int32_t index = static_cast<int32_t>(op.getDimension().getSExtValue());
std::array<gpu::Dimension, 3> dimAttr{gpu::Dimension::x, gpu::Dimension::y,
gpu::Dimension::z};
rewriter.replaceOpWithNewOp<NewOpTy>(op, op.getType(), dimAttr[index]);
NewOpTy newOp =
rewriter.replaceOpWithNewOp<NewOpTy>(op, op.getType(), dimAttr[index]);
if (IntegerAttr bound = op.getUpperBoundAttr())
newOp.setUpperBoundAttr(bound);
return success();
}
};
Expand Down
16 changes: 11 additions & 5 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,13 @@ addLowerAndOptimizeAddressComputationPasses(FunctionLikeNest &funcPassManager) {
.addPass(createCSEPass)
// Hoist the resulting decompositions.
.addPass(createIREELoopInvariantCodeMotionPass)
.addPass(createLowerAffinePass);
.addPass(affine::createAffineExpandIndexOpsPass)
.addPass(createLowerAffinePass)
.addPass(IREE::Util::createOptimizeIntArithmeticPass)
// Do another round of LICM now that we've lowered and optimized
// arithmetic
.addPass(createCSEPass)
.addPass(createIREELoopInvariantCodeMotionPass);
}

static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,
Expand Down Expand Up @@ -1103,7 +1109,9 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,
FunctionLikeNest funcPassManager(modulePassManager);
funcPassManager.addPass(createFoldTensorExtractOpPass)
.addPass(createLLVMGPUVectorLoweringPass)
.addPass(createExpandGPUOpsPass);
.addPass(createExpandGPUOpsPass)
// Expose workitem and workgroup counts to range inference later.
.addPass(createGPUPropagateDispatchSizeBoundsPass);

// This pass needs to run before SCF -> CF.
addLowerAndOptimizeAddressComputationPasses(funcPassManager);
Expand All @@ -1130,9 +1138,7 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,
.addPass(memref::createExpandStridedMetadataPass)
.addPass(createEmulateNarrowTypePass)
.addPass(affine::createAffineExpandIndexOpsPass)
.addPass(createLowerAffinePass)
.addPass(createCanonicalizerPass)
.addPass(createCSEPass);
.addPass(createLowerAffinePass);

// Strip out the debug info for the kernel.
modulePassManager.addPass(createStripDebugInfoPass());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
// CHECK-DAG: %[[C8192:.*]] = llvm.mlir.constant(8192 : index) : i64
//
// Match the interesting special registers.
// CHECK-DAG: %[[TID_Y:.*]] = nvvm.read.ptx.sreg.tid.y : i32
// CHECK-DAG: %[[TID_Y:.*]] = nvvm.read.ptx.sreg.tid.y range <i32, 0, 2> : i32
// CHECK-DAG: %[[TID_Y_EXT:.*]] = llvm.sext %[[TID_Y]] : i32 to i64
// CHECK-DAG: %[[LANEID:.*]] = nvvm.read.ptx.sreg.laneid range <i32, 0, 32> : i32
// CHECK-DAG: %[[LANEID_EXT:.*]] = llvm.sext %[[LANEID]] : i32 to i64
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,11 @@ static void addMemRefLoweringPasses(OpPassManager &modulePassManager) {
/// Adds passes to perform the final SPIR-V conversion.
static void addSPIRVLoweringPasses(OpPassManager &modulePassManager) {
FunctionLikeNest(modulePassManager)
.addPass(createGPUPropagateDispatchSizeBoundsPass)
.addPass(createCanonicalizerPass)
.addPass(createCSEPass)
.addPass(createLowerAffinePass)
.addPass(IREE::Util::createOptimizeIntArithmeticPass)

// Lower ApplyScale before the i64 Emulation Pass so that new 64-bit ops
// are also emulated if not supported by the target.
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ iree_td_library(
"//compiler/src/iree/compiler/Dialect/Util/IR:td_files",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:FuncTdFiles",
"@llvm-project//mlir:InferIntRangeInterfaceTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:ViewLikeInterfaceTdFiles",
Expand Down Expand Up @@ -81,6 +82,7 @@ iree_compiler_cc_library(
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferIntRangeInterface",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Parser",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/HAL/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ iree_cc_library(
MLIRFuncDialect
MLIRFunctionInterfaces
MLIRIR
MLIRInferIntRangeInterface
MLIRInferTypeOpInterface
MLIRMemRefDialect
MLIRParser
Expand Down
Loading

0 comments on commit 63cdc7d

Please sign in to comment.