diff --git a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
index df87f70469..f04407b985 100644
--- a/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
+++ b/stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
@@ -110,3 +110,56 @@ func.func @eval_convert_f64_precision_loss() -> (tensor<1xf32>, tensor<f32>) {
   %3 = stablehlo.convert %1 : (tensor<f64>) -> tensor<f32>
   func.return %2, %3 : tensor<1xf32>, tensor<f32>
 }
+
+// -----
+
+// CHECK-LABEL: func @eval_transpose
+func.func @eval_transpose() -> (tensor<2x3x2xi32>, tensor<2x4x3xi32>, tensor<4x3x2xi32>) {
+  // CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<
+  // CHECK-SAME: {{\[\[}}[1, 7], [3, 9], [5, 11]],
+  // CHECK-SAME: {{\[}}[2, 8], [4, 10], [6, 12]]]> : tensor<2x3x2xi32>
+  //
+  // CHECK: [[RESULT1:%.*]] = stablehlo.constant dense<
+  // CHECK-SAME: {{\[\[}}[1, 3, 5], [7, 9, 11], [13, 15, 17], [19, 21, 23]],
+  // CHECK-SAME: {{\[}}[2, 4, 6], [8, 10, 12], [14, 16, 18], [20, 22, 24]]]> : tensor<2x4x3xi32>
+  //
+  // CHECK: [[RESULT2:%.*]] = stablehlo.constant dense<
+  // CHECK-SAME: {{\[\[}}[1, 2], [3, 4], [5, 6]],
+  // CHECK-SAME: {{\[}}[7, 8], [9, 10], [11, 12]],
+  // CHECK-SAME: {{\[}}[13, 14], [15, 16], [17, 18]],
+  // CHECK-SAME: {{\[}}[19, 20], [21, 22], [23, 24]]]> : tensor<4x3x2xi32>
+  //
+  // CHECK: return [[RESULT0]], [[RESULT1]], [[RESULT2]]
+  %0 = stablehlo.constant dense<[[[1,2], [3,4], [5,6]],
+                                 [[7,8], [9,10], [11,12]]]> : tensor<2x3x2xi32>
+  %1 = stablehlo.constant dense<[[[1, 2], [3, 4], [5, 6]],
+                                 [[7, 8], [9, 10], [11,12]],
+                                 [[13,14], [15,16], [17,18]],
+                                 [[19,20], [21,22], [23,24]]]> : tensor<4x3x2xi32>
+  %2 = stablehlo.transpose %0, dims = [2, 1, 0] : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
+  %3 = stablehlo.transpose %1, dims = [2, 0, 1] : (tensor<4x3x2xi32>) -> tensor<2x4x3xi32>
+  %4 = stablehlo.transpose %3, dims = [1, 2, 0] : (tensor<2x4x3xi32>) -> tensor<4x3x2xi32>
+  func.return %2, %3, %4 : tensor<2x3x2xi32>, tensor<2x4x3xi32>, tensor<4x3x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @eval_transpose_zerodim
+func.func @eval_transpose_zerodim() -> (tensor<10x3x0xf32>) {
+  // CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<> : tensor<10x3x0xf32>
+  // CHECK: return [[RESULT0]]
+  %0 = stablehlo.constant dense<> : tensor<3x0x10xf32>
+  %1 = stablehlo.transpose %0, dims = [2, 0, 1] : (tensor<3x0x10xf32>) -> tensor<10x3x0xf32>
+  func.return %1 : tensor<10x3x0xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @eval_transpose_zerorank
+func.func @eval_transpose_zerorank() -> tensor<i32> {
+  // CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<1> : tensor<i32>
+  // CHECK: return [[RESULT0]]
+  %0 = stablehlo.constant dense<1> : tensor<i32>
+  %1 = stablehlo.transpose %0, dims = [] : (tensor<i32>) -> tensor<i32>
+  func.return %1 : tensor<i32>
+}
diff --git a/stablehlo/transforms/StablehloAggressiveFolder.cpp b/stablehlo/transforms/StablehloAggressiveFolder.cpp
index a5768f4c9f..510b1864d1 100644
--- a/stablehlo/transforms/StablehloAggressiveFolder.cpp
+++ b/stablehlo/transforms/StablehloAggressiveFolder.cpp
@@ -643,6 +643,58 @@ struct EvalIotaOpPattern : public OpRewritePattern<IotaOp> {
   }
 };
 
+template <typename RangeType>
+DenseElementsAttr transposeType(TransposeOp& op, const RangeType& data) {
+  using ElementType = std::decay_t<decltype(*std::begin(data))>;
+
+  RankedTensorType operandType = op.getOperand().getType();
+  RankedTensorType resultType = op.getResult().getType();
+
+  const auto operandStrides = computeStrides(operandType.getShape());
+  const auto resultStrides = computeStrides(resultType.getShape());
+  const auto inversePermutation = invertPermutationVector(op.getPermutation());
+
+  SmallVector<ElementType> result;
+  result.reserve(resultType.getNumElements());
+
+  for (int64_t i = 0; i < resultType.getNumElements(); ++i) {
+    auto dstDimIndex = delinearize(i, resultStrides);
+    auto srcDimIndex = applyPermutation(dstDimIndex, inversePermutation);
+    auto srcLinearIndex = linearize(srcDimIndex, operandStrides);
+    result.push_back(data[srcLinearIndex]);
+  }
+
+  return DenseElementsAttr::get(op.getResult().getType(),
+                                ArrayRef(result));
+}
+
+struct EvalTransposeOpPattern : public OpRewritePattern<TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(TransposeOp op,
+                                PatternRewriter& rewriter) const override {
+    auto resultType = op.getType();
+    if (failed(validateResultTypeForEval(rewriter, op, resultType)))
+      return failure();
+
+    ElementsAttr els;
+    if (!matchPattern(op.getOperand(), m_Constant(&els)))
+      return rewriter.notifyMatchFailure(
+          op, "expected constant integer or float operand");
+
+    DenseElementsAttr resAttr;
+    if (auto data = els.tryGetValues<APInt>())
+      resAttr = transposeType(op, *data);
+    else if (auto data = els.tryGetValues<APFloat>())
+      resAttr = transposeType(op, *data);
+    else
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "unsupported element type");
+
+    rewriter.replaceOpWithNewOp<ConstantOp>(op, resAttr);
+    return success();
+  }
+};
+
 struct StablehloAggressiveFolderPass
     : public impl::StablehloAggressiveFolderPassBase<
           StablehloAggressiveFolderPass> {
@@ -672,6 +724,7 @@ void populateStablehloAggressiveFolderPatterns(RewritePatternSet* patterns,
                                                bool foldFloat) {
   populateStablehloShapeFolderPatterns(patterns, context, foldFloat);
   patterns->add<EvalIotaOpPattern>(context);
+  patterns->add<EvalTransposeOpPattern>(context);
 }
 
 void populateStablehloShapeFolderPatterns(RewritePatternSet* patterns,
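
Note (illustration only, not part of the patch): the new EvalTransposeOpPattern folds a transpose of a constant by walking every linear index of the result, delinearizing it with the result's row-major strides, mapping that multi-index through the inverse of the transpose permutation, and linearizing it with the operand's strides to locate the source element. The standalone C++ sketch below reproduces that index arithmetic with no MLIR dependencies; computeStrides/delinearize/linearize/invertPermutation are local reimplementations named after the MLIR utilities the pattern calls, and the sample data matches the @eval_transpose test above.

// Standalone sketch of the fold's index arithmetic (no MLIR dependencies).
#include <cstdint>
#include <iostream>
#include <vector>

using Index = std::vector<int64_t>;

// Row-major strides, e.g. shape [2, 3, 2] -> strides [6, 2, 1].
Index computeStrides(const Index& shape) {
  Index strides(shape.size(), 1);
  for (int64_t d = static_cast<int64_t>(shape.size()) - 2; d >= 0; --d)
    strides[d] = strides[d + 1] * shape[d + 1];
  return strides;
}

// Split a linear offset into a multi-dimensional index.
Index delinearize(int64_t linear, const Index& strides) {
  Index index(strides.size());
  for (size_t d = 0; d < strides.size(); ++d) {
    index[d] = linear / strides[d];
    linear %= strides[d];
  }
  return index;
}

// Collapse a multi-dimensional index back into a linear offset.
int64_t linearize(const Index& index, const Index& strides) {
  int64_t linear = 0;
  for (size_t d = 0; d < index.size(); ++d) linear += index[d] * strides[d];
  return linear;
}

// inverse[permutation[d]] = d, mapping result dims back to operand dims.
Index invertPermutation(const Index& permutation) {
  Index inverse(permutation.size());
  for (size_t d = 0; d < permutation.size(); ++d)
    inverse[permutation[d]] = static_cast<int64_t>(d);
  return inverse;
}

// StableHLO transpose semantics: result dim d has size shape[permutation[d]],
// and result[resultIndex] = operand[operandIndex] where
// operandIndex[permutation[d]] = resultIndex[d].
std::vector<int> transpose(const std::vector<int>& data, const Index& shape,
                           const Index& permutation) {
  Index resultShape(shape.size());
  for (size_t d = 0; d < shape.size(); ++d)
    resultShape[d] = shape[permutation[d]];

  const Index operandStrides = computeStrides(shape);
  const Index resultStrides = computeStrides(resultShape);
  const Index inverse = invertPermutation(permutation);

  std::vector<int> result(data.size());
  for (int64_t i = 0; i < static_cast<int64_t>(data.size()); ++i) {
    const Index dst = delinearize(i, resultStrides);
    // Equivalent to applyPermutation(dst, inverse) in the pattern:
    // src[j] = dst[inverse[j]]  <=>  src[permutation[d]] = dst[d].
    Index src(dst.size());
    for (size_t j = 0; j < dst.size(); ++j) src[j] = dst[inverse[j]];
    result[i] = data[linearize(src, operandStrides)];
  }
  return result;
}

int main() {
  // Mirrors @eval_transpose: a 2x3x2 constant transposed with dims = [2, 1, 0].
  const std::vector<int> data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
  for (int v : transpose(data, {2, 3, 2}, {2, 1, 0})) std::cout << v << ' ';
  std::cout << '\n';  // Prints: 1 7 3 9 5 11 2 8 4 10 6 12
}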