Skip to content

Commit

Permalink
add casting of int types for resources
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewfl committed Sep 20, 2024
1 parent 0e8e84a commit 60fdf0d
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 43 deletions.
1 change: 1 addition & 0 deletions include/mlir-tcp/Dialect/IR/TcpOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def Tcp_ConstOp : Tcp_Op<"const", [ConstantLike, Pure]> {
let assemblyFormat = "attr-dict `:` type($out)";

let hasFolder = 1;
let hasVerifier = 1;
}

def Tcp_BroadcastOp : Tcp_Op<"broadcast", [
Expand Down
44 changes: 1 addition & 43 deletions lib/Conversion/TorchToTcp/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,6 @@ class ConvertAtenIndexTensorHackedTwin
indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(),
indices);



for(unsigned int i = 0; i < indices.size(); i++) {
auto idx = indices[i];
auto ttype = cast<RankedTensorType>(idx.getType());
Expand All @@ -323,10 +321,9 @@ class ConvertAtenIndexTensorHackedTwin
outShape, cast<RankedTensorType>(self.getType()).getElementType());

auto expandedShape = torch_to_tcp::broadcastRankInLeadingDims(rewriter, idx, outShape.size() - ttype.getRank());

SmallVector<Value> broadcastValues;
SmallVector<int64_t> broadcastAxes;


for(unsigned int j = 0; j < selfType.getRank(); j++) {
if(j != i) {
broadcastAxes.push_back(j);
Expand All @@ -348,45 +345,6 @@ class ConvertAtenIndexTensorHackedTwin
self = gather.getResult();
}

/*for (unsigned int i = 0; i < indices.size(); i++) {
auto idx = indices[i];
int numNonOneAxis = 0;
auto ttype = cast<RankedTensorType>(idx.getType());
if(ttype.getRank() != indices.size() - i) {
// there is a version of this op, where everything comes in as a single dim and then is should instead select the different indicies from each?
// so the difference would be if it keeps the dim or shrinks it. But not 100% clear on what the definition of the different semantics are
return op.emitError("unsure what to do");
}
for (int j = 0; j < ttype.getRank(); j++)
if (ttype.getShape()[j] != 1)
numNonOneAxis++;
if (numNonOneAxis > 1)
return op.emitError(
"Expected the input shape to have a single non-one axis");
// convert it to a 1-dim vector
if (ttype.getRank() != 1) {
ReassociationIndices reassocIndices;
for (int j = 0; j < ttype.getRank(); j++)
reassocIndices.push_back(j);
SmallVector<ReassociationIndices> ri = {reassocIndices};
auto reshape =
rewriter.create<tensor::CollapseShapeOp>(op.getLoc(), idx, ri);
idx = reshape.getResult();
}
SmallVector<int64_t> outShape(
cast<RankedTensorType>(self.getType()).getShape());
outShape[i] = ttype.getNumElements();
auto outType = RankedTensorType::get(
outShape, cast<RankedTensorType>(self.getType()).getElementType());
auto gather = rewriter.create<tcp::GatherOp>(
op.getLoc(), outType, self, idx, rewriter.getIndexAttr(i));
self = gather.getResult();
}*/

// assert(op.getType() == self.getType());

rewriter.replaceOp(op, self);
return success();
}
Expand Down
9 changes: 9 additions & 0 deletions lib/Conversion/TorchToTcp/Misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
Expand Down Expand Up @@ -162,6 +163,14 @@ class ConvertValueTensorLiteralOp
rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType, denseIntAttr);
return success();
}
if (auto elements = dyn_cast<DenseResourceElementsAttr>(op.getValueAttr())) {
if(resultType.getElementType().isInteger() &&
resultType != adaptor.getValue().getType()) {
auto attr = DenseResourceElementsAttr::get(resultType, elements.getRawHandle());
rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType, attr);
return success();
}
}

rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType,
adaptor.getValue());
Expand Down
6 changes: 6 additions & 0 deletions lib/Dialect/IR/TcpOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@ LogicalResult IsolatedGroupOp::verify() {

OpFoldResult ConstOp::fold(FoldAdaptor) {
  // A constant op folds to its own value attribute (ConstantLike hook).
  return getValueAttr();
}

LogicalResult ConstOp::verify() {
  // The folder returns the value attribute unchanged, so if the attribute's
  // type differed from the result type the op would act as an implicit cast
  // (e.g. reinterpreting a resource blob under a different element type).
  // Reject that here; tcp.cast is the explicit casting op.
  // NOTE: the original message prefix is preserved so existing FileCheck-style
  // substring matches keep passing; the mismatched types are appended to make
  // the diagnostic actionable.
  if (getValueAttr().getType() != getType())
    return emitOpError("can not be used to cast types, got ")
           << getValueAttr().getType() << " vs " << getType();
  return success();
}

LogicalResult CastOp::verify() {
auto inputType = getIn().getType().cast<RankedTensorType>();
auto outputType = getOut().getType().cast<RankedTensorType>();
Expand Down

0 comments on commit 60fdf0d

Please sign in to comment.