Skip to content

Commit

Permalink
[Torch] support decompose aten.einsum with ellipsis slicing (llvm#3056)
Browse files Browse the repository at this point in the history
  • Loading branch information
Xinyu Yang authored Mar 27, 2024
1 parent 5f32574 commit e6e7689
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 11 deletions.
134 changes: 124 additions & 10 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include <cstdint>
#include <set>

using namespace mlir;
using namespace mlir::torch;
Expand Down Expand Up @@ -158,6 +159,105 @@ static SmallVector<int64_t> computeDimsOrderForMoveDim(int64_t srcDimInt,
return dimsOrder;
}

// Rewrites an einsum equation that uses "..." (ellipsis) into an equivalent
// equation with only explicit subscript letters, based on the rank of each
// input tensor. E.g. with input ranks {3, 3}, "...mn,...nd->...md" becomes
// "amn,and -> amd". Broadcasting is supported: an input whose ellipsis
// covers fewer dimensions than the widest one receives only the trailing
// subset of the generated subscripts (an ellipsis covering zero dims is
// simply erased).
// Returns false when the equation cannot be handled: no "->", token count
// does not match the number of inputs, an unexpected character, an input
// rank smaller than its explicit subscripts, a positive ellipsis rank with
// no literal "...", or no spare letters left to name the ellipsis dims.
static bool
rewriteEquationWithEllipsisSlicing(std::string &equation,
                                   SmallVector<int64_t> &inputRanks) {
  // Split equation into input and result parts at "->".
  size_t arrowPos = equation.find("->");
  if (arrowPos == std::string::npos) {
    return false;
  }
  std::string inputStr = equation.substr(0, arrowPos);
  std::string resultStr = equation.substr(arrowPos + 2);

  // Split the input part into one token per tensor (comma-separated).
  SmallVector<std::string> inputTokens;
  size_t start = 0;
  size_t end = 0;
  std::set<char> usedTokens;
  while (end < inputStr.size()) {
    end = inputStr.find(",", start);
    if (end == std::string::npos) {
      end = inputStr.size();
    }
    inputTokens.push_back(inputStr.substr(start, end - start));
    start = end + 1;
  }
  if (inputTokens.size() != inputRanks.size()) {
    return false;
  }

  // For each input, compute how many dimensions its ellipsis stands for.
  // Track the maximum, since ellipsis dims may be broadcast across inputs.
  SmallVector<int64_t> ellipsisRanks;
  int maxEllipsisRank = 0;
  for (const auto &[token, inputRank] : llvm::zip(inputTokens, inputRanks)) {
    int explicitRank = 0;
    for (char c : token) {
      // Cast to unsigned char: std::isalpha on a negative char is UB.
      if (std::isalpha(static_cast<unsigned char>(c))) {
        usedTokens.insert(c);
        explicitRank++;
      } else if (c == '.' || c == ' ') {
        continue;
      } else {
        return false;
      }
    }
    int ellipsisRank = inputRank - explicitRank;
    if (ellipsisRank < 0) {
      // More explicit subscripts than tensor dimensions.
      return false;
    }
    if (ellipsisRank > maxEllipsisRank) {
      maxEllipsisRank = ellipsisRank;
    }
    ellipsisRanks.push_back(ellipsisRank);
  }

  // Build a string of fresh (unused) lowercase letters, one per ellipsis
  // dimension. When maxEllipsisRank is 0, this stays empty so "..." is
  // simply erased below (the original unconditional scan collected every
  // unused letter in that case and corrupted the result).
  std::string ellipsisToken;
  for (char c = 'a';
       c <= 'z' && (int)ellipsisToken.size() < maxEllipsisRank; ++c) {
    if (usedTokens.find(c) == usedTokens.end()) {
      ellipsisToken.push_back(c);
    }
  }
  if ((int)ellipsisToken.size() < maxEllipsisRank) {
    // Ran out of subscript letters; cannot name every ellipsis dimension.
    return false;
  }

  // Replace "..." in each input token. A broadcast input only receives the
  // trailing ellipsisRanks[i] letters; substr handles the full-rank and
  // zero-rank cases uniformly.
  for (size_t i = 0; i < inputTokens.size(); i++) {
    size_t ellipsisPos = inputTokens[i].find("...");
    if (ellipsisPos == std::string::npos) {
      // No ellipsis in this token: its explicit subscripts must already
      // cover the whole rank, otherwise the equation is malformed.
      if (ellipsisRanks[i] != 0) {
        return false;
      }
      continue;
    }
    size_t keep = static_cast<size_t>(ellipsisRanks[i]);
    inputTokens[i].replace(ellipsisPos, 3,
                           ellipsisToken.substr(ellipsisToken.size() - keep));
  }

  // Replace "..." in the result with the full set of generated subscripts.
  size_t ellipsisPos = resultStr.find("...");
  if (ellipsisPos != std::string::npos) {
    resultStr.replace(ellipsisPos, 3, ellipsisToken);
  }

  // Reassemble the rewritten equation.
  equation = llvm::join(inputTokens, ",") + " -> " + resultStr;
  return true;
}

static bool parseEquation(const std::string &equation,
SmallVector<SmallVector<char>> &inputTokens,
SmallVector<char> &resultTokens) {
Expand Down Expand Up @@ -1135,16 +1235,6 @@ class DecomposeAtenEinsumOp : public OpRewritePattern<AtenEinsumOp> {
LogicalResult matchAndRewrite(AtenEinsumOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
std::string equation;
if (!matchPattern(op.getEquation(), m_TorchConstantStr(equation))) {
return rewriter.notifyMatchFailure(op, "Unsupported value of equation");
}
SmallVector<char> resultTokens;
SmallVector<SmallVector<char>> inputTokens;
if (!parseEquation(equation, inputTokens, resultTokens)) {
return rewriter.notifyMatchFailure(
op, "Unexpected character in equations encountered");
}

SmallVector<Value> inputTensors;
if (!getListConstructElements(op.getTensors(), inputTensors)) {
Expand All @@ -1164,6 +1254,30 @@ class DecomposeAtenEinsumOp : public OpRewritePattern<AtenEinsumOp> {
"all input tensors should have sizes");
}

std::string equation;
if (!matchPattern(op.getEquation(), m_TorchConstantStr(equation))) {
return rewriter.notifyMatchFailure(op, "Unsupported value of equation");
}
// if "..." in equation, modify it
if (equation.find("...") != std::string::npos) {
SmallVector<int64_t> inputRanks;
for (Value tensor : inputTensors) {
auto type = tensor.getType().cast<BaseTensorType>();
inputRanks.push_back(type.getSizes().size());
}

if (!rewriteEquationWithEllipsisSlicing(equation, inputRanks)) {
return rewriter.notifyMatchFailure(
op, "Unexpected character in equations encountered");
}
}
SmallVector<char> resultTokens;
SmallVector<SmallVector<char>> inputTokens;
if (!parseEquation(equation, inputTokens, resultTokens)) {
return rewriter.notifyMatchFailure(
op, "Unexpected character in equations encountered");
}

SmallVector<char> lhsTokens = inputTokens[0];
Value lhs = inputTensors[0];
Value result;
Expand Down
6 changes: 6 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,8 @@
"EinsumStaticContractRhsModule_basic",
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic",
"EinsumStaticWithEllipsisSlicingModule_basic",
"EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",
"ElementwiseAbsFloatModule_basic",
"ElementwiseAbsIntModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
Expand Down Expand Up @@ -954,6 +956,8 @@
"EinsumStaticContractRhsModule_basic",
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic",
"EinsumStaticWithEllipsisSlicingModule_basic",
"EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",
"ElementwiseAbsFloatModule_basic",
"ElementwiseAbsIntModule_basic",
"ElementwiseAddModule_basic",
Expand Down Expand Up @@ -1923,6 +1927,8 @@
"EinsumStaticContractRhsModule_basic",
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic",
"EinsumStaticWithEllipsisSlicingModule_basic",
"EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",

# Failure - onnx_lowering: onnx.MaxPool
"MaxPool2dWithIndicesAllNegativeValuesModule_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1099,4 +1099,39 @@ def forward(self, tensor1, tensor2):

@register_test_case(module_factory=lambda: EinsumStaticContractRhsModule())
def EinsumStaticContractRhsModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5), tu.rand(4, 5))
module.forward(tu.rand(3, 4, 5), tu.rand(4, 5))

class EinsumStaticWithEllipsisSlicingModule(torch.nn.Module):
    # Exercises the aten.einsum decomposition when the equation uses "..."
    # to stand for the shared leading batch dimension of both operands.
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 4, 6], torch.float32, True),
        ([3, 6, 5], torch.float32, True),
    ])
    def forward(self, tensor1, tensor2):
        # "..." matches the leading dim (3) of both inputs; with m=4, n=6,
        # d=5 this is a batched matmul producing shape [3, 4, 5].
        return torch.ops.aten.einsum('...mn,...nd->...md', [tensor1, tensor2])

@register_test_case(module_factory=lambda: EinsumStaticWithEllipsisSlicingModule())
def EinsumStaticWithEllipsisSlicingModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 4, 6), tu.rand(3, 6, 5))

class EinsumStaticWithEllipsisSlicingAndBroadcastModule(torch.nn.Module):
    # Exercises the aten.einsum decomposition when "..." covers a different
    # number of dimensions per operand (broadcasting): two leading dims
    # [2, 6] on tensor1 but only one [6] on tensor2.
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([2, 6, 4, 5], torch.float32, True),
        ([6, 5], torch.float32, True),
    ])
    def forward(self, tensor1, tensor2):
        # should be abnd,bd -> abn
        return torch.ops.aten.einsum('...nd,...d->...n', [tensor1, tensor2])

@register_test_case(module_factory=lambda: EinsumStaticWithEllipsisSlicingAndBroadcastModule())
def EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 6, 4, 5), tu.rand(6, 5))

0 comments on commit e6e7689

Please sign in to comment.