66//
77// ===----------------------------------------------------------------------===//
88
9+ #include " mlir/Dialect/Affine/IR/AffineOps.h"
910#include " mlir/Dialect/Arith/IR/Arith.h"
1011#include " mlir/Dialect/SCF/IR/SCF.h"
12+ #include " mlir/Dialect/SCF/Utils/Utils.h"
13+ #include " mlir/Dialect/Utils/StaticValueUtils.h"
1114#include " mlir/IR/Diagnostics.h"
1215#include " mlir/IR/MLIRContext.h"
1316#include " mlir/IR/OwningOpRef.h"
17+ #include " mlir/IR/PatternMatch.h"
18+ #include " mlir/Interfaces/LoopLikeInterface.h"
1419#include " gtest/gtest.h"
1520
1621using namespace mlir ;
@@ -23,7 +28,8 @@ using namespace mlir::scf;
2328class SCFLoopLikeTest : public ::testing::Test {
2429protected:
2530 SCFLoopLikeTest () : b(&context), loc(UnknownLoc::get(&context)) {
26- context.loadDialect <arith::ArithDialect, scf::SCFDialect>();
31+ context.loadDialect <affine::AffineDialect, arith::ArithDialect,
32+ scf::SCFDialect>();
2733 }
2834
2935 void checkUnidimensional (LoopLikeOpInterface loopLikeOp) {
@@ -88,6 +94,24 @@ class SCFLoopLikeTest : public ::testing::Test {
8894 EXPECT_EQ ((*maybeInductionVars).size (), 2u );
8995 }
9096
97+ void checkNormalized (LoopLikeOpInterface loopLikeOp) {
98+ std::optional<SmallVector<OpFoldResult>> maybeLb =
99+ loopLikeOp.getLoopLowerBounds ();
100+ ASSERT_TRUE (maybeLb.has_value ());
101+ std::optional<SmallVector<OpFoldResult>> maybeStep =
102+ loopLikeOp.getLoopSteps ();
103+ ASSERT_TRUE (maybeStep.has_value ());
104+
105+ auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
106+ return llvm::all_of (results, [&](OpFoldResult ofr) {
107+ auto intValue = getConstantIntValue (ofr);
108+ return intValue.has_value () && intValue == val;
109+ });
110+ };
111+ EXPECT_TRUE (allEqual (*maybeLb, 0 ));
112+ EXPECT_TRUE (allEqual (*maybeStep, 1 ));
113+ }
114+
91115 MLIRContext context;
92116 OpBuilder b;
93117 Location loc;
@@ -138,3 +162,36 @@ TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
138162 ValueRange ({step->getResult (), step->getResult ()}), ValueRange ());
139163 checkMultidimensional (parallelOp.get ());
140164}
165+
166+ TEST_F (SCFLoopLikeTest, testForallNormalize) {
167+ OwningOpRef<arith::ConstantIndexOp> lb =
168+ b.create <arith::ConstantIndexOp>(loc, 1 );
169+ OwningOpRef<arith::ConstantIndexOp> ub =
170+ b.create <arith::ConstantIndexOp>(loc, 10 );
171+ OwningOpRef<arith::ConstantIndexOp> step =
172+ b.create <arith::ConstantIndexOp>(loc, 3 );
173+
174+ scf::ForallOp forallOp = b.create <scf::ForallOp>(
175+ loc, ArrayRef<OpFoldResult>({lb->getResult (), lb->getResult ()}),
176+ ArrayRef<OpFoldResult>({ub->getResult (), ub->getResult ()}),
177+ ArrayRef<OpFoldResult>({step->getResult (), step->getResult ()}),
178+ ValueRange (), std::nullopt );
179+ // Create a user of the induction variable. Bitcast is chosen for simplicity
180+ // since it is unary.
181+ b.setInsertionPointToStart (forallOp.getBody ());
182+ b.create <arith::BitcastOp>(UnknownLoc::get (&context), b.getF64Type (),
183+ forallOp.getInductionVar (0 ));
184+ IRRewriter rewriter (b);
185+ FailureOr<scf::ForallOp> maybeNormalizedForallOp =
186+ normalizeForallOp (rewriter, forallOp);
187+ EXPECT_TRUE (succeeded (maybeNormalizedForallOp));
188+ OwningOpRef<scf::ForallOp> normalizedForallOp (*maybeNormalizedForallOp);
189+ checkNormalized (normalizedForallOp.get ());
190+
191+ // Check that the IV user has been updated to use the denormalized variable.
192+ Block *body = normalizedForallOp->getBody ();
193+ auto bitcastOps = body->getOps <arith::BitcastOp>();
194+ ASSERT_EQ (std::distance (bitcastOps.begin (), bitcastOps.end ()), 1 );
195+ arith::BitcastOp ivUser = *bitcastOps.begin ();
196+ ASSERT_NE (ivUser.getIn (), normalizedForallOp->getInductionVar (0 ));
197+ }
0 commit comments