diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index cc81f9d19aca7..99f356b203de4 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -784,6 +784,33 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) { return nullptr; } +/// Get the canonical order of two commutative exprs arguments. +static std::pair +orderCommutativeArgs(AffineExpr expr1, AffineExpr expr2) { + auto sym1 = dyn_cast(expr1); + auto sym2 = dyn_cast(expr2); + // Try to order by symbol/dim position first. + if (sym1 && sym2) + return sym1.getPosition() < sym2.getPosition() ? std::pair{expr1, expr2} + : std::pair{expr2, expr1}; + + auto dim1 = dyn_cast(expr1); + auto dim2 = dyn_cast(expr2); + if (dim1 && dim2) + return dim1.getPosition() < dim2.getPosition() ? std::pair{expr1, expr2} + : std::pair{expr2, expr1}; + + // Put dims before symbols. + if (dim1 && sym2) + return {dim1, sym2}; + + if (sym1 && dim2) + return {dim2, sym1}; + + // Otherwise, keep original order. + return {expr1, expr2}; +} + AffineExpr AffineExpr::operator+(int64_t v) const { return *this + getAffineConstantExpr(v, getContext()); } @@ -791,9 +818,11 @@ AffineExpr AffineExpr::operator+(AffineExpr other) const { if (auto simplified = simplifyAdd(*this, other)) return simplified; + auto [lhs, rhs] = orderCommutativeArgs(*this, other); + StorageUniquer &uniquer = getContext()->getAffineUniquer(); return uniquer.get( - /*initFn=*/{}, static_cast(AffineExprKind::Add), *this, other); + /*initFn=*/{}, static_cast(AffineExprKind::Add), lhs, rhs); } /// Simplify a multiply expression. Return nullptr if it can't be simplified. @@ -856,9 +885,11 @@ AffineExpr AffineExpr::operator*(AffineExpr other) const { if (auto simplified = simplifyMul(*this, other)) return simplified; + auto [lhs, rhs] = orderCommutativeArgs(*this, other); + StorageUniquer &uniquer = getContext()->getAffineUniquer(); return uniquer.get( - /*initFn=*/{}, static_cast(AffineExprKind::Mul), *this, other); + /*initFn=*/{}, static_cast(AffineExprKind::Mul), lhs, rhs); } // Unary minus, delegate to operator*. diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir index 6f2737a982752..653c2cb521637 100644 --- a/mlir/test/Dialect/Affine/simplify-structures.mlir +++ b/mlir/test/Dialect/Affine/simplify-structures.mlir @@ -508,7 +508,7 @@ func.func @test_not_trivially_true_or_false_returning_three_results() -> (index, // ----- // Test simplification of mod expressions. -// CHECK-DAG: #[[$MOD:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 * s1 + (s0 - s1) mod s2)> +// CHECK-DAG: #[[$MOD:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s4 + s3 + (s0 - s1) mod s2)> // CHECK-DAG: #[[$SIMPLIFIED_MOD_RHS:.*]] = affine_map<()[s0, s1, s2, s3] -> (s3 mod (s2 - s0 * s1))> // CHECK-DAG: #[[$MODULO_AND_PRODUCT:.*]] = affine_map<()[s0, s1, s2, s3] -> (s0 * s1 + s3 - (-s0 + s3) mod s2)> // CHECK-LABEL: func @semiaffine_simplification_mod @@ -547,7 +547,7 @@ func.func @semiaffine_simplification_floordiv_and_ceildiv(%arg0: index, %arg1: i // Test simplification of product expressions. // CHECK-DAG: #[[$PRODUCT:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s3 + s4 + (s0 - s1) * s2)> -// CHECK-DAG: #[[$SUM_OF_PRODUCTS:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s2 + s2 * s0 + s3 + s3 * s0 + s3 * s1 + s4 + s4 * s1)> +// CHECK-DAG: #[[$SUM_OF_PRODUCTS:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> (s0 * s2 + s0 * s3 + s1 * s3 + s1 * s4 + s2 + s3 + s4)> // CHECK-LABEL: func @semiaffine_simplification_product // CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index) func.func @semiaffine_simplification_product(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> (index, index) { diff --git a/mlir/test/IR/affine-map.mlir b/mlir/test/IR/affine-map.mlir index 977aec2536b1e..6277b28561f36 100644 --- a/mlir/test/IR/affine-map.mlir +++ b/mlir/test/IR/affine-map.mlir @@ -139,7 +139,7 @@ #map44 = affine_map<(i, j) -> (i - 2*j, j * 6 floordiv 4)> // Simplifications -// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2)[s0] -> (d0 + d1 + d2 + 1, d2 + d1, (d0 * s0) * 8)> +// CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2)[s0] -> (d0 + d1 + d2 + 1, d1 + d2, (d0 * s0) * 8)> #map45 = affine_map<(i, j, k) [N] -> (1 + i + 3 + j - 3 + k, k + 5 + j - 5, 2*i*4*N)> // CHECK: #map{{[0-9]*}} = affine_map<(d0, d1, d2) -> (0, d1, d0 * 2, 0)> diff --git a/mlir/unittests/IR/AffineExprTest.cpp b/mlir/unittests/IR/AffineExprTest.cpp index 8a2d697540d5c..f8494ecb971c2 100644 --- a/mlir/unittests/IR/AffineExprTest.cpp +++ b/mlir/unittests/IR/AffineExprTest.cpp @@ -84,6 +84,20 @@ TEST(AffineExprTest, constantFolding) { ASSERT_EQ(cminfloordivcn1.getKind(), AffineExprKind::FloorDiv); } +TEST(AffineExprTest, commutative) { + MLIRContext ctx; + OpBuilder b(&ctx); + auto c2 = b.getAffineConstantExpr(1); + auto d0 = b.getAffineDimExpr(0); + auto d1 = b.getAffineDimExpr(1); + auto s0 = b.getAffineSymbolExpr(0); + auto s1 = b.getAffineSymbolExpr(1); + + ASSERT_EQ(d0 * d1, d1 * d0); + ASSERT_EQ(s0 + s1, s1 + s0); + ASSERT_EQ(s0 * c2, c2 * s0); +} + TEST(AffineExprTest, divisionSimplification) { MLIRContext ctx; OpBuilder b(&ctx); @@ -147,3 +161,12 @@ TEST(AffineExprTest, simpleAffineExprFlattenerRegression) { ASSERT_TRUE(isa(result)); ASSERT_EQ(cast(result).getValue(), 7); } + +TEST(AffineExprTest, simplifyCommutative) { + MLIRContext ctx; + OpBuilder b(&ctx); + auto s0 = b.getAffineSymbolExpr(0); + auto s1 = b.getAffineSymbolExpr(1); + + ASSERT_EQ(s0 * s1 - s1 * s0 + 1, 1); +}