From d268b13cac068907018c02bb0defcb591d61369e Mon Sep 17 00:00:00 2001
From: Yong Wu
Date: Fri, 10 Mar 2023 09:41:50 -0800
Subject: [PATCH] [Unity] Introduce call_dps_packed (#14183)

Introduce call_dps_packed to call packed functions in destination-passing
style, reserving call_tir for TIR PrimFuncs.

* [Unity] Introduce call_dps_packed
* fix lint
* Fix comments
* Remove well_form update, enforce in InferStructInfoCallTIR
* Update src/relax/op/op.cc
* Update description of call_tir
* Remove unnecessary check in passes
---
 include/tvm/relax/dataflow_pattern.h            |   4 +
 include/tvm/relax/transform.h                   |   2 +-
 python/tvm/relax/__init__.py                    |   2 +-
 python/tvm/relax/dpl/pattern.py                 |  18 +-
 python/tvm/relax/expr.py                        |   2 +-
 python/tvm/relax/op/base.py                     |  54 +++-
 python/tvm/relax/transform/transform.py         |   2 +-
 python/tvm/script/ir_builder/relax/ir.py        |   2 +
 src/relax/backend/task_extraction.cc            |   5 -
 src/relax/ir/dataflow_pattern.cc                |  14 +
 src/relax/op/op.cc                              |  41 +++
 src/relax/transform/call_tir_rewrite.cc         |   5 +-
 src/relax/transform/fold_constant.cc            |  12 +-
 src/relax/transform/fuse_ops.cc                 |  20 +-
 src/relax/transform/fuse_tir.cc                 |  23 +-
 src/relax/transform/legalize_ops.cc             |   3 +-
 .../transform/rewrite_dataflow_reshape.cc       |   7 +-
 src/relax/transform/run_codegen.cc              |   9 +-
 src/script/printer/relax/call.cc                |  13 +-
 src/script/printer/relax/utils.h                |   3 +-
 tests/python/relax/test_analysis.py             |  12 +-
 .../python/relax/test_analysis_well_formed.py   |   6 +-
 tests/python/relax/test_ast_printer.py          |  72 ++++-
 tests/python/relax/test_binding_rewrite.py      |  16 +-
 tests/python/relax/test_dataflow_pattern.py     | 254 ++++++++++--------
 tests/python/relax/test_op_misc.py              |   2 +-
 tests/python/relax/test_transform.py            |  46 +++-
 .../test_transform_attach_global_symbol.py      |   4 +-
 .../relax/test_transform_bind_params.py         |   4 +-
 .../relax/test_transform_codegen_pass.py        |   4 +-
 tests/python/relax/test_transform_fuse_ops.py   |   4 +-
 tests/python/relax/test_transform_fuse_tir.py   |   2 +-
 .../python/relax/test_transform_normalize.py    |   6 +-
 .../python/relax/test_tvmscript_ir_builder.py   |  26 +-
 tests/python/relax/test_tvmscript_parser.py     | 123 ++++++---
 .../relax/test_tvmscript_printer_relax.py       |  15 +-
 tests/python/relax/test_vm_build.py             |   6 +-
 37 files changed, 567 insertions(+), 276 deletions(-)

diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h
index 701879745efa6..37640750a8ef5 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -793,6 +793,10 @@ ExprPattern IsOp(const String& op_name);
 CallPattern IsCallTIR(const String& name, Optional<TuplePattern> args = NullOpt);
 /*! \brief Syntatic Sugar for call_tir (return a tuple of tensor) */
 CallPattern IsCallTIR(const String& name, TuplePattern var_args);
+/*! \brief Syntactic Sugar for call_dps_packed (return a tensor) */
+CallPattern IsCallDPSPacked(const String& name, Optional<TuplePattern> args = NullOpt);
+/*! \brief Syntactic Sugar for call_dps_packed (return a tuple of tensor) */
+CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args);
 /*! \brief Syntatic Sugar for creating TuplePattern or UnorderedTuplePattern (unordered=true) */
 DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered = false);
 /*! \brief Syntatic Sugar for creating a TupleGetItemPattern */
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 9838fe53b3ca3..446b75da9f93b 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -83,7 +83,7 @@ TVM_DLL Pass LambdaLift();
 TVM_DLL Pass ToNonDataflow();
 
 /*!
- * \brief Perform explicit tensor allocation for call_tir. + * \brief Perform explicit tensor allocation for call_tir and call_dps_packed. * * \return The Pass. */ diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index bbd2040dd9d64..edbd848bd598b 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -56,7 +56,7 @@ from .exec_builder import ExecBuilder # Operator -from .op.base import call_tir +from .op.base import call_tir, call_dps_packed # BlockBuilder from .block_builder import BlockBuilder diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 300b0af568c0a..1ca41b378da54 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -862,11 +862,23 @@ def is_call_tir( return _is_call_tir(func_pattern, args) -def is_call_tir_extern( +def _is_call_dps_packed( + func_pattern: DFPattern, + args: Union[List, Tuple, TuplePattern] = None, +) -> CallPattern: + if args is None: + args = wildcard() + elif isinstance(args, (list, tuple)): + args = TuplePattern(args) + + return is_op("relax.call_dps_packed")(func_pattern, args) + + +def is_call_dps_packed( func_name: str, args: Union[List, Tuple, TuplePattern] = None, ) -> CallPattern: - """Syntax sugar for creating a CallPattern for call_tir that calls an extern function + """Syntax sugar for creating a CallPattern for call_dps_packed Parameters ---------- @@ -881,7 +893,7 @@ def is_call_tir_extern( The resulting CallPattern """ func_pattern = ExternFuncPattern(func_name) - return _is_call_tir(func_pattern, args) + return _is_call_dps_packed(func_pattern, args) def is_call_packed( diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index ab332eed618ec..4af08a3118fc8 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -597,7 +597,7 @@ def __call__(self, *args): @tvm._ffi.register_object("relax.expr.ExternFunc") class ExternFunc(BaseFunc): - """extern function, which can represent a TIR PrimFunc or a PackedFunc.""" + """extern function, which represents a PackedFunc.""" global_symbol: String diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 0b298679c1c56..aef0e731db51e 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -22,7 +22,7 @@ from tvm.runtime.object import Object from . import _ffi_api -from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc +from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar from ..expr import Tuple as RxTuple from ..struct_info import StructInfo, TensorStructInfo from ...ir import PrimExpr @@ -45,18 +45,18 @@ def null_value() -> Call: @args_converter.auto def call_tir( - func: Union[str, Expr], + gvar: GlobalVar, args: Expr, out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]], tir_vars: Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] = None, ) -> Call: """ - Call a destination-passing-style function and return the output. + Call a tir.prim_func and return the output. Parameters ---------- - func : Union[str, Expr] - The destination-passing-style function, can be ExternFunc or PrimFunc. + gvar : GlobalVar + The GlobalVar referring to a tir PrimFunc. args : Expr The input arguments. @@ -74,9 +74,6 @@ def call_tir( ret: Call A call node for the call_tir operator. 
""" - if isinstance(func, str): - func = ExternFunc(func) - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore args = RxTuple((args,)) @@ -86,7 +83,46 @@ def call_tir( if isinstance(tir_vars, (list, tuple)): tir_vars = ShapeExpr(tir_vars) - return _ffi_api.call_tir(func, args, out_sinfo, tir_vars) # type: ignore + return _ffi_api.call_tir(gvar, args, out_sinfo, tir_vars) # type: ignore + + +@args_converter.auto +def call_dps_packed( + func: Union[str, Expr], + args: Expr, + out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]], +) -> Call: + """ + Call a destination-passing-style packed function and return the output. + + Parameters + ---------- + func : Union[str, Expr] + The destination-passing-style function, can be ExternFunc. + + args : Expr + The input arguments. + + out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] + The structure info of the call_dps_packed output. + It should be a single or a list of TensorStructInfo. Each one denotes the + structure info of a returned tensor. + + Returns + ------- + ret: Call + A call node for the call_dps_packed operator. + """ + if isinstance(func, str): + func = ExternFunc(func) + + if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore + args = RxTuple((args,)) + + if not isinstance(out_sinfo, list): + out_sinfo = [out_sinfo] + + return _ffi_api.call_dps_packed(func, args, out_sinfo) # type: ignore @args_converter.auto diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 48560792e143b..97c8772b3b59b 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -60,7 +60,7 @@ def LambdaLift(): def CallTIRRewrite() -> tvm.ir.transform.Pass: - """Perform explicit tensor allocation for call_tir. + """Perform explicit tensor allocation for call_tir and call_dps_packed. 
 
     Returns
     -------
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 466a38837f87d..c658b6f77dba8 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -48,6 +48,7 @@
     builtin,
     call_builtin_with_ctx,
     call_tir,
+    call_dps_packed,
     ceil,
     clip,
     collapse_sum_like,
@@ -545,6 +546,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "builtin",
     "call_packed",
     "call_tir",
+    "call_dps_packed",
     "call_builtin_with_ctx",
     "ceil",
     "clip",
diff --git a/src/relax/backend/task_extraction.cc b/src/relax/backend/task_extraction.cc
index beb3950af1d1d..5bd764c68e780 100644
--- a/src/relax/backend/task_extraction.cc
+++ b/src/relax/backend/task_extraction.cc
@@ -73,11 +73,6 @@ class TaskExtractor : public ExprVisitor {
       return;
     }
 
-    // Do not extract external function
-    if (call->args[0].as<ExternFuncNode>()) {
-      return;
-    }
-
     const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
     const tir::PrimFunc& func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));
diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc
index 3768627c204cf..5eb1bf3ea6f66 100644
--- a/src/relax/ir/dataflow_pattern.cc
+++ b/src/relax/ir/dataflow_pattern.cc
@@ -563,6 +563,20 @@ CallPattern IsCallTIR(const String& name, Optional<TuplePattern> var_args) {
 CallPattern IsCallTIR(const String& name, TuplePattern var_args) {
   return IsOp("relax.call_tir")(GlobalVarPattern(name), var_args);
 }
+CallPattern IsCallDPSPacked(const String& name, Optional<TuplePattern> var_args) {
+  DFPattern arg_pattern;
+  if (!var_args.defined()) {
+    arg_pattern = Wildcard();
+  } else {
+    arg_pattern = var_args.value();
+  }
+
+  return IsOp("relax.call_dps_packed")(GlobalVarPattern(name), arg_pattern);
+}
+
+CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args) {
+  return IsOp("relax.call_dps_packed")(GlobalVarPattern(name), var_args);
+}
 
 DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered) {
   if (unordered)
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index c78ca539bd72b..cf084d6d2038b 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -80,6 +80,9 @@ StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) {
     ctx->ReportFatal(Diagnostic::Error(call)
                      << "sinfo_args should have exact 1 output struct info.");
   }
+  CHECK(call->args[0]->IsInstance<GlobalVarNode>())
+      << "call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. "
+      << "However, it got " << call->args[0];
   return call->sinfo_args[0];
 }
 
@@ -121,6 +124,44 @@ Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
 
 TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR);
 
+// call_dps_packed
+
+StructInfo InferStructInfoCallDPSPacked(const Call& call, const BlockBuilder& ctx) {
+  if (call->sinfo_args.size() != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "sinfo_args should have exactly 1 output struct info.");
+  }
+  return call->sinfo_args[0];
+}
+
+RELAY_REGISTER_OP("relax.call_dps_packed")
+    .set_num_inputs(2)
+    .add_argument("func", "Expr", "The destination-passing-style function.")
+    .add_argument("args", "Tuple", "The input arguments.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallDPSPacked);
+
+Expr MakeCallDPSPacked(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list) {
+  for (const TensorStructInfo& sinfo : out_sinfo_list) {
+    const auto* shape = sinfo->shape.as<ShapeExprNode>();
+    CHECK(shape != nullptr)
+        << "out_sinfo of call_dps_packed should have defined ShapeExpr as shape. "
+           "However, one given structure info is "
+        << sinfo;
+  }
+
+  StructInfo out_sinfo{nullptr};
+  if (out_sinfo_list.size() == 1) {
+    out_sinfo = out_sinfo_list[0];
+  } else {
+    out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()});
+  }
+
+  static const Op& op = Op::Get("relax.call_dps_packed");
+  return Call(op, {func, args}, {}, {out_sinfo});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.call_dps_packed").set_body_typed(MakeCallDPSPacked);
+
 // call builtin
 StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) {
   if (call->sinfo_args.size() == 0) {
diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc
index 2ea039e0229bb..6066ed8d2a7d4 100644
--- a/src/relax/transform/call_tir_rewrite.cc
+++ b/src/relax/transform/call_tir_rewrite.cc
@@ -33,7 +33,7 @@ namespace relax {
 
 // ==================
 // CallTIRMutator
-// Perform explicit tensor allocation for call_tir.
+// Perform explicit tensor allocation for call_tir or call_dps_packed.
 // Example:
 // lv0: Tensor(n, m) = rx.call_tir(func, (x), (n, m), dtype="float32")
 // -->
@@ -49,10 +49,11 @@ class CallTIRMutator : public ExprMutator {
     call = expr.as<CallNode>();
 
     static const Op& call_tir_op = Op::Get("relax.call_tir");
+    static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
     static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
     static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn");
 
-    if (call->op == call_tir_op) {
+    if (call->op == call_tir_op || call->op == call_dps_packed_op) {
       Array<Expr> outs;
       if (const auto& _tensor_sinfo = MatchStructInfo<TensorStructInfo>(expr)) {
         // single output case
diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc
index 6b28f31889151..622dd9ad09b7b 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -87,13 +87,11 @@ class ConstantFolder : public ExprMutator {
    * \return The TIR function, or nullopt if pattern match fails.
    */
   Optional<tir::PrimFunc> MatchPrimFunc(const Expr& op) {
-    if (auto* ptr = op.as<GlobalVarNode>()) {
-      // NOTE: as check works for nullptr(returns null)
-      Optional<BaseFunc> base_func =
-          builder_->GetContextIRModule()->functions.Get(GetRef<GlobalVar>(ptr));
-      if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
-        return GetRef<tir::PrimFunc>(pfunc);
-      }
+    const GlobalVar& global_var = Downcast<GlobalVar>(op);
+    // NOTE: as check works for nullptr(returns null)
+    Optional<BaseFunc> base_func = builder_->GetContextIRModule()->functions.Get(global_var);
+    if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
+      return GetRef<tir::PrimFunc>(pfunc);
    }
     return NullOpt;
   }
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 3b6b3c17ac9c3..6d7c278d8076d 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -185,19 +185,17 @@ class GraphCreator : public ExprVisitor {
     // recurse into the call expression.
     const auto* op = call->op.as<OpNode>();
     if (op == call_tir_op_.get()) {
-      // Skip ExternFunc for call_dps_packed.
-      if (const auto* global_var = call->args[0].as<GlobalVarNode>()) {
-        tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(GetRef<GlobalVar>(global_var)));
+      const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
+      tir::PrimFunc func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));
 
-        // Override args for call_tir
-        args = Downcast<Tuple>(call->args[1])->fields;
+      // Override args for call_tir
+      args = Downcast<Tuple>(call->args[1])->fields;
 
-        Optional<Integer> opt_pattern = func->GetAttr<Integer>("op_pattern");
-        if (opt_pattern.defined()) {
-          pattern = static_cast<OpPatternKind>(Downcast<IntImm>(opt_pattern)->value);
-        } else {
-          pattern = OpPatternKind::kOpaque;
-        }
+      Optional<Integer> opt_pattern = func->GetAttr<Integer>("op_pattern");
+      if (opt_pattern.defined()) {
+        pattern = static_cast<OpPatternKind>(Downcast<IntImm>(opt_pattern)->value);
+      } else {
+        pattern = OpPatternKind::kOpaque;
       }
     }
     // The pattern of the current binding variable node is set to the pattern of this operator.
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index 925f09d85d340..e90d6e4bc1d11 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -302,11 +302,10 @@ class FusedTIRConstructor : public ExprVisitor {
 
     // Step 1. Get Global var and PrimFunc
     GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
-    Optional<tir::PrimFunc> prim_func_ = GetPrimFunc(gv);
-    ICHECK(prim_func_.defined()) << "Cannot find the prim_func of the call_tir in the module: "
-                                 << gv;
+    tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+
     // Step 2. Renew all vars/buffer definitions and blocks to avoid duplication
-    tir::PrimFunc prim_func = tir::RenewDefs(prim_func_.value());
+    tir::PrimFunc prim_func = tir::RenewDefs(prim_func_);
 
     // Step 3. Check functions are all schedulable funcs. i.e. the body of func is root block
     // TODO(Siyuan): support un-schedulable functions.
@@ -364,22 +363,6 @@ class FusedTIRConstructor : public ExprVisitor {
     LOG(FATAL) << "Relax.Constant is not supported in primitive functions.";
   }
 
-  /********** Helper Functions **********/
-
-  /*!
-   * \brief Pattern match op to a TIR function and look it up.
-   * \return The TIR function, or NullOpt if patter match fails.
-   */
-  Optional<tir::PrimFunc> GetPrimFunc(const GlobalVar& global_var) {
-    // NOTE: as check works for nullptr(returns null)
-    Optional<BaseFunc> base_func = mod_->functions.Get(global_var);
-    if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
-      return GetRef<tir::PrimFunc>(pfunc);
-    } else {
-      return NullOpt;
-    }
-  }
-
   /*!
    * \brief Get the number of outputs for a call_tir node.
    * \return The number of outputs.
diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc
index f9a84c536101d..350a40c37bf87 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -75,6 +75,7 @@ class LegalizeMutator : public ExprMutator {
     Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
     static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
     static const Op& call_tir_op = Op::Get("relax.call_tir");
+    static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
     auto* op_node = visited_call->op.as<OpNode>();
 
     // Not an OpNode
@@ -102,7 +103,7 @@ class LegalizeMutator : public ExprMutator {
     }
 
     // No legalization.
-    if (op != call_tir_op) {
+    if (op != call_tir_op && op != call_dps_packed_op) {
       LOG(WARNING) << "No legalization func for " << op->name << " is found.";
     }
     return visited_call;
diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc
index aec0911ecc5a5..e5d654fba355d 100644
--- a/src/relax/transform/rewrite_dataflow_reshape.cc
+++ b/src/relax/transform/rewrite_dataflow_reshape.cc
@@ -75,11 +75,8 @@ class DataflowReshapeRewriter : public ExprMutator {
     if (call->op != call_tir_op) {
       return false;
     }
-    const auto* gv = call->args[0].as<GlobalVarNode>();
-    if (gv == nullptr) {
-      return false;
-    }
-    const auto* func = mod_->functions.Get(GetRef<GlobalVar>(gv)).as<tir::PrimFuncNode>();
+    const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
+    const auto* func = mod_->functions.Get(global_var).as<tir::PrimFuncNode>();
     ICHECK_NOTNULL(func);
     return HasReshapePattern(GetRef<tir::PrimFunc>(func));
   }
diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc
index 7deeb139d1a00..b5a4d7536f7f2 100644
--- a/src/relax/transform/run_codegen.cc
+++ b/src/relax/transform/run_codegen.cc
@@ -75,17 +75,18 @@ class CodeGenRunner : ExprMutator {
     if (auto const* gvar_node = call_node->op.as<GlobalVarNode>()) {
       const GlobalVar gvar = GetRef<GlobalVar>(gvar_node);
 
-      auto create_call_tir = [call_node, this](Expr extern_func, StructInfo ret_struct_info) {
+      auto create_call_dps_packed = [call_node, this](Expr extern_func,
+                                                      StructInfo ret_struct_info) {
         Array<Expr> new_args({extern_func});
         new_args.push_back(Tuple(call_node->args.Map([this](Expr arg) { return VisitExpr(arg); })));
 
-        static const Op& call_op = Op::Get("relax.call_tir");
+        static const Op& call_op = Op::Get("relax.call_dps_packed");
 
         return Call(call_op, new_args, tvm::Attrs(), {ret_struct_info});
       };
 
       if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) {
-        return create_call_tir(it->second.first, it->second.second);
+        return create_call_dps_packed(it->second.first, it->second.second);
       } else {
         // TODO(@sunggg): Is there any better way to get this func?
         Function func = Downcast<Function>(builder_->GetContextIRModule()->Lookup(gvar));
@@ -101,7 +102,7 @@ class CodeGenRunner : ExprMutator {
           func = (*RemoveFuncAttrFunc)(func, tvm::attr::kGlobalSymbol);
           func = (*RemoveFuncAttrFunc)(func, attr::kCodegen);
           builder_->UpdateFunction(gvar, func);
-          return create_call_tir(new_func, func->ret_struct_info);
+          return create_call_dps_packed(new_func, func->ret_struct_info);
         }
       }
     }
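After this change, the passes above can assume a fixed calling convention: relax.call_tir always carries the GlobalVar of an in-module PrimFunc (hence the unconditional Downcast<GlobalVar> casts), while external packed functions travel through relax.call_dps_packed. For orientation, a minimal sketch of building the new op from Python; "my_packed_func" is a placeholder for any destination-passing-style PackedFunc registered with the runtime, not a function defined in this patch:

    import tvm
    from tvm import relax

    bb = relax.BlockBuilder()
    x = relax.Var("x", relax.TensorStructInfo((32, 32), "float32"))
    with bb.function("main", [x]):
        # Emits relax.call_dps_packed(ExternFunc("my_packed_func"), (x,), ...).
        # out_sinfo must carry an explicit ShapeExpr, matching the CHECK in
        # MakeCallDPSPacked above.
        y = bb.emit(
            relax.call_dps_packed(
                "my_packed_func",
                [x],
                out_sinfo=relax.TensorStructInfo((32, 32), "float32"),
            )
        )
        bb.emit_func_output(y)
    mod = bb.get()

CallTIRRewrite then lowers such a call to an explicit relax.builtin.alloc_tensor plus an invocation that passes the allocated output as the destination.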
diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc
index 2feb2082c510c..e99b81df8b0cd 100644
--- a/src/script/printer/relax/call.cc
+++ b/src/script/printer/relax/call.cc
@@ -95,9 +95,11 @@ ExprDoc PrintCallee(const relax::Expr& n, const ObjectPath& n_p, const IRDocsifi
   }
 }
 
-Optional<ExprDoc> PrintCallTIR(const relax::Call& n, const ObjectPath& n_p, const IRDocsifier& d) {
+Optional<ExprDoc> PrintCallTIRDPSPacked(const relax::Call& n, const ObjectPath& n_p,
+                                        const IRDocsifier& d) {
   static const Op& call_tir_op = Op::Get("relax.call_tir");
-  if (!n->op.same_as(call_tir_op)) {
+  static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
+  if (!n->op.same_as(call_tir_op) && !n->op.same_as(call_dps_packed_op)) {
     return NullOpt;
   }
   ICHECK(n->args.size() == 2 || n->args.size() == 3);
@@ -123,6 +125,9 @@
   } else {
     kwargs_values.push_back(d->AsDoc<ExprDoc>(o_sinfo, o_sinfo_p));
   }
+  if (n->op.same_as(call_dps_packed_op)) {
+    return Relax(d, "call_dps_packed")->Call(args, kwargs_keys, kwargs_values);
+  }
   // Step 4. Print n->args[2], the tir variables
   if (n->args.size() == 3) {
     kwargs_keys.push_back("tir_vars");
@@ -134,8 +139,8 @@ Optional<ExprDoc> PrintCallTIR(const relax::Call& n, const ObjectPath& n_p, cons
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<relax::Call>(  //
         "", [](relax::Call n, ObjectPath n_p, IRDocsifier d) -> Doc {
-          // Special case: call_tir
-          if (Optional<ExprDoc> doc = PrintCallTIR(n, n_p, d)) {
+          // Special case: call_tir, call_dps_packed
+          if (Optional<ExprDoc> doc = PrintCallTIRDPSPacked(n, n_p, d)) {
             return doc.value();
           }
           ExprDoc prefix{nullptr};
diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h
index 7702f7b22dd2c..8c4281ad78903 100644
--- a/src/script/printer/relax/utils.h
+++ b/src/script/printer/relax/utils.h
@@ -82,7 +82,8 @@ inline Optional<ExprDoc> StructInfoAsAnn(const relax::Var& v, const ObjectPath&
   }
   if (const auto* call = rhs.as<relax::CallNode>()) {
     static const Op& call_tir_op = Op::Get("relax.call_tir");
-    if (call->op.same_as(call_tir_op)) {
+    static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
+    if (call->op.same_as(call_tir_op) || call->op.same_as(call_dps_packed_op)) {
       return NullOpt;
     }
   }
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
index 4a345224e57ea..8b26a2aa648a3 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -66,8 +66,10 @@ class IdentityUnused:
     def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
         with R.dataflow():
             lv0 = x
-            unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
-            unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32"))
+            unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
+            unused1 = R.call_dps_packed(
+                "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32")
+            )
             R.output(lv0)
         return lv0
 
@@ -92,8 +94,10 @@ class IdentityUnused:
     def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
         with R.dataflow():
             lv0 = x
-            unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
-            unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32"))
+            unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
+            unused1 = R.call_dps_packed(
+                "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32")
+            )
             R.output(lv0)
             z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32")))
         return z
diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py
index 7b8035b17c7e4..49d2b76011372 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -492,7 +492,7 @@ def test_sinfo_args_tir_var_used_before_define_call_tir():
     # Error: Symbolic Var m1, n1 are not defined
     m1 = tir.Var("m1", "int64")
     n1 = tir.Var("n1", "int64")
-    call = R.call_tir("my_func", x, out_sinfo=R.Tensor((m1, n1), "float32"))
+    call = R.call_dps_packed("my_func", x, out_sinfo=R.Tensor((m1, n1), "float32"))
     func = build_function([rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])])
     mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
     assert not rx.analysis.well_formed(mod, check_struct_info=False)
@@ -505,12 +505,12 @@ def test_sinfo_erase_to_well_formed():
     @R.function
     def foo(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m1", "n1"), dtype="float32"):
         m = T.int64()
         n = T.int64()
-        gv = R.call_tir("my_func", (x,), out_sinfo=R.Tensor((m, n), dtype="float32"))
+        gv = R.call_dps_packed("my_func", (x,), 
out_sinfo=R.Tensor((m, n), dtype="float32")) return gv """ m1 = tir.Var("m1", "int64") n1 = tir.Var("n1", "int64") - call = R.call_tir("my_func", x, out_sinfo=R.Tensor((m, n), "float32")) + call = R.call_dps_packed("my_func", x, out_sinfo=R.Tensor((m, n), "float32")) blocks = [rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])] seq_expr = rx.SeqExpr(blocks, blocks[-1].bindings[-1].var) func = rx.Function([x], seq_expr, R.Tensor((m1, n1), "float32")).with_attr( diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index e7f2feeaa0508..de71e81464c9b 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -429,10 +429,74 @@ def f( def test_call_tir(): # also from test_parser + @tvm.script.ir_module + class TestCallTIR: + @T.prim_func + def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: + T.func_attr(({"global_symbol": "addone"})) + for i, j in T.grid(16, 16): + with T.block("addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.int32(1) + + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + m, n = T.var("int64"), T.var("int64") + gv0 = R.call_tir(addone, (x,), R.Tensor((m, n), dtype="float32")) + return gv0 + + mod = TestCallTIR + foo = mod["foo"] + + foo_str = strip_whitespace( + dump_ast( + foo, + include_type_annotations=False, + include_struct_info_annotations=False, + include_call_attrs=False, + ) + ) + assert foo_str.startswith('Function(params=[Var(name_hint="x")]') + + # call_tir is an op in Relax and it takes an extern func as an argument + assert isinstance(foo.body, rx.SeqExpr) + tir_call = foo.body.blocks[0].bindings[0].value + tir_call_text = dump_ast( + tir_call, + include_type_annotations=False, + include_struct_info_annotations=False, + include_call_attrs=False, + ) + assert_fields( + "Call", + { + "op": 'Op(name="relax.call_tir")', + "args": """[ + GlobalVar(name_hint="addone"), + Tuple(fields=[Var(name_hint="x")]) + ]""", + "sinfo_args": """[ + TensorStructInfo( + dtype=float32, + shape=ShapeExpr( + values=[ + PrimExpr(value=`m`), + PrimExpr(value=`n`) + ] + ) + ) + ]""", + }, + tir_call_text, + ) + assert strip_whitespace(tir_call_text) in foo_str + + +def test_call_dps_packed(): @R.function def foo(x: R.Tensor(("m", "n"), "float32")): - m, n = T.int64(), T.int64() - gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + m, n = T.var("int64"), T.var("int64") + gv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) return gv0 foo_str = strip_whitespace( @@ -445,7 +509,7 @@ def foo(x: R.Tensor(("m", "n"), "float32")): ) assert foo_str.startswith('Function(params=[Var(name_hint="x")]') - # call_tir is an op in Relax and it takes an extern func as an argument + # call_dps_packed is an op in Relax and it takes an extern func as an argument assert isinstance(foo.body, rx.SeqExpr) tir_call = foo.body.blocks[0].bindings[0].value tir_call_text = dump_ast( @@ -457,7 +521,7 @@ def foo(x: R.Tensor(("m", "n"), "float32")): assert_fields( "Call", { - "op": 'Op(name="relax.call_tir")', + "op": 'Op(name="relax.call_dps_packed")', "args": """[ ExternFunc(global_symbol="test.op.identity"), Tuple(fields=[Var(name_hint="x")]) diff --git a/tests/python/relax/test_binding_rewrite.py b/tests/python/relax/test_binding_rewrite.py index 1b424b97923a0..d0d3344eb61e3 100644 --- a/tests/python/relax/test_binding_rewrite.py +++ b/tests/python/relax/test_binding_rewrite.py @@ -228,8 +228,10 @@ class 
IdentityChainedUnused: def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): lv0 = x - unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) - unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32")) + unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) + unused1 = R.call_dps_packed( + "my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32") + ) R.output(lv0) return lv0 @@ -262,19 +264,19 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): # \ / # lv4 with R.dataflow(): - lv0: R.Tensor((32, 32), "float32") = R.call_tir( + lv0: R.Tensor((32, 32), "float32") = R.call_dps_packed( "my_relu", (x,), R.Tensor((32, 32), dtype="float32") ) - lv1: R.Tensor((32, 32), "float32") = R.call_tir( + lv1: R.Tensor((32, 32), "float32") = R.call_dps_packed( "my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32") ) - lv2: R.Tensor((32, 32), "float32") = R.call_tir( + lv2: R.Tensor((32, 32), "float32") = R.call_dps_packed( "my_add", (x, lv0), R.Tensor((32, 32), dtype="float32") ) - lv3: R.Tensor((32, 32), "float32") = R.call_tir( + lv3: R.Tensor((32, 32), "float32") = R.call_dps_packed( "my_mul", (x, lv0), R.Tensor((32, 32), dtype="float32") ) - lv4: R.Tensor((32, 32), "float32") = R.call_tir( + lv4: R.Tensor((32, 32), "float32") = R.call_dps_packed( "my_whatever", (lv2, lv3), R.Tensor((32, 32), dtype="float32") ) R.output(lv4) diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index ab7a5540ad665..ba6ea995231ce 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -354,15 +354,15 @@ def test_simple_oub(): def test_counter_syntax_match(): with PatternContext() as ctx: - n0 = is_call_tir_extern("tir_matmul") - n1 = is_call_tir_extern("tir_impossible") + n0 = is_call_dps_packed("extern_matmul") + n1 = is_call_dps_packed("extern_impossible") n0 >> n1 dfb = main_fn.body.blocks[0] assert not ctx.match_dfb(dfb) with PatternContext() as ctx: - n0 = is_call_tir_extern("tir_matmul") - n1 = is_call_tir_extern("tir_impossible") + n0 = is_call_dps_packed("extern_matmul") + n1 = is_call_dps_packed("extern_impossible") n0 ^ n1 dfb = main_fn.body.blocks[0] assert not ctx.match_dfb(dfb) @@ -378,20 +378,20 @@ def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> # relu sigmoid # \ / # add - lv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32), dtype="float32")) - lv1 = R.call_tir("tir_relu", (lv0,), R.Tensor((32, 32), dtype="float32")) - lv2 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) - lv3 = R.call_tir("tir_add", (lv1, lv2), R.Tensor((32, 32), dtype="float32")) + lv0 = R.call_dps_packed("extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("extern_relu", (lv0,), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) + lv3 = R.call_dps_packed("extern_add", (lv1, lv2), R.Tensor((32, 32), dtype="float32")) R.output(lv3) return lv3 def test_diamond(): with PatternContext() as ctx: - n0 = is_call_tir_extern("tir_matmul") - n1 = is_call_tir_extern("tir_relu") - n2 = is_call_tir_extern("tir_sigmoid") - n3 = is_call_tir_extern("tir_add") + n0 = is_call_dps_packed("extern_matmul") + n1 = is_call_dps_packed("extern_relu") + n2 = is_call_dps_packed("extern_sigmoid") + n3 = is_call_dps_packed("extern_add") n0 ^ n1 n0 ^ n2 @@ -399,15 
+399,15 @@ def test_diamond(): n2 >> n3 dfb = Diamond["main"].body.blocks[0] - assert ctx.match_dfb(dfb) + assert ctx.match_dfb(dfb) # simplify it with fork_to with PatternContext() as ctx: - n1 = is_call_tir_extern("tir_relu") - n2 = is_call_tir_extern("tir_sigmoid") - n3 = is_call_tir_extern("tir_add") + n1 = is_call_dps_packed("extern_relu") + n2 = is_call_dps_packed("extern_sigmoid") + n3 = is_call_dps_packed("extern_add") - is_call_tir_extern("tir_matmul").fork_to(n1, n2) + is_call_dps_packed("extern_matmul").fork_to(n1, n2) n1 >> n3 n2 >> n3 @@ -417,10 +417,10 @@ def test_diamond(): def test_diamond_counter_oub(): with PatternContext() as ctx: - n0 = is_call_tir_extern("tir_matmul") - n1 = is_call_tir_extern("tir_relu") - n2 = is_call_tir_extern("tir_sigmoid") - n3 = is_call_tir_extern("tir_add") + n0 = is_call_dps_packed("extern_matmul") + n1 = is_call_dps_packed("extern_relu") + n2 = is_call_dps_packed("extern_sigmoid") + n3 = is_call_dps_packed("extern_add") n0 >> n1 n0 >> n2 @@ -440,8 +440,8 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: # / \ # \ / # add - lv0 = R.call_tir("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) - lv1 = R.call_tir("my_add", (lv0, lv0), R.Tensor((32, 32), dtype="float32")) + lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("my_add", (lv0, lv0), R.Tensor((32, 32), dtype="float32")) R.output(lv1) return lv1 @@ -454,9 +454,9 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: # relu relu # \ / # add - lv0 = R.call_tir("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) - lv1 = R.call_tir("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) - lv2 = R.call_tir("my_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32")) + lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_dps_packed("my_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32")) R.output(lv2) return lv2 @@ -468,8 +468,8 @@ def test_distiguish_diamond_and_parallel(): with PatternContext() as ctx: # describe a diamond pattern - fork = is_call_tir_extern("my_relu") - join = is_call_tir_extern("my_add") + fork = is_call_dps_packed("my_relu") + join = is_call_dps_packed("my_add") fork.only_used_by(join, index=0) fork.only_used_by(join, index=1) @@ -478,13 +478,13 @@ def test_distiguish_diamond_and_parallel(): with PatternContext() as ctx: # describe a parallel pattern - join = is_call_tir_extern("my_add") + join = is_call_dps_packed("my_add") # Due to one-one mathcing: - # is_call_tir_extern("my_relu") creates the 1st relu - is_call_tir_extern("my_relu") >> join - # is_call_tir_extern("my_relu") + # is_call_dps_packed("my_relu") creates the 1st relu + is_call_dps_packed("my_relu") >> join + # is_call_dps_packed("my_relu") # creates the another different relu (obj address is different) - is_call_tir_extern("my_relu") >> join + is_call_dps_packed("my_relu") >> join assert ctx.match_dfb(parallel) assert not ctx.match_dfb(diamond) @@ -507,13 +507,13 @@ def main( # \ / # concat with R.dataflow(): - lv0 = R.call_tir("conv1x1", (x, w0), R.Tensor((32, 32), dtype="float32")) - lv1 = R.call_tir("bias_add", (lv0, bias0), R.Tensor((32, 32), dtype="float32")) - lv2 = R.call_tir("my_relu", (lv1), R.Tensor((32, 32), dtype="float32")) - lv3 = R.call_tir("conv1x1", (x, w1), R.Tensor((32, 32), dtype="float32")) - lv4 = R.call_tir("bias_add", (lv3, bias1), R.Tensor((32, 32), dtype="float32")) - lv5 = R.call_tir("my_relu", 
(lv4), R.Tensor((32, 32), dtype="float32")) - lv6 = R.call_tir("concat", (lv2, lv5), R.Tensor((32, 64), dtype="float32")) + lv0 = R.call_dps_packed("conv1x1", (x, w0), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("bias_add", (lv0, bias0), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_dps_packed("my_relu", (lv1), R.Tensor((32, 32), dtype="float32")) + lv3 = R.call_dps_packed("conv1x1", (x, w1), R.Tensor((32, 32), dtype="float32")) + lv4 = R.call_dps_packed("bias_add", (lv3, bias1), R.Tensor((32, 32), dtype="float32")) + lv5 = R.call_dps_packed("my_relu", (lv4), R.Tensor((32, 32), dtype="float32")) + lv6 = R.call_dps_packed("concat", (lv2, lv5), R.Tensor((32, 64), dtype="float32")) R.output(lv6) return lv6 @@ -521,9 +521,9 @@ def main( def test_single_cbr(): with PatternContext() as ctx: ( - is_call_tir_extern("conv1x1") - >> is_call_tir_extern("bias_add") - >> is_call_tir_extern("my_relu") + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("bias_add") + >> is_call_dps_packed("my_relu") ) dfb = CBRx2["main"].body.blocks[0] matched = ctx.match_dfb(dfb) @@ -531,9 +531,9 @@ def test_single_cbr(): with PatternContext() as ctx: chain = ( - is_call_tir_extern("conv1x1") - >> is_call_tir_extern("bias_add") - >> is_call_tir_extern("my_relu") + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("bias_add") + >> is_call_dps_packed("my_relu") ) dfb = CBRx2["main"].body.blocks[0] # we want to specifically match the first CBR (lv0) @@ -549,9 +549,9 @@ def test_single_cbr(): def test_counter_single_crb(): with PatternContext() as ctx: ( - is_call_tir_extern("conv1x1") - >> is_call_tir_extern("my_relu") - >> is_call_tir_extern("bias_add") + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("my_relu") + >> is_call_dps_packed("bias_add") ) dfb = CBRx2["main"].body.blocks[0] assert not ctx.match_dfb(dfb) @@ -567,14 +567,14 @@ def test_nested_context(): dfb = CBRx2["main"].body.blocks[0] with PatternContext() as ctx0: ( - is_call_tir_extern("conv1x1") - >> is_call_tir_extern("bias_add") - >> is_call_tir_extern("my_relu") + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("bias_add") + >> is_call_dps_packed("my_relu") ) with PatternContext() as ctx1: - is_call_tir_extern("conv1x1") >> is_call_tir_extern("my_relu") # pattern to miss + is_call_dps_packed("conv1x1") >> is_call_dps_packed("my_relu") # pattern to miss with PatternContext() as ctx2: - is_call_tir_extern("bias_add") >> is_call_tir_extern("my_relu") + is_call_dps_packed("bias_add") >> is_call_dps_packed("my_relu") assert ctx2.match_dfb(dfb) assert PatternContext.current() == ctx2 assert not ctx1.match_dfb(dfb) @@ -586,9 +586,9 @@ def test_nested_context(): def test_two_cbr(): with PatternContext() as ctx: cbr0 = ( - is_call_tir_extern("conv1x1") - >> is_call_tir_extern("bias_add") - >> is_call_tir_extern("my_relu") + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("bias_add") + >> is_call_dps_packed("my_relu") ) cbr1 = cbr0.dup() @@ -603,9 +603,9 @@ def test_two_cbr(): with PatternContext() as ctx: # Deny the pattern cbr0 = ( - is_call_tir_extern("conv1x1") - >> is_call_tir_extern("bias_add") - >> is_call_tir_extern("my_relu") + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("bias_add") + >> is_call_dps_packed("my_relu") ) cbr1 = cbr0.dup() @@ -626,25 +626,25 @@ def main( c: R.Tensor((48, 32), "float32"), ) -> R.Tensor: with R.dataflow(): - lv0 = R.call_tir("matmul", (a, b), R.Tensor((32, 48), dtype="float32")) - lv1 = R.call_tir("matmul", (lv0, c), R.Tensor((32, 32), dtype="float32")) + lv0 = 
R.call_dps_packed("matmul", (a, b), R.Tensor((32, 48), dtype="float32")) + lv1 = R.call_dps_packed("matmul", (lv0, c), R.Tensor((32, 32), dtype="float32")) R.output(lv1) return lv1 with PatternContext() as ctx: - is_call_tir_extern("matmul") >> is_call_tir_extern("matmul") + is_call_dps_packed("matmul") >> is_call_dps_packed("matmul") dfb = MatMul2["main"].body.blocks[0] assert ctx.match_dfb(dfb) with PatternContext() as ctx: - is_call_tir_extern("matmul").has_shape([32, 48]) >> is_call_tir_extern("matmul").has_shape( + is_call_dps_packed("matmul").has_shape([32, 48]) >> is_call_dps_packed("matmul").has_shape( [32, 32] ) dfb = MatMul2["main"].body.blocks[0] assert ctx.match_dfb(dfb) with PatternContext() as ctx: - is_call_tir_extern("matmul") >> is_call_tir_extern("matmul") >> is_call_tir_extern("matmul") + is_call_dps_packed("matmul") >> is_call_dps_packed("matmul") >> is_call_dps_packed("matmul") dfb = MatMul2["main"].body.blocks[0] # Three MatMul cannot match assert not ctx.match_dfb(dfb) @@ -661,9 +661,9 @@ def main( c: R.Tensor((16, 32), "float32"), ) -> R.Tensor: with R.dataflow(): - lv0 = R.call_tir("my_concat", (b, c), R.Tensor((32, 32), dtype="float32")) - lv1 = R.call_tir("my_matmul", (a, lv0), R.Tensor((32, 32), dtype="float32")) - lv2 = R.call_tir( + lv0 = R.call_dps_packed("my_concat", (b, c), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("my_matmul", (a, lv0), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_dps_packed( "my_split", (lv1,), [R.Tensor((16, 32), dtype="float32"), R.Tensor((16, 32), dtype="float32")], @@ -676,15 +676,15 @@ def main( with PatternContext() as ctx: ( - is_call_tir_extern("my_concat") - >> is_call_tir_extern("my_matmul") - >> is_call_tir_extern("my_split") + is_call_dps_packed("my_concat") + >> is_call_dps_packed("my_matmul") + >> is_call_dps_packed("my_split") ) dfb = CMS["main"].body.blocks[0] assert ctx.match_dfb(dfb) with PatternContext() as ctx: - split = is_call_tir_extern("my_split") + split = is_call_dps_packed("my_split") lv3 = TupleGetItemPattern(split, 0).has_shape([16, 32]) lv4 = TupleGetItemPattern(split, 1).has_shape([16, 32]) split.fork_to(lv3, lv4) @@ -711,18 +711,26 @@ def main( ) -> R.Tensor: b, s, n, h = T.int64(), T.int64(), T.int64(), T.int64() with R.dataflow(): - fcq = R.call_tir("my_fc", (x, wq), R.Tensor((b, s, n, h), dtype="float32")) - tpq = R.call_tir("my_transpose", (fcq,), R.Tensor((b, s, h, n), dtype="float32")) + fcq = R.call_dps_packed("my_fc", (x, wq), R.Tensor((b, s, n, h), dtype="float32")) + tpq = R.call_dps_packed( + "my_transpose", (fcq,), R.Tensor((b, s, h, n), dtype="float32") + ) - fck = R.call_tir("my_fc", (x, wk), R.Tensor((b, s, n, h), dtype="float32")) - tpk = R.call_tir("my_transpose", (fck,), R.Tensor((b, s, h, n), dtype="float32")) + fck = R.call_dps_packed("my_fc", (x, wk), R.Tensor((b, s, n, h), dtype="float32")) + tpk = R.call_dps_packed( + "my_transpose", (fck,), R.Tensor((b, s, h, n), dtype="float32") + ) mul = R.multiply(tpq, tpk) scale = R.multiply(mul, R.const(1.1, "float32")) - softmax = R.call_tir("softmax", (scale,), R.Tensor((b, s, n, h), dtype="float32")) + softmax = R.call_dps_packed( + "softmax", (scale,), R.Tensor((b, s, n, h), dtype="float32") + ) - fcv = R.call_tir("my_fc", (x, wv), R.Tensor((b, s, n, h), dtype="float32")) - tpv = R.call_tir("my_transpose", (fcv,), R.Tensor((b, s, h, n), dtype="float32")) + fcv = R.call_dps_packed("my_fc", (x, wv), R.Tensor((b, s, n, h), dtype="float32")) + tpv = R.call_dps_packed( + "my_transpose", (fcv,), R.Tensor((b, s, h, n), 
dtype="float32") + ) out = R.multiply(softmax, tpv) R.output(out) @@ -730,7 +738,7 @@ def main( return out with PatternContext() as ctx: - fc_trans_q = is_call_tir_extern("my_fc") >> is_call_tir_extern("my_transpose") + fc_trans_q = is_call_dps_packed("my_fc") >> is_call_dps_packed("my_transpose") fc_trans_k = fc_trans_q.dup() fc_trans_v = fc_trans_q.dup() @@ -752,43 +760,59 @@ def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> # add5 add6 # \ / # add7 - lv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32), dtype="float32")) - lv1 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32), dtype="float32")) - lv2 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) - lv3 = R.call_tir("tir_sigmoid", (lv1), R.Tensor((32, 32), dtype="float32")) - lv4 = R.call_tir("tir_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32")) - lv5 = R.call_tir("tir_add", (lv2, lv4), R.Tensor((32, 32), dtype="float32")) - lv6 = R.call_tir("tir_add", (lv3, lv4), R.Tensor((32, 32), dtype="float32")) - lv7 = R.call_tir("tir_add", (lv5, lv6), R.Tensor((32, 32), dtype="float32")) + lv0 = R.call_dps_packed( + "extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32") + ) + lv1 = R.call_dps_packed( + "extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32") + ) + lv2 = R.call_dps_packed( + "extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32") + ) + lv3 = R.call_dps_packed( + "extern_sigmoid", (lv1), R.Tensor((32, 32), dtype="float32") + ) + lv4 = R.call_dps_packed( + "extern_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32") + ) + lv5 = R.call_dps_packed( + "extern_add", (lv2, lv4), R.Tensor((32, 32), dtype="float32") + ) + lv6 = R.call_dps_packed( + "extern_add", (lv3, lv4), R.Tensor((32, 32), dtype="float32") + ) + lv7 = R.call_dps_packed( + "extern_add", (lv5, lv6), R.Tensor((32, 32), dtype="float32") + ) R.output(lv7) return lv7 # match matmul0 diamond with PatternContext() as ctx: - sigmoid2 = is_call_tir_extern("tir_sigmoid") - add4 = is_call_tir_extern("tir_add") - is_call_tir_extern("tir_matmul").fork_to(sigmoid2, add4) - add5 = is_call_tir_extern("tir_add") + sigmoid2 = is_call_dps_packed("extern_sigmoid") + add4 = is_call_dps_packed("extern_add") + is_call_dps_packed("extern_matmul").fork_to(sigmoid2, add4) + add5 = is_call_dps_packed("extern_add") sigmoid2 >> add5 add4 ^ add5 assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) # counter case: mis-match matmul0 diamond with PatternContext() as ctx: - sigmoid2 = is_call_tir_extern("tir_sigmoid") - add4 = is_call_tir_extern("tir_add") - is_call_tir_extern("tir_matmul").fork_to(sigmoid2, add4) - add5 = is_call_tir_extern("tir_add") + sigmoid2 = is_call_dps_packed("extern_sigmoid") + add4 = is_call_dps_packed("extern_add") + is_call_dps_packed("extern_matmul").fork_to(sigmoid2, add4) + add5 = is_call_dps_packed("extern_add") sigmoid2 >> add5 add4 >> add5 # not only-used-by relation assert not ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) # match matmul1 diamond with PatternContext() as ctx: - sigmoid3 = is_call_tir_extern("tir_sigmoid") - add4 = is_call_tir_extern("tir_add") - is_call_tir_extern("tir_matmul").fork_to(sigmoid3, add4) - add6 = is_call_tir_extern("tir_add") + sigmoid3 = is_call_dps_packed("extern_sigmoid") + add4 = is_call_dps_packed("extern_add") + is_call_dps_packed("extern_matmul").fork_to(sigmoid3, add4) + add6 = is_call_dps_packed("extern_add") sigmoid3 >> add6 add4 ^ add6 assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) @@ -796,11 +820,11 @@ def main(x: 
R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> # match add-4-5-6-7 with PatternContext() as ctx: add5, add6, add7 = ( - is_call_tir_extern("tir_add"), - is_call_tir_extern("tir_add"), - is_call_tir_extern("tir_add"), + is_call_dps_packed("extern_add"), + is_call_dps_packed("extern_add"), + is_call_dps_packed("extern_add"), ) - is_call_tir_extern("tir_add").fork_to(add5, add6) # add4 + is_call_dps_packed("extern_add").fork_to(add5, add6) # add4 add5 >> add7 add6 >> add7 assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) @@ -811,15 +835,15 @@ def test_incremental_solving(): def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): # relu -> sigmoid -> neg - lv0 = R.call_tir("tir_relu", (x), R.Tensor((32, 32), dtype="float32")) - lv1 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) - lv2 = R.call_tir("tir_neg", (lv1), R.Tensor((32, 32), dtype="float32")) + lv0 = R.call_dps_packed("extern_relu", (x), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_dps_packed("extern_neg", (lv1), R.Tensor((32, 32), dtype="float32")) R.output(lv2) return lv2 - relu = is_call_tir_extern("tir_relu") - sigmoid = is_call_tir_extern("tir_sigmoid") - neg = is_call_tir_extern("tir_neg") + relu = is_call_dps_packed("extern_relu") + sigmoid = is_call_dps_packed("extern_sigmoid") + neg = is_call_dps_packed("extern_neg") with PatternContext() as ctx0: relu >> sigmoid @@ -840,14 +864,14 @@ def test_incremental_solving_counter(): def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): # sigmoid -> neg - lv0 = R.call_tir("tir_sigmoid", (x), R.Tensor((32, 32), dtype="float32")) - lv1 = R.call_tir("tir_neg", (lv0), R.Tensor((32, 32), dtype="float32")) + lv0 = R.call_dps_packed("extern_sigmoid", (x), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("extern_neg", (lv0), R.Tensor((32, 32), dtype="float32")) R.output(lv1) return lv1 - relu = is_call_tir_extern("tir_relu") - sigmoid = is_call_tir_extern("tir_sigmoid") - neg = is_call_tir_extern("tir_neg") + relu = is_call_dps_packed("extern_relu") + sigmoid = is_call_dps_packed("extern_sigmoid") + neg = is_call_dps_packed("extern_neg") with PatternContext() as ctx0: relu >> sigmoid # cannot match diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py index 523a628fa98b5..fd2391153324b 100644 --- a/tests/python/relax/test_op_misc.py +++ b/tests/python/relax/test_op_misc.py @@ -39,7 +39,7 @@ def identity_tir(a: T.handle, b: T.handle) -> None: def test_call_tir() -> None: v0 = rx.Var("v0", R.Tensor([54, 96], "float32")) - v1 = rx.call_tir(rx.extern("test.op.identity"), [v0], R.Tensor((54, 96), "float32")) + v1 = rx.call_dps_packed(rx.extern("test.op.identity"), [v0], R.Tensor((54, 96), "float32")) v1 = rx.call_tir(identity_tir, [v0], R.Tensor((54, 96), "float32")) diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 85de4f912ecf1..3e6305c492871 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -32,8 +32,10 @@ class TestToNonDataflow: def foo(x: R.Tensor(("m", "n"), "float32")): m, n = T.int64(), T.int64() with R.dataflow(): - lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) - gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32")) + lv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), 
dtype="float32")) + gv0 = R.call_dps_packed( + "test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32") + ) R.output(gv0) return gv0 @@ -73,10 +75,14 @@ def fvisit(e): def test_call_tir_rewrite(): @tvm.script.ir_module class TestCallTIRRewrite: + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + @R.function def foo(x: R.Tensor(("m", "n"), "float32")): m, n = T.int64(), T.int64() - gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + gv0 = R.call_tir(exp, (x,), R.Tensor((m, n), dtype="float32")) return gv0 mod = TestCallTIRRewrite @@ -94,6 +100,40 @@ def foo(x: R.Tensor(("m", "n"), "float32")): block = func.body.blocks[0] assert not isinstance(block, relax.DataflowBlock) + s1 = block.bindings[0].value + assert isinstance(s1, relax.Call) + assert s1.op.name == "relax.builtin.alloc_tensor" + assert isinstance(s1.args[0], relax.ShapeExpr) + assert structural_equal(s1.args[0], s0.sinfo_args[0].shape) + s2 = block.bindings[1].value + tvm.ir.expr.GlobalVar + assert s2.op.name_hint == "exp" + + +def test_call_dps_packed_rewrite(): + @tvm.script.ir_module + class TestCallDPSPackedRewrite: + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + m, n = T.int64(), T.int64() + gv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + return gv0 + + mod = TestCallDPSPackedRewrite + + # before rewrite + v0 = mod["foo"].body.blocks[0].bindings[0].var + s0 = mod["foo"].body.blocks[0].bindings[0].value + assert isinstance(s0, relax.Call) + assert s0.op.name == "relax.call_dps_packed" + + # CallTIRRewrite also works for call_dps_packed + new_mod = relax.transform.CallTIRRewrite()(mod) + func = new_mod["foo"] + + block = func.body.blocks[0] + assert not isinstance(block, relax.DataflowBlock) + s1 = block.bindings[0].value assert isinstance(s1, relax.Call) assert s1.op.name == "relax.builtin.alloc_tensor" diff --git a/tests/python/relax/test_transform_attach_global_symbol.py b/tests/python/relax/test_transform_attach_global_symbol.py index cef3842e3e494..7fc6798e37c4a 100644 --- a/tests/python/relax/test_transform_attach_global_symbol.py +++ b/tests/python/relax/test_transform_attach_global_symbol.py @@ -45,7 +45,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: @R.function def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")) -> R.Tensor: m, n, k = T.int64(), T.int64(), T.int64() - gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), dtype="float32")) + gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) return gv0 @@ -75,7 +75,7 @@ def main( ) -> R.Tensor: R.func_attr({"global_symbol": "main"}) m, n, k = T.int64(), T.int64(), T.int64() - gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), dtype="float32")) + gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) return gv0 before = Before diff --git a/tests/python/relax/test_transform_bind_params.py b/tests/python/relax/test_transform_bind_params.py index 1dfd9e0c8e198..2a30586b1b16f 100644 --- a/tests/python/relax/test_transform_bind_params.py +++ b/tests/python/relax/test_transform_bind_params.py @@ -87,10 +87,10 @@ def main( m = T.Var("m", "int64") n = T.Var("n", "int64") with R.dataflow(): - lv0 = R.call_tir( + lv0 = R.call_dps_packed( "linear0", (x, w0, b0), out_sinfo=R.Tensor((batch, n), dtype="float32") ) - out = R.call_tir( + out = R.call_dps_packed( "linear1", (lv0, w1, b1), out_sinfo=R.Tensor((batch, k), dtype="float32") ) R.output(out) 
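The test updates in this area all follow the same split, summarized by this TVMScript sketch (names are illustrative only: addone stands for any in-module PrimFunc, my_packed_func for any registered packed function):

    import tvm
    from tvm.script import relax as R, tir as T

    @tvm.script.ir_module
    class Example:
        @T.prim_func
        def addone(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
            for i in T.serial(4):
                with T.block("addone"):
                    vi = T.axis.remap("S", [i])
                    B[vi] = A[vi] + T.float32(1)

        @R.function
        def main(x: R.Tensor((4,), "float32")):
            # call_tir takes the GlobalVar of a PrimFunc defined in this IRModule
            lv = R.call_tir(addone, (x,), R.Tensor((4,), dtype="float32"))
            # call_dps_packed takes an external packed function by name (placeholder)
            gv = R.call_dps_packed("my_packed_func", (lv,), R.Tensor((4,), dtype="float32"))
            return gv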
diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py index 3e9501147aa09..d82706200aa9e 100644 --- a/tests/python/relax/test_transform_codegen_pass.py +++ b/tests/python/relax/test_transform_codegen_pass.py @@ -235,12 +235,12 @@ def main( weight2: R.Tensor((16, 3, 3, 16), dtype="float16"), ) -> R.Tensor((16, 32, 32, 16), dtype="float16"): with R.dataflow(): - lv = R.call_tir( + lv = R.call_dps_packed( "fused_relax_nn_conv2d_tensorrt", (data, weight1), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"), ) - gv = R.call_tir( + gv = R.call_dps_packed( "fused_relax_nn_conv2d_tensorrt", (lv, weight2), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"), diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index 33d57417cf0af..14d70ab77cda2 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -826,7 +826,7 @@ class Module: @R.function def main(x: R.Tensor((2, 3), "float32")): with R.dataflow(): - y = R.call_tir("func_packed_dps", x, R.Tensor((2, 3), "float32")) + y = R.call_dps_packed("func_packed_dps", x, R.Tensor((2, 3), "float32")) R.output(y) return y @@ -842,7 +842,7 @@ def main(x: R.Tensor((2, 3), "float32")): with R.dataflow(): a = R.call_tir(exp, (x,), out_sinfo=R.Tensor((2, 3), "float32")) b = R.call_tir(exp, (a,), out_sinfo=R.Tensor((2, 3), "float32")) - c = R.call_tir("packed_dps", (a,), out_sinfo=R.Tensor((2, 3), "float32")) + c = R.call_dps_packed("packed_dps", (a,), out_sinfo=R.Tensor((2, 3), "float32")) R.output(b, c) return R.tuple(b, c) diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index c2784edec733c..f8d488e43b8ed 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -690,7 +690,7 @@ class Module: @R.function def main(x: R.Tensor((2, 3), "float32")): with R.dataflow(): - y = R.call_tir("func_packed_dps", x, R.Tensor((2, 3), "float32")) + y = R.call_dps_packed("func_packed_dps", x, R.Tensor((2, 3), "float32")) R.output(y) return y diff --git a/tests/python/relax/test_transform_normalize.py b/tests/python/relax/test_transform_normalize.py index da123f956d59f..874e83c7f955f 100644 --- a/tests/python/relax/test_transform_normalize.py +++ b/tests/python/relax/test_transform_normalize.py @@ -124,8 +124,10 @@ class ANFMod2: def foo(x: R.Tensor(("m", "n"), "float32")): m, n = T.int64(), T.int64() with R.dataflow(): - lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) - gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32")) + lv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + gv0 = R.call_dps_packed( + "test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32") + ) R.output(gv0) return gv0 diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py index f7c29b8dbe4c6..e103e9cddded1 100644 --- a/tests/python/relax/test_tvmscript_ir_builder.py +++ b/tests/python/relax/test_tvmscript_ir_builder.py @@ -25,7 +25,7 @@ def test_function_simple(): """ @R.function def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): - out = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + out = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) return out """ # create with Script IRBuilder @@ -35,8 
@@ -35,8 +35,15 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2)
         R.func_attr({"Primitive": 1})
         x = R.arg("x", relax.TensorStructInfo((128, 128), "float32"))
         R.func_ret_struct_info(relax.TensorStructInfo(dtype="float32", ndim=2))
+        y = R.emit(
+            R.call_dps_packed(
+                "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32")
+            )
+        )
         out = R.emit(
-            R.call_tir("extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32"))
+            R.call_dps_packed(
+                "extern_dps_func", y, relax.TensorStructInfo((128, 128), dtype="float32")
+            )
         )
         IRBuilder.name("out", out)
         R.func_ret_value(out)
@@ -45,8 +52,15 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2)
     x = relax.Var("x", relax.TensorStructInfo((128, 128), "float32"))
     bb = relax.BlockBuilder()
     with bb.function("foo", (x,), attrs={"Primitive": 1}):
+        y = bb.emit(
+            relax.call_dps_packed(
+                "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32")
+            )
+        )
         out = bb.emit(
-            relax.call_tir("extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32"))
+            relax.call_dps_packed(
+                "extern_dps_func", y, relax.TensorStructInfo((128, 128), dtype="float32")
+            )
         )
         bb.emit_func_output(out)
     mod = bb.get()
@@ -112,7 +126,7 @@ def test_dataflow_block():
     def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2):
         # block 0
         with R.dataflow():
-            lv0 = R.call_tir("extern_func", (x,), R.Tensor((128, 128), dtype="float32"))
+            lv0 = R.call_dps_packed("extern_func", (x,), R.Tensor((128, 128), dtype="float32"))
             gv: Tensor((128, 128), "float32") = lv0
             R.output(gv)
         return gv
@@ -124,7 +138,7 @@ def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2):
     x = R.arg("x", relax.TensorStructInfo((128, 128), "float32"))
     with R.dataflow() as df:
         lv0 = R.emit(
-            R.call_tir(
+            R.call_dps_packed(
                 "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32")
             )
         )
@@ -142,7 +156,7 @@ def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2):
     with bb.function("foo", (x,)):
         with bb.dataflow():
             lv0 = bb.emit(
-                relax.call_tir(
+                relax.call_dps_packed(
                     "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32")
                 )
             )
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index ffc024403d0cb..136bd8c1ea4cc 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -43,13 +43,17 @@ def test_simple_func():
     @R.function
     def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"):
         R.func_attr({"Primitive": 1})
-        gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))
-        return gv0
+        gv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32"))
+        gv1 = R.call_dps_packed("extern_dps_func", gv0, R.Tensor((128, 128), dtype="float32"))
+        return gv1
 
     x = relax.Var("x", R.Tensor((128, 128), "float32"))
     bb = relax.BlockBuilder()
     with bb.function("foo", (x,), attrs={"Primitive": 1}):
-        out = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")))
+        y = bb.emit(relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")))
+        out = bb.emit(
+            relax.call_dps_packed("extern_dps_func", y, R.Tensor((128, 128), dtype="float32"))
+        )
         bb.emit_func_output(out)
 
     _check(foo, bb.get()["foo"])
@@ -111,15 +115,34 @@ def f(x: R.Tensor(("m",), "float32")):
         return R.call_tir("foo", (x,), R.Tensor((T.cast("int32", m, 1),), dtype="float32"))
 
-def test_unexpected_tir_max_args():
+def test_unexpected_tir_args():
+
+    with pytest.raises(tvm.error.DiagnosticError):
+
+        @tvm.script.ir_module
+        class TestWellCallTIR:
+            @T.prim_func
+            def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None:
+                T.func_attr({"global_symbol": "tir_addone"})
+                for i, j in T.grid(16, 16):
+                    with T.block("tir_addone"):
+                        vi, vj = T.axis.remap("SS", [i, j])
+                        B[vi, vj] = A[vi, vj] + T.int32(1)
+
+            @R.function
+            def foo(x: R.Tensor(("m", "m"), "float32")):
+                m = T.int64()
+                # tir.max expects 2 arguments, but got 1
+                gv = R.call_tir(tir_addone, (x,), R.Tensor((T.max(16),), dtype="float32"))
+                return gv
 
     with pytest.raises(tvm.error.DiagnosticError):
 
         @R.function
         def f(x: R.Tensor(("m", "n"), "float32")):
             m = T.int64()
-            # tir.max expects 2 arguments, but got 1
-            return relax.call_tir("foo", (x,), R.Tensor((T.max(m),), dtype="float32"))
+            # call_tir expects a tir prim_func
+            return relax.call_tir("extern_func", (x,), R.Tensor((T.max(m),), dtype="float32"))
 
 def test_func_type_annotation_fail():
@@ -315,14 +338,14 @@ def test_symbolic_shape():
     def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
         m = T.int64()
         n = T.int64()
-        gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32"))
+        gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32"))
         return gv0
 
     @R.function
     def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
         m = T.int64()
         n = T.int64()
-        gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32"))
+        gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32"))
         return gv0
 
     with pytest.raises(tvm.error.DiagnosticError):
@@ -331,7 +354,7 @@ def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"):
         def mismatch_dtype(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2):
             m = T.int64()
             n = T.int32()  # The shape dtype should be int64
-            gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32"))
+            gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32"))
             return gv0
 
     def _expected(name: str):
@@ -339,7 +362,9 @@ def _expected(name: str):
         x = relax.Var("x", R.Tensor([m, n], "float32"))
         bb = relax.BlockBuilder()
         with bb.function(name, (x,)):
-            out = bb.emit(relax.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32")))
+            out = bb.emit(
+                relax.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32"))
+            )
             bb.emit_func_output(out)
         return bb.get()[name]
 
@@ -403,15 +428,15 @@ def foo(x: R.Tensor("float32"), y: R.Tensor("float32")):
 def test_tuple_return():
     @R.function
     def foo(x: R.Tensor((4, 4), "float32")):
-        gv0 = R.call_tir("extern_func_0", x, R.Tensor((4, 4), dtype="float32"))
-        gv1 = R.call_tir("extern_func_1", x, R.Tensor((4, 4), dtype="float32"))
+        gv0 = R.call_dps_packed("extern_func_0", x, R.Tensor((4, 4), dtype="float32"))
+        gv1 = R.call_dps_packed("extern_func_1", x, R.Tensor((4, 4), dtype="float32"))
         return (gv0, gv1)
 
     x = relax.Var("x", R.Tensor((4, 4), "float32"))
     bb = relax.BlockBuilder()
     with bb.function("foo", (x,)):
-        gv0 = bb.emit(relax.call_tir("extern_func_0", x, R.Tensor((4, 4), dtype="float32")))
-        gv1 = bb.emit(relax.call_tir("extern_func_1", x, R.Tensor((4, 4), dtype="float32")))
+        gv0 = bb.emit(relax.call_dps_packed("extern_func_0", x, R.Tensor((4, 4), dtype="float32")))
+        gv1 = bb.emit(relax.call_dps_packed("extern_func_1", x, R.Tensor((4, 4), dtype="float32")))
         bb.emit_func_output(relax.Tuple((gv0, gv1)))
 
bb.get()["foo"]) @@ -483,8 +508,8 @@ def test_dataflow_block(): @R.function def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): with R.dataflow(): - lv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) - lv1 = R.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) + lv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) + lv1 = R.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) gv = lv1 R.output(gv) return gv @@ -493,8 +518,12 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) bb = relax.BlockBuilder() with bb.function("foo", (x,)): with bb.dataflow(): - lv0 = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))) - lv1 = bb.emit(relax.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32"))) + lv0 = bb.emit( + relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) + ) + lv1 = bb.emit( + relax.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) + ) gv = bb.emit_output(lv1) bb.emit_func_output(gv) @@ -504,22 +533,22 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) def test_dataflow_block_advanced(): @R.function def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): - gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) - gv1 = R.call_tir("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) + gv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) + gv1 = R.call_dps_packed("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) with R.dataflow(): m = T.int64() n = T.int64() - lv0 = R.call_tir("extern_func", gv1, R.Tensor((128, 128), dtype="float32")) + lv0 = R.call_dps_packed("extern_func", gv1, R.Tensor((128, 128), dtype="float32")) lv1 = R.match_cast(lv0, R.Tensor((m, n), "float32")) - gv2 = R.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) - gv2 = R.call_tir("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) + gv2 = R.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) + gv2 = R.call_dps_packed("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) gv3 = R.match_cast(gv2, R.Tensor((m, n), "float32")) gv3 = R.match_cast(lv0, R.Tensor((m, n), "float32")) gv4 = gv3 gv5 = gv2 R.output(gv5, gv4) - gv6 = R.call_tir("extern_func", gv5, R.Tensor((128, 128), dtype="float32")) - gv7 = R.call_tir("extern_func", gv6, R.Tensor((128, 128), dtype="float32")) + gv6 = R.call_dps_packed("extern_func", gv5, R.Tensor((128, 128), dtype="float32")) + gv7 = R.call_dps_packed("extern_func", gv6, R.Tensor((128, 128), dtype="float32")) return gv7 x = relax.Var("x", R.Tensor((128, 128), "float32")) @@ -527,21 +556,33 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) m = tir.Var("m", dtype="int64") n = tir.Var("n", dtype="int64") with bb.function("foo", (x,)): - gv0 = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))) - gv1 = bb.emit(relax.call_tir("extern_func", gv0, R.Tensor((128, 128), dtype="float32"))) + gv0 = bb.emit( + relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) + ) + gv1 = bb.emit( + relax.call_dps_packed("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) + ) with bb.dataflow(): - lv0 = bb.emit(relax.call_tir("extern_func", gv1, R.Tensor((128, 128), dtype="float32"))) + lv0 = bb.emit( + 
relax.call_dps_packed("extern_func", gv1, R.Tensor((128, 128), dtype="float32")) + ) lv1 = bb.match_cast(lv0, R.Tensor((m, n), "float32")) - gv2 = bb.emit(relax.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32"))) + gv2 = bb.emit( + relax.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) + ) gv21 = bb.emit( - relax.call_tir("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) + relax.call_dps_packed("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) ) gv3 = bb.match_cast(gv21, R.Tensor((m, n), "float32")) gv31 = bb.match_cast(lv0, R.Tensor((m, n), "float32")) gv32 = bb.emit_output(gv31) gv22 = bb.emit_output(gv21) - gv4 = bb.emit(relax.call_tir("extern_func", gv22, R.Tensor((128, 128), dtype="float32"))) - gv5 = bb.emit(relax.call_tir("extern_func", gv4, R.Tensor((128, 128), dtype="float32"))) + gv4 = bb.emit( + relax.call_dps_packed("extern_func", gv22, R.Tensor((128, 128), dtype="float32")) + ) + gv5 = bb.emit( + relax.call_dps_packed("extern_func", gv4, R.Tensor((128, 128), dtype="float32")) + ) bb.emit_func_output(gv5) _check(foo, bb.get()["foo"]) @@ -640,13 +681,13 @@ def foo(x: R.Tensor((128, 128), "float32")): def test_tensor_type_without_args(): @R.function def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: - v = R.call_tir("tir_relu", x, R.Tensor((32, 32), dtype="float32")) + v = R.call_dps_packed("extern_relu", x, R.Tensor((32, 32), dtype="float32")) return v x = relax.Var("x", R.Tensor((32, 32), "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x)): - v = bb.emit(relax.call_tir("tir_relu", x, R.Tensor((32, 32), dtype="float32"))) + v = bb.emit(relax.call_dps_packed("extern_relu", x, R.Tensor((32, 32), dtype="float32"))) bb.emit_func_output(v) _check(foo, bb.get()["foo"]) @@ -753,10 +794,10 @@ def bar(x: R.Tensor): assert isinstance(z_bind.var.struct_info, relax.TensorStructInfo) -def test_call_tir_empty_shape(): +def test_call_dps_packed_empty_shape(): @R.function def foo(x: R.Tensor((), "float32")): - z = R.call_tir("scalar_add", x, R.Tensor((), dtype="float32")) + z = R.call_dps_packed("scalar_add", x, R.Tensor((), dtype="float32")) return z (z_bind,) = foo.body.blocks[0].bindings @@ -1024,7 +1065,7 @@ def bar( x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32") ) -> R.Tensor(("T.max(m, 20) + 1",), "float32"): m = T.int64() - z = R.call_tir("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 1,), dtype="float32")) + z = R.call_dps_packed("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 1,), dtype="float32")) return z m = tir.Var("m", "int64") @@ -1033,7 +1074,9 @@ def bar( bb = relax.BlockBuilder() with bb.function("bar", (x, y)): z = bb.emit( - relax.call_tir("test_intrin", (x, y), R.Tensor((tir.max(m, 20) + 1,), dtype="float32")) + relax.call_dps_packed( + "test_intrin", (x, y), R.Tensor((tir.max(m, 20) + 1,), dtype="float32") + ) ) bb.emit_func_output(z) @@ -1043,7 +1086,7 @@ def bar( @R.function def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")): m = T.int64() - z = R.call_tir("test_intrin", y, R.Tensor((m * 2,), dtype="float32")) + z = R.call_dps_packed("test_intrin", y, R.Tensor((m * 2,), dtype="float32")) return z m = tir.Var("m", "int64") @@ -1051,7 +1094,7 @@ def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")): y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32")) bb = relax.BlockBuilder() with bb.function("baz", (x, y)): - z = bb.emit(relax.call_tir("test_intrin", (y), R.Tensor((m * 2,), dtype="float32"))) + z = 
+        z = bb.emit(relax.call_dps_packed("test_intrin", (y), R.Tensor((m * 2,), dtype="float32")))
         bb.emit_func_output(z)
 
     _check(baz, bb.get()["baz"])
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py
index 464591f2592b8..76bb3bb812290 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -299,13 +299,22 @@ def test_shape_expr():
 def test_call():
     x = tir.Var("x", "int64")
     a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32"))
-    obj = relax.call_tir("my_func", args=a, out_sinfo=a.struct_info, tir_vars=[x])
+    o0 = relax.call_tir(relax.GlobalVar("tir_func"), args=a, out_sinfo=a.struct_info, tir_vars=[x])
+    o1 = relax.call_dps_packed("my_dps_func", args=a, out_sinfo=a.struct_info)
     _assert_print(
-        obj,
+        o0,
+        """
+x = T.int64()
+a: R.Tensor((1, x, 3), dtype="float32")
+R.call_tir(tir_func, (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=R.shape([x]))
+""",
+    )
+    _assert_print(
+        o1,
         """
 x = T.int64()
 a: R.Tensor((1, x, 3), dtype="float32")
-R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=R.shape([x]))
+R.call_dps_packed("my_dps_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"))
 """,
     )
 
diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py
index e51e22e3233c7..776679103f863 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -121,7 +121,7 @@ class TestVMCompileStage3:
     @R.function
     def foo(x: R.Tensor((32, 16), "float32")) -> R.Tensor:
         with R.dataflow():
-            y = R.call_tir("test.vm.identity", (x), R.Tensor((32, 16), dtype="float32"))
+            y = R.call_dps_packed("test.vm.identity", (x), R.Tensor((32, 16), dtype="float32"))
             R.output(y)
         return y
 
@@ -145,7 +145,7 @@ def foo(x: R.Tensor(dtype="float32")) -> R.Tensor:
         with R.dataflow():
             n, m = T.int64(), T.int64()
             _ = R.match_cast(x, R.Tensor((n, m), "float32"))
-            y = R.call_tir("test.vm.tile", (x), R.Tensor((n, m * 2), dtype="float32"))
+            y = R.call_dps_packed("test.vm.tile", (x), R.Tensor((n, m * 2), dtype="float32"))
             R.output(y)
         return y
 
@@ -714,7 +714,7 @@ def test_vm_nested_tuple(
     @R.function
     def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor:
-        gv0 = R.call_tir("test_vm_mul", (x, w), R.Tensor((32, 32), dtype="float32"))
+        gv0 = R.call_tir(test_vm_mul, (x, w), R.Tensor((32, 32), dtype="float32"))
         return gv0
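Taken together, the updated tests pin down the new convention: R.call_tir now takes the
GlobalVar of a TIR PrimFunc defined in the same IRModule, while R.call_dps_packed names
an external PackedFunc. A minimal side-by-side sketch, not part of the diff itself; the
packed-function name "my_packed_func" is illustrative only:

import tvm
from tvm.script import relax as R, tir as T

@tvm.script.ir_module
class Example:
    @T.prim_func
    def add_one(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
        for i in T.serial(4):
            with T.block("add_one"):
                vi = T.axis.spatial(4, i)
                B[vi] = A[vi] + T.float32(1)

    @R.function
    def main(x: R.Tensor((4,), "float32")):
        # TIR PrimFunc: referenced by its GlobalVar, not by a string name.
        y = R.call_tir(add_one, (x,), out_sinfo=R.Tensor((4,), dtype="float32"))
        # External PackedFunc: referenced by its registered string name.
        z = R.call_dps_packed("my_packed_func", (y,), out_sinfo=R.Tensor((4,), dtype="float32"))
        return z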