Skip to content

Commit

Permalink
[Unity] Introduce call_dps_packed
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Mar 3, 2023
1 parent 69f0abb commit 416cd24
Show file tree
Hide file tree
Showing 33 changed files with 651 additions and 210 deletions.
4 changes: 4 additions & 0 deletions include/tvm/relax/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,10 @@ ExprPattern IsOp(const String& op_name);
CallPattern IsCallTIR(const String& name, Optional<TuplePattern> args = NullOpt);
/*! \brief Syntatic Sugar for call_tir (return a tuple of tensor) */
CallPattern IsCallTIR(const String& name, TuplePattern var_args);
/*! \brief Syntatic Sugar for call_dps_packed (return a tensor) */
CallPattern IsCallDPSPacked(const String& name, Optional<TuplePattern> args = NullOpt);
/*! \brief Syntatic Sugar for call_dps_packed (return a tuple of tensor) */
CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args);
/*! \brief Syntatic Sugar for creating TuplePattern or UnorderedTuplePattern (unordered=true) */
DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered = false);
/*! \brief Syntatic Sugar for creating a TupleGetItemPattern */
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ TVM_DLL Pass ToNonDataflow();
*/
TVM_DLL Pass CallTIRRewrite();

/*!
* \brief Perform explicit tensor allocation for call_dps_packed.
*
* \return The Pass.
*/
TVM_DLL Pass CallDPSPackedRewrite();

/*!
* \brief Convert all reshape-like call_tir whose corresponding binding
* vars are DataflowVars to relax.reshape operator calls. The relax.reshape
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from .exec_builder import ExecBuilder

# Operator
from .op.base import call_tir
from .op.base import call_tir, call_dps_packed

# BlockBuilder
from .block_builder import BlockBuilder
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
Binding,
)
from .struct_info import PrimStructInfo, ShapeStructInfo, StructInfo, TensorStructInfo
from .op.base import call_tir
from .op.base import call_tir, call_dps_packed
from . import _ffi_api


Expand Down
18 changes: 15 additions & 3 deletions python/tvm/relax/dpl/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,11 +862,23 @@ def is_call_tir(
return _is_call_tir(func_pattern, args)


def is_call_tir_extern(
def _is_call_dps_packed(
func_pattern: DFPattern,
args: Union[List, Tuple, TuplePattern] = None,
) -> CallPattern:
if args is None:
args = wildcard()
elif isinstance(args, (list, tuple)):
args = TuplePattern(args)

return is_op("relax.call_dps_packed")(func_pattern, args)


def is_call_dps_packed(
func_name: str,
args: Union[List, Tuple, TuplePattern] = None,
) -> CallPattern:
"""Syntax sugar for creating a CallPattern for call_tir that calls an extern function
"""Syntax sugar for creating a CallPattern for call_dps_packed
Parameters
----------
Expand All @@ -881,7 +893,7 @@ def is_call_tir_extern(
The resulting CallPattern
"""
func_pattern = ExternFuncPattern(func_name)
return _is_call_tir(func_pattern, args)
return _is_call_dps_packed(func_pattern, args)


def is_call_packed(
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def __call__(self, *args):

@tvm._ffi.register_object("relax.expr.ExternFunc")
class ExternFunc(BaseFunc):
"""extern function, which can represent a TIR PrimFunc or a PackedFunc."""
"""extern function, which represents a PackedFunc."""

global_symbol: String

Expand Down
47 changes: 43 additions & 4 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tvm.runtime.object import Object

from . import _ffi_api
from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc
from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar
from ..expr import Tuple as RxTuple
from ..struct_info import StructInfo, TensorStructInfo
from ...ir import PrimExpr
Expand Down Expand Up @@ -51,12 +51,12 @@ def call_tir(
tir_vars: Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] = None,
) -> Call:
"""
Call a destination-passing-style function and return the output.
Call a tir.prim_func function and return the output.
Parameters
----------
func : Union[str, Expr]
The destination-passing-style function, can be ExternFunc or PrimFunc.
The tir PrimFunc function.
args : Expr
The input arguments.
Expand All @@ -75,7 +75,7 @@ def call_tir(
A call node for the call_tir operator.
"""
if isinstance(func, str):
func = ExternFunc(func)
func = GlobalVar(func)

if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))
Expand All @@ -89,6 +89,45 @@ def call_tir(
return _ffi_api.call_tir(func, args, out_sinfo, tir_vars) # type: ignore


@args_converter.auto
def call_dps_packed(
func: Union[str, Expr],
args: Expr,
out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]],
) -> Call:
"""
Call a destination-passing-style packed function and return the output.
Parameters
----------
func : Union[str, Expr]
The destination-passing-style function, can be ExternFunc.
args : Expr
The input arguments.
out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]]
The structure info of the call_dps_packed output.
It should be a single or a list of TensorStructInfo. Each one denotes the
structure info of a returned tensor.
Returns
-------
ret: Call
A call node for the call_dps_packed operator.
"""
if isinstance(func, str):
func = ExternFunc(func)

if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))

if not isinstance(out_sinfo, list):
out_sinfo = [out_sinfo]

return _ffi_api.call_dps_packed(func, args, out_sinfo) # type: ignore


@args_converter.auto
def call_builtin_with_ctx(
func: Union[str, Expr],
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ def CallTIRRewrite() -> tvm.ir.transform.Pass:
return _ffi_api.CallTIRRewrite() # type: ignore


def CallDPSPackedRewrite() -> tvm.ir.transform.Pass:
"""Perform explicit tensor allocation for call_dps_packed.
Returns
-------
ret: tvm.ir.transform.Pass
"""
return _ffi_api.CallDPSPackedRewrite() # type: ignore


def Normalize() -> tvm.ir.transform.Pass:
"""Transforming Relax IR to normal form, i.e., the expressions are normalized(no nesting
and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are available.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/vm_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")):
passes.append(relax.transform.RewriteDataflowReshape())
passes.append(relax.transform.ToNonDataflow())
passes.append(relax.transform.CallTIRRewrite())
passes.append(relax.transform.CallDPSPackedRewrite())
passes.append(relax.transform.StaticPlanBlockMemory())
passes.append(relax.transform.VMBuiltinLower())
passes.append(relax.transform.VMShapeLower())
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
builtin,
call_builtin_with_ctx,
call_tir,
call_dps_packed,
ceil,
clip,
collapse_sum_like,
Expand Down Expand Up @@ -529,6 +530,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"builtin",
"call_packed",
"call_tir",
"call_dps_packed",
"call_builtin_with_ctx",
"ceil",
"clip",
Expand Down
14 changes: 14 additions & 0 deletions src/relax/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,20 @@ CallPattern IsCallTIR(const String& name, Optional<TuplePattern> var_args) {
CallPattern IsCallTIR(const String& name, TuplePattern var_args) {
return IsOp("relax.call_tir")(GlobalVarPattern(name), var_args);
}
CallPattern IsCallDPSPacked(const String& name, Optional<TuplePattern> var_args) {
DFPattern arg_pattern;
if (!var_args.defined()) {
arg_pattern = Wildcard();
} else {
arg_pattern = var_args.value();
}

return IsOp("relax.call_dps_packed")(GlobalVarPattern(name), arg_pattern);
}

CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args) {
return IsOp("relax.call_dps_packed")(GlobalVarPattern(name), var_args);
}

DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered) {
if (unordered)
Expand Down
39 changes: 39 additions & 0 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,45 @@ Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,

TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR);


// call_dps_packed

StructInfo InferStructInfoCallDPSPacked(const Call& call, const BlockBuilder& ctx) {
if (call->sinfo_args.size() != 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "sinfo_args should have exact 1 output struct info.");
}
return call->sinfo_args[0];
}

RELAY_REGISTER_OP("relax.call_dps_packed")
.set_num_inputs(2)
.add_argument("func", "Expr", "The destination-passing-style function.")
.add_argument("args", "Tuple", "The input arguments.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallDPSPacked);

Expr MakeCallDPSPacked(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list) {
for (const TensorStructInfo& sinfo : out_sinfo_list) {
const auto* shape = sinfo->shape.as<ShapeExprNode>();
CHECK(shape != nullptr) << "out_sinfo of call_dps_packed should have defined ShapeExpr as shape. "
"However, one given structure info is "
<< sinfo;
}

StructInfo out_sinfo{nullptr};
if (out_sinfo_list.size() == 1) {
out_sinfo = out_sinfo_list[0];
} else {
out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()});
}

static const Op& op = Op::Get("relax.call_dps_packed");
return Call(op, {func, args}, {}, {out_sinfo});
}

TVM_REGISTER_GLOBAL("relax.op.call_dps_packed").set_body_typed(MakeCallDPSPacked);


// call builtin
StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) {
if (call->sinfo_args.size() == 0) {
Expand Down
Loading

0 comments on commit 416cd24

Please sign in to comment.