[Codegen][GPU] Keep range and divisibility annotations on push constants #19348

Open · wants to merge 2 commits into base: main
158 changes: 128 additions & 30 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
@@ -9,7 +9,9 @@
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
@@ -225,15 +227,13 @@ getKernelArgMapping(Operation *funcOp) {
return mapBindingArgIndex;
}

class ConvertFunc : public ConvertToLLVMPattern {
class ConvertFunc : public ConvertOpToLLVMPattern<func::FuncOp> {
public:
explicit ConvertFunc(MLIRContext *context, LLVMTypeConverter &converter)
: ConvertToLLVMPattern(mlir::func::FuncOp::getOperationName(), context,
converter, 100) {}
explicit ConvertFunc(LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern(converter, 100) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto funcOp = cast<func::FuncOp>(op);
FunctionType fnType = funcOp.getFunctionType();
(void)fnType;
if (!funcOp.isPublic())
@@ -302,13 +302,9 @@ class ConvertFunc : public ConvertToLLVMPattern {
}
};

class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern {
public:
explicit ConvertIREEBindingSubspanOp(MLIRContext *context,
LLVMTypeConverter &converter)
: ConvertToLLVMPattern(
IREE::HAL::InterfaceBindingSubspanOp::getOperationName(), context,
converter) {}
struct ConvertIREEBindingSubspanOp final
: public ConvertOpToLLVMPattern<IREE::HAL::InterfaceBindingSubspanOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  /// Checks that all subspanOps with the same binding have the readonly attribute.
static bool checkAllSubspansReadonly(LLVM::LLVMFuncOp llvmFuncOp,
@@ -330,7 +326,7 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern {
}

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(IREE::HAL::InterfaceBindingSubspanOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Bail until nested under an LLVMFuncOp.
auto llvmFuncOp = op->getParentOfType<LLVM::LLVMFuncOp>();
@@ -341,8 +337,6 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern {
auto argMapping = getKernelArgMapping(llvmFuncOp);
Location loc = op->getLoc();
auto subspanOp = cast<IREE::HAL::InterfaceBindingSubspanOp>(op);
IREE::HAL::InterfaceBindingSubspanOpAdaptor adaptor(
operands, op->getAttrDictionary());
MemRefType memrefType =
llvm::dyn_cast<MemRefType>(subspanOp.getResult().getType());
mlir::BlockArgument llvmBufferArg =
@@ -453,15 +447,12 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern {
}
};

class ConvertIREEConstantOp : public ConvertToLLVMPattern {
public:
explicit ConvertIREEConstantOp(MLIRContext *context,
LLVMTypeConverter &converter)
: ConvertToLLVMPattern(
IREE::HAL::InterfaceConstantLoadOp::getOperationName(), context,
converter) {}
struct ConvertIREEConstantOp final
: public ConvertOpToLLVMPattern<IREE::HAL::InterfaceConstantLoadOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(IREE::HAL::InterfaceConstantLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Bail until nested under an LLVMFuncOp.
auto llvmFuncOp = op->getParentOfType<LLVM::LLVMFuncOp>();
@@ -470,9 +461,8 @@ class ConvertIREEConstantOp : public ConvertToLLVMPattern {
assert(llvmFuncOp.getNumArguments() > 0);

auto argMapping = getKernelArgMapping(llvmFuncOp);
auto ireeConstantOp = cast<IREE::HAL::InterfaceConstantLoadOp>(op);
mlir::BlockArgument llvmBufferArg = llvmFuncOp.getArgument(
argMapping.size() + ireeConstantOp.getOrdinal().getZExtValue());
argMapping.size() + op.getOrdinal().getZExtValue());
assert(llvmBufferArg.getType().isInteger(32));

// Push constants are never `undef`, annotate that here, just as with
@@ -481,7 +471,54 @@ class ConvertIREEConstantOp : public ConvertToLLVMPattern {
LLVM::LLVMDialect::getNoUndefAttrName(),
rewriter.getUnitAttr());

Type dstType = getTypeConverter()->convertType(ireeConstantOp.getType());
// If the constant has non-trivial assumptions placed on it about
// its min and max values or divisibility, use that information to
// annotate the corresponding arguments. The hasOneUse() check prevents us
// from applying assumptions that don't hold at all usage sites.
if (op.getResult().hasOneUse()) {
[Inline review comment — Contributor]: Why the one use condition?
OpOperand *operand = op.getResult().getUses().begin().getOperand();
auto assumeOp = dyn_cast<IREE::Util::AssumeIntOp>(operand->getOwner());
if (assumeOp) {
unsigned opIdx = operand->getOperandNumber();
auto [min, max] = assumeOp.getUnionedUnsignedRange(opIdx);

if (min.has_value() && max.has_value()) {
assert(*min <= std::numeric_limits<uint32_t>::max() &&
"Push-constant's maximum value can't be outside 32 bits, but "
"this is assumed");
// Note: LLVM's range(iN lb, ub) is [lb, ub), while MLIR's is [lb,
// ub], so we add 1 to the upper bound.
llvmFuncOp.setArgAttr(llvmBufferArg.getArgNumber(),
LLVM::LLVMDialect::getRangeAttrName(),
rewriter.getAttr<LLVM::ConstantRangeAttr>(
APInt(32, *min), APInt(32, *max) + 1));
}

auto divisibility = assumeOp.getUnionedUnsignedDivisor(opIdx);

auto makeI32Const = [&](uint32_t val) -> Value {
return rewriter.create<LLVM::ConstantOp>(
assumeOp.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(val));
};
if (divisibility.has_value() && *divisibility > 1) {
Location loc = assumeOp.getLoc();
assert(*divisibility <= std::numeric_limits<uint32_t>::max() &&
"push constant shouldn't be statically divisible by a value "
"it can't hold");
Value knownDivisibleBy = makeI32Const(*divisibility);
        // This will almost always lower to an `and`, since the divisor is
        // typically a power of two.
Value lowPart = rewriter.create<LLVM::URemOp>(loc, llvmBufferArg,
knownDivisibleBy);
Value zero = makeI32Const(0);
Value isEvenlyDivided = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::eq, lowPart, zero);
rewriter.create<LLVM::AssumeOp>(loc, isEvenlyDivided);
}
}
}

Type dstType = getTypeConverter()->convertType(op.getType());
// llvm.zext requires that the result type has a larger bitwidth.
if (dstType == llvmBufferArg.getType()) {
rewriter.replaceOp(op, llvmBufferArg);
@@ -513,14 +550,75 @@ struct HALInterfaceWorkgroupOpsConverter final
}
};

struct ConvertIREEUtilAssumeIntOp final
: public ConvertOpToLLVMPattern<IREE::Util::AssumeIntOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(IREE::Util::AssumeIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Bail until nested under an LLVMFuncOp.
auto llvmFuncOp = op->getParentOfType<LLVM::LLVMFuncOp>();
if (!llvmFuncOp)
return failure();

Location loc = op.getLoc();
auto updateConds = [&](std::optional<Value> &conds, Value cond) {
if (!conds)
conds = cond;
else
conds = rewriter.create<LLVM::AndOp>(loc, *conds, cond);
};
    // Materialize the assumptions that aren't attached directly to arguments
// in order to account for the fact that i64 inputs get passed in as a pair
// of i32 constants.
for (auto [idx, mlirVal, llvmVal] :
llvm::enumerate(op.getOperands(), adaptor.getOperands())) {
if (mlirVal.getDefiningOp<IREE::HAL::InterfaceConstantLoadOp>())
continue;
std::optional<Value> conds;
Type type = llvmVal.getType();
auto [min, max] = op.getUnionedUnsignedRange(idx);
// This should be a range() bundle but LLVM doesn't understand those yet.
if (min.has_value() && *min > 0) {
Value minConst = createIndexAttrConstant(rewriter, loc, type, *min);
Value minCond = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::uge, llvmVal, minConst);
updateConds(conds, minCond);
}
if (max.has_value()) {
Value maxConst = createIndexAttrConstant(rewriter, loc, type, *max);
Value maxCond = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::ule, llvmVal, maxConst);
updateConds(conds, maxCond);
}
std::optional<uint64_t> divisor = op.getUnionedUnsignedDivisor(idx);
if (divisor && *divisor > 1) {
Value divisorConst =
createIndexAttrConstant(rewriter, loc, type, *divisor);
Value remainder =
rewriter.create<LLVM::URemOp>(loc, llvmVal, divisorConst);
Value zero = createIndexAttrConstant(rewriter, loc, type, 0);
Value divisorCond = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::eq, remainder, zero);
updateConds(conds, divisorCond);
}

if (conds.has_value()) {
rewriter.create<LLVM::AssumeOp>(loc, *conds);
}
}
rewriter.replaceOp(op, adaptor.getOperands());
return success();
}
};
} // namespace

void populateLLVMConversionPatterns(MLIRContext *context,
RewritePatternSet &patterns,
LLVMTypeConverter &converter) {
patterns
.insert<ConvertFunc, ConvertIREEBindingSubspanOp, ConvertIREEConstantOp>(
context, converter);
patterns.insert<ConvertFunc, ConvertIREEBindingSubspanOp,
ConvertIREEConstantOp, ConvertIREEUtilAssumeIntOp>(converter);
}

void populateScalarizeMathOps(RewritePatternSet &patterns) {
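Editorial note on the semantics above, since the conversion mixes two range conventions: util.assume.int carries a closed range [umin, umax], while LLVM's range(i32 lb, ub) argument attribute is half-open [lb, ub), which is why the pattern adds 1 to the upper bound; divisibility has no attribute form and is instead emitted as llvm.intr.assume on `x % udiv == 0`. A minimal standalone C++ sketch of those facts follows (no MLIR dependency; the type and function names are illustrative, not IREE APIs):

#include <cstdint>
#include <optional>

// Hypothetical stand-in for the facts util.assume.int carries about one value.
struct PushConstantAssumption {
  std::optional<uint64_t> umin; // inclusive lower bound
  std::optional<uint64_t> umax; // inclusive upper bound
  std::optional<uint64_t> udiv; // known divisor
};

// Closed MLIR range [umin, umax] -> half-open LLVM range [lb, ub).
// The pass asserts that umax fits in 32 bits before doing this.
struct HalfOpenRange {
  uint32_t lb; // inclusive
  uint32_t ub; // exclusive
};

HalfOpenRange toHalfOpenRange(uint32_t uminInclusive, uint32_t umaxInclusive) {
  return {uminInclusive, umaxInclusive + 1};
}

// The predicate the emitted annotations assert about a concrete i32 value:
// the range bounds plus, when udiv > 1, divisibility.
bool assumptionHolds(const PushConstantAssumption &a, uint32_t value) {
  if (a.umin && value < *a.umin)
    return false;
  if (a.umax && value > *a.umax)
    return false;
  if (a.udiv && *a.udiv > 1 && value % *a.udiv != 0)
    return false;
  return true;
}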
7 changes: 5 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -1144,7 +1144,8 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,
modulePassManager.addPass(createStripDebugInfoPass());
// Cast address spaces of all function arguments to generic.
modulePassManager.addPass(createLLVMGPUCastAddressSpaceFunctionPass());
modulePassManager.addPass(IREE::Util::createDropCompilerHintsPass());
modulePassManager.addPass(IREE::Util::createDropCompilerHintsPass(
IREE::Util::DropCompilerHintsPassOptions{/*keepAssumeInt=*/true}));

if (forROCDL) {
// convert to ROCDL.
@@ -1273,7 +1274,9 @@ void buildROCDLCodegenPassPipeline(OpPassManager &variantPassManager) {
.addPass(createVerifyWorkgroupDistributionPass);
}
variantPassManager.addPass(createReconcileTranslationInfoPass());
variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass());
variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass(
IREE::Util::DropCompilerHintsPassOptions{/*keepAssumeInt=*/true}));

addLowerToLLVMGPUPasses(variantPassManager.nest<ModuleOp>(),
/*forROCDL=*/true);
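The keepAssumeInt option above is what lets util.assume.int ops survive DropCompilerHints long enough for the ConvertToLLVM patterns in the first file to read them. A minimal sketch of that kind of option-gated filtering, using illustrative stand-in types rather than IREE's actual pass classes:

#include <string>
#include <vector>

// Illustrative stand-ins; the real pass operates on MLIR operations.
struct HintOp {
  std::string name; // e.g. "util.assume.int", "util.optimization_barrier"
};

struct DropCompilerHintsOptions {
  // When true, integer-assumption ops are preserved so a later conversion
  // can still read their range/divisibility facts before discarding them.
  bool keepAssumeInt = false;
};

std::vector<HintOp> dropHints(std::vector<HintOp> ops,
                              const DropCompilerHintsOptions &opts) {
  std::vector<HintOp> kept;
  for (const HintOp &op : ops) {
    if (opts.keepAssumeInt && op.name == "util.assume.int")
      kept.push_back(op); // survives this pass
    // Every other hint op is dropped, as before.
  }
  return kept;
}

In the actual pipeline the preserved assumptions are then consumed, and erased, by ConvertIREEConstantOp and ConvertIREEUtilAssumeIntOp during the LLVM conversion.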
@@ -40,7 +40,7 @@ hal.executable @abs_ex_dispatch_0 {
// CHECK: llvm.store %[[FADD]], %[[ADDR]] : f32, !llvm.ptr
// -----

#pipeline_layout = #hal.pipeline.layout<constants = 4, bindings = [
#pipeline_layout = #hal.pipeline.layout<constants = 5, bindings = [
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>,
#hal.pipeline.binding<storage_buffer>
@@ -54,13 +54,24 @@ hal.executable @abs_dynamic {
%c3 = arith.constant 3 : index
%c5 = arith.constant 5 : index
%c7 = arith.constant 7 : index
%o = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
%d0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
%d1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index
%d2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) offset(%o) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>{%d0, %d1, %d2}
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref<?x?x?xi32>{%d0, %d1, %d2}
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : memref<?x?x?xf32>{%d0, %d1, %d2}
%c32_i64 = arith.constant 32 : i64
// This method for passing in 64-bit values is taken from a Llama dispatch
// and added here to test integer range assumption preservation.
%o.low = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : i32
%o.high = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : i32
%o.low.ext = arith.extui %o.low : i32 to i64
%o.high.ext = arith.extui %o.high : i32 to i64
%o.high.shift = arith.shli %o.high.ext, %c32_i64 : i64
%o.i64 = arith.ori %o.low.ext, %o.high.shift : i64
%o.index = arith.index_castui %o.i64 : i64 to index
%o.tagged = util.assume.int %o.index[<umin = 5185728, umax = 4438911803328>] : index
%d0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index
%d1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) : index
%d1.tagged = util.assume.int %d1[<umin = 0, umax = 0>, <umin = 4096, umax = 4096, udiv = 4096>] : index
%d2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(4) : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) offset(%o.tagged) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>{%d0, %d1.tagged, %d2}
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref<?x?x?xi32>{%d0, %d1.tagged, %d2}
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : memref<?x?x?xf32>{%d0, %d1.tagged, %d2}
%9 = memref.load %0[%c3, %c5, %c7] : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>
%10 = memref.load %1[%c3, %c5, %c7] : memref<?x?x?xi32>
%11 = arith.sitofp %10 : i32 to f32
@@ -78,10 +89,26 @@ hal.executable @abs_dynamic {
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: i32 {llvm.noundef},
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i32 {llvm.noundef},
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32 {llvm.noundef},
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32 {llvm.noundef})
// CHECK-DAG: %[[OFFSET:.+]] = llvm.zext %[[ARG3]] : i32 to i64
// CHECK-DAG: %[[D1:.+]] = llvm.zext %[[ARG5]] : i32 to i64
// CHECK-DAG: %[[D2:.+]] = llvm.zext %[[ARG6]] : i32 to i64
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32 {llvm.noundef, llvm.range = #llvm.constant_range<i32, 0, 4097>},
// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: i32 {llvm.noundef})
// CHECK-DAG: %[[C4096_i32:.+]] = llvm.mlir.constant(4096 : i32) : i32
// CHECK-DAG: %[[C0_i32:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-DAG: %[[C32_i64:.+]] = llvm.mlir.constant(32 : i64) : i64
// CHECK-DAG: %[[C5185728_i64:.+]] = llvm.mlir.constant(5185728 : index) : i64
// CHECK-DAG: %[[C4438911803328_i64:.+]] = llvm.mlir.constant(4438911803328 : index) : i64
// CHECK-DAG: %[[OFFSET_LO:.+]] = llvm.zext %[[ARG3]] : i32 to i64
// CHECK-DAG: %[[OFFSET_HI:.+]] = llvm.zext %[[ARG4]] : i32 to i64
// CHECK-DAG: %[[OFFSET_HI_SHL:.+]] = llvm.shl %[[OFFSET_HI]], %[[C32_i64]] : i64
// CHECK-DAG: %[[OFFSET:.+]] = llvm.or %[[OFFSET_LO]], %[[OFFSET_HI_SHL]] : i64
// CHECK-DAG: %[[MIN_COND:.+]] = llvm.icmp "uge" %[[OFFSET]], %[[C5185728_i64]] : i64
// CHECK-DAG: %[[MAX_COND:.+]] = llvm.icmp "ule" %[[OFFSET]], %[[C4438911803328_i64]] : i64
// CHECK-DAG: %[[OFFSET_COND:.+]] = llvm.and %[[MIN_COND]], %[[MAX_COND]] : i1
// CHECK-DAG: llvm.intr.assume %[[OFFSET_COND]]
// CHECK-DAG: %[[D1:.+]] = llvm.zext %[[ARG6]] : i32 to i64
// CHECK-DAG: %[[ARG6_UREM:.+]] = llvm.urem %[[ARG6]], %[[C4096_i32]] : i32
// CHECK-DAG: %[[ARG6_CMP:.+]] = llvm.icmp "eq" %[[ARG6_UREM]], %[[C0_i32]]
// CHECK-DAG: llvm.intr.assume %[[ARG6_CMP]]
// CHECK-DAG: %[[D2:.+]] = llvm.zext %[[ARG7]] : i32 to i64
// CHECK: %[[GEP1:.+]] = llvm.getelementptr %[[ARG1]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: %[[GEP:.+]] = llvm.getelementptr %[[GEP1]][%{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: %[[LOAD:.+]] = llvm.load %[[GEP]] : !llvm.ptr -> f32
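The test above reassembles a 64-bit offset from two 32-bit push constants (arith.extui, shli, ori) before attaching the range assumption, which is the case the new ConvertIREEUtilAssumeIntOp pattern materializes as explicit icmp plus llvm.intr.assume rather than an argument attribute. A small standalone C++ sketch of that reconstruction, with example values chosen by the editor rather than taken from the PR:

#include <cassert>
#include <cstdint>

// (hi << 32) | lo, mirroring the extui/shli/ori sequence in the test IR.
uint64_t combinePushConstants(uint32_t lo, uint32_t hi) {
  return (static_cast<uint64_t>(hi) << 32) | static_cast<uint64_t>(lo);
}

int main() {
  // Hypothetical values that land inside the asserted range
  // [5185728, 4438911803328] from the util.assume.int annotation.
  uint64_t offset = combinePushConstants(0x00500000u, 0x00000001u);
  assert(offset >= 5185728ull && offset <= 4438911803328ull);
  return 0;
}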