From e26a9661bfcba6ea37855bc2d118f863d0c1b315 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Wed, 27 Nov 2024 20:34:27 +0000 Subject: [PATCH] [Codegen][GPU] Keep range and divisibility annotations on push constants IREE has useful information indicating the minimum values, maximum values, and divisibility of push constants encoded in util.assume.int ops. This information was being thrown away when, in some cases, it could be profitably communicated to compiler backends. This commit: - Changes drop-compiler-hints to have an option that keeps util.assume.int ops - Adds rewrites to the LLVMGPU and SPIRV lowerings that erase these ops - Changes the rewrites for hal.interface.constant.load to look for util.assume.int ops in the input IR and use them to add annotations to the loaded constant - In the LLVM case, these annotations take the form of a `range(iN lb, ub)` attribute on the corresponding function parameter - For SPIR-V, these annotations are calls to KHR_AssumeTrue if the capability is avaliable - This commit also adds a case for integer assumption operations to the SPIR-V i64 emulation pass While I was here, I converted some of the LLVM lowering patterns to use ConvertOpToLLVMPattern<>. --- .../Codegen/LLVMGPU/ConvertToLLVM.cpp | 110 +++++++++++++----- .../iree/compiler/Codegen/LLVMGPU/Passes.cpp | 7 +- .../Codegen/LLVMGPU/test/convert_to_nvvm.mlir | 10 +- .../Codegen/SPIRV/ConvertToSPIRVPass.cpp | 67 ++++++++++- .../iree/compiler/Codegen/SPIRV/Passes.cpp | 3 +- .../Codegen/SPIRV/SPIRVEmulateI64.cpp | 78 +++++++++++-- .../Codegen/SPIRV/test/convert_to_spirv.mlir | 39 +++++++ .../Codegen/SPIRV/test/emulate_i64.mlir | 34 ++++++ .../Util/Transforms/DropCompilerHints.cpp | 11 +- .../Dialect/Util/Transforms/Passes.td | 13 +++ 10 files changed, 320 insertions(+), 52 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp index c056d44538bb..6d6b4c19e5fb 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp @@ -9,7 +9,9 @@ #include "iree/compiler/Codegen/LLVMGPU/Passes.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Codegen/Utils/Utils.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" @@ -225,15 +227,13 @@ getKernelArgMapping(Operation *funcOp) { return mapBindingArgIndex; } -class ConvertFunc : public ConvertToLLVMPattern { +class ConvertFunc : public ConvertOpToLLVMPattern { public: - explicit ConvertFunc(MLIRContext *context, LLVMTypeConverter &converter) - : ConvertToLLVMPattern(mlir::func::FuncOp::getOperationName(), context, - converter, 100) {} + explicit ConvertFunc(LLVMTypeConverter &converter) + : ConvertOpToLLVMPattern(converter, 100) {} LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto funcOp = cast(op); FunctionType fnType = funcOp.getFunctionType(); (void)fnType; if (!funcOp.isPublic()) @@ -302,13 +302,9 @@ class ConvertFunc : public ConvertToLLVMPattern { } }; -class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern { -public: - explicit ConvertIREEBindingSubspanOp(MLIRContext *context, - LLVMTypeConverter &converter) - : ConvertToLLVMPattern( - IREE::HAL::InterfaceBindingSubspanOp::getOperationName(), context, - converter) {} +struct ConvertIREEBindingSubspanOp final + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; /// Checks all subspanOps with the same binding has readonly attribute static bool checkAllSubspansReadonly(LLVM::LLVMFuncOp llvmFuncOp, @@ -330,7 +326,7 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern { } LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(IREE::HAL::InterfaceBindingSubspanOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Bail until nested under an LLVMFuncOp. auto llvmFuncOp = op->getParentOfType(); @@ -341,8 +337,6 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern { auto argMapping = getKernelArgMapping(llvmFuncOp); Location loc = op->getLoc(); auto subspanOp = cast(op); - IREE::HAL::InterfaceBindingSubspanOpAdaptor adaptor( - operands, op->getAttrDictionary()); MemRefType memrefType = llvm::dyn_cast(subspanOp.getResult().getType()); mlir::BlockArgument llvmBufferArg = @@ -453,15 +447,12 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern { } }; -class ConvertIREEConstantOp : public ConvertToLLVMPattern { -public: - explicit ConvertIREEConstantOp(MLIRContext *context, - LLVMTypeConverter &converter) - : ConvertToLLVMPattern( - IREE::HAL::InterfaceConstantLoadOp::getOperationName(), context, - converter) {} +struct ConvertIREEConstantOp final + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(IREE::HAL::InterfaceConstantLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Bail until nested under an LLVMFuncOp. auto llvmFuncOp = op->getParentOfType(); @@ -470,9 +461,8 @@ class ConvertIREEConstantOp : public ConvertToLLVMPattern { assert(llvmFuncOp.getNumArguments() > 0); auto argMapping = getKernelArgMapping(llvmFuncOp); - auto ireeConstantOp = cast(op); mlir::BlockArgument llvmBufferArg = llvmFuncOp.getArgument( - argMapping.size() + ireeConstantOp.getOrdinal().getZExtValue()); + argMapping.size() + op.getOrdinal().getZExtValue()); assert(llvmBufferArg.getType().isInteger(32)); // Push constants are never `undef`, annotate that here, just as with @@ -481,7 +471,53 @@ class ConvertIREEConstantOp : public ConvertToLLVMPattern { LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr()); - Type dstType = getTypeConverter()->convertType(ireeConstantOp.getType()); + // If the constant has non-trivial assumptions placed on it about + // its min and max values or divisibility, use that information to + // annotate the corresponding arguments. + if (op.getResult().hasOneUse()) { + OpOperand *operand = op.getResult().getUses().begin().getOperand(); + auto assumeOp = dyn_cast(operand->getOwner()); + if (assumeOp) { + unsigned opIdx = operand->getOperandNumber(); + auto [min, max] = assumeOp.getUnionedUnsignedRange(opIdx); + + if (min.has_value() && max.has_value()) { + assert(*min <= std::numeric_limits::max() && + "Push-constant's maximum value can't be outside 32 bits, but " + "this is assumed"); + // Note: LLVM's range(iN lb, ub) is [lb, ub), while MLIR's is [lb, + // ub], so we add 1 to the upper bound. + llvmFuncOp.setArgAttr(llvmBufferArg.getArgNumber(), + LLVM::LLVMDialect::getRangeAttrName(), + rewriter.getAttr( + APInt(32, *min), APInt(32, *max) + 1)); + } + + auto divisibility = assumeOp.getUnionedUnsignedDivisor(opIdx); + + auto makeI32Const = [&](uint32_t val) -> Value { + return rewriter.create( + assumeOp.getLoc(), rewriter.getI32Type(), + rewriter.getI32IntegerAttr(val)); + }; + if (divisibility.has_value() && *divisibility > 1) { + Location loc = assumeOp.getLoc(); + assert(*divisibility <= std::numeric_limits::max() && + "push constant shouldn't be statically divisible by a value " + "it can't hold"); + Value knownDivisibleBy = makeI32Const(*divisibility); + // This'll almost always become an and + Value lowPart = rewriter.create(loc, llvmBufferArg, + knownDivisibleBy); + Value zero = makeI32Const(0); + Value isEvenlyDivided = rewriter.create( + loc, LLVM::ICmpPredicate::eq, lowPart, zero); + rewriter.create(loc, isEvenlyDivided); + } + } + } + + Type dstType = getTypeConverter()->convertType(op.getType()); // llvm.zext requires that the result type has a larger bitwidth. if (dstType == llvmBufferArg.getType()) { rewriter.replaceOp(op, llvmBufferArg); @@ -510,14 +546,28 @@ struct HALInterfaceWorkgroupOpsConverter final } }; +struct ConvertIREEUtilAssumeIntOp final + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(IREE::Util::AssumeIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Bail until nested under an LLVMFuncOp. + auto llvmFuncOp = op->getParentOfType(); + if (!llvmFuncOp) + return failure(); + rewriter.replaceOp(op, adaptor.getOperands()); + return success(); + } +}; } // namespace void populateLLVMConversionPatterns(MLIRContext *context, RewritePatternSet &patterns, LLVMTypeConverter &converter) { - patterns - .insert( - context, converter); + patterns.insert(converter); } void populateScalarizeMathOps(RewritePatternSet &patterns) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 53e49efbf66a..ee72a8dbe22a 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -1138,7 +1138,8 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager, modulePassManager.addPass(createStripDebugInfoPass()); // Cast address spaces of all function arguments to generic. modulePassManager.addPass(createLLVMGPUCastAddressSpaceFunctionPass()); - modulePassManager.addPass(IREE::Util::createDropCompilerHintsPass()); + modulePassManager.addPass(IREE::Util::createDropCompilerHintsPass( + IREE::Util::DropCompilerHintsPassOptions{/*keepAssumeInt=*/true})); if (forROCDL) { // convert to ROCDL. @@ -1267,7 +1268,9 @@ void buildROCDLCodegenPassPipeline(OpPassManager &variantPassManager) { .addPass(createVerifyWorkgroupDistributionPass); } variantPassManager.addPass(createReconcileTranslationInfoPass()); - variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass()); + variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass( + IREE::Util::DropCompilerHintsPassOptions{/*keepAssumeInt=*/true})); + ; addLowerToLLVMGPUPasses(variantPassManager.nest(), /*forROCDL=*/true); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir index bd876a377857..6a9c4e82490b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir @@ -55,10 +55,11 @@ hal.executable @abs_dynamic { %c5 = arith.constant 5 : index %c7 = arith.constant 7 : index %o = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index + %o.tagged = util.assume.int %o[, ] : index %d0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index %d1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index %d2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) : index - %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) offset(%o) : memref>{%d0, %d1, %d2} + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) offset(%o.tagged) : memref>{%d0, %d1, %d2} %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref{%d0, %d1, %d2} %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : memref{%d0, %d1, %d2} %9 = memref.load %0[%c3, %c5, %c7] : memref> @@ -75,10 +76,15 @@ hal.executable @abs_dynamic { // CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef}, // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef}, // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef}, -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: i32 {llvm.noundef}, +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: i32 {llvm.noundef, llvm.range = #llvm.constant_range}, // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i32 {llvm.noundef}, // CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32 {llvm.noundef}, // CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32 {llvm.noundef}) +// CHECK-DAG: %[[C4096_i32:.+]] = llvm.mlir.constant(4096 : i32) : i32 +// CHECK-DAG: %[[C0_i32:.+]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK-DAG: %[[ARG3_UREM:.+]] = llvm.urem %[[ARG3]], %[[C4096_i32]] : i32 +// CHECK-DAG: %[[ARG3_CMP:.+]] = llvm.icmp "eq" %[[ARG3_UREM]], %[[C0_i32]] +// CHECK: llvm.intr.assume %[[ARG3_CMP]] // CHECK-DAG: %[[OFFSET:.+]] = llvm.zext %[[ARG3]] : i32 to i64 // CHECK-DAG: %[[D1:.+]] = llvm.zext %[[ARG5]] : i32 to i64 // CHECK-DAG: %[[D2:.+]] = llvm.zext %[[ARG6]] : i32 to i64 diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp index f4099c450849..58c6bf9843de 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp @@ -277,7 +277,12 @@ namespace { /// ops to load from a global variable representing the push constant storage. struct HALInterfaceLoadConstantConverter final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + bool supportsAssume = false; + + HALInterfaceLoadConstantConverter(TypeConverter &typeConverter, + MLIRContext *context, bool supportsAssume) + : OpConversionPattern(typeConverter, context), + supportsAssume(supportsAssume) {} LogicalResult matchAndRewrite(IREE::HAL::InterfaceConstantLoadOp loadOp, OpAdaptor adaptor, @@ -299,11 +304,61 @@ struct HALInterfaceLoadConstantConverter final Value value = spirv::getPushConstantValue(loadOp, elementCount, index, i32Type, rewriter); + if (loadOp.getResult().hasOneUse() && supportsAssume) { + OpOperand *operand = loadOp.getResult().getUses().begin().getOperand(); + auto assumeOp = dyn_cast(operand->getOwner()); + if (assumeOp) { + Location loc = assumeOp.getLoc(); + unsigned opIdx = operand->getOperandNumber(); + + auto [min, max] = assumeOp.getUnionedUnsignedRange(opIdx); + if (min.has_value() && max.has_value()) { + Value minConst = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(*min)); + Value maxConst = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(*max)); + Value minBound = + rewriter.create(loc, value, minConst); + rewriter.create(loc, minBound); + Value maxBound = + rewriter.create(loc, value, maxConst); + rewriter.create(loc, maxBound); + } + + std::optional divisibility = + assumeOp.getUnionedUnsignedDivisor(opIdx); + if (divisibility.has_value() && *divisibility > 1) { + Value divisor = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(*divisibility)); + Value zero = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(0)); + Value lowPart = rewriter.create(loc, value, divisor); + Value dividesExactly = + rewriter.create(loc, lowPart, zero); + rewriter.create(loc, dividesExactly); + } + } + } + rewriter.replaceOp(loadOp, value); return success(); } }; +/// A pattern to convert util.assume.int into a noop, since we're using it +/// to annotate push constants. +struct UtilAssumeIntConverter final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(IREE::Util::AssumeIntOp assumeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(assumeOp, adaptor.getOperands()); + return success(); + } +}; + /// A pattern to convert hal.interface.workgroup.id/count/size into /// corresponding SPIR-V Builtin ops. template @@ -615,6 +670,8 @@ void ConvertToSPIRVPass::runOnOperation() { return signalPassFailure(); } + bool supportsAssume = targetEnv.allows(spirv::Capability::ExpectAssumeKHR); + SPIRVConversionOptions options = {}; options.use64bitIndex = use64bitIndex; @@ -660,14 +717,16 @@ void ConvertToSPIRVPass::runOnOperation() { // Add IREE HAL interface op conversions. patterns.add< - HALInterfaceLoadConstantConverter, HALInterfaceWorkgroupOpsConverter, HALInterfaceWorkgroupOpsConverter, HALInterfaceWorkgroupOpsConverter>( - typeConverter, context); + spirv::BuiltIn::NumWorkgroups>, + UtilAssumeIntConverter>(typeConverter, context); + + patterns.add(typeConverter, context, + supportsAssume); // Performs a prelimiary step to analyze all hal.interface.binding.subspan ops // and creates spirv.GlobalVariables. diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp index ea0aa9f45116..dbb233af0f3b 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp @@ -636,7 +636,8 @@ void buildSPIRVCodegenPassPipeline(OpPassManager &variantPassManager) { addMemRefLoweringPasses(modulePassManager); } variantPassManager.addPass(createReconcileTranslationInfoPass()); - variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass()); + variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass( + IREE::Util::DropCompilerHintsPassOptions{/*keepAssumeInt=*/true})); { OpPassManager &modulePassManager = variantPassManager.nest(); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp index 9e0c7a32247e..f62a72909566 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp @@ -16,6 +16,7 @@ #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" @@ -71,6 +72,65 @@ struct ConvertHalInterfaceBindingSubspan final } }; +/// Rewrite away assumptions on integers. If the input to the operation is the +/// result of an extui (so usually it's an i32 that's been extended to an i64) +/// we port the assumptions over to the underlying 32-bit value. Otherwise, if +/// there's a type conversion, we drop the assumptions. +struct ConvertUtilAssumeIntOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(IREE::Util::AssumeIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector replacements; + SmallVector newArgs; + SmallVector newAssumptions; + for (auto [result, oldArg, newArg, assumeList] : + llvm::zip_equal(op.getResults(), op.getOperands(), + adaptor.getOperands(), op.getAssumptions())) { + if (auto isExt = oldArg.getDefiningOp()) { + Value smallerOp = rewriter.getRemappedValue(isExt.getIn()); + newArgs.push_back(smallerOp); + newAssumptions.push_back(cast(assumeList)); + replacements.push_back(nullptr); + } else if (oldArg.getType() == newArg.getType()) { + newArgs.push_back(newArg); + newAssumptions.push_back(cast(assumeList)); + replacements.push_back(nullptr); + } else { + replacements.push_back(newArg); + } + } + + if (!newArgs.empty()) { + auto newOp = rewriter.create( + op.getLoc(), newArgs, newAssumptions); + LLVM_DEBUG(llvm::dbgs() + << "WideIntegerEmulation: new op: " << newOp << "\n"); + + unsigned replacementLoc = 0; + for (auto result : newOp.getResults()) { + while (replacements[replacementLoc] != nullptr) + replacementLoc++; + Value replacement = result; + Type newType = getTypeConverter()->convertType( + op.getResult(replacementLoc).getType()); + if (auto vecType = dyn_cast_if_present(newType)) { + Value zeros = rewriter.create( + op.getLoc(), newType, rewriter.getZeroAttr(newType)); + replacement = rewriter.create( + op.getLoc(), result, zeros, ArrayRef{0}); + } + replacements[replacementLoc] = replacement; + } + } + + rewriter.replaceOp(op, replacements); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Rewrite patterns //===----------------------------------------------------------------------===// @@ -150,8 +210,8 @@ struct FlattenElementwisePattern final : RewritePattern { static void populateIreeI64EmulationPatterns(arith::WideIntEmulationConverter &converter, RewritePatternSet &patterns) { - patterns.add(converter, - patterns.getContext()); + patterns.add( + converter, patterns.getContext()); } static bool supportsI64(FunctionOpInterface op) { @@ -190,13 +250,13 @@ struct SPIRVEmulateI64Pass final }); target.addDynamicallyLegalDialect< arith::ArithDialect, func::FuncDialect, IREE::HAL::HALDialect, - memref::MemRefDialect, vector::VectorDialect>( - [&typeConverter](Operation *op) { - bool legal = typeConverter.isLegal(op); - LLVM_DEBUG(if (!legal) llvm::dbgs() - << "WideIntegerEmulation: illegal op: " << *op << "\n"); - return legal; - }); + memref::MemRefDialect, vector::VectorDialect, + IREE::Util::UtilDialect>([&typeConverter](Operation *op) { + bool legal = typeConverter.isLegal(op); + LLVM_DEBUG(if (!legal) llvm::dbgs() + << "WideIntegerEmulation: illegal op: " << *op << "\n"); + return legal; + }); RewritePatternSet patterns(ctx); arith::populateArithWideIntEmulationPatterns(typeConverter, patterns); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir index b43700618377..2bfdb8f6195f 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir @@ -36,6 +36,45 @@ hal.executable private @push_constant { // ----- +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding +]> +hal.executable private @push_constant_annotated { + hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) { + hal.executable.export @push_constant_annotated layout(#pipeline_layout) attributes { + workgroup_size = [32: index, 1: index, 1: index] + } + builtin.module attributes {spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { + // CHECK-LABEL: spirv.func @push_constant_annotated() + func.func @push_constant_annotated() -> index { + // CHECK: %[[LOAD:.+]] = spirv.Load "PushConstant" %{{.*}} : i32 + // CHECK-DAG: %[[C0:.+]] = spirv.Constant 0 : i32 + // CHECK-DAG: %[[C4096:.+]] = spirv.Constant 4096 : i32 + // CHECK-DAG: %[[UGE:.+]] = spirv.UGreaterThanEqual %[[LOAD]], %[[C0]] + // CHECK-DAG: spirv.KHR.AssumeTrue %[[UGE]] + // CHECK-DAG: %[[ULE:.+]] = spirv.ULessThanEqual %[[LOAD]], %[[C4096]] + // CHECK-DAG: spirv.KHR.AssumeTrue %[[ULE]] + // CHECK-DAG: %[[C2048:.+]] = spirv.Constant 2048 : i32 + // CHECK-DAG: %[[C0_2:.+]] = spirv.Constant 0 : i32 + // CHECK-DAG: %[[MOD:.+]] = spirv.UMod %[[LOAD]], %[[C2048]] + // CHECK-DAG: %[[LOWCLEAR:.+]] = spirv.IEqual %[[MOD]], %[[C0_2]] + // CHECK-DAG: spirv.KHR.AssumeTrue %[[LOWCLEAR]] + // CHECK: spirv.ReturnValue %[[LOAD]] + %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : i32 + %1 = util.assume.int %0[ + , + , + ] : i32 + %2 = arith.index_castui %1 : i32 to index + return %2 : index + } + } + } +} + +// ----- + #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding, diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/emulate_i64.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/emulate_i64.mlir index 28405723f7ab..bafaa8da7034 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/emulate_i64.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/emulate_i64.mlir @@ -39,6 +39,40 @@ func.func @buffer_types() attributes {hal.executable.target = #executable_target // ----- +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> +#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", { + iree.gpu.target = #iree_gpu.target> +}> +func.func @splat_i64_with_assume() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} { + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = hal.interface.constant.load layout(], flags = Indirect>) ordinal(0) : i32 + %1 = arith.extui %0 : i32 to i64 + %2 = util.assume.int %1[, , , , ] : i64 + %3 = hal.interface.binding.subspan layout(], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags(Indirect) : memref>{%c1} + memref.store %2, %3[%c0] : memref> + return +} + +// Check that assume operations that annonatate i64 values which were really only +// 32 bits become assumptions on the underlying values +// CHECK-LABEL: func.func @splat_i64_with_assume +// CHECK: %[[PUSH_CONST:.+]] = hal.interface.constant.load +// CHECK: %[[ASSUME:.+]] = util.assume.int %[[PUSH_CONST]] +// CHECK: %[[ASSUME_EXT:.+]] = vector.insert %[[ASSUME]], %{{.*}}[0] +// CHECK: memref.store %[[ASSUME_EXT]] + +// ----- + #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp index 2d38d54a9748..1f5a21b1d010 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp @@ -17,9 +17,10 @@ namespace mlir::iree_compiler::IREE::Util { namespace { -class DropCompilerHintsPass +struct DropCompilerHintsPass : public impl::DropCompilerHintsPassBase { -public: + using Base::Base; + void runOnOperation() override { // We can't use patterns and applyPatternsAndFoldGreedily because that // automatically does canonicalization. @@ -28,8 +29,10 @@ class DropCompilerHintsPass op.replaceAllUsesWith(op.getOperands()); op.erase(); } else if (auto op = dyn_cast(genericOp)) { - op.replaceAllUsesWith(op.getOperands()); - op.erase(); + if (!keepAssumeInt) { + op.replaceAllUsesWith(op.getOperands()); + op.erase(); + } } }); } diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td index fb0cf7d028a0..68d1e3421d35 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td @@ -27,7 +27,20 @@ def DropCompilerHintsPass : Pass<"iree-util-drop-compiler-hints", ""> { Deletes operations that have no runtime equivalent and are only used in the compiler. This should be performed after all other compiler passes. + + With keep-assume-int=true, leaves util.int.assume operations in place + so they can be propagated to backends. This is a temporary measure + until all bbackends have a rewrite for those assumptions (currently + they're only handled by the patterns that target LLVM). }]; + + let options = [ + Option< + "keepAssumeInt", "keep-assume-int", + "bool", "false", + "Whether annotations about the ranges and divisibility of integers should be kept." + >, + ]; } def DumpModulePass : Pass<"iree-util-dump-module", "mlir::ModuleOp"> {