Change to remove DPS style calling convention in plan dialect #161

Draft: wants to merge 1 commit into main
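This draft removes the destination-passing-style (DPS) calling convention from the plan and TensorRT runtime dialects: `plan.inline_closed_group` and `trtrt.enqueue` drop their `outs` operands, `DestinationStyleOpInterface`, and (for the group op) the `res_attrs` bounds, and instead declare their result types directly. As a rough before/after sketch of the textual IR for `trtrt.enqueue` (the value names and tensor types here are illustrative, not taken from the repository):

```mlir
// Before (DPS): the caller supplies destination buffers via `outs`,
// and the result types are tied to those destinations.
%r0 = trtrt.enqueue %ctx stream(%stream) (%in0, %in1) outs(%out0)
        : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

// After this change: no `outs` clause; the op's own results carry the outputs.
%r0 = trtrt.enqueue %ctx stream(%stream) (%in0, %in1)
        : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
```

To support this, the lowering changes below thread the op's result types through to the generated `executor.call`, and a new pattern folds away `bufferization.clone` ops so results no longer have to be copied into pre-allocated outputs.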
2 changes: 2 additions & 0 deletions mlir-tensorrt/.gitignore
@@ -32,6 +32,8 @@ compile_commands.json
**/*.private.*
*.private
**/tmp/**
**/tmp**
**/tripy/**

# TRT Timing Cache artifacts
*.timing-cache
@@ -7,7 +7,6 @@ include "mlir-tensorrt-dialect/Interface/TensorKindOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/OpAsmInterface.td"

class Plan_NativeOpTrait<string name,
@@ -136,8 +135,6 @@ def Plan_InlineGroupOp : Plan_GroupOpBase<"inline_group", [

def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
IsolatedFromAbove,
AttrSizedOperandSegments,
DestinationStyleOpInterface,
SingleBlockImplicitTerminator<"plan::YieldOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>,
@@ -199,19 +196,16 @@ def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [

}];
let arguments = (ins Variadic<AnyTypeOf<[AnyRankedTensor, AnySignlessIntegerOrIndex]>>:$inputs,
Variadic<AnyRankedTensor>:$outs,
BoundsAttrArray:$input_attrs,
BoundsAttrArray:$res_attrs,
AnyAttr:$target);

let results = (outs Variadic<AnyType>:$results);

let assemblyFormat = [{
`target` `(` $target `)` `\n`
`inputs` `(` ( $inputs^ `:` type($inputs) `)` ) : ( `)` ) ? `\n`
`outs` `(` $outs `:` type($outs) `)` `\n`
`in_attrs` $input_attrs `\n`
`res_attrs` $res_attrs attr-dict-with-keyword `->` type($results)
attr-dict-with-keyword `->` type($results)
$body
}];

@@ -220,18 +214,14 @@ def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
let skipDefaultBuilders = 1;

let builders = [
OpBuilder<(ins "Attribute":$target,
"ValueRange":$inputs, "ValueRange":$outs,
CArg<"ArrayRef<BoundsAttr>", "{}">:$input_attrs,
CArg<"ArrayRef<BoundsAttr>", "{}">:$res_attrs)>
OpBuilder<(ins "TypeRange":$results,
"Attribute":$target,
"ValueRange":$inputs,
CArg<"ArrayRef<BoundsAttr>", "{}">:$input_attrs)>
];

let extraClassDeclaration = baseExtraClassDeclaration # [{

MutableOperandRange getDpsInitsMutable() {
return getOutsMutable();
}

/// Returns true if the `i-th` input argument has a tensor type.
bool argHasTensorType(unsigned inputIdx) {
assert(inputIdx < getInputs().size() && "input index out-of-bounds");
@@ -244,17 +234,6 @@ def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
return cast<BoundsAttr>(getInputAttrs()[inputIdx]);
}

ArrayRef<BlockArgument> getRegionOutArgs() {
return getBody().getArguments().take_back(getOuts().size());
}

/// Populate the `res_attrs` from an array of BoundsAttrs.
void setResAttrsAttr(ArrayRef<BoundsAttr> boundsAttrs) {
setResAttrsAttr(::mlir::ArrayAttr::get(
getOperation()->getContext(),
ArrayRef<Attribute>(boundsAttrs.begin(), boundsAttrs.end())
));
}

/// Populate the `input_attrs` from an array of BoundsAttrs.
void setInputAttrsAttr(ArrayRef<BoundsAttr> boundsAttrs) {
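Read together, the assembly format and builder above imply a printed form for the revised `plan.inline_closed_group` roughly like the following sketch (the `target` and bounds-attribute spellings, value names, and types are placeholders, not verified output from the repository):

```mlir
// Result types are declared directly after `->`; there is no `outs` clause.
%0 = plan.inline_closed_group target(#plan.tensorrt_cluster)
     inputs(%arg0, %arg1 : tensor<4xf32>, tensor<4xf32>)
     in_attrs [#plan.bounds<none>, #plan.bounds<none>]
     -> tensor<4xf32> {
^bb0(%in0: tensor<4xf32>, %in1: tensor<4xf32>):
  // ... body producing %res from the region arguments ...
  plan.yield %res : tensor<4xf32>
}
```

Before this change the op additionally carried `outs(...)` operands, `res_attrs` bounds, and one trailing region block argument per destination tensor.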
@@ -7,9 +7,10 @@ include "mlir-tensorrt/Dialect/CUDA/IR/CUDATypes.td"
include "mlir-tensorrt-dialect/Interface/TensorKindOpInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"



def TensorRTRuntime_TensorRTRuntimeOpTrait: NativeOpTrait<"TensorRTRuntimeOpTrait"> {
@@ -61,11 +62,8 @@ def TensorRTRuntime_CompileOp : TensorRTRuntime_Op<"compile", [Pure]> {
//===----------------------------------------------------------------------===//

def TensorRTRuntime_EnqueueOp : TensorRTRuntime_Op<"enqueue", [
DeclareOpInterfaceMethods<InferTypeOpInterface>,
DeclareOpInterfaceMethods<TensorKindOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
AttrSizedOperandSegments,
DestinationStyleOpInterface
]> {
let description = [{

@@ -88,23 +86,19 @@ def TensorRTRuntime_EnqueueOp : TensorRTRuntime_Op<"enqueue", [
let arguments = (ins TensorRTRuntime_Context:$execution_context,
CUDA_Stream:$stream,
Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outs,
OptionalAttr<DenseI64ArrayAttr>:$host_tensor_args);
let results = (outs Variadic<AnyType>:$results);

let assemblyFormat = [{
$execution_context `stream` `(` $stream `)` ` `
(`host_tensor_args` $host_tensor_args^ ` ` )?
`(` $inputs `)` `outs` `(` $outs `)`
attr-dict `:` functional-type($inputs, $outs)
`(` $inputs `)`
attr-dict `:` functional-type($inputs, $results)
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
// Declare the outs as inits/outs to DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getOutsMutable(); }

/// Return true if the operand at the specified index is a host tensor
/// argument.
bool isOperandOnHost(int64_t operandIdx) {
@@ -25,6 +25,7 @@
#include "mlir-tensorrt/Dialect/CUDA/IR/CUDADialect.h"
#include "mlir-tensorrt/Dialect/TensorRTRuntime/IR/TensorRTRuntime.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
@@ -41,6 +42,9 @@ using namespace mlir;
using namespace mlir::executor;
using namespace mlir::cuda;

static ExecutorOpaqueType getTrtOutputsOpaqueType(MLIRContext *ctx) {
return ExecutorOpaqueType::get(ctx, "trtrt_outputs");
}
static ExecutorOpaqueType getTrtRuntimeOpaqueType(MLIRContext *ctx) {
return ExecutorOpaqueType::get(ctx, "trtrt_runtime");
}
@@ -184,8 +188,6 @@ struct ConvertEnqueueToCall
std::string funcName;
funcName =
"_" + llvm::join(llvm::split(op->getName().getStringRef(), "."), "_");
if (op->getNumResults() > 0)
return failure();

SmallVector<Value> newOperands = {adaptor.getExecutionContext(),
adaptor.getStream()};
@@ -217,10 +219,6 @@ struct ConvertEnqueueToCall
if (failed(createMemRefAndExractPtr(oldVal, newVal)))
return failure();
}
for (auto [oldVal, newVal] : llvm::zip(op.getOuts(), adaptor.getOuts())) {
if (failed(createMemRefAndExractPtr(oldVal, newVal)))
return failure();
}

// Create the table containing the pointer/offset args and append it to the
// arguments for the call op.
@@ -230,21 +228,39 @@ struct ConvertEnqueueToCall
argTablePack);
newOperands.push_back(args);

SmallVector<Type, 4> resultTypes(op->getResultTypes().begin(), op->getResultTypes().end());

auto parentModule = op->getParentOfType<ModuleOp>();
auto enqueueFunc = getOrInsertFuncDeclaration(
rewriter, op.getLoc(), parentModule, funcName,
ExecutorFunctionType::get(rewriter.getContext(),
{adaptor.getExecutionContext().getType(),
adaptor.getStream().getType()},
{}, rewriter.getUnitAttr()));
resultTypes, rewriter.getUnitAttr()));

rewriter.replaceOpWithNewOp<CallOp>(
op, TypeRange{}, enqueueFunc.getLeafReference(), newOperands);
op, op->getResultTypes(), enqueueFunc.getLeafReference(), newOperands);

return success();
}
};

class RemoveBufferizationClonePattern : public OpRewritePattern<bufferization::CloneOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(bufferization::CloneOp op,
PatternRewriter &rewriter) const override {
// Replace all uses of the clone op with its input
rewriter.replaceAllUsesWith(op.getResult(), op.getInput());

// Erase the clone op
rewriter.eraseOp(op);

return success();
}
};

struct ConvertTrtrtOpToCall : public ConvertToExecutorPattern {
ConvertTrtrtOpToCall(ExecutorTypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
@@ -282,6 +298,11 @@ struct ConvertTrtrtOpToCall : public ConvertToExecutorPattern {
}
};

void populateRemoveBufferizationClonePatterns(RewritePatternSet &patterns) {
patterns.add<RemoveBufferizationClonePattern>(patterns.getContext());
}


} // namespace

namespace {
@@ -320,6 +341,7 @@ class TensorRTRuntimeToExecutorPass
typeConverter.addConversion([](cuda::StreamType t) {
return getCudaStreamOpaqueType(t.getContext());
});

// Convert `trtrt.compile` to globals that create execution context from
// serialized TensorRT engine data.
{
@@ -335,6 +357,15 @@
return signalPassFailure();
}

{
RewritePatternSet patterns(&getContext());
populateRemoveBufferizationClonePatterns(patterns);

if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns))))
return signalPassFailure();
}

// Convert `trtrt.enqueue|create_runtime|execution_context|load` to
// `executor.call` and function declarations.
{
@@ -35,6 +35,11 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "tensorrt-to-tensorrt-runtime"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")

namespace mlir {
#define GEN_PASS_DEF_CONVERTTENSORRTTOTENSORRTRUNTIMEPASS
#include "mlir-tensorrt/Conversion/Passes.h.inc"
@@ -123,12 +128,15 @@ class ConvertTensorRTToRuntimePass
{FlatSymbolRefAttr::get(trtFunc)}));
Value stream = rewriter.create<cuda::GetGlobalStreamOp>(loc, 0);
auto enqueueOp = rewriter.create<trtrt::EnqueueOp>(
loc, executionContext, stream, callOp.getInputs(),
callOp.getOutputs(),
loc, callOp->getResultTypes(), executionContext, stream, callOp.getInputs(),
/*host_tensors_args=*/hostTensorArgs.empty()
? DenseI64ArrayAttr{}
: DenseI64ArrayAttr::get(ctx, hostTensorArgs));
rewriter.setInsertionPointAfter(enqueueOp);

DBGS() << "Number of call op results: " << callOp->getNumResults() << "\n";
DBGS() << "Number of enqueue op results: " << enqueueOp->getNumResults() << "\n";

rewriter.replaceOp(callOp, enqueueOp->getResults());
}
}
42 changes: 7 additions & 35 deletions mlir-tensorrt/compiler/lib/Dialect/Plan/IR/PlanOps.cpp
@@ -387,21 +387,6 @@ LogicalResult InlineClosedGroupOp::verify() {
return failure();
}

SmallVector<BoundsAttr> resAttrs =
llvm::to_vector(getResAttrs().getAsRange<BoundsAttr>());
if (resAttrs.size() != getNumResults())
return emitOpError("expected number of results (")
<< getNumResults()
<< ") to equal the number of res_attrs BoundsAttrs ("
<< resAttrs.size() << ")";

for (auto [idx, type] : llvm::enumerate(getResultTypes())) {
BoundsAttr boundsAttr = resAttrs[idx];
if (failed(verifyBoundsAttr("result", idx, type, boundsAttr,
[&]() { return emitOpError(); })))
return failure();
}

return success();
}

@@ -424,33 +409,22 @@

void InlineClosedGroupOp::getAsmBlockArgumentNames(
Region &region, OpAsmSetValueNameFn setNameFn) {
assert(region.front().getNumArguments() ==
getInputs().size() + getOuts().size() &&
"expected one block arg for each input and destination argument");
unsigned numInputs = getInputs().size();
assert(region.front().getNumArguments() == getInputs().size() &&
"expected one block arg for each input argument");
for (BlockArgument arg : region.front().getArguments()) {
StringRef name = arg.getArgNumber() < numInputs ? "in" : "out";
setNameFn(arg, name);
setNameFn(arg, "in");
}
}

void InlineClosedGroupOp::build(OpBuilder &b, OperationState &state,
Attribute target, ValueRange inputs,
ValueRange outs,
ArrayRef<BoundsAttr> input_attrs,
ArrayRef<BoundsAttr> result_attrs) {
TypeRange resultTypes, Attribute target,
ValueRange inputs,
ArrayRef<BoundsAttr> input_attrs) {
state.addTypes(resultTypes);
state.addOperands(inputs);
state.addOperands(outs);
state.getOrAddProperties<Properties>().target = target;
state.getOrAddProperties<Properties>().setInputAttrs(b.getArrayAttr(
SmallVector<Attribute>(input_attrs.begin(), input_attrs.end())));
state.getOrAddProperties<Properties>().setResAttrs(b.getArrayAttr(
SmallVector<Attribute>(result_attrs.begin(), result_attrs.end())));

llvm::copy(
ArrayRef<int32_t>{static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outs.size())},
state.getOrAddProperties<Properties>().operandSegmentSizes.begin());
Region *body = state.addRegion();
auto getLocs = [](ValueRange r) {
SmallVector<Location> locs;
Expand All @@ -461,8 +435,6 @@ void InlineClosedGroupOp::build(OpBuilder &b, OperationState &state,
};
(void)body->emplaceBlock();
body->addArguments(TypeRange(inputs), getLocs(inputs));
body->addArguments(TypeRange(outs), getLocs(outs));
state.addTypes(TypeRange(outs));
}

//===----------------------------------------------------------------------===//
@@ -854,10 +854,10 @@ class AllocTensorsPass

// First rewrite public functions to conform to DPS style.
IRRewriter rewriter(ctx);
if (failed(rewriteNotPrivateFuncsToDPS(rewriter, op))) {
op->emitError("Failed to convert non-private functions to DPS");
return signalPassFailure();
}
// if (failed(rewriteNotPrivateFuncsToDPS(rewriter, op))) {
// op->emitError("Failed to convert non-private functions to DPS");
// return signalPassFailure();
// }

// Rewrite SCF for and while loop bodies for better bufferization results,
// if possible.