
Add tcp.gather_nd and rework index.Tensor_hacked_twin to use gather_nd #101

Merged: 5 commits, Nov 7, 2024
24 changes: 24 additions & 0 deletions include/mlir-tcp/Dialect/IR/TcpOps.td
@@ -662,6 +662,30 @@ def Tcp_GatherOp : Tcp_Op<"gather", [Pure, AllElementTypesMatch<["input", "out"]
let hasVerifier = 1;
}

def Tcp_GatherNDOp : Tcp_Op<"gather_nd", [Pure, AllElementTypesMatch<["input", "out"]>]> {

let summary = "Gather elements from input based on indices over multiple dimensions";

let description = [{
Gathers elements from a given tensor based on indices that index along multiple dimensions.

More details regarding this op: docs/gather.md
}];

let arguments = (ins
Tcp_Tensor:$input,
Tcp_IntTensor:$indices
);

let results = (outs
Tcp_Tensor:$out
);

let assemblyFormat = "$input `,` $indices attr-dict `:` type($input) `,` type($indices) `->` type($out)";

let hasVerifier = 1;
}

def Tcp_SliceOp : Tcp_Op<"slice", [Pure, AllElementTypesMatch<["in", "out"]>, SameVariadicOperandSize]> {

let summary = "Extracts a slice of the input tensor";
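For illustration (not part of this diff), the assembly format of the new tcp.gather_nd op admits IR like the following, where the trailing dimension of the indices tensor (size 2) indexes the first two dimensions of the input and the remaining input dimensions are carried through:

%out = tcp.gather_nd %input, %indices
    : tensor<5x6x7x8xf32>, tensor<9x4x3x2xi64> -> tensor<9x4x3x7x8xf32>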
97 changes: 97 additions & 0 deletions lib/Conversion/TcpToLinalg/DataMovement.cpp
@@ -91,6 +91,101 @@ class ConvertGatherOp : public OpConversionPattern<GatherOp> {
}
};

/**
 * tcp.gather_nd is lowered to linalg.generic, which allows us to define every
 * element in the result tensor using a programmatic expression. The last
 * dimension of the indices tensor is used to index into the input tensor.
 *
 * For example, given an indices tensor of shape 9x4x3x2 and an input tensor
 * of shape 5x6x7x8, the resulting tensor has shape 9x4x3x7x8. The first three
 * dimensions of the result index into the indices tensor, and the last
 * dimension of the indices tensor (the size-2 dimension) is used to index
 * into the input tensor.
 */
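//
// A rough sketch of the element-wise rule realized by the generated
// linalg.generic for that example (illustrative notation, not code from this
// patch):
//
//   out[i, j, k, m, n] = input[indices[i, j, k, 0], indices[i, j, k, 1], m, n]
//
// i.e. each index tuple stored along the trailing dimension of the indices
// tensor selects into the leading input dimensions, and the remaining input
// dimensions are copied through.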
class ConvertGatherNDOp : public OpConversionPattern<GatherNDOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(GatherNDOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTensorType = getTypeConverter()
->convertType(op.getOut().getType())
.cast<RankedTensorType>();

auto inputTensor = adaptor.getInput();
auto indicesTensor = adaptor.getIndices();
auto indicesType = cast<RankedTensorType>(indicesTensor.getType());
auto inputType = cast<RankedTensorType>(inputTensor.getType());
int numGatherAxes = indicesType.getShape().back();

SmallVector<Value> resultDimSizes;
for (int i = 0; i < indicesType.getRank() - 1; i++) {
resultDimSizes.push_back(
rewriter.createOrFold<tensor::DimOp>(loc, indicesTensor, i));
}
for (int i = numGatherAxes; i < inputType.getRank(); i++) {
resultDimSizes.push_back(
rewriter.createOrFold<tensor::DimOp>(loc, inputTensor, i));
}

assert(resultDimSizes.size() == resultTensorType.getRank());

Value emptyTensor =
rewriter.create<tensor::EmptyOp>(loc, getAsOpFoldResult(resultDimSizes),
resultTensorType.getElementType());

auto bodyBuilder = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) {
SmallVector<Value> valueIndices, gatherIndices;
for (int i = 0; i < indicesType.getRank() - 1; i++) {
auto idx = b.create<linalg::IndexOp>(loc, b.getIndexType(),
b.getI64IntegerAttr(i));
gatherIndices.push_back(idx);
}
for (int i = 0; i < numGatherAxes; i++) {
SmallVector<Value> gi = gatherIndices;
auto gidx = b.create<arith::ConstantOp>(loc, b.getIndexAttr(i));
gi.push_back(gidx);
assert(gi.size() == indicesType.getRank());
auto idxExtract = b.create<tensor::ExtractOp>(
loc, indicesType.getElementType(), indicesTensor, gi);
auto idxCast =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), idxExtract);
valueIndices.push_back(idxCast);
}
for (int i = indicesType.getRank() - 1; i < resultTensorType.getRank();
i++) {
auto idx = b.create<linalg::IndexOp>(loc, b.getIndexType(),
b.getI64IntegerAttr(i));
valueIndices.push_back(idx);
}
assert(valueIndices.size() == inputType.getRank());
auto extract =
b.create<tensor::ExtractOp>(loc, resultTensorType.getElementType(),
inputTensor, valueIndices)
.getResult();

b.create<linalg::YieldOp>(loc, extract);
};

SmallVector<Value> empty;
SmallVector<AffineMap> indexingMaps;
indexingMaps.push_back(
rewriter.getMultiDimIdentityMap(resultTensorType.getRank()));
SmallVector<utils::IteratorType> iteratorTypes(
resultTensorType.getRank(), utils::IteratorType::parallel);

auto generic = rewriter.create<linalg::GenericOp>(
loc, resultTensorType, empty, emptyTensor, indexingMaps, iteratorTypes,
bodyBuilder);

rewriter.replaceOp(op, generic.getResult(0));

return success();
}
};

} // namespace

void mlir::TcpToLinalg::populateDataMovementPatternsAndLegality(
@@ -100,4 +195,6 @@ void mlir::TcpToLinalg::populateDataMovementPatternsAndLegality(

target.addIllegalOp<GatherOp>();
patterns.add<ConvertGatherOp>(typeConverter, context);
target.addIllegalOp<GatherNDOp>();
patterns.add<ConvertGatherNDOp>(typeConverter, context);
}
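To make the lowering concrete, here is a rough hand-written sketch (not generated by this patch; the exact attribute spelling may differ from what the rewriter emits) of the linalg.generic produced for the 9x4x3x2 indices / 5x6x7x8 input example from the comment above:

#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
%empty = tensor.empty() : tensor<9x4x3x7x8xf32>
%result = linalg.generic
    {indexing_maps = [#map],
     iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
    outs(%empty : tensor<9x4x3x7x8xf32>) {
  ^bb0(%out: f32):
    // result coordinates
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    %k = linalg.index 2 : index
    %m = linalg.index 3 : index
    %n = linalg.index 4 : index
    // read the gather coordinates from the trailing dimension of %indices
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %g0 = tensor.extract %indices[%i, %j, %k, %c0] : tensor<9x4x3x2xi64>
    %g1 = tensor.extract %indices[%i, %j, %k, %c1] : tensor<9x4x3x2xi64>
    %d0 = arith.index_cast %g0 : i64 to index
    %d1 = arith.index_cast %g1 : i64 to index
    // gather from the input and yield the element
    %v = tensor.extract %input[%d0, %d1, %m, %n] : tensor<5x6x7x8xf32>
    linalg.yield %v : f32
} -> tensor<9x4x3x7x8xf32>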
94 changes: 46 additions & 48 deletions lib/Conversion/TorchToTcp/DataMovement.cpp
@@ -278,75 +278,73 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern<AtenIndexSelectOp> {
}
};

/**
 * index.Tensor_hacked_twin takes a list of index tensors which must be
 * broadcast together to the same shape; the broadcast indices are then fed
 * into a single gather_nd, which selects along multiple axes at once.
 */
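//
// Sketch of the rewrite with illustrative shapes (not taken from the patch):
// for self : tensor<5x6x7xf32> and two index tensors that broadcast to
// tensor<2x3xi64>, each index tensor is expanded to tensor<2x3x1xi64>, the two
// are concatenated along the trailing dimension into tensor<2x3x2xi64>, and a
// single tcp.gather_nd then yields tensor<2x3x7xf32>.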
class ConvertAtenIndexTensorHackedTwin
: public OpConversionPattern<AtenIndexTensorHackedTwinOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// ------- Matching the OP -------
auto self = adaptor.getSelf();
auto selfType = cast<RankedTensorType>(self.getType());
auto indicesList = op.getIndices();
SmallVector<Value> indices;
if (!getListConstructElements(indicesList, indices))
return op.emitError("Failed to match list of indices");

for (unsigned int i = 0; i < indices.size(); i++) {
auto ttype = cast<RankedTensorType>(
getTypeConverter()->convertType(indices[i].getType()));
if (ttype.getRank() != selfType.getRank() - i) {
// Can use tensor.gather instead for this. But will require that there
// are some broadcasting to get the shapes to match what is expected
return failure("Failed to rewrite Tensor_hacked_twin. Need the "
"element gather for this");
}
for (int j = 1; j < ttype.getRank(); j++) {
if (ttype.getShape()[j] != 1)
return failure("Expected the axes >=1 to have size 1");
indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(),
indices);

if (auto indicesBroadcasted = torch_to_tcp::broadcastManyToMatchShape(
rewriter, op.getLoc(), indices)) {
indices = indicesBroadcasted.value();
} else {
return failure("failed to broadcast the shapes of the input indices");
}

for (int i = 0; i < indices.size(); i++) {
Value v =
torch_to_tcp::broadcastRankInTrailingDims(rewriter, indices[i], 1);
if (!cast<RankedTensorType>(v.getType()).getElementType().isInteger(64)) {
v = rewriter.createOrFold<tcp::CastOp>(
op.getLoc(),
RankedTensorType::get(
cast<RankedTensorType>(v.getType()).getShape(),
rewriter.getI64Type()),
v, SignednessAttr::get(op->getContext(), Signedness::Signed),
SignednessAttr::get(op->getContext(), Signedness::Signless));
}
indices[i] = v;
}

// ------ Rewriting the OP ---------
auto indicesType = cast<RankedTensorType>(indices[0].getType());
int indicesRank = indicesType.getRank();
SmallVector<int64_t> outIndexShape;
outIndexShape.insert(outIndexShape.begin(), indicesType.getShape().begin(),
indicesType.getShape().end());
outIndexShape.back() = indices.size();

indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(),
indices);
auto outIndexType =
RankedTensorType::get(outIndexShape, indicesType.getElementType());
auto indexTensor =
rewriter
.create<tensor::ConcatOp>(
op.getLoc(), outIndexType,
rewriter.getI64IntegerAttr(indicesRank - 1), indices)
.getResult();

for (unsigned int i = 0; i < indices.size(); i++) {
auto idx = indices[i];
auto ttype = cast<RankedTensorType>(idx.getType());
auto selfType = cast<RankedTensorType>(self.getType());
SmallVector<int64_t> outShape(selfType.getShape());
outShape[i] = ttype.getNumElements();
auto outType = RankedTensorType::get(
outShape, cast<RankedTensorType>(self.getType()).getElementType());

auto expandedShape = torch_to_tcp::broadcastRankInLeadingDims(
rewriter, idx, outShape.size() - ttype.getRank());

SmallVector<Value> broadcastValues;
SmallVector<int64_t> broadcastAxes;
for (unsigned int j = 0; j < selfType.getRank(); j++) {
if (j != i) {
broadcastAxes.push_back(j);
broadcastValues.push_back(
rewriter.create<tensor::DimOp>(op.getLoc(), self, j));
}
}
auto outType =
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));

auto broadcastedShape = rewriter.create<tcp::BroadcastOp>(
op.getLoc(), RankedTensorType::get(outShape, ttype.getElementType()),
expandedShape, broadcastValues,
rewriter.getI64ArrayAttr(broadcastAxes));
auto gatherOp = rewriter.create<tcp::GatherNDOp>(op.getLoc(), outType, self,
indexTensor);

auto gather = rewriter.create<tcp::GatherOp>(op.getLoc(), outType, self,
broadcastedShape.getResult(),
rewriter.getIndexAttr(i));
self = gather.getResult();
}
rewriter.replaceOp(op, gatherOp);

rewriter.replaceOp(op, self);
return success();
}
};
118 changes: 118 additions & 0 deletions lib/Conversion/TorchToTcp/Utils.cpp
@@ -72,6 +72,32 @@ Value broadcastRankInLeadingDims(ConversionPatternRewriter &rewriter,
input.getDefiningOp()->getLoc(), resultType, input, reassociationMap);
}

// The parameter input is expected to be of RankedTensorType.
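// For example (illustrative shapes): a tensor<2x3xf32> input with
// rankIncrease = 2 is reshaped via tensor.expand_shape into a
// tensor<2x3x1x1xf32>, i.e. size-1 dimensions are appended at the end.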
Value broadcastRankInTrailingDims(ConversionPatternRewriter &rewriter,
Value input, int64_t rankIncrease) {
if (rankIncrease == 0)
return input;
RankedTensorType inputType = input.getType().cast<RankedTensorType>();

SmallVector<ReassociationExprs> reassociationMap(inputType.getRank());
if (inputType.getRank() > 0) {
for (int64_t inputAxis = 0; inputAxis < inputType.getRank(); inputAxis++)
reassociationMap[inputAxis].push_back(
rewriter.getAffineDimExpr(inputAxis));
for (int64_t axis = 0; axis < rankIncrease; axis++)
reassociationMap.back().push_back(
rewriter.getAffineDimExpr(axis + inputType.getRank()));
}

SmallVector<int64_t> resultShape(inputType.getShape());
resultShape.insert(resultShape.end(), rankIncrease, 1);
auto resultType =
inputType.cloneWith(ArrayRef(resultShape), inputType.getElementType());

return rewriter.create<tensor::ExpandShapeOp>(
input.getDefiningOp()->getLoc(), resultType, input, reassociationMap);
}

Value broadcastRank0Dor1DToND(ConversionPatternRewriter &rewriter, Value input,
int64_t targetRank, int64_t axisInOutput) {
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
@@ -130,6 +156,98 @@ Value broadcastShapeExceptDims(ConversionPatternRewriter &rewriter, Value input,
axesAttr);
}

// The parameter `values` is expected to contain tensors of RankedTensorType.
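// For example (illustrative shapes): given values of type tensor<4x1xf32>,
// tensor<1x5xf32>, and tensor<5xf32>, the last is first expanded to
// tensor<1x5xf32>, the common shape 4x5 is derived per axis, and tcp.broadcast
// ops are emitted so that all three returned values have type tensor<4x5xf32>.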
std::optional<SmallVector<Value>>
broadcastManyToMatchShape(ConversionPatternRewriter &rewriter, Location loc,
ValueRange values) {
if (values.size() <= 1) {
return values;
}
SmallVector<Value> ret;

int64_t maxRank = 0;
for (auto v : values) {
assert(isa<RankedTensorType>(v.getType()) && "expected RankedTensorType values");
auto t = cast<RankedTensorType>(v.getType());
if (t.getRank() > maxRank)
maxRank = t.getRank();
}

for (auto v : values) {
auto type = cast<RankedTensorType>(v.getType());
v = broadcastRankInLeadingDims(rewriter, v, maxRank - type.getRank());
ret.push_back(v);
}

// figure out what the shape should be for each dim
struct DimInfo {
Value value;
bool found = false;
int64_t staticValue = 1;
};
SmallVector<DimInfo> resultShape(maxRank);

for (auto v : ret) {
auto t = cast<RankedTensorType>(v.getType());
auto shape = t.getShape();
for (int64_t i = 0; i < maxRank; i++) {
if (shape[i] != 1) {
// this dimension is not 1, so it is not broadcast and instead determines
// the result size for this axis
if (resultShape[i].found) {
// multiple inputs have a non-1 size for this axis; their static sizes
// must match, otherwise the broadcast would fail
if (shape[i] != ShapedType::kDynamic &&
resultShape[i].staticValue != ShapedType::kDynamic &&
resultShape[i].staticValue != shape[i]) {
// the broadcast fails: this axis has two different static sizes
llvm::errs()
<< "failed with broadcasting, have two different shapes "
<< shape[i] << " " << resultShape[i].staticValue << "\n";
return {};
}
} else {
resultShape[i].found = true;
if (shape[i] == ShapedType::kDynamic) {
resultShape[i].value = rewriter.create<tensor::DimOp>(loc, v, i);
resultShape[i].staticValue = ShapedType::kDynamic;
} else {
resultShape[i].value = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIndexAttr(shape[i]));
resultShape[i].staticValue = shape[i];
}
}
}
}
}

// broadcast each value to the computed result shape
for (int64_t i = 0; i < ret.size(); i++) {
auto v = ret[i];
auto t = cast<RankedTensorType>(v.getType());
SmallVector<int64_t> axes;
SmallVector<Value> sizes;
SmallVector<int64_t> staticShape;
for (int64_t j = 0; j < maxRank; j++) {
if (t.getShape()[j] == 1 && resultShape[j].found) {
axes.push_back(j);
sizes.push_back(resultShape[j].value);
}
staticShape.push_back(resultShape[j].staticValue);
}
if (!axes.empty()) {
// there is something to broadcast here, so add the op
Type resultType = t.cloneWith(staticShape, t.getElementType());
ret[i] = rewriter.create<tcp::BroadcastOp>(
loc, resultType, ret[i], sizes, rewriter.getI64ArrayAttr(axes));
}
}

return ret;
}

// The parameters input are expected to be of RankedTensorType.
std::pair<Value, Value>
broadcastToMatchShape(ConversionPatternRewriter &rewriter, Value lhs,