Skip to content

Commit

Permalink
[midend] Canonicalize dynamic rank detection.
Browse files Browse the repository at this point in the history
  • Loading branch information
EllisLambda committed Oct 31, 2023
1 parent fea1ffd commit c21c500
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//===- BatchMatMulOptimize.cpp
//-------------------------------------------------===//
//===- BatchMatMulOptimize.cpp --------------------------------------------===//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -145,7 +144,7 @@ class BatchMatMulOptimizePattern : public ConversionPattern {
Value loopVarColOfB = ivRange.front();

// Compile time branch detection.
if (C.getType().cast<MemRefType>().getDimSize(2) < 0 or
if (C.getType().cast<MemRefType>().isDynamicDim(2) or
C.getType().cast<MemRefType>().getDimSize(2) % affineVectorSize !=
0) {

Expand Down
2 changes: 1 addition & 1 deletion midend/lib/Conversion/MatMulOptimization/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ add_mlir_library(BatchMatMulOptimization

add_mlir_library(MatMulParallelVectorization
MatMulParallelVectorization.cpp
)
)
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//===- MatMulParallelVectorization.cpp
//-------------------------------------------------===//
//===- MatMulParallelVectorization.cpp ------------------------------------===//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -141,7 +140,7 @@ class MatMulParallelVectorizationPattern : public ConversionPattern {
ArrayRef<Value>{aRow, bRow}, false, 3, true);

// Compile time branch detection.
if (C.getType().cast<MemRefType>().getDimSize(1) < 0 or
if (C.getType().cast<MemRefType>().isDynamicDim(1) or
C.getType().cast<MemRefType>().getDimSize(1) % affineVectorSize != 0) {

// Depending on the position, use either full vectors or tail vectors.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//===- BuiltinTransposeVectorization.cpp
//-------------------------------------------------===//
//===- BuiltinTransposeVectorization.cpp ----------------------------------===//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -205,7 +204,7 @@ class TransposeOptimizationPattern : public ConversionPattern {
});

// Compile time branch detection.
if (A.getType().cast<MemRefType>().getDimSize(0) < 0 or
if (A.getType().cast<MemRefType>().isDynamicDim(0) or
A.getType().cast<MemRefType>().getDimSize(0) % affineVectorSize !=
0) {
// Depending on the position, use either full vectors or tail
Expand Down Expand Up @@ -265,7 +264,7 @@ class TransposeOptimizationPattern : public ConversionPattern {
parallelColLoop.getRegion().push_back(loopBody);
rewriter.setInsertionPointAfter(parallelColLoop);

if (A.getType().cast<MemRefType>().getDimSize(1) < 0 or
if (A.getType().cast<MemRefType>().isDynamicDim(1) or
A.getType().cast<MemRefType>().getDimSize(1) % affineVectorSize != 0) {

affine::AffineIfOp branchingColUnaligned =
Expand Down Expand Up @@ -325,7 +324,7 @@ class TransposeOptimizationPattern : public ConversionPattern {
});
});

if (A.getType().cast<MemRefType>().getDimSize(0) < 0 or
if (A.getType().cast<MemRefType>().isDynamicDim(0) or
A.getType().cast<MemRefType>().getDimSize(0) % affineVectorSize !=
0) {
affine::AffineIfOp branchingRowColUnaligned =
Expand Down Expand Up @@ -413,7 +412,9 @@ class TransposeOptimizationPass
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TransposeOptimizationPass)
StringRef getArgument() const final { return "transpose-optimize"; }
StringRef getDescription() const final { return "Transpose Optimization only for rank 2 tensor."; }
StringRef getDescription() const final {
return "Transpose Optimization only for rank 2 tensor.";
}
TransposeOptimizationPass() = default;
TransposeOptimizationPass(const TransposeOptimizationPass &) {}
explicit TransposeOptimizationPass(int64_t affineVectorSizeParam) {
Expand Down
2 changes: 1 addition & 1 deletion midend/lib/Conversion/TransposeOptimization/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ add_mlir_library(TransposeOptimization
BuiltinTransposeVectorization.cpp
LINK_LIBS PUBLIC
BuddyUtils
)
)
2 changes: 1 addition & 1 deletion tools/buddy-opt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ target_link_libraries(buddy-opt
BuddyGemmini
LowerGemminiPass
LowerLinalgToGemminiPass
)
)

0 comments on commit c21c500

Please sign in to comment.