From 60fdf0d7004a9e68787326156cc269b4ecbb9e48 Mon Sep 17 00:00:00 2001
From: Matthew Francis-Landau <matthew@matthewfl.com>
Date: Fri, 20 Sep 2024 08:15:47 -0700
Subject: [PATCH] add casting of int types for resources

---
 include/mlir-tcp/Dialect/IR/TcpOps.td      |  1 +
 lib/Conversion/TorchToTcp/DataMovement.cpp | 44 +---------------------
 lib/Conversion/TorchToTcp/Misc.cpp         |  9 +++++
 lib/Dialect/IR/TcpOps.cpp                  |  6 +++
 4 files changed, 17 insertions(+), 43 deletions(-)

diff --git a/include/mlir-tcp/Dialect/IR/TcpOps.td b/include/mlir-tcp/Dialect/IR/TcpOps.td
index e6d6e3a..6aa267a 100644
--- a/include/mlir-tcp/Dialect/IR/TcpOps.td
+++ b/include/mlir-tcp/Dialect/IR/TcpOps.td
@@ -178,6 +178,7 @@ def Tcp_ConstOp : Tcp_Op<"const", [ConstantLike, Pure]> {
   let assemblyFormat = "attr-dict `:` type($out)";
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def Tcp_BroadcastOp : Tcp_Op<"broadcast", [
diff --git a/lib/Conversion/TorchToTcp/DataMovement.cpp b/lib/Conversion/TorchToTcp/DataMovement.cpp
index 0edea8d..80868fa 100644
--- a/lib/Conversion/TorchToTcp/DataMovement.cpp
+++ b/lib/Conversion/TorchToTcp/DataMovement.cpp
@@ -311,8 +311,6 @@ class ConvertAtenIndexTensorHackedTwin
 
     indices = getTypeConvertedValues(rewriter, op.getLoc(), getTypeConverter(),
                                      indices);
-
-
     for(unsigned int i = 0; i < indices.size(); i++) {
       auto idx = indices[i];
      auto ttype = cast<RankedTensorType>(idx.getType());
@@ -323,10 +321,9 @@ class ConvertAtenIndexTensorHackedTwin
          outShape, cast<RankedTensorType>(self.getType()).getElementType());
      auto expandedShape = torch_to_tcp::broadcastRankInLeadingDims(
          rewriter, idx, outShape.size() - ttype.getRank());
+      SmallVector<Value> broadcastValues;
      SmallVector<int64_t> broadcastAxes;
 
-
-
      for(unsigned int j = 0; j < selfType.getRank(); j++) {
        if(j != i) {
          broadcastAxes.push_back(j);
@@ -348,45 +345,6 @@
      self = gather.getResult();
    }
 
-    /*for (unsigned int i = 0; i < indices.size(); i++) {
-      auto idx = indices[i];
-      int numNonOneAxis = 0;
-      auto ttype = cast<RankedTensorType>(idx.getType());
-      if(ttype.getRank() != indices.size() - i) {
-        // there is a version of this op, where everything comes in as a single dim and then is should instead select the different indicies from each?
-        // so the difference would be if it keeps the dim or shrinks it. But not 100% clear on what the definition of the different semantics are
-        return op.emitError("unsure what to do");
-      }
-      for (int j = 0; j < ttype.getRank(); j++)
-        if (ttype.getShape()[j] != 1)
-          numNonOneAxis++;
-      if (numNonOneAxis > 1)
-        return op.emitError(
-            "Expected the input shape to have a single non-one axis");
-      // convert it to a 1-dim vector
-      if (ttype.getRank() != 1) {
-        ReassociationIndices reassocIndices;
-        for (int j = 0; j < ttype.getRank(); j++)
-          reassocIndices.push_back(j);
-        SmallVector<ReassociationIndices> ri = {reassocIndices};
-        auto reshape =
-            rewriter.create<tensor::CollapseShapeOp>(op.getLoc(), idx, ri);
-        idx = reshape.getResult();
-      }
-
-      SmallVector<int64_t> outShape(
-          cast<RankedTensorType>(self.getType()).getShape());
-      outShape[i] = ttype.getNumElements();
-      auto outType = RankedTensorType::get(
-          outShape, cast<RankedTensorType>(self.getType()).getElementType());
-
-      auto gather = rewriter.create<tcp::GatherOp>(
-          op.getLoc(), outType, self, idx, rewriter.getIndexAttr(i));
-      self = gather.getResult();
-    }*/
-
-    // assert(op.getType() == self.getType());
-
    rewriter.replaceOp(op, self);
    return success();
  }
diff --git a/lib/Conversion/TorchToTcp/Misc.cpp b/lib/Conversion/TorchToTcp/Misc.cpp
index c66453a..3acffd1 100644
--- a/lib/Conversion/TorchToTcp/Misc.cpp
+++ b/lib/Conversion/TorchToTcp/Misc.cpp
@@ -16,6 +16,7 @@
 #include "Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
@@ -162,6 +163,14 @@ class ConvertValueTensorLiteralOp
      rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType, denseIntAttr);
      return success();
    }
+    if (auto elements = dyn_cast<DenseResourceElementsAttr>(op.getValueAttr())) {
+      if(resultType.getElementType().isInteger() &&
+         resultType != adaptor.getValue().getType()) {
+        auto attr = DenseResourceElementsAttr::get(resultType, elements.getRawHandle());
+        rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType, attr);
+        return success();
+      }
+    }
 
    rewriter.replaceOpWithNewOp<tcp::ConstOp>(op, resultType,
                                              adaptor.getValue());
diff --git a/lib/Dialect/IR/TcpOps.cpp b/lib/Dialect/IR/TcpOps.cpp
index 83c098d..ffafaa9 100644
--- a/lib/Dialect/IR/TcpOps.cpp
+++ b/lib/Dialect/IR/TcpOps.cpp
@@ -127,6 +127,12 @@ LogicalResult IsolatedGroupOp::verify() {
 
 OpFoldResult ConstOp::fold(FoldAdaptor) { return getValueAttr(); }
 
+LogicalResult ConstOp::verify() {
+  if(getValueAttr().getType() != getType())
+    return emitOpError("can not be used to cast types");
+  return success();
+}
+
 LogicalResult CastOp::verify() {
   auto inputType = getIn().getType().cast<RankedTensorType>();
   auto outputType = getOut().getType().cast<RankedTensorType>();