Skip to content

Commit

Permalink
Add EvalTranspose pattern to StablehloAggressiveFolder
Browse files Browse the repository at this point in the history
  • Loading branch information
mvpant committed Oct 1, 2024
1 parent 9e407a9 commit b3758fc
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 0 deletions.
53 changes: 53 additions & 0 deletions stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,56 @@ func.func @eval_convert_f64_precision_loss() -> (tensor<1xf32>, tensor<f32>) {
%3 = stablehlo.convert %1 : (tensor<f64>) -> tensor<f32>
func.return %2, %3 : tensor<1xf32>, tensor<f32>
}

// -----

// CHECK-LABEL: func @eval_transpose
func.func @eval_transpose() -> (tensor<2x3x2xi32>, tensor<2x4x3xi32>, tensor<4x3x2xi32>) {
// CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<
// CHECK-SAME: {{\[\[}}[1, 7], [3, 9], [5, 11]],
// CHECK-SAME: {{\[}}[2, 8], [4, 10], [6, 12]]]> : tensor<2x3x2xi32>
//
// CHECK: [[RESULT1:%.*]] = stablehlo.constant dense<
// CHECK-SAME: {{\[\[}}[1, 3, 5], [7, 9, 11], [13, 15, 17], [19, 21, 23]],
// CHECK-SAME: {{\[}}[2, 4, 6], [8, 10, 12], [14, 16, 18], [20, 22, 24]]]> : tensor<2x4x3xi32>
//
// CHECK: [[RESULT2:%.*]] = stablehlo.constant dense<
// CHECK-SAME: {{\[\[}}[1, 2], [3, 4], [5, 6]]
// CHECK-SAME: {{\[}}[7, 8], [9, 10], [11, 12]],
// CHECK-SAME: {{\[}}[13, 14], [15, 16], [17, 18]],
// CHECK-SAME: {{\[}}[19, 20], [21, 22], [23, 24]]]> : tensor<4x3x2xi32>
//
// CHECK: return [[RESULT0]], [[RESULT1]], [[RESULT2]]
%0 = stablehlo.constant dense<[[[1,2], [3,4], [5,6]],
[[7,8], [9,10], [11,12]]]> : tensor<2x3x2xi32>
%1 = stablehlo.constant dense<[[[1, 2], [3, 4], [5, 6]],
[[7, 8], [9, 10], [11,12]],
[[13,14], [15,16], [17,18]],
[[19,20], [21,22], [23,24]]]> : tensor<4x3x2xi32>
%2 = stablehlo.transpose %0, dims = [2, 1, 0] : (tensor<2x3x2xi32>) -> tensor<2x3x2xi32>
%3 = stablehlo.transpose %1, dims = [2, 0, 1] : (tensor<4x3x2xi32>) -> tensor<2x4x3xi32>
%4 = stablehlo.transpose %3, dims = [1, 2, 0] : (tensor<2x4x3xi32>) -> tensor<4x3x2xi32>
func.return %2, %3, %4 : tensor<2x3x2xi32>, tensor<2x4x3xi32>, tensor<4x3x2xi32>
}

// -----

// CHECK-LABEL: func @eval_transpose_zerodim
func.func @eval_transpose_zerodim() -> (tensor<10x3x0xf32>) {
// CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<> : tensor<10x3x0xf32>
// CHECK: return [[RESULT0]]
%0 = stablehlo.constant dense<> : tensor<3x0x10xf32>
%1 = stablehlo.transpose %0, dims = [2, 0, 1] : (tensor<3x0x10xf32>) -> tensor<10x3x0xf32>
func.return %1 : tensor<10x3x0xf32>
}

// -----

// CHECK-LABEL: func @eval_transpose_zerorank
func.func @eval_transpose_zerorank() -> tensor<i32> {
// CHECK: [[RESULT0:%.*]] = stablehlo.constant dense<1> : tensor<i32>
// CHECK: return [[RESULT0]]
%0 = stablehlo.constant dense<1> : tensor<i32>
%1 = stablehlo.transpose %0, dims = [] : (tensor<i32>) -> tensor<i32>
func.return %1 : tensor<i32>
}
53 changes: 53 additions & 0 deletions stablehlo/transforms/StablehloAggressiveFolder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,58 @@ struct EvalIotaOpPattern : public OpRewritePattern<IotaOp> {
}
};

template <typename RangeType>
DenseElementsAttr transposeType(TransposeOp& op, const RangeType& data) {
using ElementType = std::decay_t<decltype(*std::begin(data))>;

RankedTensorType operandType = op.getOperand().getType();
RankedTensorType resultType = op.getResult().getType();

const auto operandStrides = computeStrides(operandType.getShape());
const auto resultStrides = computeStrides(resultType.getShape());
const auto inversePermutation = invertPermutationVector(op.getPermutation());

SmallVector<ElementType> result;
result.reserve(resultType.getNumElements());

for (int64_t i = 0; i < resultType.getNumElements(); ++i) {
auto dstDimIndex = delinearize(i, resultStrides);
auto srcDimIndex = applyPermutation(dstDimIndex, inversePermutation);
auto srcLinearIndex = linearize(srcDimIndex, operandStrides);
result.push_back(data[srcLinearIndex]);
}

return DenseElementsAttr::get(op.getResult().getType(),
ArrayRef<ElementType>(result));
}

struct EvalTransposeOpPattern : public OpRewritePattern<TransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TransposeOp op,
PatternRewriter& rewriter) const override {
auto resultType = op.getType();
if (failed(validateResultTypeForEval(rewriter, op, resultType)))
return failure();

ElementsAttr els;
if (!matchPattern(op.getOperand(), m_Constant(&els)))
return rewriter.notifyMatchFailure(
op, "expected constant integer or float operand");

DenseElementsAttr resAttr;
if (auto data = els.tryGetValues<APInt>())
resAttr = transposeType(op, *data);
else if (auto data = els.tryGetValues<APFloat>())
resAttr = transposeType(op, *data);
else
return rewriter.notifyMatchFailure(op.getLoc(),
"unsupported element type");

rewriter.replaceOpWithNewOp<ConstantOp>(op, resAttr);
return success();
}
};

struct StablehloAggressiveFolderPass
: public impl::StablehloAggressiveFolderPassBase<
StablehloAggressiveFolderPass> {
Expand Down Expand Up @@ -672,6 +724,7 @@ void populateStablehloAggressiveFolderPatterns(RewritePatternSet* patterns,
bool foldFloat) {
populateStablehloShapeFolderPatterns(patterns, context, foldFloat);
patterns->add<EvalIotaOpPattern>(context);
patterns->add<EvalTransposeOpPattern>(context);
}

void populateStablehloShapeFolderPatterns(RewritePatternSet* patterns,
Expand Down

0 comments on commit b3758fc

Please sign in to comment.