diff --git a/include/polygeist/Passes/Passes.h b/include/polygeist/Passes/Passes.h index 7a4fe27c5be4..76721b4fc638 100644 --- a/include/polygeist/Passes/Passes.h +++ b/include/polygeist/Passes/Passes.h @@ -19,6 +19,7 @@ std::unique_ptr createParallelLowerPass(); std::unique_ptr createConvertPolygeistToLLVMPass(const LowerToLLVMOptions &options); std::unique_ptr createConvertPolygeistToLLVMPass(); +std::unique_ptr createLowerPolygeistOpsPass(); } // namespace polygeist } // namespace mlir diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td index 86eed651d7d6..d209bcd0ffee 100644 --- a/include/polygeist/Passes/Passes.td +++ b/include/polygeist/Passes/Passes.td @@ -61,6 +61,12 @@ def RemoveTrivialUse : FunctionPass<"trivialuse"> { let constructor = "mlir::polygeist::createRemoveTrivialUsePass()"; } +def LowerPolygeistOps : FunctionPass<"lower-polygeist-ops"> { + let summary = "Lower polygeist ops to memref operations"; + let constructor = "mlir::polygeist::createLowerPolygeistOpsPass()"; + let dependentDialects = ["::mlir::memref::MemRefDialect"]; +} + def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> { let summary = "Convert scalar and vector operations from the Standard to the " "LLVM dialect"; diff --git a/lib/polygeist/Ops.cpp b/lib/polygeist/Ops.cpp index 9d40ad88d610..0ecb833b38b6 100644 --- a/lib/polygeist/Ops.cpp +++ b/lib/polygeist/Ops.cpp @@ -176,48 +176,6 @@ class SubToCast final : public OpRewritePattern { } }; -// Simplify polygeist.subindex to memref.subview. -class SubToSubView final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(SubIndexOp op, - PatternRewriter &rewriter) const override { - auto srcMemRefType = op.source().getType().cast(); - auto resMemRefType = op.result().getType().cast(); - auto dims = srcMemRefType.getShape().size(); - - // For now, restrict subview lowering to statically defined memref's - if (!srcMemRefType.hasStaticShape() | !resMemRefType.hasStaticShape()) - return failure(); - - // For now, restrict to simple rank-reducing indexing - if (srcMemRefType.getShape().size() <= resMemRefType.getShape().size()) - return failure(); - - // Build offset, sizes and strides - SmallVector sizes(dims, rewriter.getIndexAttr(0)); - sizes[0] = op.index(); - SmallVector offsets(dims); - for (auto dim : llvm::enumerate(srcMemRefType.getShape())) { - if (dim.index() == 0) - offsets[0] = rewriter.getIndexAttr(1); - else - offsets[dim.index()] = rewriter.getIndexAttr(dim.value()); - } - SmallVector strides(dims, rewriter.getIndexAttr(1)); - - // Generate the appropriate return type: - auto subMemRefType = MemRefType::get(srcMemRefType.getShape().drop_front(), - srcMemRefType.getElementType()); - - rewriter.replaceOpWithNewOp( - op, subMemRefType, op.source(), sizes, offsets, strides); - - return success(); - } -}; - // Simplify redundant dynamic subindex patterns which tries to represent // rank-reducing indexing: // %3 = "polygeist.subindex"(%1, %arg0) : (memref<2x1000xi32>, index) -> @@ -678,7 +636,7 @@ void SubIndexOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); - // Disabled: SubToSubView + // Disabled: } /// Simplify pointer2memref(memref2pointer(x)) to cast(x) diff --git a/lib/polygeist/Passes/CMakeLists.txt b/lib/polygeist/Passes/CMakeLists.txt index 19f5ec443855..371c5cef2bee 100644 --- a/lib/polygeist/Passes/CMakeLists.txt +++ b/lib/polygeist/Passes/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRPolygeistTransforms ParallelLower.cpp TrivialUse.cpp ConvertPolygeistToLLVM.cpp + LowerPolygeistOps.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine diff --git a/lib/polygeist/Passes/LowerPolygeistOps.cpp b/lib/polygeist/Passes/LowerPolygeistOps.cpp new file mode 100644 index 000000000000..be3152b0d513 --- /dev/null +++ b/lib/polygeist/Passes/LowerPolygeistOps.cpp @@ -0,0 +1,88 @@ +//===- TrivialUse.cpp - Remove trivial use instruction ---------------- -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to lower gpu kernels in NVVM/gpu dialects into +// a generic parallel for representation +//===----------------------------------------------------------------------===// +#include "PassDetails.h" + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/StandardOps/Transforms/Passes.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/DialectConversion.h" +#include "polygeist/Dialect.h" +#include "polygeist/Ops.h" + +using namespace mlir; +using namespace polygeist; +using namespace mlir::arith; + +namespace { + +struct SubIndexToReinterpretCast + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(polygeist::SubIndexOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcMemRefType = op.source().getType().cast(); + auto resMemRefType = op.result().getType().cast(); + auto shape = srcMemRefType.getShape(); + + if (!resMemRefType.hasStaticShape()) + return failure(); + + int64_t innerSize = resMemRefType.getNumElements(); + auto offset = rewriter.create( + op.getLoc(), op.index(), + rewriter.create(op.getLoc(), innerSize)); + + llvm::SmallVector sizes, strides; + for (auto dim : shape.drop_front()) { + sizes.push_back(rewriter.getIndexAttr(dim)); + strides.push_back(rewriter.getIndexAttr(1)); + } + + rewriter.replaceOpWithNewOp( + op, resMemRefType, op.source(), offset.getResult(), sizes, strides); + + return success(); + } +}; + +struct LowerPolygeistOpsPass + : public LowerPolygeistOpsBase { + + void runOnFunction() override { + auto op = getOperation(); + auto ctx = op.getContext(); + RewritePatternSet patterns(ctx); + patterns.insert(ctx); + + ConversionTarget target(*ctx); + target.addIllegalDialect(); + target.addLegalDialect(); + + if (failed(applyPartialConversion(op, target, std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace + +namespace mlir { +namespace polygeist { +std::unique_ptr createLowerPolygeistOpsPass() { + return std::make_unique(); +} + +} // namespace polygeist +} // namespace mlir diff --git a/test/polygeist-opt/canonicalization.mlir b/test/polygeist-opt/canonicalization.mlir deleted file mode 100644 index d68b8c40dc34..000000000000 --- a/test/polygeist-opt/canonicalization.mlir +++ /dev/null @@ -1,29 +0,0 @@ -// RUN: polygeist-opt --canonicalize --split-input-file %s | FileCheck %s -// XFAIL: * -// CHECK: func @main(%arg0: index) -> memref<30xi32> { -// CHECK: %0 = memref.alloca() : memref<30x30xi32> -// CHECK: %1 = memref.subview %0[%arg0, 0] [1, 30] [1, 1] : memref<30x30xi32> to memref<30xi32> -// CHECK: return %1 : memref<30xi32> -// CHECK: } -module { - func @main(%arg0 : index) -> memref<30xi32> { - %0 = memref.alloca() : memref<30x30xi32> - %1 = "polygeist.subindex"(%0, %arg0) : (memref<30x30xi32>, index) -> memref<30xi32> - return %1 : memref<30xi32> - } -} - -// ----- - -// CHECK: func @main(%arg0: index) -> memref<1000xi32> { -// CHECK: %0 = memref.alloca() : memref<2x1000xi32> -// CHECK: %1 = memref.subview %0[%arg0, 0] [1, 1000] [1, 1] : memref<2x1000xi32> to memref<1000xi32> -// CHECK: return %1 : memref<1000xi32> -// CHECK: } -func @main(%arg0 : index) -> memref<1000xi32> { - %c0 = arith.constant 0 : index - %1 = memref.alloca() : memref<2x1000xi32> - %3 = "polygeist.subindex"(%1, %arg0) : (memref<2x1000xi32>, index) -> memref - %4 = "polygeist.subindex"(%3, %c0) : (memref, index) -> memref<1000xi32> - return %4 : memref<1000xi32> -} diff --git a/test/polygeist-opt/lower_polygeist_ops.mlir b/test/polygeist-opt/lower_polygeist_ops.mlir new file mode 100644 index 000000000000..cd84039e637b --- /dev/null +++ b/test/polygeist-opt/lower_polygeist_ops.mlir @@ -0,0 +1,17 @@ +// RUN: polygeist-opt --lower-polygeist-ops --split-input-file %s | FileCheck %s + +// CHECK-LABEL: func @main( +// CHECK-SAME: %[[VAL_0:.*]]: index) -> memref<30xi32> { +// CHECK: %[[VAL_1:.*]] = memref.alloca() : memref<30x30xi32> +// CHECK: %[[VAL_2:.*]] = arith.constant 30 : index +// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_2]] : index +// CHECK: %[[VAL_4:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: {{\[}}%[[VAL_3]]], sizes: [30], strides: [1] : memref<30x30xi32> to memref<30xi32> +// CHECK: return %[[VAL_4]] : memref<30xi32> +// CHECK: } +module { + func @main(%arg0 : index) -> memref<30xi32> { + %0 = memref.alloca() : memref<30x30xi32> + %1 = "polygeist.subindex"(%0, %arg0) : (memref<30x30xi32>, index) -> memref<30xi32> + return %1 : memref<30xi32> + } +}