Change to remove DPS style calling convention in plan dialect
jhalakpatel committed Aug 30, 2024
1 parent ed92919 commit 44df0ce
Showing 19 changed files with 169 additions and 424 deletions.
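At a high level, this commit drops the destination-passing-style (DPS) `outs` operands and the matching `res_attrs` result bounds from `plan.inline_closed_group` (and, further below, from `trtrt.enqueue`), so outputs become ordinary op results instead of values tied to caller-provided destination tensors. A hypothetical before/after sketch of the textual IR, reconstructed from the assembly formats in this diff (value names, tensor types, and bounds-attribute syntax are illustrative):

    // Before: DPS form with explicit destinations and result bounds.
    %r = plan.inline_closed_group target(#some_target)
        inputs(%x : tensor<10xf32>)
        outs(%dest : tensor<10xf32>)
        in_attrs [#plan.bounds<...>]
        res_attrs [#plan.bounds<...>] -> tensor<10xf32> {
      ...
      plan.yield %y : tensor<10xf32>
    }

    // After: no outs operands and no res_attrs; the result type alone
    // describes what the group returns.
    %r = plan.inline_closed_group target(#some_target)
        inputs(%x : tensor<10xf32>)
        in_attrs [#plan.bounds<...>] -> tensor<10xf32> {
      ...
      plan.yield %y : tensor<10xf32>
    }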
2 changes: 2 additions & 0 deletions mlir-tensorrt/.gitignore
@@ -32,6 +32,8 @@ compile_commands.json
**/*.private.*
*.private
**/tmp/**
**/tmp**
**/tripy/**

# TRT Timing Cache artifacts
*.timing-cache
@@ -7,7 +7,6 @@ include "mlir-tensorrt-dialect/Interface/TensorKindOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/OpAsmInterface.td"

class Plan_NativeOpTrait<string name,
@@ -136,8 +135,6 @@ def Plan_InlineGroupOp : Plan_GroupOpBase<"inline_group", [

def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
IsolatedFromAbove,
AttrSizedOperandSegments,
DestinationStyleOpInterface,
SingleBlockImplicitTerminator<"plan::YieldOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>,
@@ -199,19 +196,16 @@ def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [

}];
let arguments = (ins Variadic<AnyTypeOf<[AnyRankedTensor, AnySignlessIntegerOrIndex]>>:$inputs,
Variadic<AnyRankedTensor>:$outs,
BoundsAttrArray:$input_attrs,
BoundsAttrArray:$res_attrs,
AnyAttr:$target);

let results = (outs Variadic<AnyType>:$results);

let assemblyFormat = [{
`target` `(` $target `)` `\n`
`inputs` `(` ( $inputs^ `:` type($inputs) `)` ) : ( `)` ) ? `\n`
`outs` `(` $outs `:` type($outs) `)` `\n`
`in_attrs` $input_attrs `\n`
`res_attrs` $res_attrs attr-dict-with-keyword `->` type($results)
attr-dict-with-keyword `->` type($results)
$body
}];

@@ -220,18 +214,14 @@ def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
let skipDefaultBuilders = 1;

let builders = [
OpBuilder<(ins "Attribute":$target,
"ValueRange":$inputs, "ValueRange":$outs,
CArg<"ArrayRef<BoundsAttr>", "{}">:$input_attrs,
CArg<"ArrayRef<BoundsAttr>", "{}">:$res_attrs)>
OpBuilder<(ins "TypeRange":$results,
"Attribute":$target,
"ValueRange":$inputs,
CArg<"ArrayRef<BoundsAttr>", "{}">:$input_attrs)>
];

let extraClassDeclaration = baseExtraClassDeclaration # [{

MutableOperandRange getDpsInitsMutable() {
return getOutsMutable();
}

/// Returns true if the `i-th` input argument has a tensor type.
bool argHasTensorType(unsigned inputIdx) {
assert(inputIdx < getInputs().size() && "input index out-of-bounds");
@@ -244,17 +234,6 @@ def Plan_InlineClosedGroupOp : Plan_GroupOpBase<"inline_closed_group", [
return cast<BoundsAttr>(getInputAttrs()[inputIdx]);
}

ArrayRef<BlockArgument> getRegionOutArgs() {
return getBody().getArguments().take_back(getOuts().size());
}

/// Populate the `res_attrs` from an array of BoundsAttrs.
void setResAttrsAttr(ArrayRef<BoundsAttr> boundsAttrs) {
setResAttrsAttr(::mlir::ArrayAttr::get(
getOperation()->getContext(),
ArrayRef<Attribute>(boundsAttrs.begin(), boundsAttrs.end())
));
}

/// Populate the `input_attrs` from an array of BoundsAttrs.
void setInputAttrsAttr(ArrayRef<BoundsAttr> boundsAttrs) {
@@ -7,9 +7,10 @@ include "mlir-tensorrt/Dialect/CUDA/IR/CUDATypes.td"
include "mlir-tensorrt-dialect/Interface/TensorKindOpInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"



def TensorRTRuntime_TensorRTRuntimeOpTrait: NativeOpTrait<"TensorRTRuntimeOpTrait"> {
@@ -61,11 +62,8 @@ def TensorRTRuntime_CompileOp : TensorRTRuntime_Op<"compile", [Pure]> {
//===----------------------------------------------------------------------===//

def TensorRTRuntime_EnqueueOp : TensorRTRuntime_Op<"enqueue", [
DeclareOpInterfaceMethods<InferTypeOpInterface>,
DeclareOpInterfaceMethods<TensorKindOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
AttrSizedOperandSegments,
DestinationStyleOpInterface
]> {
let description = [{

@@ -88,23 +86,19 @@
let arguments = (ins TensorRTRuntime_Context:$execution_context,
CUDA_Stream:$stream,
Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outs,
OptionalAttr<DenseI64ArrayAttr>:$host_tensor_args);
let results = (outs Variadic<AnyType>:$results);

let assemblyFormat = [{
$execution_context `stream` `(` $stream `)` ` `
(`host_tensor_args` $host_tensor_args^ ` ` )?
`(` $inputs `)` `outs` `(` $outs `)`
attr-dict `:` functional-type($inputs, $outs)
`(` $inputs `)`
attr-dict `:` functional-type($inputs, $results)
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
// Declare the outs as inits/outs to DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getOutsMutable(); }

/// Return true if the operand at the specified index is a host tensor
/// argument.
bool isOperandOnHost(int64_t operandIdx) {
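The same change applies to `trtrt.enqueue` in the TensorRT runtime dialect: the `outs` group and the DPS interface are gone, and the trailing functional type now maps inputs to op results. A hypothetical sketch (operand names and shapes are illustrative):

    // Before: DPS form; results were inferred from the outs operands.
    %out = trtrt.enqueue %ctx stream(%stream) (%input) outs(%dest)
        : (tensor<4xf32>) -> tensor<4xf32>

    // After: the result is a plain op result with no destination operand.
    %out = trtrt.enqueue %ctx stream(%stream) (%input)
        : (tensor<4xf32>) -> tensor<4xf32>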
@@ -25,6 +25,7 @@
#include "mlir-tensorrt/Dialect/CUDA/IR/CUDADialect.h"
#include "mlir-tensorrt/Dialect/TensorRTRuntime/IR/TensorRTRuntime.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
@@ -41,6 +42,9 @@ using namespace mlir;
using namespace mlir::executor;
using namespace mlir::cuda;

static ExecutorOpaqueType getTrtOutputsOpaqueType(MLIRContext *ctx) {
return ExecutorOpaqueType::get(ctx, "trtrt_outputs");
}
static ExecutorOpaqueType getTrtRuntimeOpaqueType(MLIRContext *ctx) {
return ExecutorOpaqueType::get(ctx, "trtrt_runtime");
}
@@ -184,8 +188,6 @@ struct ConvertEnqueueToCall
std::string funcName;
funcName =
"_" + llvm::join(llvm::split(op->getName().getStringRef(), "."), "_");
if (op->getNumResults() > 0)
return failure();

SmallVector<Value> newOperands = {adaptor.getExecutionContext(),
adaptor.getStream()};
@@ -217,10 +219,6 @@ struct ConvertEnqueueToCall
if (failed(createMemRefAndExractPtr(oldVal, newVal)))
return failure();
}
for (auto [oldVal, newVal] : llvm::zip(op.getOuts(), adaptor.getOuts())) {
if (failed(createMemRefAndExractPtr(oldVal, newVal)))
return failure();
}

// Create the table containing the pointer/offset args and append it to the
// arguments for the call op.
@@ -230,21 +228,39 @@
argTablePack);
newOperands.push_back(args);

SmallVector<Type, 4> resultTypes(op->getResultTypes().begin(), op->getResultTypes().end());

auto parentModule = op->getParentOfType<ModuleOp>();
auto enqueueFunc = getOrInsertFuncDeclaration(
rewriter, op.getLoc(), parentModule, funcName,
ExecutorFunctionType::get(rewriter.getContext(),
{adaptor.getExecutionContext().getType(),
adaptor.getStream().getType()},
{}, rewriter.getUnitAttr()));
resultTypes, rewriter.getUnitAttr()));

rewriter.replaceOpWithNewOp<CallOp>(
op, TypeRange{}, enqueueFunc.getLeafReference(), newOperands);
op, op->getResultTypes(), enqueueFunc.getLeafReference(), newOperands);

return success();
}
};

class RemoveBufferizationClonePattern : public OpRewritePattern<bufferization::CloneOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(bufferization::CloneOp op,
PatternRewriter &rewriter) const override {
// Replace all uses of the clone op with its input
rewriter.replaceAllUsesWith(op.getResult(), op.getInput());

// Erase the clone op
rewriter.eraseOp(op);

return success();
}
};
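For context, this new pattern forwards the source of a `bufferization.clone` to all of its users and erases the clone; conceptually (illustrative IR):

    // Before:
    %copy = bufferization.clone %buf : memref<4xf32> to memref<4xf32>
    "some.use"(%copy) : (memref<4xf32>) -> ()

    // After the pattern runs: the user reads %buf directly.
    "some.use"(%buf) : (memref<4xf32>) -> ()

Note that the pattern removes the copy unconditionally; it assumes the extra aliasing introduced by dropping the clone is benign in the lowered program.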

struct ConvertTrtrtOpToCall : public ConvertToExecutorPattern {
ConvertTrtrtOpToCall(ExecutorTypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
@@ -282,6 +298,11 @@ struct ConvertTrtrtOpToCall : public ConvertToExecutorPattern {
}
};

void populateRemoveBufferizationClonePatterns(RewritePatternSet &patterns) {
patterns.add<RemoveBufferizationClonePattern>(patterns.getContext());
}


} // namespace

namespace {
@@ -320,6 +341,7 @@ class TensorRTRuntimeToExecutorPass
typeConverter.addConversion([](cuda::StreamType t) {
return getCudaStreamOpaqueType(t.getContext());
});

// Convert `trtrt.compile` to globals that create execution context from
// serialized TensorRT engine data.
{
@@ -335,6 +357,15 @@
return signalPassFailure();
}

{
RewritePatternSet patterns(&getContext());
populateRemoveBufferizationClonePatterns(patterns);

if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns))))
return signalPassFailure();
}

// Convert `trtrt.enqueue|create_runtime|execution_context|load` to
// `executor.call` and function declarations.
{
@@ -35,6 +35,11 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "tensorrt-to-tensorrt-runtime"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")

namespace mlir {
#define GEN_PASS_DEF_CONVERTTENSORRTTOTENSORRTRUNTIMEPASS
#include "mlir-tensorrt/Conversion/Passes.h.inc"
@@ -123,12 +128,15 @@ class ConvertTensorRTToRuntimePass
{FlatSymbolRefAttr::get(trtFunc)}));
Value stream = rewriter.create<cuda::GetGlobalStreamOp>(loc, 0);
auto enqueueOp = rewriter.create<trtrt::EnqueueOp>(
loc, executionContext, stream, callOp.getInputs(),
callOp.getOutputs(),
loc, callOp->getResultTypes(), executionContext, stream, callOp.getInputs(),
/*host_tensors_args=*/hostTensorArgs.empty()
? DenseI64ArrayAttr{}
: DenseI64ArrayAttr::get(ctx, hostTensorArgs));
rewriter.setInsertionPointAfter(enqueueOp);

DBGS() << "Number of call op results: " << callOp->getNumResults() << "\n";
DBGS() << "Number of enqueue op results: " << enqueueOp->getNumResults() << "\n";

rewriter.replaceOp(callOp, enqueueOp->getResults());
}
}
42 changes: 7 additions & 35 deletions mlir-tensorrt/compiler/lib/Dialect/Plan/IR/PlanOps.cpp
@@ -387,21 +387,6 @@ LogicalResult InlineClosedGroupOp::verify() {
return failure();
}

SmallVector<BoundsAttr> resAttrs =
llvm::to_vector(getResAttrs().getAsRange<BoundsAttr>());
if (resAttrs.size() != getNumResults())
return emitOpError("expected number of results (")
<< getNumResults()
<< ") to equal the number of res_attrs BoundsAttrs ("
<< resAttrs.size() << ")";

for (auto [idx, type] : llvm::enumerate(getResultTypes())) {
BoundsAttr boundsAttr = resAttrs[idx];
if (failed(verifyBoundsAttr("result", idx, type, boundsAttr,
[&]() { return emitOpError(); })))
return failure();
}

return success();
}

@@ -424,33 +409,22 @@

void InlineClosedGroupOp::getAsmBlockArgumentNames(
Region &region, OpAsmSetValueNameFn setNameFn) {
assert(region.front().getNumArguments() ==
getInputs().size() + getOuts().size() &&
"expected one block arg for each input and destination argument");
unsigned numInputs = getInputs().size();
assert(region.front().getNumArguments() == getInputs().size() &&
"expected one block arg for each input argument");
for (BlockArgument arg : region.front().getArguments()) {
StringRef name = arg.getArgNumber() < numInputs ? "in" : "out";
setNameFn(arg, name);
setNameFn(arg, "in");
}
}

void InlineClosedGroupOp::build(OpBuilder &b, OperationState &state,
Attribute target, ValueRange inputs,
ValueRange outs,
ArrayRef<BoundsAttr> input_attrs,
ArrayRef<BoundsAttr> result_attrs) {
TypeRange resultTypes, Attribute target,
ValueRange inputs,
ArrayRef<BoundsAttr> input_attrs) {
state.addTypes(resultTypes);
state.addOperands(inputs);
state.addOperands(outs);
state.getOrAddProperties<Properties>().target = target;
state.getOrAddProperties<Properties>().setInputAttrs(b.getArrayAttr(
SmallVector<Attribute>(input_attrs.begin(), input_attrs.end())));
state.getOrAddProperties<Properties>().setResAttrs(b.getArrayAttr(
SmallVector<Attribute>(result_attrs.begin(), result_attrs.end())));

llvm::copy(
ArrayRef<int32_t>{static_cast<int32_t>(inputs.size()),
static_cast<int32_t>(outs.size())},
state.getOrAddProperties<Properties>().operandSegmentSizes.begin());
Region *body = state.addRegion();
auto getLocs = [](ValueRange r) {
SmallVector<Location> locs;
@@ -461,8 +435,6 @@ void InlineClosedGroupOp::build(OpBuilder &b, OperationState &state,
};
(void)body->emplaceBlock();
body->addArguments(TypeRange(inputs), getLocs(inputs));
body->addArguments(TypeRange(outs), getLocs(outs));
state.addTypes(TypeRange(outs));
}

//===----------------------------------------------------------------------===//
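A hypothetical call site for the updated non-DPS builder, assuming an OpBuilder `b`, a Location `loc`, and locally defined `resultTypes`, `targetAttr`, `inputValues`, and `inputBounds` (all names illustrative):

    // Sketch only: builds the op with explicit result types and no outs.
    auto group = b.create<plan::InlineClosedGroupOp>(
        loc,
        /*results=*/resultTypes,
        /*target=*/targetAttr,
        /*inputs=*/inputValues,
        /*input_attrs=*/inputBounds);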
@@ -854,10 +854,10 @@ class AllocTensorsPass

// First rewrite public functions to conform to DPS style.
IRRewriter rewriter(ctx);
if (failed(rewriteNotPrivateFuncsToDPS(rewriter, op))) {
op->emitError("Failed to convert non-private functions to DPS");
return signalPassFailure();
}
// if (failed(rewriteNotPrivateFuncsToDPS(rewriter, op))) {
// op->emitError("Failed to convert non-private functions to DPS");
// return signalPassFailure();
// }

// Rewrite SCF for and while loop bodies for better bufferization results,
// if possible.
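Commenting out `rewriteNotPrivateFuncsToDPS` means public functions are no longer rewritten into DPS form with preallocated output arguments. Illustratively (signatures are hypothetical):

    // Old behavior: public function rewritten to DPS form.
    func.func @main(%arg0: tensor<4xf32>, %out: tensor<4xf32>) -> tensor<4xf32>

    // With this commit: results stay as ordinary returned tensors.
    func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32>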
(Diff truncated; the remaining changed files are not shown.)
