Skip to content

Commit

Permalink
[Unity] Introduce call_dps_packed (#14183)
Browse files Browse the repository at this point in the history
Introduce call_dps_packed to call packed functions in destination-passing style, reserving call_tir for TIR PrimFuncs instead.

* [Unity] Introduce call_dps_packed

* fix lint

* Fix comments

* Remove well_form update, enforce in InferStructInfoCallTIR

* Update src/relax/op/op.cc

* Update description of call_tir

* Remove unnecessary check in passes
  • Loading branch information
yongwww authored and tqchen committed Mar 13, 2023
1 parent 0e56872 commit aaf0008
Show file tree
Hide file tree
Showing 37 changed files with 567 additions and 276 deletions.
4 changes: 4 additions & 0 deletions include/tvm/relax/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,10 @@ ExprPattern IsOp(const String& op_name);
CallPattern IsCallTIR(const String& name, Optional<TuplePattern> args = NullOpt);
/*! \brief Syntactic Sugar for call_tir (return a tuple of tensor) */
CallPattern IsCallTIR(const String& name, TuplePattern var_args);
/*! \brief Syntactic Sugar for call_dps_packed (return a tensor) */
CallPattern IsCallDPSPacked(const String& name, Optional<TuplePattern> args = NullOpt);
/*! \brief Syntactic Sugar for call_dps_packed (return a tuple of tensor) */
CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args);
/*! \brief Syntactic Sugar for creating TuplePattern or UnorderedTuplePattern (unordered=true) */
DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered = false);
/*! \brief Syntactic Sugar for creating a TupleGetItemPattern */
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ TVM_DLL Pass LambdaLift();
TVM_DLL Pass ToNonDataflow();

/*!
* \brief Perform explicit tensor allocation for call_tir.
* \brief Perform explicit tensor allocation for call_tir and call_dps_packed.
*
* \return The Pass.
*/
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from .exec_builder import ExecBuilder

# Operator
from .op.base import call_tir
from .op.base import call_tir, call_dps_packed

# BlockBuilder
from .block_builder import BlockBuilder
Expand Down
18 changes: 15 additions & 3 deletions python/tvm/relax/dpl/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,11 +862,23 @@ def is_call_tir(
return _is_call_tir(func_pattern, args)


def is_call_tir_extern(
def _is_call_dps_packed(
    func_pattern: DFPattern,
    args: Union[List, Tuple, TuplePattern, None] = None,
) -> CallPattern:
    """Build a CallPattern matching ``relax.call_dps_packed``.

    Parameters
    ----------
    func_pattern : DFPattern
        The pattern matching the callee of call_dps_packed.

    args : Union[List, Tuple, TuplePattern, None]
        The pattern for the packed arguments; ``None`` matches any arguments.

    Returns
    -------
    result : CallPattern
        The resulting CallPattern.
    """
    if args is None:
        # No constraint on the arguments: match anything.
        args = wildcard()
    elif isinstance(args, (list, tuple)):
        args = TuplePattern(args)

    return is_op("relax.call_dps_packed")(func_pattern, args)


def is_call_dps_packed(
func_name: str,
args: Union[List, Tuple, TuplePattern] = None,
) -> CallPattern:
"""Syntax sugar for creating a CallPattern for call_tir that calls an extern function
"""Syntax sugar for creating a CallPattern for call_dps_packed
Parameters
----------
Expand All @@ -881,7 +893,7 @@ def is_call_tir_extern(
The resulting CallPattern
"""
func_pattern = ExternFuncPattern(func_name)
return _is_call_tir(func_pattern, args)
return _is_call_dps_packed(func_pattern, args)


def is_call_packed(
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def __call__(self, *args):

@tvm._ffi.register_object("relax.expr.ExternFunc")
class ExternFunc(BaseFunc):
"""extern function, which can represent a TIR PrimFunc or a PackedFunc."""
"""extern function, which represents a PackedFunc."""

global_symbol: String

Expand Down
54 changes: 45 additions & 9 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tvm.runtime.object import Object

from . import _ffi_api
from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc
from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar
from ..expr import Tuple as RxTuple
from ..struct_info import StructInfo, TensorStructInfo
from ...ir import PrimExpr
Expand All @@ -45,18 +45,18 @@ def null_value() -> Call:

@args_converter.auto
def call_tir(
func: Union[str, Expr],
gvar: GlobalVar,
args: Expr,
out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]],
tir_vars: Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] = None,
) -> Call:
"""
Call a destination-passing-style function and return the output.
Call a tir.prim_func and return the output.
Parameters
----------
func : Union[str, Expr]
The destination-passing-style function, can be ExternFunc or PrimFunc.
gvar : GlobalVar
The GlobalVar referring to a tir PrimFunc.
args : Expr
The input arguments.
Expand All @@ -74,9 +74,6 @@ def call_tir(
ret: Call
A call node for the call_tir operator.
"""
if isinstance(func, str):
func = ExternFunc(func)

if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))

Expand All @@ -86,7 +83,46 @@ def call_tir(
if isinstance(tir_vars, (list, tuple)):
tir_vars = ShapeExpr(tir_vars)

return _ffi_api.call_tir(func, args, out_sinfo, tir_vars) # type: ignore
return _ffi_api.call_tir(gvar, args, out_sinfo, tir_vars) # type: ignore


@args_converter.auto
def call_dps_packed(
    func: Union[str, Expr],
    args: Expr,
    out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]],
) -> Call:
    """Call a destination-passing-style packed function and return the output.

    Parameters
    ----------
    func : Union[str, Expr]
        The destination-passing-style function, can be ExternFunc.

    args : Expr
        The input arguments.

    out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]]
        The structure info of the call_dps_packed output.
        It should be a single or a list of TensorStructInfo. Each one denotes the
        structure info of a returned tensor.

    Returns
    -------
    ret: Call
        A call node for the call_dps_packed operator.
    """
    # A bare string names an external packed function.
    callee = ExternFunc(func) if isinstance(func, str) else func

    # A single non-tuple argument expression is wrapped into a 1-tuple.
    call_args = args
    if isinstance(call_args, Expr) and not isinstance(call_args, RxTuple):  # type: ignore
        call_args = RxTuple((call_args,))

    # Normalize the output struct info to a list form.
    sinfo_list = out_sinfo if isinstance(out_sinfo, list) else [out_sinfo]

    return _ffi_api.call_dps_packed(callee, call_args, sinfo_list)  # type: ignore


@args_converter.auto
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def LambdaLift():


def CallTIRRewrite() -> tvm.ir.transform.Pass:
"""Perform explicit tensor allocation for call_tir.
"""Perform explicit tensor allocation for call_tir and call_dps_packed.
Returns
-------
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
builtin,
call_builtin_with_ctx,
call_tir,
call_dps_packed,
ceil,
clip,
collapse_sum_like,
Expand Down Expand Up @@ -545,6 +546,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"builtin",
"call_packed",
"call_tir",
"call_dps_packed",
"call_builtin_with_ctx",
"ceil",
"clip",
Expand Down
5 changes: 0 additions & 5 deletions src/relax/backend/task_extraction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,6 @@ class TaskExtractor : public ExprVisitor {
return;
}

// Do not extract external function
if (call->args[0].as<ExternFuncNode>()) {
return;
}

const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
const tir::PrimFunc& func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));

Expand Down
14 changes: 14 additions & 0 deletions src/relax/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,20 @@ CallPattern IsCallTIR(const String& name, Optional<TuplePattern> var_args) {
CallPattern IsCallTIR(const String& name, TuplePattern var_args) {
return IsOp("relax.call_tir")(GlobalVarPattern(name), var_args);
}
/*! \brief Build a call_dps_packed pattern; unspecified args match anything. */
CallPattern IsCallDPSPacked(const String& name, Optional<TuplePattern> var_args) {
  // When no argument pattern is supplied, fall back to a wildcard that matches
  // any packed-argument tuple.
  DFPattern arg_pattern =
      var_args.defined() ? DFPattern(var_args.value()) : DFPattern(Wildcard());
  return IsOp("relax.call_dps_packed")(GlobalVarPattern(name), arg_pattern);
}

/*! \brief call_dps_packed pattern whose callee matches `name` and whose packed
 *  arguments must match the given tuple pattern exactly. */
CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args) {
  return IsOp("relax.call_dps_packed")(GlobalVarPattern(name), var_args);
}

DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered) {
if (unordered)
Expand Down
41 changes: 41 additions & 0 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "sinfo_args should have exact 1 output struct info.");
}
CHECK(call->args[0]->IsInstance<GlobalVarNode>())
<< "call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. "
<< "However, gets " << call->args[0];
return call->sinfo_args[0];
}

Expand Down Expand Up @@ -121,6 +124,44 @@ Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,

TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR);

// call_dps_packed

// Infer the struct info of a call_dps_packed call node.
// The output struct info is supplied explicitly by the caller as the single
// entry of sinfo_args, so inference just validates and forwards it.
StructInfo InferStructInfoCallDPSPacked(const Call& call, const BlockBuilder& ctx) {
  if (call->sinfo_args.size() != 1) {
    ctx->ReportFatal(Diagnostic::Error(call)
                     << "sinfo_args should have exact 1 output struct info.");
  }
  return call->sinfo_args[0];
}

RELAY_REGISTER_OP("relax.call_dps_packed")
.set_num_inputs(2)
.add_argument("func", "Expr", "The destination-passing-style function.")
.add_argument("args", "Tuple", "The input arguments.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallDPSPacked);

// Construct a relax.call_dps_packed Call node from a callee, its packed
// arguments, and the declared output struct info.
Expr MakeCallDPSPacked(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list) {
  // Every declared output must carry a concrete ShapeExpr so the destination
  // buffer can later be allocated explicitly.
  for (const TensorStructInfo& tensor_sinfo : out_sinfo_list) {
    CHECK(tensor_sinfo->shape.as<ShapeExprNode>() != nullptr)
        << "out_sinfo of call_dps_packed should have defined ShapeExpr as shape. "
           "However, one given structure info is "
        << tensor_sinfo;
  }

  // A single output keeps its TensorStructInfo; multiple outputs fold into a tuple.
  StructInfo result_sinfo =
      out_sinfo_list.size() == 1
          ? StructInfo(out_sinfo_list[0])
          : StructInfo(TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()}));

  static const Op& op = Op::Get("relax.call_dps_packed");
  return Call(op, {func, args}, {}, {result_sinfo});
}

TVM_REGISTER_GLOBAL("relax.op.call_dps_packed").set_body_typed(MakeCallDPSPacked);

// call builtin
StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) {
if (call->sinfo_args.size() == 0) {
Expand Down
5 changes: 3 additions & 2 deletions src/relax/transform/call_tir_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace relax {

// ==================
// CallTIRMutator
// Perform explicit tensor allocation for call_tir.
// Perform explicit tensor allocation for call_tir or call_dps_packed.
// Example:
// lv0: Tensor(n, m) = rx.call_tir(func, (x), (n, m), dtype="float32")
// -->
Expand All @@ -49,10 +49,11 @@ class CallTIRMutator : public ExprMutator {
call = expr.as<CallNode>();

static const Op& call_tir_op = Op::Get("relax.call_tir");
static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn");

if (call->op == call_tir_op) {
if (call->op == call_tir_op || call->op == call_dps_packed_op) {
Array<Expr> outs;
if (const auto& _tensor_sinfo = MatchStructInfo<TensorStructInfo>(expr)) {
// single output case
Expand Down
12 changes: 5 additions & 7 deletions src/relax/transform/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,11 @@ class ConstantFolder : public ExprMutator {
* \return The TIR function, or nullopt if pattern match fails.
*/
Optional<tir::PrimFunc> MatchPrimFunc(const Expr& op) {
if (auto* ptr = op.as<GlobalVarNode>()) {
// NOTE: as check works for nullptr(returns null)
Optional<BaseFunc> base_func =
builder_->GetContextIRModule()->functions.Get(GetRef<GlobalVar>(ptr));
if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
return GetRef<tir::PrimFunc>(pfunc);
}
const GlobalVar& global_var = Downcast<GlobalVar>(op);
// NOTE: as check works for nullptr(returns null)
Optional<BaseFunc> base_func = builder_->GetContextIRModule()->functions.Get(global_var);
if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
return GetRef<tir::PrimFunc>(pfunc);
}
return NullOpt;
}
Expand Down
20 changes: 9 additions & 11 deletions src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,19 +185,17 @@ class GraphCreator : public ExprVisitor {
// recurse into the call expression.
const auto* op = call->op.as<OpNode>();
if (op == call_tir_op_.get()) {
// Skip ExternFunc for call_dps_packed.
if (const auto* global_var = call->args[0].as<GlobalVarNode>()) {
tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(GetRef<GlobalVar>(global_var)));
const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));

// Override args for call_tir
args = Downcast<Tuple>(call->args[1])->fields;
// Override args for call_tir
args = Downcast<Tuple>(call->args[1])->fields;

Optional<Integer> opt_pattern = func->GetAttr<Integer>("op_pattern");
if (opt_pattern.defined()) {
pattern = static_cast<OpPatternKind>(Downcast<IntImm>(opt_pattern)->value);
} else {
pattern = OpPatternKind::kOpaque;
}
Optional<Integer> opt_pattern = func->GetAttr<Integer>("op_pattern");
if (opt_pattern.defined()) {
pattern = static_cast<OpPatternKind>(Downcast<IntImm>(opt_pattern)->value);
} else {
pattern = OpPatternKind::kOpaque;
}
}
// The pattern of the current binding variable node is set to the pattern of this operator.
Expand Down
23 changes: 3 additions & 20 deletions src/relax/transform/fuse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,10 @@ class FusedTIRConstructor : public ExprVisitor {

// Step 1. Get Global var and PrimFunc
GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
Optional<tir::PrimFunc> prim_func_ = GetPrimFunc(gv);
ICHECK(prim_func_.defined()) << "Cannot find the prim_func of the call_tir in the module: "
<< gv;
tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));

// Step 2. Renew all vars/buffer definitions and blocks to avoid duplication
tir::PrimFunc prim_func = tir::RenewDefs(prim_func_.value());
tir::PrimFunc prim_func = tir::RenewDefs(prim_func_);

// Step 3. Check functions are all schedulable funcs. i.e. the body of func is root block
// TODO(Siyuan): support un-schedulable functions.
Expand Down Expand Up @@ -364,22 +363,6 @@ class FusedTIRConstructor : public ExprVisitor {
LOG(FATAL) << "Relax.Constant is not supported in primitive functions.";
}

/********** Helper Functions **********/

/*!
* \brief Pattern match op to a TIR function and look it up.
* \return The TIR function, or NullOpt if pattern match fails.
*/
Optional<tir::PrimFunc> GetPrimFunc(const GlobalVar& global_var) {
// NOTE: as check works for nullptr(returns null)
Optional<BaseFunc> base_func = mod_->functions.Get(global_var);
if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
return GetRef<tir::PrimFunc>(pfunc);
} else {
return NullOpt;
}
}

/*!
* \brief Get the number of outputs for a call_tir node.
* \return The number of outputs.
Expand Down
Loading

0 comments on commit aaf0008

Please sign in to comment.