Skip to content

Commit

Permalink
[midend] Restrict the transpose optimize for 2 rank tensor and remove…
Browse files Browse the repository at this point in the history
… unused header files.
  • Loading branch information
EllisLambda committed Oct 30, 2023
1 parent 2fa21b1 commit d603b56
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,10 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/ValueRange.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Debug.h"
#include <cstdint>
#include <mlir/Dialect/Affine/Analysis/AffineAnalysis.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
#include <mlir/Dialect/Affine/Analysis/AffineAnalysis.h>
Expand Down Expand Up @@ -62,12 +63,23 @@ class TransposeOptimizationPattern : public ConversionPattern {
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> /*operands*/,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto permutationArrayAttr =
op->getAttr(rewriter.getStringAttr("permutation"))
.cast<DenseI64ArrayAttr>()
.asArrayRef();

// Retrieve input tensors A, B, and C.
// Retrieve input tensors A, B.
Value A = op->getOperand(0);
Value B = op->getOperand(1);

// Only to rewrite the rank 2 tensor transpose.
if (permutationArrayAttr[0] != 1 or permutationArrayAttr[1] != 0 or
A.getType().cast<MemRefType>().getRank() != 2) {
return failure();
}

auto loc = op->getLoc();

// Acquire the element type of input tensors.
Type elementType = A.getType().cast<MemRefType>().getElementType();

Expand Down Expand Up @@ -124,7 +136,7 @@ class TransposeOptimizationPattern : public ConversionPattern {
llvm::map_range(ArrayRef<LoopReduction>{},
[](const LoopReduction &red) { return red.value; }));

// Create the primary parallel batch level loop.
// Create the primary parallel loop.
AffineParallelOp parallelColLoop =
rewriter.create<affine::AffineParallelOp>(
loc, ValueRange(reducedValues).getTypes(), ValueRange{Col},
Expand Down Expand Up @@ -253,7 +265,7 @@ class TransposeOptimizationPattern : public ConversionPattern {
parallelColLoop.getRegion().push_back(loopBody);
rewriter.setInsertionPointAfter(parallelColLoop);

if (A.getType().cast<MemRefType>().getDimSize(1) < 0 or
if (A.getType().cast<MemRefType>().getDimSize(0) < 0 or
A.getType().cast<MemRefType>().getDimSize(1) % affineVectorSize != 0) {

affine::AffineIfOp branchingColUnaligned =
Expand Down Expand Up @@ -401,7 +413,7 @@ class TransposeOptimizationPass
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransposeOptimizationPass)
StringRef getArgument() const final { return "transpose-optimize"; }
StringRef getDescription() const final { return "Transpose Optimization."; }
StringRef getDescription() const final { return "Transpose Optimization only for rank 2 tensor."; }
TransposeOptimizationPass() = default;
TransposeOptimizationPass(const TransposeOptimizationPass &) {}
explicit TransposeOptimizationPass(int64_t affineVectorSizeParam) {
Expand Down

0 comments on commit d603b56

Please sign in to comment.