Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity] Introduce call_dps_packed #14183

Merged
merged 7 commits into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/tvm/relax/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,10 @@ ExprPattern IsOp(const String& op_name);
CallPattern IsCallTIR(const String& name, Optional<TuplePattern> args = NullOpt);
/*! \brief Syntactic Sugar for call_tir (return a tuple of tensors) */
CallPattern IsCallTIR(const String& name, TuplePattern var_args);
/*! \brief Syntactic Sugar for call_dps_packed (return a tensor) */
CallPattern IsCallDPSPacked(const String& name, Optional<TuplePattern> args = NullOpt);
/*! \brief Syntactic Sugar for call_dps_packed (return a tuple of tensors) */
CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args);
/*! \brief Syntactic Sugar for creating TuplePattern or UnorderedTuplePattern (unordered=true) */
DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered = false);
/*! \brief Syntactic Sugar for creating a TupleGetItemPattern */
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ TVM_DLL Pass LambdaLift();
TVM_DLL Pass ToNonDataflow();

/*!
* \brief Perform explicit tensor allocation for call_tir.
* \brief Perform explicit tensor allocation for call_tir and call_dps_packed.
*
* \return The Pass.
*/
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from .exec_builder import ExecBuilder

# Operator
from .op.base import call_tir
from .op.base import call_tir, call_dps_packed

# BlockBuilder
from .block_builder import BlockBuilder
Expand Down
18 changes: 15 additions & 3 deletions python/tvm/relax/dpl/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,11 +862,23 @@ def is_call_tir(
return _is_call_tir(func_pattern, args)


def is_call_tir_extern(
def _is_call_dps_packed(
    func_pattern: DFPattern,
    args: Union[List, Tuple, TuplePattern] = None,
) -> CallPattern:
    """Build a CallPattern for the ``relax.call_dps_packed`` operator.

    Parameters
    ----------
    func_pattern : DFPattern
        Pattern used for the callee (first operand of the call).

    args : Union[List, Tuple, TuplePattern]
        Pattern for the argument operand. ``None`` is replaced by a wildcard,
        a list or tuple is wrapped in a ``TuplePattern``, and any other value
        is used as given.

    Returns
    -------
    CallPattern
        The resulting CallPattern
    """
    if isinstance(args, (list, tuple)):
        arg_pattern = TuplePattern(args)
    elif args is None:
        arg_pattern = wildcard()
    else:
        arg_pattern = args

    call_dps_packed_op = is_op("relax.call_dps_packed")
    return call_dps_packed_op(func_pattern, arg_pattern)


def is_call_dps_packed(
func_name: str,
args: Union[List, Tuple, TuplePattern] = None,
) -> CallPattern:
"""Syntax sugar for creating a CallPattern for call_tir that calls an extern function
"""Syntax sugar for creating a CallPattern for call_dps_packed

Parameters
----------
Expand All @@ -881,7 +893,7 @@ def is_call_tir_extern(
The resulting CallPattern
"""
func_pattern = ExternFuncPattern(func_name)
return _is_call_tir(func_pattern, args)
return _is_call_dps_packed(func_pattern, args)


def is_call_packed(
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def __call__(self, *args):

@tvm._ffi.register_object("relax.expr.ExternFunc")
class ExternFunc(BaseFunc):
"""extern function, which can represent a TIR PrimFunc or a PackedFunc."""
"""extern function, which represents a PackedFunc."""

global_symbol: String

Expand Down
46 changes: 41 additions & 5 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def call_tir(
tir_vars: Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] = None,
) -> Call:
"""
Call a destination-passing-style function and return the output.
Call a tir.prim_func function and return the output.

Parameters
----------
func : Union[str, Expr]
yongwww marked this conversation as resolved.
Show resolved Hide resolved
The destination-passing-style function, can be ExternFunc or PrimFunc.
The tir PrimFunc function.

args : Expr
The input arguments.
Expand All @@ -74,9 +74,6 @@ def call_tir(
ret: Call
A call node for the call_tir operator.
"""
if isinstance(func, str):
func = ExternFunc(func)

if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))

Expand All @@ -89,6 +86,45 @@ def call_tir(
return _ffi_api.call_tir(func, args, out_sinfo, tir_vars) # type: ignore


@args_converter.auto
def call_dps_packed(
    func: Union[str, Expr],
    args: Expr,
    out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]],
) -> Call:
    """
    Call a destination-passing-style packed function and return the output.

    Parameters
    ----------
    func : Union[str, Expr]
        The destination-passing-style function. A ``str`` is wrapped in an
        ``ExternFunc``; any other value is forwarded unchanged.

    args : Expr
        The input arguments. An ``Expr`` that is not an ``RxTuple`` is wrapped
        into a one-element ``RxTuple``.

    out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]]
        The structure info of the call_dps_packed output.
        It should be a single or a list of TensorStructInfo. Each one denotes the
        structure info of a returned tensor. A value that is not a ``list`` is
        wrapped into a one-element list.

    Returns
    -------
    ret: Call
        A call node for the call_dps_packed operator.
    """
    callee = ExternFunc(func) if isinstance(func, str) else func

    is_single_expr = isinstance(args, Expr) and not isinstance(args, RxTuple)  # type: ignore
    packed_args = RxTuple((args,)) if is_single_expr else args

    sinfo_list = out_sinfo if isinstance(out_sinfo, list) else [out_sinfo]

    return _ffi_api.call_dps_packed(callee, packed_args, sinfo_list)  # type: ignore


@args_converter.auto
def call_builtin_with_ctx(
func: Union[str, Expr],
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def LambdaLift():


def CallTIRRewrite() -> tvm.ir.transform.Pass:
"""Perform explicit tensor allocation for call_tir.
"""Perform explicit tensor allocation for call_tir and call_dps_packed.

Returns
-------
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
builtin,
call_builtin_with_ctx,
call_tir,
call_dps_packed,
ceil,
clip,
collapse_sum_like,
Expand Down Expand Up @@ -534,6 +535,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"builtin",
"call_packed",
"call_tir",
"call_dps_packed",
"call_builtin_with_ctx",
"ceil",
"clip",
Expand Down
14 changes: 14 additions & 0 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ class WellFormedChecker : public relax::ExprVisitor,
} else {
Malformed(Diagnostic::Error(op) << "The called expression must be a leaf expression");
}

for (size_t i = 0; i < op->args.size(); i++) {
Expr arg = op->args[i];
if (IsLeafOrTuple(arg)) {
Expand All @@ -260,6 +261,19 @@ class WellFormedChecker : public relax::ExprVisitor,
}
}

// call_tir works for tir PrimFunc, call_dps_packed works for ExternFunc
static const Op& call_tir_op = Op::Get("relax.call_tir");
static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
if (op->op == call_tir_op && !op->args[0]->IsInstance<GlobalVarNode>()) {
Malformed(Diagnostic::Error(op->args[0]->span)
<< "call_tir expects a prim_func, but gets " << op->args[0]->GetTypeKey());
}
slyubomirsky marked this conversation as resolved.
Show resolved Hide resolved
if (op->op == call_dps_packed_op && !op->args[0]->IsInstance<ExternFuncNode>()) {
Malformed(Diagnostic::Error(op->args[0]->span)
<< "call_dps_packed expects an extern func, but gets "
<< op->args[0]->GetTypeKey());
}

for (const StructInfo& sinfo_arg : op->sinfo_args) {
this->VisitStructInfo(sinfo_arg);
}
Expand Down
14 changes: 14 additions & 0 deletions src/relax/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,20 @@ CallPattern IsCallTIR(const String& name, Optional<TuplePattern> var_args) {
CallPattern IsCallTIR(const String& name, TuplePattern var_args) {
return IsOp("relax.call_tir")(GlobalVarPattern(name), var_args);
}
/*!
 * \brief Build a pattern that matches relax.call_dps_packed on the extern function `name`.
 *
 * call_dps_packed takes an ExternFunc as its callee (the well-formed checker rejects any
 * other callee), so the callee is matched with ExternFuncPattern rather than
 * GlobalVarPattern. This also agrees with the Python helper is_call_dps_packed.
 *
 * \param name The global symbol of the extern function to match.
 * \param var_args Optional pattern for the argument tuple; when not defined, any
 *        argument expression is accepted.
 * \return The resulting CallPattern.
 */
CallPattern IsCallDPSPacked(const String& name, Optional<TuplePattern> var_args) {
  DFPattern arg_pattern;
  if (!var_args.defined()) {
    // No explicit argument pattern: accept anything in the argument position.
    arg_pattern = Wildcard();
  } else {
    arg_pattern = var_args.value();
  }

  return IsOp("relax.call_dps_packed")(ExternFuncPattern(name), arg_pattern);
}

/*!
 * \brief Build a pattern that matches relax.call_dps_packed on the extern function `name`
 *        with the given argument tuple pattern.
 *
 * \param name The global symbol of the extern function to match.
 * \param var_args Pattern for the argument tuple.
 * \return The resulting CallPattern.
 */
CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args) {
  return IsOp("relax.call_dps_packed")(ExternFuncPattern(name), var_args);
}

DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered) {
if (unordered)
Expand Down
38 changes: 38 additions & 0 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,44 @@ Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,

TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR);

// call_dps_packed

/*!
 * \brief Infer the struct info of a call_dps_packed call.
 *
 * The result is taken directly from the call's single sinfo_args entry; a fatal
 * diagnostic is reported when the call does not carry exactly one entry.
 *
 * \param call The call_dps_packed call node.
 * \param ctx The block builder used for error reporting.
 * \return The struct info stored in call->sinfo_args[0].
 */
StructInfo InferStructInfoCallDPSPacked(const Call& call, const BlockBuilder& ctx) {
  const size_t num_sinfo_args = call->sinfo_args.size();
  if (num_sinfo_args != 1) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "sinfo_args should have exact 1 output struct info.");
  }
  StructInfo out_sinfo = call->sinfo_args[0];
  return out_sinfo;
}

// Register the relax.call_dps_packed operator: it takes two inputs (the callee `func`
// and the argument tuple `args`), and its struct info is inferred by
// InferStructInfoCallDPSPacked.
RELAY_REGISTER_OP("relax.call_dps_packed")
.set_num_inputs(2)
.add_argument("func", "Expr", "The destination-passing-style function.")
.add_argument("args", "Tuple", "The input arguments.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallDPSPacked);

/*!
 * \brief Construct a call to the relax.call_dps_packed operator.
 *
 * \param func The callee expression.
 * \param args The tuple of input arguments.
 * \param out_sinfo_list Struct info of each output tensor; every entry must have a
 *        ShapeExpr as its shape.
 * \return A Call whose sinfo_args holds the output struct info: the single entry itself
 *         when the list has exactly one element, otherwise a TupleStructInfo of all entries.
 */
Expr MakeCallDPSPacked(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list) {
  // Validate that each output struct info has a ShapeExpr shape.
  for (const TensorStructInfo& tensor_sinfo : out_sinfo_list) {
    const bool has_shape_expr = tensor_sinfo->shape.as<ShapeExprNode>() != nullptr;
    CHECK(has_shape_expr)
        << "out_sinfo of call_dps_packed should have defined ShapeExpr as shape. "
           "However, one given structure info is "
        << tensor_sinfo;
  }

  // One output is passed through as-is; any other count is bundled into a tuple.
  StructInfo ret_sinfo{nullptr};
  if (out_sinfo_list.size() != 1) {
    ret_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()});
  } else {
    ret_sinfo = out_sinfo_list[0];
  }

  static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
  return Call(call_dps_packed_op, {func, args}, {}, {ret_sinfo});
}

TVM_REGISTER_GLOBAL("relax.op.call_dps_packed").set_body_typed(MakeCallDPSPacked);

// call builtin
StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) {
if (call->sinfo_args.size() == 0) {
Expand Down
5 changes: 3 additions & 2 deletions src/relax/transform/call_tir_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace relax {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: should we rename this file and pass to reflect that it handles both call_dps_packed and call_tir?

Copy link
Member Author

@yongwww yongwww Mar 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about this but don't have a better name. Any suggestions? If we change this, then the pass name CallTIRRewrite needs to be updated accordingly. We can keep it untouched for now if we don't have any alternative names.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DPSRewrite would be my choice but we can discuss and change it later.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion! We can discuss it later; renaming the pass might need wider agreement.

// ==================
// CallTIRMutator
// Perform explicit tensor allocation for call_tir.
// Perform explicit tensor allocation for call_tir or call_dps_packed.
// Example:
// lv0: Tensor(n, m) = rx.call_tir(func, (x), (n, m), dtype="float32")
// -->
Expand All @@ -49,10 +49,11 @@ class CallTIRMutator : public ExprMutator {
call = expr.as<CallNode>();

static const Op& call_tir_op = Op::Get("relax.call_tir");
static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn");

if (call->op == call_tir_op) {
if (call->op == call_tir_op || call->op == call_dps_packed_op) {
Array<Expr> outs;
if (const auto& _tensor_sinfo = MatchStructInfo<TensorStructInfo>(expr)) {
// single output case
Expand Down
9 changes: 5 additions & 4 deletions src/relax/transform/run_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,18 @@ class CodeGenRunner : ExprMutator {
if (auto const* gvar_node = call_node->op.as<GlobalVarNode>()) {
const GlobalVar gvar = GetRef<GlobalVar>(gvar_node);

auto create_call_tir = [call_node, this](Expr extern_func, StructInfo ret_struct_info) {
auto create_call_dps_packed = [call_node, this](Expr extern_func,
StructInfo ret_struct_info) {
Array<Expr> new_args({extern_func});
new_args.push_back(Tuple(call_node->args.Map([this](Expr arg) { return VisitExpr(arg); })));

static const Op& call_op = Op::Get("relax.call_tir");
static const Op& call_op = Op::Get("relax.call_dps_packed");

return Call(call_op, new_args, tvm::Attrs(), {ret_struct_info});
};

if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) {
return create_call_tir(it->second.first, it->second.second);
return create_call_dps_packed(it->second.first, it->second.second);
} else {
// TODO(@sunggg): Is there any better way to get this func?
Function func = Downcast<Function>(builder_->GetContextIRModule()->Lookup(gvar));
Expand All @@ -101,7 +102,7 @@ class CodeGenRunner : ExprMutator {
func = (*RemoveFuncAttrFunc)(func, tvm::attr::kGlobalSymbol);
func = (*RemoveFuncAttrFunc)(func, attr::kCodegen);
builder_->UpdateFunction(gvar, func);
return create_call_tir(new_func, func->ret_struct_info);
return create_call_dps_packed(new_func, func->ret_struct_info);
}
}
}
Expand Down
8 changes: 6 additions & 2 deletions src/script/printer/relax/call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ ExprDoc PrintCallee(const relax::Expr& n, const ObjectPath& n_p, const IRDocsifi

Optional<ExprDoc> PrintCallTIR(const relax::Call& n, const ObjectPath& n_p, const IRDocsifier& d) {
static const Op& call_tir_op = Op::Get("relax.call_tir");
if (!n->op.same_as(call_tir_op)) {
static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
if (!n->op.same_as(call_tir_op) && !n->op.same_as(call_dps_packed_op)) {
return NullOpt;
}
ICHECK(n->args.size() == 2 || n->args.size() == 3);
Expand All @@ -123,6 +124,9 @@ Optional<ExprDoc> PrintCallTIR(const relax::Call& n, const ObjectPath& n_p, cons
} else {
kwargs_values.push_back(d->AsDoc<ExprDoc>(o_sinfo, o_sinfo_p));
}
if (n->op.same_as(call_dps_packed_op)) {
return Relax(d, "call_dps_packed")->Call(args, kwargs_keys, kwargs_values);
}
// Step 4. Print n->args[2], the tir variables
if (n->args.size() == 3) {
kwargs_keys.push_back("tir_vars");
Expand All @@ -134,7 +138,7 @@ Optional<ExprDoc> PrintCallTIR(const relax::Call& n, const ObjectPath& n_p, cons
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::Call>( //
"", [](relax::Call n, ObjectPath n_p, IRDocsifier d) -> Doc {
// Special case: call_tir
// Special case: call_tir, call_dps_packed
if (Optional<ExprDoc> doc = PrintCallTIR(n, n_p, d)) {
yongwww marked this conversation as resolved.
Show resolved Hide resolved
return doc.value();
}
Expand Down
3 changes: 2 additions & 1 deletion src/script/printer/relax/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ inline Optional<ExprDoc> StructInfoAsAnn(const relax::Var& v, const ObjectPath&
}
if (const auto* call = rhs.as<relax::CallNode>()) {
static const Op& call_tir_op = Op::Get("relax.call_tir");
if (call->op.same_as(call_tir_op)) {
static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
if (call->op.same_as(call_tir_op) || call->op.same_as(call_dps_packed_op)) {
return NullOpt;
}
}
Expand Down
12 changes: 8 additions & 4 deletions tests/python/relax/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ class IdentityUnused:
def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
lv0 = x
unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32"))
unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
unused1 = R.call_dps_packed(
"my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32")
)
R.output(lv0)
return lv0

Expand All @@ -92,8 +94,10 @@ class IdentityUnused:
def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
with R.dataflow():
lv0 = x
unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32"))
unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
unused1 = R.call_dps_packed(
"my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32")
)
R.output(lv0)
z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32")))
return z
Expand Down
Loading