Skip to content

Commit

Permalink
[Codegen][GPU] Keep range and divisibility annotations on push constants
Browse files Browse the repository at this point in the history
IREE has useful information indicating the minimum values, maximum
values, and divisibility of push constants encoded in util.assume.int
ops. This information was being thrown away when, in some cases, it
could be profitably communicated to compiler backends.

This commit:
- Changes drop-compiler-hints to have an option that keeps
util.assume.int ops
- Adds rewrites to the LLVMGPU and SPIRV lowerings that erase these
ops
- Changes the rewrites for hal.interface.constant.load to look for
util.assume.int ops in the input IR and use them to add annotations to
the loaded constant
  - In the LLVM case, these annotations take the form of a
    `range(iN lb, ub)` attribute on the corresponding function
    parameter
  - For SPIR-V, these annotations are calls to KHR_AssumeTrue if the
  capability is avaliable
- This commit also adds a case for integer assumption operations to
the SPIR-V i64 emulation pass

While I was here, I converted some of the LLVM lowering patterns to
use ConvertOpToLLVMPattern<>.
  • Loading branch information
krzysz00 committed Dec 12, 2024
1 parent 27742f6 commit d567cd2
Show file tree
Hide file tree
Showing 10 changed files with 320 additions and 52 deletions.
110 changes: 80 additions & 30 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "iree/compiler/Dialect/HAL/IR/HALOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
Expand Down Expand Up @@ -225,15 +227,13 @@ getKernelArgMapping(Operation *funcOp) {
return mapBindingArgIndex;
}

class ConvertFunc : public ConvertToLLVMPattern {
class ConvertFunc : public ConvertOpToLLVMPattern<func::FuncOp> {
public:
explicit ConvertFunc(MLIRContext *context, LLVMTypeConverter &converter)
: ConvertToLLVMPattern(mlir::func::FuncOp::getOperationName(), context,
converter, 100) {}
explicit ConvertFunc(LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern(converter, 100) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto funcOp = cast<func::FuncOp>(op);
FunctionType fnType = funcOp.getFunctionType();
(void)fnType;
if (!funcOp.isPublic())
Expand Down Expand Up @@ -302,13 +302,9 @@ class ConvertFunc : public ConvertToLLVMPattern {
}
};

class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern {
public:
explicit ConvertIREEBindingSubspanOp(MLIRContext *context,
LLVMTypeConverter &converter)
: ConvertToLLVMPattern(
IREE::HAL::InterfaceBindingSubspanOp::getOperationName(), context,
converter) {}
struct ConvertIREEBindingSubspanOp final
: public ConvertOpToLLVMPattern<IREE::HAL::InterfaceBindingSubspanOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

/// Checks all subspanOps with the same binding has readonly attribute
static bool checkAllSubspansReadonly(LLVM::LLVMFuncOp llvmFuncOp,
Expand All @@ -330,7 +326,7 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern {
}

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(IREE::HAL::InterfaceBindingSubspanOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Bail until nested under an LLVMFuncOp.
auto llvmFuncOp = op->getParentOfType<LLVM::LLVMFuncOp>();
Expand All @@ -341,8 +337,6 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern {
auto argMapping = getKernelArgMapping(llvmFuncOp);
Location loc = op->getLoc();
auto subspanOp = cast<IREE::HAL::InterfaceBindingSubspanOp>(op);
IREE::HAL::InterfaceBindingSubspanOpAdaptor adaptor(
operands, op->getAttrDictionary());
MemRefType memrefType =
llvm::dyn_cast<MemRefType>(subspanOp.getResult().getType());
mlir::BlockArgument llvmBufferArg =
Expand Down Expand Up @@ -453,15 +447,12 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern {
}
};

class ConvertIREEConstantOp : public ConvertToLLVMPattern {
public:
explicit ConvertIREEConstantOp(MLIRContext *context,
LLVMTypeConverter &converter)
: ConvertToLLVMPattern(
IREE::HAL::InterfaceConstantLoadOp::getOperationName(), context,
converter) {}
struct ConvertIREEConstantOp final
: public ConvertOpToLLVMPattern<IREE::HAL::InterfaceConstantLoadOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
matchAndRewrite(IREE::HAL::InterfaceConstantLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Bail until nested under an LLVMFuncOp.
auto llvmFuncOp = op->getParentOfType<LLVM::LLVMFuncOp>();
Expand All @@ -470,9 +461,8 @@ class ConvertIREEConstantOp : public ConvertToLLVMPattern {
assert(llvmFuncOp.getNumArguments() > 0);

auto argMapping = getKernelArgMapping(llvmFuncOp);
auto ireeConstantOp = cast<IREE::HAL::InterfaceConstantLoadOp>(op);
mlir::BlockArgument llvmBufferArg = llvmFuncOp.getArgument(
argMapping.size() + ireeConstantOp.getOrdinal().getZExtValue());
argMapping.size() + op.getOrdinal().getZExtValue());
assert(llvmBufferArg.getType().isInteger(32));

// Push constants are never `undef`, annotate that here, just as with
Expand All @@ -481,7 +471,53 @@ class ConvertIREEConstantOp : public ConvertToLLVMPattern {
LLVM::LLVMDialect::getNoUndefAttrName(),
rewriter.getUnitAttr());

Type dstType = getTypeConverter()->convertType(ireeConstantOp.getType());
// If the constant has non-trivial assumptions placed on it about
// its min and max values or divisibility, use that information to
// annotate the corresponding arguments.
if (op.getResult().hasOneUse()) {
OpOperand *operand = op.getResult().getUses().begin().getOperand();
auto assumeOp = dyn_cast<IREE::Util::AssumeIntOp>(operand->getOwner());
if (assumeOp) {
unsigned opIdx = operand->getOperandNumber();
auto [min, max] = assumeOp.getUnionedUnsignedRange(opIdx);

if (min.has_value() && max.has_value()) {
assert(*min <= std::numeric_limits<uint32_t>::max() &&
"Push-constant's maximum value can't be outside 32 bits, but "
"this is assumed");
// Note: LLVM's range(iN lb, ub) is [lb, ub), while MLIR's is [lb,
// ub], so we add 1 to the upper bound.
llvmFuncOp.setArgAttr(llvmBufferArg.getArgNumber(),
LLVM::LLVMDialect::getRangeAttrName(),
rewriter.getAttr<LLVM::ConstantRangeAttr>(
APInt(32, *min), APInt(32, *max) + 1));
}

auto divisibility = assumeOp.getUnionedUnsignedDivisor(opIdx);

auto makeI32Const = [&](uint32_t val) -> Value {
return rewriter.create<LLVM::ConstantOp>(
assumeOp.getLoc(), rewriter.getI32Type(),
rewriter.getI32IntegerAttr(val));
};
if (divisibility.has_value() && *divisibility > 1) {
Location loc = assumeOp.getLoc();
assert(*divisibility <= std::numeric_limits<uint32_t>::max() &&
"push constant shouldn't be statically divisible by a value "
"it can't hold");
Value knownDivisibleBy = makeI32Const(*divisibility);
// This'll almost always become an and
Value lowPart = rewriter.create<LLVM::URemOp>(loc, llvmBufferArg,
knownDivisibleBy);
Value zero = makeI32Const(0);
Value isEvenlyDivided = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::ICmpPredicate::eq, lowPart, zero);
rewriter.create<LLVM::AssumeOp>(loc, isEvenlyDivided);
}
}
}

Type dstType = getTypeConverter()->convertType(op.getType());
// llvm.zext requires that the result type has a larger bitwidth.
if (dstType == llvmBufferArg.getType()) {
rewriter.replaceOp(op, llvmBufferArg);
Expand Down Expand Up @@ -510,14 +546,28 @@ struct HALInterfaceWorkgroupOpsConverter final
}
};

struct ConvertIREEUtilAssumeIntOp final
: public ConvertOpToLLVMPattern<IREE::Util::AssumeIntOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(IREE::Util::AssumeIntOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Bail until nested under an LLVMFuncOp.
auto llvmFuncOp = op->getParentOfType<LLVM::LLVMFuncOp>();
if (!llvmFuncOp)
return failure();
rewriter.replaceOp(op, adaptor.getOperands());
return success();
}
};
} // namespace

void populateLLVMConversionPatterns(MLIRContext *context,
RewritePatternSet &patterns,
LLVMTypeConverter &converter) {
patterns
.insert<ConvertFunc, ConvertIREEBindingSubspanOp, ConvertIREEConstantOp>(
context, converter);
patterns.insert<ConvertFunc, ConvertIREEBindingSubspanOp,
ConvertIREEConstantOp, ConvertIREEUtilAssumeIntOp>(converter);
}

void populateScalarizeMathOps(RewritePatternSet &patterns) {
Expand Down
7 changes: 5 additions & 2 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1138,7 +1138,8 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,
modulePassManager.addPass(createStripDebugInfoPass());
// Cast address spaces of all function arguments to generic.
modulePassManager.addPass(createLLVMGPUCastAddressSpaceFunctionPass());
modulePassManager.addPass(IREE::Util::createDropCompilerHintsPass());
modulePassManager.addPass(IREE::Util::createDropCompilerHintsPass(
IREE::Util::DropCompilerHintsPassOptions{/*keepAssumeInt=*/true}));

if (forROCDL) {
// convert to ROCDL.
Expand Down Expand Up @@ -1267,7 +1268,9 @@ void buildROCDLCodegenPassPipeline(OpPassManager &variantPassManager) {
.addPass(createVerifyWorkgroupDistributionPass);
}
variantPassManager.addPass(createReconcileTranslationInfoPass());
variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass());
variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass(
IREE::Util::DropCompilerHintsPassOptions{/*keepAssumeInt=*/true}));
;

addLowerToLLVMGPUPasses(variantPassManager.nest<ModuleOp>(),
/*forROCDL=*/true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ hal.executable @abs_dynamic {
%c5 = arith.constant 5 : index
%c7 = arith.constant 7 : index
%o = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index
%o.tagged = util.assume.int %o[<umin = 0, umax = 0>, <umin = 4096, umax = 4096, udiv = 4096>] : index
%d0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index
%d1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index
%d2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) offset(%o) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>{%d0, %d1, %d2}
%0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) offset(%o.tagged) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>{%d0, %d1, %d2}
%1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref<?x?x?xi32>{%d0, %d1, %d2}
%2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : memref<?x?x?xf32>{%d0, %d1, %d2}
%9 = memref.load %0[%c3, %c5, %c7] : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>
Expand All @@ -75,10 +76,15 @@ hal.executable @abs_dynamic {
// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef},
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef},
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef},
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: i32 {llvm.noundef},
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: i32 {llvm.noundef, llvm.range = #llvm.constant_range<i32, 0, 4097>},
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i32 {llvm.noundef},
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32 {llvm.noundef},
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32 {llvm.noundef})
// CHECK-DAG: %[[C4096_i32:.+]] = llvm.mlir.constant(4096 : i32) : i32
// CHECK-DAG: %[[C0_i32:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK-DAG: %[[ARG3_UREM:.+]] = llvm.urem %[[ARG3]], %[[C4096_i32]] : i32
// CHECK-DAG: %[[ARG3_CMP:.+]] = llvm.icmp "eq" %[[ARG3_UREM]], %[[C0_i32]]
// CHECK: llvm.intr.assume %[[ARG3_CMP]]
// CHECK-DAG: %[[OFFSET:.+]] = llvm.zext %[[ARG3]] : i32 to i64
// CHECK-DAG: %[[D1:.+]] = llvm.zext %[[ARG5]] : i32 to i64
// CHECK-DAG: %[[D2:.+]] = llvm.zext %[[ARG6]] : i32 to i64
Expand Down
67 changes: 63 additions & 4 deletions compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,12 @@ namespace {
/// ops to load from a global variable representing the push constant storage.
struct HALInterfaceLoadConstantConverter final
: OpConversionPattern<IREE::HAL::InterfaceConstantLoadOp> {
using OpConversionPattern::OpConversionPattern;
bool supportsAssume = false;

HALInterfaceLoadConstantConverter(TypeConverter &typeConverter,
MLIRContext *context, bool supportsAssume)
: OpConversionPattern(typeConverter, context),
supportsAssume(supportsAssume) {}

LogicalResult
matchAndRewrite(IREE::HAL::InterfaceConstantLoadOp loadOp, OpAdaptor adaptor,
Expand All @@ -299,11 +304,61 @@ struct HALInterfaceLoadConstantConverter final
Value value = spirv::getPushConstantValue(loadOp, elementCount, index,
i32Type, rewriter);

if (loadOp.getResult().hasOneUse() && supportsAssume) {
OpOperand *operand = loadOp.getResult().getUses().begin().getOperand();
auto assumeOp = dyn_cast<IREE::Util::AssumeIntOp>(operand->getOwner());
if (assumeOp) {
Location loc = assumeOp.getLoc();
unsigned opIdx = operand->getOperandNumber();

auto [min, max] = assumeOp.getUnionedUnsignedRange(opIdx);
if (min.has_value() && max.has_value()) {
Value minConst = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(*min));
Value maxConst = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(*max));
Value minBound =
rewriter.create<spirv::UGreaterThanEqualOp>(loc, value, minConst);
rewriter.create<spirv::KHRAssumeTrueOp>(loc, minBound);
Value maxBound =
rewriter.create<spirv::ULessThanEqualOp>(loc, value, maxConst);
rewriter.create<spirv::KHRAssumeTrueOp>(loc, maxBound);
}

std::optional<uint64_t> divisibility =
assumeOp.getUnionedUnsignedDivisor(opIdx);
if (divisibility.has_value() && *divisibility > 1) {
Value divisor = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(*divisibility));
Value zero = rewriter.create<spirv::ConstantOp>(
loc, i32Type, rewriter.getI32IntegerAttr(0));
Value lowPart = rewriter.create<spirv::UModOp>(loc, value, divisor);
Value dividesExactly =
rewriter.create<spirv::IEqualOp>(loc, lowPart, zero);
rewriter.create<spirv::KHRAssumeTrueOp>(loc, dividesExactly);
}
}
}

rewriter.replaceOp(loadOp, value);
return success();
}
};

/// A pattern to convert util.assume.int into a noop, since we're using it
/// to annotate push constants.
struct UtilAssumeIntConverter final
: OpConversionPattern<IREE::Util::AssumeIntOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(IREE::Util::AssumeIntOp assumeOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOp(assumeOp, adaptor.getOperands());
return success();
}
};

/// A pattern to convert hal.interface.workgroup.id/count/size into
/// corresponding SPIR-V Builtin ops.
template <typename InterfaceOpTy, spirv::BuiltIn builtin>
Expand Down Expand Up @@ -615,6 +670,8 @@ void ConvertToSPIRVPass::runOnOperation() {
return signalPassFailure();
}

bool supportsAssume = targetEnv.allows(spirv::Capability::ExpectAssumeKHR);

SPIRVConversionOptions options = {};
options.use64bitIndex = use64bitIndex;

Expand Down Expand Up @@ -660,14 +717,16 @@ void ConvertToSPIRVPass::runOnOperation() {

// Add IREE HAL interface op conversions.
patterns.add<
HALInterfaceLoadConstantConverter,
HALInterfaceWorkgroupOpsConverter<IREE::HAL::InterfaceWorkgroupIDOp,
spirv::BuiltIn::WorkgroupId>,
HALInterfaceWorkgroupOpsConverter<IREE::HAL::InterfaceWorkgroupSizeOp,
spirv::BuiltIn::WorkgroupSize>,
HALInterfaceWorkgroupOpsConverter<IREE::HAL::InterfaceWorkgroupCountOp,
spirv::BuiltIn::NumWorkgroups>>(
typeConverter, context);
spirv::BuiltIn::NumWorkgroups>,
UtilAssumeIntConverter>(typeConverter, context);

patterns.add<HALInterfaceLoadConstantConverter>(typeConverter, context,
supportsAssume);

// Performs a prelimiary step to analyze all hal.interface.binding.subspan ops
// and creates spirv.GlobalVariables.
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,8 @@ void buildSPIRVCodegenPassPipeline(OpPassManager &variantPassManager) {
addMemRefLoweringPasses(modulePassManager);
}
variantPassManager.addPass(createReconcileTranslationInfoPass());
variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass());
variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass(
IREE::Util::DropCompilerHintsPassOptions{/*keepAssumeInt=*/true}));

{
OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
Expand Down
Loading

0 comments on commit d567cd2

Please sign in to comment.