Skip to content

Commit

Permalink
Cherry-pick: Add TCP custom op builder helper (cruise-automation#33) (cruise-automation#12)

Browse files Browse the repository at this point in the history

Add a utility to aid in converting torch ops to `tcp.custom_op`

Cherry picking cruise-automation#33

---------

Co-authored-by: Srinath Avadhanula <[email protected]>
  • Loading branch information
Srinath Avadhanula authored and GitHub Enterprise committed Jan 23, 2024
2 parents bf6191f + 4e1d4c9 commit ef63c4d
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 214 deletions.
175 changes: 32 additions & 143 deletions lib/Conversion/TorchToTcp/TcpCustomOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,14 @@ class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
LogicalResult
matchAndRewrite(AtenGatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> resultTypes;
if (failed(
OpConversionPattern<AtenGatherOp>::getTypeConverter()->convertTypes(
op->getResultTypes(), resultTypes))) {
return failure();
}

int64_t dimVal;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimVal)))
return failure();
torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
getTypeConverter()};

auto indexAttr =
rewriter.getNamedAttr("axis", rewriter.getI64IntegerAttr(dimVal));
helper.addOperand("self", adaptor.getSelf());
helper.addOperand("index", adaptor.getIndex());
helper.addIntAttr("axis", op.getDim());

auto newOp = rewriter.replaceOpWithNewOp<tcp::CustomOp>(
op, resultTypes, ValueRange{adaptor.getSelf(), adaptor.getIndex()},
indexAttr);
newOp.setOpName(op->getName().getStringRef());
return success();
return helper.replace();
}
};

Expand All @@ -64,37 +53,21 @@ class ConvertAtenIndexTensorHackedTwinOp
LogicalResult
matchAndRewrite(AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> resultTypes;
if (failed(
OpConversionPattern<AtenIndexTensorHackedTwinOp>::getTypeConverter()
->convertTypes(op->getResultTypes(), resultTypes))) {
return failure();
}

SmallVector<Value> tensorOperands;
torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
getTypeConverter()};

Value input = adaptor.getSelf();
auto inputTensorType = input.getType().dyn_cast<RankedTensorType>();
// Check input is a tensor type.
if (!inputTensorType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");
tensorOperands.push_back(input);

// Deal with torch.prim.ListConstruct of non const value to get the index
Value indexList = op.getIndices();
SmallVector<Value> indicesTorchType;
if (!getListConstructElements(indexList, indicesTorchType))
return op.emitError(
"unimplemented: the tensor list is not from list construct");
SmallVector<Value> indexTensors = getTypeConvertedValues(
rewriter, op->getLoc(), getTypeConverter(), indicesTorchType);
helper.addOperand("self", input);
helper.addAsMultipleTensorOperands("index_", op.getIndices());

tensorOperands.append(indexTensors.begin(), indexTensors.end());

auto newOp = rewriter.replaceOpWithNewOp<tcp::CustomOp>(op, resultTypes,
tensorOperands);
newOp.setOpName(op->getName().getStringRef());
return success();
return helper.replace();
}
};

Expand All @@ -106,44 +79,16 @@ class ConvertAten_IndexPutImplOp
LogicalResult
matchAndRewrite(Aten_IndexPutImplOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> resultTypes;
if (failed(OpConversionPattern<Aten_IndexPutImplOp>::getTypeConverter()
->convertTypes(op->getResultTypes(), resultTypes))) {
return failure();
}

SmallVector<Value> operands;
operands.push_back(adaptor.getSelf());

// Handle indices
SmallVector<Value> indicesTorchType;
if (!getListConstructElements(adaptor.getIndices(), indicesTorchType))
return op.emitError(
"unimplemented: the tensor list is not from list construct");
SmallVector<Value> indexTensors = getTypeConvertedValues(
rewriter, op->getLoc(), getTypeConverter(), indicesTorchType);
operands.append(indexTensors.begin(), indexTensors.end());
torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
getTypeConverter()};
helper.addOperand("self", adaptor.getSelf());
helper.addAsMultipleTensorOperands("index_", adaptor.getIndices());
helper.addOperand("values", adaptor.getValues());
helper.addBoolAttr("accumulate", op.getAccumulate());
helper.addBoolAttr("unsafe", op.getUnsafe());

operands.push_back(adaptor.getValues());

bool accumulate;
if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate)))
return op.emitError("expected accumulate operand to be a bool constant");
bool unsafe;
if (!matchPattern(op.getUnsafe(), m_TorchConstantBool(&unsafe)))
return op.emitError("expected unsafe operand to be a bool constant");

SmallVector<NamedAttribute> attrs;
attrs.push_back(
rewriter.getNamedAttr("accumulate", rewriter.getBoolAttr(accumulate)));
attrs.push_back(
rewriter.getNamedAttr("unsafe", rewriter.getBoolAttr(unsafe)));
attrs.push_back(rewriter.getNamedAttr(
"op_name", rewriter.getStringAttr(op->getName().getStringRef())));

rewriter.replaceOpWithNewOp<tcp::CustomOp>(op, resultTypes, operands,
attrs);
return success();
return helper.replace();
}
};

Expand All @@ -154,78 +99,22 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
LogicalResult
matchAndRewrite(AtenConvolutionOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
SmallVector<Type> resultTypes;
if (failed(OpConversionPattern<AtenConvolutionOp>::getTypeConverter()
->convertTypes(op->getResultTypes(), resultTypes))) {
return failure();
}

SmallVector<Value> operands;
SmallVector<StringRef> operandNames;

auto addOperand = [&](std::string name, Value value) {
operandNames.push_back(name);
operands.push_back(value);
};

addOperand("input", adaptor.getInput());
addOperand("weight", adaptor.getWeight());
torch_to_tcp::TorchToTcpCustomOpConversionHelper helper{op, rewriter,
getTypeConverter()};
helper.addOperand("input", adaptor.getInput());
helper.addOperand("weight", adaptor.getWeight());
if (!adaptor.getBias().getType().isa<Torch::NoneType>()) {
addOperand("bias", adaptor.getBias());
helper.addOperand("bias", adaptor.getBias());
}

SmallVector<NamedAttribute> attrs;

attrs.push_back(rewriter.getNamedAttr(
"torch_operand_names", rewriter.getStrArrayAttr(operandNames)));

auto addListOfIntAttr = [&](const std::string &name, Value value) {
SmallVector<int64_t> valueInt;
if (!matchPattern(value, m_TorchListOfConstantInts(valueInt)))
return rewriter.notifyMatchFailure(op, std::string("non-const") + name +
"list unsupported");
attrs.push_back(
rewriter.getNamedAttr(name, rewriter.getIndexArrayAttr(valueInt)));
return success();
};

if (auto result = addListOfIntAttr("stride", adaptor.getStride());
result.failed()) {
return result;
}
if (auto result = addListOfIntAttr("padding", adaptor.getPadding());
result.failed()) {
return result;
}
if (auto result = addListOfIntAttr("dilation", adaptor.getDilation());
result.failed()) {
return result;
}
if (auto result =
addListOfIntAttr("output_padding", adaptor.getOutputPadding());
result.failed()) {
return result;
}

bool transposed;
if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed)))
return rewriter.notifyMatchFailure(op,
"non const transposed unsupported");
attrs.push_back(
rewriter.getNamedAttr("transposed", rewriter.getBoolAttr(transposed)));

int64_t groups;
if (!matchPattern(op.getGroups(), m_TorchConstantInt(&groups)))
return rewriter.notifyMatchFailure(op, "non const groups unsupported");
attrs.push_back(
rewriter.getNamedAttr("groups", rewriter.getI64IntegerAttr(groups)));

auto replOp = rewriter.replaceOpWithNewOp<tcp::CustomOp>(op, resultTypes,
operands, attrs);

replOp.setOpName(op->getName().getStringRef());
helper.addListOfIntsAttr("stride", adaptor.getStride());
helper.addListOfIntsAttr("padding", adaptor.getPadding());
helper.addListOfIntsAttr("dilation", adaptor.getDilation());
helper.addListOfIntsAttr("output_padding", adaptor.getOutputPadding());
helper.addBoolAttr("transposed", op.getTransposed());
helper.addIntAttr("groups", op.getGroups());

return success();
return helper.replace();
}
};

Expand Down
Loading

0 comments on commit ef63c4d

Please sign in to comment.