diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp index c056d44538bb..6d6b4c19e5fb 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp @@ -9,7 +9,9 @@ #include "iree/compiler/Codegen/LLVMGPU/Passes.h" #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Codegen/Utils/Utils.h" +#include "iree/compiler/Dialect/HAL/IR/HALOps.h" #include "iree/compiler/Dialect/Util/IR/UtilOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilTypes.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" @@ -225,15 +227,13 @@ getKernelArgMapping(Operation *funcOp) { return mapBindingArgIndex; } -class ConvertFunc : public ConvertToLLVMPattern { +class ConvertFunc : public ConvertOpToLLVMPattern { public: - explicit ConvertFunc(MLIRContext *context, LLVMTypeConverter &converter) - : ConvertToLLVMPattern(mlir::func::FuncOp::getOperationName(), context, - converter, 100) {} + explicit ConvertFunc(LLVMTypeConverter &converter) + : ConvertOpToLLVMPattern(converter, 100) {} LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto funcOp = cast(op); FunctionType fnType = funcOp.getFunctionType(); (void)fnType; if (!funcOp.isPublic()) @@ -302,13 +302,9 @@ class ConvertFunc : public ConvertToLLVMPattern { } }; -class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern { -public: - explicit ConvertIREEBindingSubspanOp(MLIRContext *context, - LLVMTypeConverter &converter) - : ConvertToLLVMPattern( - IREE::HAL::InterfaceBindingSubspanOp::getOperationName(), context, - converter) {} +struct ConvertIREEBindingSubspanOp final + : public ConvertOpToLLVMPattern { + using 
ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; /// Checks all subspanOps with the same binding has readonly attribute static bool checkAllSubspansReadonly(LLVM::LLVMFuncOp llvmFuncOp, @@ -330,7 +326,7 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern { } LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(IREE::HAL::InterfaceBindingSubspanOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Bail until nested under an LLVMFuncOp. auto llvmFuncOp = op->getParentOfType(); @@ -341,8 +337,6 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern { auto argMapping = getKernelArgMapping(llvmFuncOp); Location loc = op->getLoc(); auto subspanOp = cast(op); - IREE::HAL::InterfaceBindingSubspanOpAdaptor adaptor( - operands, op->getAttrDictionary()); MemRefType memrefType = llvm::dyn_cast(subspanOp.getResult().getType()); mlir::BlockArgument llvmBufferArg = @@ -453,15 +447,12 @@ class ConvertIREEBindingSubspanOp : public ConvertToLLVMPattern { } }; -class ConvertIREEConstantOp : public ConvertToLLVMPattern { -public: - explicit ConvertIREEConstantOp(MLIRContext *context, - LLVMTypeConverter &converter) - : ConvertToLLVMPattern( - IREE::HAL::InterfaceConstantLoadOp::getOperationName(), context, - converter) {} +struct ConvertIREEConstantOp final + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(IREE::HAL::InterfaceConstantLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Bail until nested under an LLVMFuncOp. 
auto llvmFuncOp = op->getParentOfType(); @@ -470,9 +461,8 @@ class ConvertIREEConstantOp : public ConvertToLLVMPattern { assert(llvmFuncOp.getNumArguments() > 0); auto argMapping = getKernelArgMapping(llvmFuncOp); - auto ireeConstantOp = cast(op); mlir::BlockArgument llvmBufferArg = llvmFuncOp.getArgument( - argMapping.size() + ireeConstantOp.getOrdinal().getZExtValue()); + argMapping.size() + op.getOrdinal().getZExtValue()); assert(llvmBufferArg.getType().isInteger(32)); // Push constants are never `undef`, annotate that here, just as with @@ -481,7 +471,53 @@ class ConvertIREEConstantOp : public ConvertToLLVMPattern { LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr()); - Type dstType = getTypeConverter()->convertType(ireeConstantOp.getType()); + // If the constant has non-trivial assumptions placed on it about + // its min and max values or divisibility, use that information to + // annotate the corresponding arguments. + if (op.getResult().hasOneUse()) { + OpOperand *operand = op.getResult().getUses().begin().getOperand(); + auto assumeOp = dyn_cast(operand->getOwner()); + if (assumeOp) { + unsigned opIdx = operand->getOperandNumber(); + auto [min, max] = assumeOp.getUnionedUnsignedRange(opIdx); + + if (min.has_value() && max.has_value()) { + assert(*max <= std::numeric_limits::max() && + "Push-constant's maximum value can't be outside 32 bits, but " + "this is assumed"); + // Note: LLVM's range(iN lb, ub) is [lb, ub), while MLIR's is [lb, + // ub], so we add 1 to the upper bound. 
+ llvmFuncOp.setArgAttr(llvmBufferArg.getArgNumber(), + LLVM::LLVMDialect::getRangeAttrName(), + rewriter.getAttr( + APInt(32, *min), APInt(32, *max) + 1)); + } + + auto divisibility = assumeOp.getUnionedUnsignedDivisor(opIdx); + + auto makeI32Const = [&](uint32_t val) -> Value { + return rewriter.create( + assumeOp.getLoc(), rewriter.getI32Type(), + rewriter.getI32IntegerAttr(val)); + }; + if (divisibility.has_value() && *divisibility > 1) { + Location loc = assumeOp.getLoc(); + assert(*divisibility <= std::numeric_limits::max() && + "push constant shouldn't be statically divisible by a value " + "it can't hold"); + Value knownDivisibleBy = makeI32Const(*divisibility); + // This'll almost always become an and + Value lowPart = rewriter.create(loc, llvmBufferArg, + knownDivisibleBy); + Value zero = makeI32Const(0); + Value isEvenlyDivided = rewriter.create( + loc, LLVM::ICmpPredicate::eq, lowPart, zero); + rewriter.create(loc, isEvenlyDivided); + } + } + } + + Type dstType = getTypeConverter()->convertType(op.getType()); // llvm.zext requires that the result type has a larger bitwidth. if (dstType == llvmBufferArg.getType()) { rewriter.replaceOp(op, llvmBufferArg); @@ -510,14 +546,28 @@ struct HALInterfaceWorkgroupOpsConverter final } }; +struct ConvertIREEUtilAssumeIntOp final + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(IREE::Util::AssumeIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Bail until nested under an LLVMFuncOp. 
+ auto llvmFuncOp = op->getParentOfType(); + if (!llvmFuncOp) + return failure(); + rewriter.replaceOp(op, adaptor.getOperands()); + return success(); + } +}; } // namespace void populateLLVMConversionPatterns(MLIRContext *context, RewritePatternSet &patterns, LLVMTypeConverter &converter) { - patterns - .insert( - context, converter); + patterns.insert(converter); } void populateScalarizeMathOps(RewritePatternSet &patterns) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp index 53e49efbf66a..ee72a8dbe22a 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp @@ -1138,7 +1138,8 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager, modulePassManager.addPass(createStripDebugInfoPass()); // Cast address spaces of all function arguments to generic. modulePassManager.addPass(createLLVMGPUCastAddressSpaceFunctionPass()); - modulePassManager.addPass(IREE::Util::createDropCompilerHintsPass()); + modulePassManager.addPass(IREE::Util::createDropCompilerHintsPass( + IREE::Util::DropCompilerHintsPassOptions{/*keepAssumeInt=*/true})); if (forROCDL) { // convert to ROCDL. 
@@ -1267,7 +1268,9 @@ void buildROCDLCodegenPassPipeline(OpPassManager &variantPassManager) { .addPass(createVerifyWorkgroupDistributionPass); } variantPassManager.addPass(createReconcileTranslationInfoPass()); - variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass()); + variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass( + IREE::Util::DropCompilerHintsPassOptions{/*keepAssumeInt=*/true})); + ; addLowerToLLVMGPUPasses(variantPassManager.nest(), /*forROCDL=*/true); diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir index bd876a377857..6a9c4e82490b 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_nvvm.mlir @@ -55,10 +55,11 @@ hal.executable @abs_dynamic { %c5 = arith.constant 5 : index %c7 = arith.constant 7 : index %o = hal.interface.constant.load layout(#pipeline_layout) ordinal(0) : index + %o.tagged = util.assume.int %o[, ] : index %d0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(1) : index %d1 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : index %d2 = hal.interface.constant.load layout(#pipeline_layout) ordinal(3) : index - %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) offset(%o) : memref>{%d0, %d1, %d2} + %0 = hal.interface.binding.subspan layout(#pipeline_layout) binding(1) offset(%o.tagged) : memref>{%d0, %d1, %d2} %1 = hal.interface.binding.subspan layout(#pipeline_layout) binding(0) : memref{%d0, %d1, %d2} %2 = hal.interface.binding.subspan layout(#pipeline_layout) binding(2) : memref{%d0, %d1, %d2} %9 = memref.load %0[%c3, %c5, %c7] : memref> @@ -75,10 +76,15 @@ hal.executable @abs_dynamic { // CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef}, // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !llvm.ptr {llvm.align = 
16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef}, // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: !llvm.ptr {llvm.align = 16 : i32, llvm.noalias, llvm.nonnull, llvm.noundef}, -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: i32 {llvm.noundef}, +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: i32 {llvm.noundef, llvm.range = #llvm.constant_range}, // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i32 {llvm.noundef}, // CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32 {llvm.noundef}, // CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i32 {llvm.noundef}) +// CHECK-DAG: %[[C4096_i32:.+]] = llvm.mlir.constant(4096 : i32) : i32 +// CHECK-DAG: %[[C0_i32:.+]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK-DAG: %[[ARG3_UREM:.+]] = llvm.urem %[[ARG3]], %[[C4096_i32]] : i32 +// CHECK-DAG: %[[ARG3_CMP:.+]] = llvm.icmp "eq" %[[ARG3_UREM]], %[[C0_i32]] +// CHECK: llvm.intr.assume %[[ARG3_CMP]] // CHECK-DAG: %[[OFFSET:.+]] = llvm.zext %[[ARG3]] : i32 to i64 // CHECK-DAG: %[[D1:.+]] = llvm.zext %[[ARG5]] : i32 to i64 // CHECK-DAG: %[[D2:.+]] = llvm.zext %[[ARG6]] : i32 to i64 diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp index f4099c450849..58c6bf9843de 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/ConvertToSPIRVPass.cpp @@ -277,7 +277,12 @@ namespace { /// ops to load from a global variable representing the push constant storage. 
struct HALInterfaceLoadConstantConverter final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + bool supportsAssume = false; + + HALInterfaceLoadConstantConverter(TypeConverter &typeConverter, + MLIRContext *context, bool supportsAssume) + : OpConversionPattern(typeConverter, context), + supportsAssume(supportsAssume) {} LogicalResult matchAndRewrite(IREE::HAL::InterfaceConstantLoadOp loadOp, OpAdaptor adaptor, @@ -299,11 +304,61 @@ struct HALInterfaceLoadConstantConverter final Value value = spirv::getPushConstantValue(loadOp, elementCount, index, i32Type, rewriter); + if (loadOp.getResult().hasOneUse() && supportsAssume) { + OpOperand *operand = loadOp.getResult().getUses().begin().getOperand(); + auto assumeOp = dyn_cast(operand->getOwner()); + if (assumeOp) { + Location loc = assumeOp.getLoc(); + unsigned opIdx = operand->getOperandNumber(); + + auto [min, max] = assumeOp.getUnionedUnsignedRange(opIdx); + if (min.has_value() && max.has_value()) { + Value minConst = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(*min)); + Value maxConst = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(*max)); + Value minBound = + rewriter.create(loc, value, minConst); + rewriter.create(loc, minBound); + Value maxBound = + rewriter.create(loc, value, maxConst); + rewriter.create(loc, maxBound); + } + + std::optional divisibility = + assumeOp.getUnionedUnsignedDivisor(opIdx); + if (divisibility.has_value() && *divisibility > 1) { + Value divisor = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(*divisibility)); + Value zero = rewriter.create( + loc, i32Type, rewriter.getI32IntegerAttr(0)); + Value lowPart = rewriter.create(loc, value, divisor); + Value dividesExactly = + rewriter.create(loc, lowPart, zero); + rewriter.create(loc, dividesExactly); + } + } + } + rewriter.replaceOp(loadOp, value); return success(); } }; +/// A pattern to convert util.assume.int into a noop, since we're using it +/// to annotate push 
constants. +struct UtilAssumeIntConverter final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(IREE::Util::AssumeIntOp assumeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(assumeOp, adaptor.getOperands()); + return success(); + } +}; + /// A pattern to convert hal.interface.workgroup.id/count/size into /// corresponding SPIR-V Builtin ops. template @@ -615,6 +670,8 @@ void ConvertToSPIRVPass::runOnOperation() { return signalPassFailure(); } + bool supportsAssume = targetEnv.allows(spirv::Capability::ExpectAssumeKHR); + SPIRVConversionOptions options = {}; options.use64bitIndex = use64bitIndex; @@ -660,14 +717,16 @@ void ConvertToSPIRVPass::runOnOperation() { // Add IREE HAL interface op conversions. patterns.add< - HALInterfaceLoadConstantConverter, HALInterfaceWorkgroupOpsConverter, HALInterfaceWorkgroupOpsConverter, HALInterfaceWorkgroupOpsConverter>( - typeConverter, context); + spirv::BuiltIn::NumWorkgroups>, + UtilAssumeIntConverter>(typeConverter, context); + + patterns.add(typeConverter, context, + supportsAssume); // Performs a prelimiary step to analyze all hal.interface.binding.subspan ops // and creates spirv.GlobalVariables. 
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp index ea0aa9f45116..dbb233af0f3b 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp @@ -636,7 +636,8 @@ void buildSPIRVCodegenPassPipeline(OpPassManager &variantPassManager) { addMemRefLoweringPasses(modulePassManager); } variantPassManager.addPass(createReconcileTranslationInfoPass()); - variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass()); + variantPassManager.addPass(IREE::Util::createDropCompilerHintsPass( + IREE::Util::DropCompilerHintsPassOptions{/*keepAssumeInt=*/true})); { OpPassManager &modulePassManager = variantPassManager.nest(); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp index 9e0c7a32247e..f62a72909566 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp +++ b/compiler/src/iree/compiler/Codegen/SPIRV/SPIRVEmulateI64.cpp @@ -16,6 +16,7 @@ #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Dialect/HAL/IR/HALDialect.h" #include "iree/compiler/Dialect/HAL/IR/HALOps.h" +#include "iree/compiler/Dialect/Util/IR/UtilDialect.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" @@ -71,6 +72,65 @@ struct ConvertHalInterfaceBindingSubspan final } }; +/// Rewrite away assumptions on integers. If the input to the operation is the +/// result of an extui (so usually it's an i32 that's been extended to an i64) +/// we port the assumptions over to the underlying 32-bit value. Otherwise, if +/// there's a type conversion, we drop the assumptions. 
+struct ConvertUtilAssumeIntOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(IREE::Util::AssumeIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector replacements; + SmallVector newArgs; + SmallVector newAssumptions; + for (auto [result, oldArg, newArg, assumeList] : + llvm::zip_equal(op.getResults(), op.getOperands(), + adaptor.getOperands(), op.getAssumptions())) { + if (auto isExt = oldArg.getDefiningOp()) { + Value smallerOp = rewriter.getRemappedValue(isExt.getIn()); + newArgs.push_back(smallerOp); + newAssumptions.push_back(cast(assumeList)); + replacements.push_back(nullptr); + } else if (oldArg.getType() == newArg.getType()) { + newArgs.push_back(newArg); + newAssumptions.push_back(cast(assumeList)); + replacements.push_back(nullptr); + } else { + replacements.push_back(newArg); + } + } + + if (!newArgs.empty()) { + auto newOp = rewriter.create( + op.getLoc(), newArgs, newAssumptions); + LLVM_DEBUG(llvm::dbgs() + << "WideIntegerEmulation: new op: " << newOp << "\n"); + + unsigned replacementLoc = 0; + for (auto result : newOp.getResults()) { + while (replacements[replacementLoc] != nullptr) + replacementLoc++; + Value replacement = result; + Type newType = getTypeConverter()->convertType( + op.getResult(replacementLoc).getType()); + if (auto vecType = dyn_cast_if_present(newType)) { + Value zeros = rewriter.create( + op.getLoc(), newType, rewriter.getZeroAttr(newType)); + replacement = rewriter.create( + op.getLoc(), result, zeros, ArrayRef{0}); + } + replacements[replacementLoc] = replacement; + } + } + + rewriter.replaceOp(op, replacements); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Rewrite patterns //===----------------------------------------------------------------------===// @@ -150,8 +210,8 @@ struct FlattenElementwisePattern final : RewritePattern { static void 
populateIreeI64EmulationPatterns(arith::WideIntEmulationConverter &converter, RewritePatternSet &patterns) { - patterns.add(converter, - patterns.getContext()); + patterns.add( + converter, patterns.getContext()); } static bool supportsI64(FunctionOpInterface op) { @@ -190,13 +250,13 @@ struct SPIRVEmulateI64Pass final }); target.addDynamicallyLegalDialect< arith::ArithDialect, func::FuncDialect, IREE::HAL::HALDialect, - memref::MemRefDialect, vector::VectorDialect>( - [&typeConverter](Operation *op) { - bool legal = typeConverter.isLegal(op); - LLVM_DEBUG(if (!legal) llvm::dbgs() - << "WideIntegerEmulation: illegal op: " << *op << "\n"); - return legal; - }); + memref::MemRefDialect, vector::VectorDialect, + IREE::Util::UtilDialect>([&typeConverter](Operation *op) { + bool legal = typeConverter.isLegal(op); + LLVM_DEBUG(if (!legal) llvm::dbgs() + << "WideIntegerEmulation: illegal op: " << *op << "\n"); + return legal; + }); RewritePatternSet patterns(ctx); arith::populateArithWideIntEmulationPatterns(typeConverter, patterns); diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir index b43700618377..2bfdb8f6195f 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/convert_to_spirv.mlir @@ -36,6 +36,45 @@ hal.executable private @push_constant { // ----- +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding +]> +hal.executable private @push_constant_annotated { + hal.executable.variant @vulkan target(<"vulkan-spirv", "vulkan-spirv-fb">) { + hal.executable.export @push_constant_annotated layout(#pipeline_layout) attributes { + workgroup_size = [32: index, 1: index, 1: index] + } + builtin.module attributes {spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>>} { + // CHECK-LABEL: spirv.func @push_constant_annotated() + func.func @push_constant_annotated() -> 
index { + // CHECK: %[[LOAD:.+]] = spirv.Load "PushConstant" %{{.*}} : i32 + // CHECK-DAG: %[[C0:.+]] = spirv.Constant 0 : i32 + // CHECK-DAG: %[[C4096:.+]] = spirv.Constant 4096 : i32 + // CHECK-DAG: %[[UGE:.+]] = spirv.UGreaterThanEqual %[[LOAD]], %[[C0]] + // CHECK-DAG: spirv.KHR.AssumeTrue %[[UGE]] + // CHECK-DAG: %[[ULE:.+]] = spirv.ULessThanEqual %[[LOAD]], %[[C4096]] + // CHECK-DAG: spirv.KHR.AssumeTrue %[[ULE]] + // CHECK-DAG: %[[C2048:.+]] = spirv.Constant 2048 : i32 + // CHECK-DAG: %[[C0_2:.+]] = spirv.Constant 0 : i32 + // CHECK-DAG: %[[MOD:.+]] = spirv.UMod %[[LOAD]], %[[C2048]] + // CHECK-DAG: %[[LOWCLEAR:.+]] = spirv.IEqual %[[MOD]], %[[C0_2]] + // CHECK-DAG: spirv.KHR.AssumeTrue %[[LOWCLEAR]] + // CHECK: spirv.ReturnValue %[[LOAD]] + %0 = hal.interface.constant.load layout(#pipeline_layout) ordinal(2) : i32 + %1 = util.assume.int %0[ + , + , + ] : i32 + %2 = arith.index_castui %1 : i32 to index + return %2 : index + } + } + } +} + +// ----- + #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding, diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/emulate_i64.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/emulate_i64.mlir index 28405723f7ab..bafaa8da7034 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/emulate_i64.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/emulate_i64.mlir @@ -39,6 +39,40 @@ func.func @buffer_types() attributes {hal.executable.target = #executable_target // ----- +#pipeline_layout = #hal.pipeline.layout, + #hal.pipeline.binding, + #hal.pipeline.binding +]> +#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", { + iree.gpu.target = #iree_gpu.target> +}> +func.func @splat_i64_with_assume() attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} { + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = hal.interface.constant.load layout(], flags = Indirect>) ordinal(0) : i32 + %1 = 
arith.extui %0 : i32 to i64 + %2 = util.assume.int %1[, , , , ] : i64 + %3 = hal.interface.binding.subspan layout(], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags(Indirect) : memref>{%c1} + memref.store %2, %3[%c0] : memref> + return +} + +// Check that assume operations that annotate i64 values which were really only +// 32 bits become assumptions on the underlying values +// CHECK-LABEL: func.func @splat_i64_with_assume +// CHECK: %[[PUSH_CONST:.+]] = hal.interface.constant.load +// CHECK: %[[ASSUME:.+]] = util.assume.int %[[PUSH_CONST]] +// CHECK: %[[ASSUME_EXT:.+]] = vector.insert %[[ASSUME]], %{{.*}}[0] +// CHECK: memref.store %[[ASSUME_EXT]] + +// ----- + #pipeline_layout = #hal.pipeline.layout, #hal.pipeline.binding diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp index 2d38d54a9748..1f5a21b1d010 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/DropCompilerHints.cpp @@ -17,9 +17,10 @@ namespace mlir::iree_compiler::IREE::Util { namespace { -class DropCompilerHintsPass +struct DropCompilerHintsPass : public impl::DropCompilerHintsPassBase { -public: + using Base::Base; + void runOnOperation() override { // We can't use patterns and applyPatternsAndFoldGreedily because that // automatically does canonicalization. 
@@ -28,8 +29,10 @@ class DropCompilerHintsPass op.replaceAllUsesWith(op.getOperands()); op.erase(); } else if (auto op = dyn_cast(genericOp)) { - op.replaceAllUsesWith(op.getOperands()); - op.erase(); + if (!keepAssumeInt) { + op.replaceAllUsesWith(op.getOperands()); + op.erase(); + } } }); } diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td index fb0cf7d028a0..68d1e3421d35 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/Passes.td @@ -27,7 +27,20 @@ def DropCompilerHintsPass : Pass<"iree-util-drop-compiler-hints", ""> { Deletes operations that have no runtime equivalent and are only used in the compiler. This should be performed after all other compiler passes. + + With keep-assume-int=true, leaves util.assume.int operations in place + so they can be propagated to backends. This is a temporary measure + until all backends have a rewrite for those assumptions (currently + they're only handled by the patterns that target LLVM). }]; + + let options = [ + Option< + "keepAssumeInt", "keep-assume-int", + "bool", "false", + "Whether annotations about the ranges and divisibility of integers should be kept." + >, + ]; } def DumpModulePass : Pass<"iree-util-dump-module", "mlir::ModuleOp"> {