From f9113039b04ab4242ace61977f39857e1576e4eb Mon Sep 17 00:00:00 2001
From: Yong Wu
Date: Thu, 2 Mar 2023 08:11:09 -0800
Subject: [PATCH 1/7] [Unity] Introduce call_dps_packed

---
 include/tvm/relax/dataflow_pattern.h            |   4 +
 include/tvm/relax/transform.h                   |   7 +
 python/tvm/relax/__init__.py                    |   2 +-
 python/tvm/relax/dpl/pattern.py                 |  18 +-
 python/tvm/relax/expr.py                        |   2 +-
 python/tvm/relax/op/base.py                     |  47 +++-
 python/tvm/relax/transform/transform.py         |  10 +
 python/tvm/relax/vm_build.py                    |   1 +
 python/tvm/script/ir_builder/relax/ir.py        |   2 +
 src/relax/ir/dataflow_pattern.cc                |  14 +
 src/relax/op/op.cc                              |  39 +++
 .../transform/call_dps_packed_rewrite.cc        | 140 ++++++++++
 src/relax/transform/run_codegen.cc              |   9 +-
 src/script/printer/relax/call.cc                |   8 +-
 src/script/printer/relax/utils.h                |   3 +-
 tests/python/relax/test_analysis.py             |  12 +-
 .../python/relax/test_analysis_well_formed.py   |   6 +-
 tests/python/relax/test_ast_printer.py          |  72 ++++-
 tests/python/relax/test_binding_rewrite.py      |  16 +-
 tests/python/relax/test_dataflow_pattern.py     | 254 ++++++++++--------
 tests/python/relax/test_op_misc.py              |   2 +-
 tests/python/relax/test_transform.py            |  46 +++-
 .../test_transform_attach_global_symbol.py      |   4 +-
 .../relax/test_transform_bind_params.py         |   4 +-
 .../relax/test_transform_codegen_pass.py        |   4 +-
 tests/python/relax/test_transform_fuse_ops.py   |   4 +-
 tests/python/relax/test_transform_fuse_tir.py   |   2 +-
 .../python/relax/test_transform_normalize.py    |   6 +-
 .../python/relax/test_tvmscript_ir_builder.py   |  26 +-
 tests/python/relax/test_tvmscript_parser.py     |  94 ++++---
 .../relax/test_tvmscript_printer_relax.py       |  15 +-
 tests/python/relax/test_vm_build.py             |   6 +-
 32 files changed, 668 insertions(+), 211 deletions(-)
 create mode 100644 src/relax/transform/call_dps_packed_rewrite.cc

diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h
index 701879745efa..37640750a8ef 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -793,6 +793,10 @@ ExprPattern IsOp(const String& op_name);
 CallPattern IsCallTIR(const String& name, Optional<TuplePattern> args = NullOpt);
 /*! \brief Syntactic Sugar for call_tir (return a tuple of tensor) */
 CallPattern IsCallTIR(const String& name, TuplePattern var_args);
+/*! \brief Syntactic Sugar for call_dps_packed (return a tensor) */
+CallPattern IsCallDPSPacked(const String& name, Optional<TuplePattern> args = NullOpt);
+/*! \brief Syntactic Sugar for call_dps_packed (return a tuple of tensor) */
+CallPattern IsCallDPSPacked(const String& name, TuplePattern var_args);
 /*! \brief Syntactic Sugar for creating TuplePattern or UnorderedTuplePattern (unordered=true) */
 DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered = false);
 /*! \brief Syntactic Sugar for creating a TupleGetItemPattern */
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 715c8e56ff9c..e7e7dfcf319b 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -89,6 +89,13 @@ TVM_DLL Pass ToNonDataflow();
  */
 TVM_DLL Pass CallTIRRewrite();

+/*!
+ * \brief Perform explicit tensor allocation for call_dps_packed.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass CallDPSPackedRewrite();
+
 /*!
  * \brief Convert all reshape-like call_tir whose corresponding binding
  * vars are DataflowVars to relax.reshape operator calls.
The relax.reshape diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py index bbd2040dd9d6..edbd848bd598 100644 --- a/python/tvm/relax/__init__.py +++ b/python/tvm/relax/__init__.py @@ -56,7 +56,7 @@ from .exec_builder import ExecBuilder # Operator -from .op.base import call_tir +from .op.base import call_tir, call_dps_packed # BlockBuilder from .block_builder import BlockBuilder diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 300b0af568c0..1ca41b378da5 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -862,11 +862,23 @@ def is_call_tir( return _is_call_tir(func_pattern, args) -def is_call_tir_extern( +def _is_call_dps_packed( + func_pattern: DFPattern, + args: Union[List, Tuple, TuplePattern] = None, +) -> CallPattern: + if args is None: + args = wildcard() + elif isinstance(args, (list, tuple)): + args = TuplePattern(args) + + return is_op("relax.call_dps_packed")(func_pattern, args) + + +def is_call_dps_packed( func_name: str, args: Union[List, Tuple, TuplePattern] = None, ) -> CallPattern: - """Syntax sugar for creating a CallPattern for call_tir that calls an extern function + """Syntax sugar for creating a CallPattern for call_dps_packed Parameters ---------- @@ -881,7 +893,7 @@ def is_call_tir_extern( The resulting CallPattern """ func_pattern = ExternFuncPattern(func_name) - return _is_call_tir(func_pattern, args) + return _is_call_dps_packed(func_pattern, args) def is_call_packed( diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index ab332eed618e..4af08a3118fc 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -597,7 +597,7 @@ def __call__(self, *args): @tvm._ffi.register_object("relax.expr.ExternFunc") class ExternFunc(BaseFunc): - """extern function, which can represent a TIR PrimFunc or a PackedFunc.""" + """extern function, which represents a PackedFunc.""" global_symbol: String diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 0b298679c1c5..2bfe9e85aae8 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -22,7 +22,7 @@ from tvm.runtime.object import Object from . import _ffi_api -from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc +from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar from ..expr import Tuple as RxTuple from ..struct_info import StructInfo, TensorStructInfo from ...ir import PrimExpr @@ -51,12 +51,12 @@ def call_tir( tir_vars: Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] = None, ) -> Call: """ - Call a destination-passing-style function and return the output. + Call a tir.prim_func function and return the output. Parameters ---------- func : Union[str, Expr] - The destination-passing-style function, can be ExternFunc or PrimFunc. + The tir PrimFunc function. args : Expr The input arguments. @@ -75,7 +75,7 @@ def call_tir( A call node for the call_tir operator. """ if isinstance(func, str): - func = ExternFunc(func) + func = GlobalVar(func) if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore args = RxTuple((args,)) @@ -89,6 +89,45 @@ def call_tir( return _ffi_api.call_tir(func, args, out_sinfo, tir_vars) # type: ignore +@args_converter.auto +def call_dps_packed( + func: Union[str, Expr], + args: Expr, + out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]], +) -> Call: + """ + Call a destination-passing-style packed function and return the output. 
+ + Parameters + ---------- + func : Union[str, Expr] + The destination-passing-style function, can be ExternFunc. + + args : Expr + The input arguments. + + out_sinfo : Union[TensorStructInfo, List[TensorStructInfo]] + The structure info of the call_dps_packed output. + It should be a single or a list of TensorStructInfo. Each one denotes the + structure info of a returned tensor. + + Returns + ------- + ret: Call + A call node for the call_dps_packed operator. + """ + if isinstance(func, str): + func = ExternFunc(func) + + if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore + args = RxTuple((args,)) + + if not isinstance(out_sinfo, list): + out_sinfo = [out_sinfo] + + return _ffi_api.call_dps_packed(func, args, out_sinfo) # type: ignore + + @args_converter.auto def call_builtin_with_ctx( func: Union[str, Expr], diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index a33ad630935f..cbd8e61cf2b7 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -69,6 +69,16 @@ def CallTIRRewrite() -> tvm.ir.transform.Pass: return _ffi_api.CallTIRRewrite() # type: ignore +def CallDPSPackedRewrite() -> tvm.ir.transform.Pass: + """Perform explicit tensor allocation for call_dps_packed. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.CallDPSPackedRewrite() # type: ignore + + def Normalize() -> tvm.ir.transform.Pass: """Transforming Relax IR to normal form, i.e., the expressions are normalized(no nesting and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are available. diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py index 0586bf9217a2..3fca75dfc933 100644 --- a/python/tvm/relax/vm_build.py +++ b/python/tvm/relax/vm_build.py @@ -296,6 +296,7 @@ def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")): passes.append(relax.transform.RewriteDataflowReshape()) passes.append(relax.transform.ToNonDataflow()) passes.append(relax.transform.CallTIRRewrite()) + passes.append(relax.transform.CallDPSPackedRewrite()) passes.append(relax.transform.StaticPlanBlockMemory()) passes.append(relax.transform.VMBuiltinLower()) passes.append(relax.transform.VMShapeLower()) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 14ef36307ad5..4574f33c5ac4 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -46,6 +46,7 @@ builtin, call_builtin_with_ctx, call_tir, + call_dps_packed, ceil, clip, collapse_sum_like, @@ -534,6 +535,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "builtin", "call_packed", "call_tir", + "call_dps_packed", "call_builtin_with_ctx", "ceil", "clip", diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc index 3768627c204c..5eb1bf3ea6f6 100644 --- a/src/relax/ir/dataflow_pattern.cc +++ b/src/relax/ir/dataflow_pattern.cc @@ -563,6 +563,20 @@ CallPattern IsCallTIR(const String& name, Optional var_args) { CallPattern IsCallTIR(const String& name, TuplePattern var_args) { return IsOp("relax.call_tir")(GlobalVarPattern(name), var_args); } +CallPattern IsCallDPSPacked(const String& name, Optional var_args) { + DFPattern arg_pattern; + if (!var_args.defined()) { + arg_pattern = Wildcard(); + } else { + arg_pattern = var_args.value(); + } + + return IsOp("relax.call_dps_packed")(GlobalVarPattern(name), arg_pattern); +} + +CallPattern IsCallDPSPacked(const String& name, TuplePattern 
var_args) {
+  return IsOp("relax.call_dps_packed")(GlobalVarPattern(name), var_args);
+}

 DFPattern IsTuple(const Array<DFPattern>& fields, bool unordered) {
   if (unordered)
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 21d692b6a460..e0f1d26bdf4a 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -106,6 +106,45 @@ Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,

 TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR);

+
+// call_dps_packed
+
+StructInfo InferStructInfoCallDPSPacked(const Call& call, const BlockBuilder& ctx) {
+  if (call->sinfo_args.size() != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "sinfo_args should have exactly 1 output struct info.");
+  }
+  return call->sinfo_args[0];
+}
+
+RELAY_REGISTER_OP("relax.call_dps_packed")
+    .set_num_inputs(2)
+    .add_argument("func", "Expr", "The destination-passing-style function.")
+    .add_argument("args", "Tuple", "The input arguments.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallDPSPacked);
+
+Expr MakeCallDPSPacked(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list) {
+  for (const TensorStructInfo& sinfo : out_sinfo_list) {
+    const auto* shape = sinfo->shape.as<ShapeExprNode>();
+    CHECK(shape != nullptr) << "out_sinfo of call_dps_packed should have a defined ShapeExpr as "
+                               "its shape. However, one given structure info is "
+                            << sinfo;
+  }
+
+  StructInfo out_sinfo{nullptr};
+  if (out_sinfo_list.size() == 1) {
+    out_sinfo = out_sinfo_list[0];
+  } else {
+    out_sinfo = TupleStructInfo({out_sinfo_list.begin(), out_sinfo_list.end()});
+  }
+
+  static const Op& op = Op::Get("relax.call_dps_packed");
+  return Call(op, {func, args}, {}, {out_sinfo});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.call_dps_packed").set_body_typed(MakeCallDPSPacked);
+
+
 // call builtin
 StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) {
   if (call->sinfo_args.size() == 0) {
diff --git a/src/relax/transform/call_dps_packed_rewrite.cc b/src/relax/transform/call_dps_packed_rewrite.cc
new file mode 100644
index 000000000000..714e0c59eeea
--- /dev/null
+++ b/src/relax/transform/call_dps_packed_rewrite.cc
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/call_dps_packed_rewrite.cc
+ * \brief Perform explicit tensor allocation for call_dps_packed.
+ */
+#include
+#include
+#include
+#include
+#include
+
+#include "../../relay/transforms/pattern_utils.h"
+
+namespace tvm {
+namespace relax {
+
+// ==================
+// CallDPSPackedMutator
+// Perform explicit tensor allocation for call_dps_packed.
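+//
+// For orientation, the front-end construct that this pass consumes can be built
+// from Python roughly as follows. This is only a sketch based on the
+// relax.op.base.call_dps_packed API added in this patch: "my_packed_func" is a
+// hypothetical PackedFunc name and `x` an existing Relax expression.
+//
+//   from tvm import relax
+//   from tvm.script import relax as R
+//
+//   # single output tensor: out_sinfo is one TensorStructInfo
+//   y = relax.call_dps_packed("my_packed_func", (x,),
+//                             out_sinfo=R.Tensor((2, 3), "float32"))
+//   # multiple output tensors: out_sinfo is a list, one entry per output
+//   y2 = relax.call_dps_packed(
+//       "my_packed_func", (x,),
+//       out_sinfo=[R.Tensor((2, 3), "float32"), R.Tensor((2, 3), "float32")])
+//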
+// Example:
+// lv0: Tensor((n, m), "float32") = rx.call_dps_packed(func, (x,), out_sinfo=R.Tensor((n, m), "float32"))
+// -->
+// gv0 = rx.call("relax.builtin.alloc_tensor", [n, m], dtype="float32")
+// rx.call_packed(func, x, gv0)

+class CallDPSPackedMutator : public ExprMutator {
+ public:
+  using ExprMutator::VisitExpr_;
+  Expr VisitExpr_(const CallNode* call) override {
+    // post-order mutation
+    Expr expr = VisitExprPostOrder_(call);
+    call = expr.as<CallNode>();
+
+    static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
+    static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+    static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn");
+
+    if (call->op == call_dps_packed_op) {
+      Array<Expr> outs;
+      if (const auto& _tensor_sinfo = MatchStructInfo<TensorStructInfo>(expr)) {
+        // single output case
+        const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value();
+        ICHECK(tensor_sinfo->shape.defined())
+            << "the TensorStructInfo shape of call_dps_packed has not been populated";
+        outs.push_back(
+            builder_->Emit(Call(alloc_tensor_op,  //
+                                {Downcast<ShapeExpr>(tensor_sinfo->shape.value()),
+                                 DataTypeImm(tensor_sinfo->dtype), PrimValue::Int64(0)},  //
+                                Attrs()),
+                           "alloc"));
+      } else if (const auto& _tuple_sinfo = MatchStructInfo<TupleStructInfo>(expr)) {
+        // multiple output case
+        const TupleStructInfo& tuple_sinfo = _tuple_sinfo.value();
+        for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) {
+          const auto& field = tuple_sinfo->fields[i];
+
+          ICHECK(field->IsInstance<TensorStructInfoNode>())
+              << "call_dps_packed expects a Tuple of TensorStructInfo, but got " << field
+              << " as an element of TupleStructInfo";
+          const auto& field_tensor = Downcast<TensorStructInfo>(field);
+          ICHECK(field_tensor->shape.defined())
+              << "call_dps_packed expects every TensorStructInfo to have a defined shape, but got "
+              << field_tensor << " as an element of TupleStructInfo";
+          outs.push_back(
+              builder_->Emit(Call(alloc_tensor_op,
+                                  {Downcast<ShapeExpr>(field_tensor->shape.value()),
+                                   DataTypeImm(field_tensor->dtype), PrimValue::Int64(0)},
+                                  Attrs()),
+                             "alloc"));
+        }
+      } else {
+        LOG(FATAL)
+            << "TypeError: The struct info of call_dps_packed expects to be TensorStructInfo or "
+               "TupleStructInfo, but got "
+            << expr->struct_info_;
+      }
+
+      Array<Expr> args;
+      if (call->args[1].as<TupleNode>()) {
+        args = Downcast<Tuple>(call->args[1])->fields;
+        args.insert(args.end(), outs.begin(), outs.end());
+
+        if (call->args.size() == 2) {
+          builder_->Emit(Call(call->args[0], args), "_");
+        } else {
+          // unpack semantics
+          args.push_back(call->args[2]);
+          builder_->Emit(Call(call_tir_dyn_op, {call->args[0], Tuple(args)}), "_");
+        }
+      } else {
+        args = outs;
+        args.insert(args.begin(), call->args[1]);
+        builder_->Emit(Call(call->args[0], args), "_");
+      }
+
+      if (outs.size() == 1) {
+        return outs[0];
+      }
+      return Tuple(outs);
+    }
+
+    return GetRef<Expr>(call);
+  }
+};
+
+Expr CallDPSPackedRewrite(const Expr& e) { return CallDPSPackedMutator().VisitExpr(e); }
+
+namespace transform {
+
+Pass CallDPSPackedRewrite() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(CallDPSPackedRewrite(f));
+      };
+  return CreateFunctionPass(pass_func, 0, "CallDPSPackedRewrite", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.CallDPSPackedRewrite").set_body_typed(CallDPSPackedRewrite);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc
index 7deeb139d1a0..b5a4d7536f7f 100644
--- a/src/relax/transform/run_codegen.cc
+++ b/src/relax/transform/run_codegen.cc
@@ -75,17 +75,18 @@ class CodeGenRunner : ExprMutator {
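// After this change, RunCodegen rewrites calls into BYOC-compiled functions to
// relax.call_dps_packed rather than relax.call_tir. A sketch of the rewritten
// IR, mirroring the updated test_transform_codegen_pass.py expectation later in
// this patch:
//
//   lv = R.call_dps_packed("fused_relax_nn_conv2d_tensorrt", (data, weight1),
//                          out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"))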
if (auto const* gvar_node = call_node->op.as()) { const GlobalVar gvar = GetRef(gvar_node); - auto create_call_tir = [call_node, this](Expr extern_func, StructInfo ret_struct_info) { + auto create_call_dps_packed = [call_node, this](Expr extern_func, + StructInfo ret_struct_info) { Array new_args({extern_func}); new_args.push_back(Tuple(call_node->args.Map([this](Expr arg) { return VisitExpr(arg); }))); - static const Op& call_op = Op::Get("relax.call_tir"); + static const Op& call_op = Op::Get("relax.call_dps_packed"); return Call(call_op, new_args, tvm::Attrs(), {ret_struct_info}); }; if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) { - return create_call_tir(it->second.first, it->second.second); + return create_call_dps_packed(it->second.first, it->second.second); } else { // TODO(@sunggg): Is there any better way to get this func? Function func = Downcast(builder_->GetContextIRModule()->Lookup(gvar)); @@ -101,7 +102,7 @@ class CodeGenRunner : ExprMutator { func = (*RemoveFuncAttrFunc)(func, tvm::attr::kGlobalSymbol); func = (*RemoveFuncAttrFunc)(func, attr::kCodegen); builder_->UpdateFunction(gvar, func); - return create_call_tir(new_func, func->ret_struct_info); + return create_call_dps_packed(new_func, func->ret_struct_info); } } } diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index 2feb2082c510..318a0827899a 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -97,7 +97,8 @@ ExprDoc PrintCallee(const relax::Expr& n, const ObjectPath& n_p, const IRDocsifi Optional PrintCallTIR(const relax::Call& n, const ObjectPath& n_p, const IRDocsifier& d) { static const Op& call_tir_op = Op::Get("relax.call_tir"); - if (!n->op.same_as(call_tir_op)) { + static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); + if (!n->op.same_as(call_tir_op) && !n->op.same_as(call_dps_packed_op)) { return NullOpt; } ICHECK(n->args.size() == 2 || n->args.size() == 3); @@ -123,6 +124,9 @@ Optional PrintCallTIR(const relax::Call& n, const ObjectPath& n_p, cons } else { kwargs_values.push_back(d->AsDoc(o_sinfo, o_sinfo_p)); } + if (n->op.same_as(call_dps_packed_op)) { + return Relax(d, "call_dps_packed")->Call(args, kwargs_keys, kwargs_values); + } // Step 4. 
Print n->args[2], the tir variables if (n->args.size() == 3) { kwargs_keys.push_back("tir_vars"); @@ -134,7 +138,7 @@ Optional PrintCallTIR(const relax::Call& n, const ObjectPath& n_p, cons TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::Call n, ObjectPath n_p, IRDocsifier d) -> Doc { - // Special case: call_tir + // Special case: call_tir, call_dps_packed if (Optional doc = PrintCallTIR(n, n_p, d)) { return doc.value(); } diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h index 7702f7b22dd2..8c4281ad7890 100644 --- a/src/script/printer/relax/utils.h +++ b/src/script/printer/relax/utils.h @@ -82,7 +82,8 @@ inline Optional StructInfoAsAnn(const relax::Var& v, const ObjectPath& } if (const auto* call = rhs.as()) { static const Op& call_tir_op = Op::Get("relax.call_tir"); - if (call->op.same_as(call_tir_op)) { + static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); + if (call->op.same_as(call_tir_op) || call->op.same_as(call_dps_packed_op)) { return NullOpt; } } diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index 4a345224e57e..8b26a2aa648a 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -66,8 +66,10 @@ class IdentityUnused: def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): lv0 = x - unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) - unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32")) + unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) + unused1 = R.call_dps_packed( + "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32") + ) R.output(lv0) return lv0 @@ -92,8 +94,10 @@ class IdentityUnused: def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): lv0 = x - unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) - unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32")) + unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) + unused1 = R.call_dps_packed( + "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32") + ) R.output(lv0) z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32"))) return z diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 7b8035b17c7e..49d2b7601137 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -492,7 +492,7 @@ def test_sinfo_args_tir_var_used_before_define_call_tir(): # Error: Symbolic Var m1, n1 are not defined m1 = tir.Var("m1", "int64") n1 = tir.Var("n1", "int64") - call = R.call_tir("my_func", x, out_sinfo=R.Tensor((m1, n1), "float32")) + call = R.call_dps_packed("my_func", x, out_sinfo=R.Tensor((m1, n1), "float32")) func = build_function([rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])]) mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func)) assert not rx.analysis.well_formed(mod, check_struct_info=False) @@ -505,12 +505,12 @@ def test_sinfo_erase_to_well_formed(): def foo(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m1", "n1"), dtype="float32"): m = T.int64() n = T.int64() - gv = R.call_tir("my_func", (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) + gv = R.call_dps_packed("my_func", (x,), out_sinfo=R.Tensor((m, n), dtype="float32")) return gv """ m1 = tir.Var("m1", "int64") n1 = 
tir.Var("n1", "int64") - call = R.call_tir("my_func", x, out_sinfo=R.Tensor((m, n), "float32")) + call = R.call_dps_packed("my_func", x, out_sinfo=R.Tensor((m, n), "float32")) blocks = [rx.BindingBlock([rx.VarBinding(rx.Var("gv"), call)])] seq_expr = rx.SeqExpr(blocks, blocks[-1].bindings[-1].var) func = rx.Function([x], seq_expr, R.Tensor((m1, n1), "float32")).with_attr( diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index c21dbd2bd1f5..1b029af7e025 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -429,10 +429,74 @@ def f( def test_call_tir(): # also from test_parser + @tvm.script.ir_module + class TestCallTIR: + @T.prim_func + def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: + T.func_attr(({"global_symbol": "addone"})) + for i, j in T.grid(16, 16): + with T.block("addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.int32(1) + + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + m, n = T.var("int64"), T.var("int64") + gv0 = R.call_tir(addone, (x,), R.Tensor((m, n), dtype="float32")) + return gv0 + + mod = TestCallTIR + foo = mod["foo"] + + foo_str = strip_whitespace( + dump_ast( + foo, + include_type_annotations=False, + include_struct_info_annotations=False, + include_call_attrs=False, + ) + ) + assert foo_str.startswith('Function(params=[Var(name_hint="x")]') + + # call_tir is an op in Relax and it takes an extern func as an argument + assert isinstance(foo.body, rx.SeqExpr) + tir_call = foo.body.blocks[0].bindings[0].value + tir_call_text = dump_ast( + tir_call, + include_type_annotations=False, + include_struct_info_annotations=False, + include_call_attrs=False, + ) + assert_fields( + "Call", + { + "op": 'Op(name="relax.call_tir")', + "args": """[ + GlobalVar(name_hint="addone"), + Tuple(fields=[Var(name_hint="x")]) + ]""", + "sinfo_args": """[ + TensorStructInfo( + dtype=float32, + shape=ShapeExpr( + values=[ + PrimExpr(value=`m`), + PrimExpr(value=`n`) + ] + ) + ) + ]""", + }, + tir_call_text, + ) + assert strip_whitespace(tir_call_text) in foo_str + + +def test_call_dps_packed(): @R.function def foo(x: R.Tensor(("m", "n"), "float32")): - m, n = T.int64(), T.int64() - gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + m, n = T.var("int64"), T.var("int64") + gv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) return gv0 foo_str = strip_whitespace( @@ -445,7 +509,7 @@ def foo(x: R.Tensor(("m", "n"), "float32")): ) assert foo_str.startswith('Function(params=[Var(name_hint="x")]') - # call_tir is an op in Relax and it takes an extern func as an argument + # call_dps_packed is an op in Relax and it takes an extern func as an argument assert isinstance(foo.body, rx.SeqExpr) tir_call = foo.body.blocks[0].bindings[0].value tir_call_text = dump_ast( @@ -457,7 +521,7 @@ def foo(x: R.Tensor(("m", "n"), "float32")): assert_fields( "Call", { - "op": 'Op(name="relax.call_tir")', + "op": 'Op(name="relax.call_dps_packed")', "args": """[ ExternFunc(global_symbol="test.op.identity"), Tuple(fields=[Var(name_hint="x")]) diff --git a/tests/python/relax/test_binding_rewrite.py b/tests/python/relax/test_binding_rewrite.py index 1b424b97923a..d0d3344eb61e 100644 --- a/tests/python/relax/test_binding_rewrite.py +++ b/tests/python/relax/test_binding_rewrite.py @@ -228,8 +228,10 @@ class IdentityChainedUnused: def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): 
lv0 = x - unused0 = R.call_tir("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) - unused1 = R.call_tir("my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32")) + unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32")) + unused1 = R.call_dps_packed( + "my_sigmoid", (unused0,), R.Tensor((32, 32), dtype="float32") + ) R.output(lv0) return lv0 @@ -262,19 +264,19 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): # \ / # lv4 with R.dataflow(): - lv0: R.Tensor((32, 32), "float32") = R.call_tir( + lv0: R.Tensor((32, 32), "float32") = R.call_dps_packed( "my_relu", (x,), R.Tensor((32, 32), dtype="float32") ) - lv1: R.Tensor((32, 32), "float32") = R.call_tir( + lv1: R.Tensor((32, 32), "float32") = R.call_dps_packed( "my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32") ) - lv2: R.Tensor((32, 32), "float32") = R.call_tir( + lv2: R.Tensor((32, 32), "float32") = R.call_dps_packed( "my_add", (x, lv0), R.Tensor((32, 32), dtype="float32") ) - lv3: R.Tensor((32, 32), "float32") = R.call_tir( + lv3: R.Tensor((32, 32), "float32") = R.call_dps_packed( "my_mul", (x, lv0), R.Tensor((32, 32), dtype="float32") ) - lv4: R.Tensor((32, 32), "float32") = R.call_tir( + lv4: R.Tensor((32, 32), "float32") = R.call_dps_packed( "my_whatever", (lv2, lv3), R.Tensor((32, 32), dtype="float32") ) R.output(lv4) diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index ab7a5540ad66..ba6ea995231c 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -354,15 +354,15 @@ def test_simple_oub(): def test_counter_syntax_match(): with PatternContext() as ctx: - n0 = is_call_tir_extern("tir_matmul") - n1 = is_call_tir_extern("tir_impossible") + n0 = is_call_dps_packed("extern_matmul") + n1 = is_call_dps_packed("extern_impossible") n0 >> n1 dfb = main_fn.body.blocks[0] assert not ctx.match_dfb(dfb) with PatternContext() as ctx: - n0 = is_call_tir_extern("tir_matmul") - n1 = is_call_tir_extern("tir_impossible") + n0 = is_call_dps_packed("extern_matmul") + n1 = is_call_dps_packed("extern_impossible") n0 ^ n1 dfb = main_fn.body.blocks[0] assert not ctx.match_dfb(dfb) @@ -378,20 +378,20 @@ def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> # relu sigmoid # \ / # add - lv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32), dtype="float32")) - lv1 = R.call_tir("tir_relu", (lv0,), R.Tensor((32, 32), dtype="float32")) - lv2 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) - lv3 = R.call_tir("tir_add", (lv1, lv2), R.Tensor((32, 32), dtype="float32")) + lv0 = R.call_dps_packed("extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("extern_relu", (lv0,), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) + lv3 = R.call_dps_packed("extern_add", (lv1, lv2), R.Tensor((32, 32), dtype="float32")) R.output(lv3) return lv3 def test_diamond(): with PatternContext() as ctx: - n0 = is_call_tir_extern("tir_matmul") - n1 = is_call_tir_extern("tir_relu") - n2 = is_call_tir_extern("tir_sigmoid") - n3 = is_call_tir_extern("tir_add") + n0 = is_call_dps_packed("extern_matmul") + n1 = is_call_dps_packed("extern_relu") + n2 = is_call_dps_packed("extern_sigmoid") + n3 = is_call_dps_packed("extern_add") n0 ^ n1 n0 ^ n2 @@ -399,15 +399,15 @@ def test_diamond(): n2 >> n3 dfb = Diamond["main"].body.blocks[0] - assert ctx.match_dfb(dfb) 
+ assert ctx.match_dfb(dfb) # simplify it with fork_to with PatternContext() as ctx: - n1 = is_call_tir_extern("tir_relu") - n2 = is_call_tir_extern("tir_sigmoid") - n3 = is_call_tir_extern("tir_add") + n1 = is_call_dps_packed("extern_relu") + n2 = is_call_dps_packed("extern_sigmoid") + n3 = is_call_dps_packed("extern_add") - is_call_tir_extern("tir_matmul").fork_to(n1, n2) + is_call_dps_packed("extern_matmul").fork_to(n1, n2) n1 >> n3 n2 >> n3 @@ -417,10 +417,10 @@ def test_diamond(): def test_diamond_counter_oub(): with PatternContext() as ctx: - n0 = is_call_tir_extern("tir_matmul") - n1 = is_call_tir_extern("tir_relu") - n2 = is_call_tir_extern("tir_sigmoid") - n3 = is_call_tir_extern("tir_add") + n0 = is_call_dps_packed("extern_matmul") + n1 = is_call_dps_packed("extern_relu") + n2 = is_call_dps_packed("extern_sigmoid") + n3 = is_call_dps_packed("extern_add") n0 >> n1 n0 >> n2 @@ -440,8 +440,8 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: # / \ # \ / # add - lv0 = R.call_tir("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) - lv1 = R.call_tir("my_add", (lv0, lv0), R.Tensor((32, 32), dtype="float32")) + lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("my_add", (lv0, lv0), R.Tensor((32, 32), dtype="float32")) R.output(lv1) return lv1 @@ -454,9 +454,9 @@ def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: # relu relu # \ / # add - lv0 = R.call_tir("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) - lv1 = R.call_tir("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) - lv2 = R.call_tir("my_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32")) + lv0 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("my_relu", (x,), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_dps_packed("my_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32")) R.output(lv2) return lv2 @@ -468,8 +468,8 @@ def test_distiguish_diamond_and_parallel(): with PatternContext() as ctx: # describe a diamond pattern - fork = is_call_tir_extern("my_relu") - join = is_call_tir_extern("my_add") + fork = is_call_dps_packed("my_relu") + join = is_call_dps_packed("my_add") fork.only_used_by(join, index=0) fork.only_used_by(join, index=1) @@ -478,13 +478,13 @@ def test_distiguish_diamond_and_parallel(): with PatternContext() as ctx: # describe a parallel pattern - join = is_call_tir_extern("my_add") + join = is_call_dps_packed("my_add") # Due to one-one mathcing: - # is_call_tir_extern("my_relu") creates the 1st relu - is_call_tir_extern("my_relu") >> join - # is_call_tir_extern("my_relu") + # is_call_dps_packed("my_relu") creates the 1st relu + is_call_dps_packed("my_relu") >> join + # is_call_dps_packed("my_relu") # creates the another different relu (obj address is different) - is_call_tir_extern("my_relu") >> join + is_call_dps_packed("my_relu") >> join assert ctx.match_dfb(parallel) assert not ctx.match_dfb(diamond) @@ -507,13 +507,13 @@ def main( # \ / # concat with R.dataflow(): - lv0 = R.call_tir("conv1x1", (x, w0), R.Tensor((32, 32), dtype="float32")) - lv1 = R.call_tir("bias_add", (lv0, bias0), R.Tensor((32, 32), dtype="float32")) - lv2 = R.call_tir("my_relu", (lv1), R.Tensor((32, 32), dtype="float32")) - lv3 = R.call_tir("conv1x1", (x, w1), R.Tensor((32, 32), dtype="float32")) - lv4 = R.call_tir("bias_add", (lv3, bias1), R.Tensor((32, 32), dtype="float32")) - lv5 = R.call_tir("my_relu", (lv4), R.Tensor((32, 32), dtype="float32")) - lv6 = R.call_tir("concat", (lv2, lv5), R.Tensor((32, 64), 
dtype="float32")) + lv0 = R.call_dps_packed("conv1x1", (x, w0), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("bias_add", (lv0, bias0), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_dps_packed("my_relu", (lv1), R.Tensor((32, 32), dtype="float32")) + lv3 = R.call_dps_packed("conv1x1", (x, w1), R.Tensor((32, 32), dtype="float32")) + lv4 = R.call_dps_packed("bias_add", (lv3, bias1), R.Tensor((32, 32), dtype="float32")) + lv5 = R.call_dps_packed("my_relu", (lv4), R.Tensor((32, 32), dtype="float32")) + lv6 = R.call_dps_packed("concat", (lv2, lv5), R.Tensor((32, 64), dtype="float32")) R.output(lv6) return lv6 @@ -521,9 +521,9 @@ def main( def test_single_cbr(): with PatternContext() as ctx: ( - is_call_tir_extern("conv1x1") - >> is_call_tir_extern("bias_add") - >> is_call_tir_extern("my_relu") + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("bias_add") + >> is_call_dps_packed("my_relu") ) dfb = CBRx2["main"].body.blocks[0] matched = ctx.match_dfb(dfb) @@ -531,9 +531,9 @@ def test_single_cbr(): with PatternContext() as ctx: chain = ( - is_call_tir_extern("conv1x1") - >> is_call_tir_extern("bias_add") - >> is_call_tir_extern("my_relu") + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("bias_add") + >> is_call_dps_packed("my_relu") ) dfb = CBRx2["main"].body.blocks[0] # we want to specifically match the first CBR (lv0) @@ -549,9 +549,9 @@ def test_single_cbr(): def test_counter_single_crb(): with PatternContext() as ctx: ( - is_call_tir_extern("conv1x1") - >> is_call_tir_extern("my_relu") - >> is_call_tir_extern("bias_add") + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("my_relu") + >> is_call_dps_packed("bias_add") ) dfb = CBRx2["main"].body.blocks[0] assert not ctx.match_dfb(dfb) @@ -567,14 +567,14 @@ def test_nested_context(): dfb = CBRx2["main"].body.blocks[0] with PatternContext() as ctx0: ( - is_call_tir_extern("conv1x1") - >> is_call_tir_extern("bias_add") - >> is_call_tir_extern("my_relu") + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("bias_add") + >> is_call_dps_packed("my_relu") ) with PatternContext() as ctx1: - is_call_tir_extern("conv1x1") >> is_call_tir_extern("my_relu") # pattern to miss + is_call_dps_packed("conv1x1") >> is_call_dps_packed("my_relu") # pattern to miss with PatternContext() as ctx2: - is_call_tir_extern("bias_add") >> is_call_tir_extern("my_relu") + is_call_dps_packed("bias_add") >> is_call_dps_packed("my_relu") assert ctx2.match_dfb(dfb) assert PatternContext.current() == ctx2 assert not ctx1.match_dfb(dfb) @@ -586,9 +586,9 @@ def test_nested_context(): def test_two_cbr(): with PatternContext() as ctx: cbr0 = ( - is_call_tir_extern("conv1x1") - >> is_call_tir_extern("bias_add") - >> is_call_tir_extern("my_relu") + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("bias_add") + >> is_call_dps_packed("my_relu") ) cbr1 = cbr0.dup() @@ -603,9 +603,9 @@ def test_two_cbr(): with PatternContext() as ctx: # Deny the pattern cbr0 = ( - is_call_tir_extern("conv1x1") - >> is_call_tir_extern("bias_add") - >> is_call_tir_extern("my_relu") + is_call_dps_packed("conv1x1") + >> is_call_dps_packed("bias_add") + >> is_call_dps_packed("my_relu") ) cbr1 = cbr0.dup() @@ -626,25 +626,25 @@ def main( c: R.Tensor((48, 32), "float32"), ) -> R.Tensor: with R.dataflow(): - lv0 = R.call_tir("matmul", (a, b), R.Tensor((32, 48), dtype="float32")) - lv1 = R.call_tir("matmul", (lv0, c), R.Tensor((32, 32), dtype="float32")) + lv0 = R.call_dps_packed("matmul", (a, b), R.Tensor((32, 48), dtype="float32")) + lv1 = R.call_dps_packed("matmul", 
(lv0, c), R.Tensor((32, 32), dtype="float32")) R.output(lv1) return lv1 with PatternContext() as ctx: - is_call_tir_extern("matmul") >> is_call_tir_extern("matmul") + is_call_dps_packed("matmul") >> is_call_dps_packed("matmul") dfb = MatMul2["main"].body.blocks[0] assert ctx.match_dfb(dfb) with PatternContext() as ctx: - is_call_tir_extern("matmul").has_shape([32, 48]) >> is_call_tir_extern("matmul").has_shape( + is_call_dps_packed("matmul").has_shape([32, 48]) >> is_call_dps_packed("matmul").has_shape( [32, 32] ) dfb = MatMul2["main"].body.blocks[0] assert ctx.match_dfb(dfb) with PatternContext() as ctx: - is_call_tir_extern("matmul") >> is_call_tir_extern("matmul") >> is_call_tir_extern("matmul") + is_call_dps_packed("matmul") >> is_call_dps_packed("matmul") >> is_call_dps_packed("matmul") dfb = MatMul2["main"].body.blocks[0] # Three MatMul cannot match assert not ctx.match_dfb(dfb) @@ -661,9 +661,9 @@ def main( c: R.Tensor((16, 32), "float32"), ) -> R.Tensor: with R.dataflow(): - lv0 = R.call_tir("my_concat", (b, c), R.Tensor((32, 32), dtype="float32")) - lv1 = R.call_tir("my_matmul", (a, lv0), R.Tensor((32, 32), dtype="float32")) - lv2 = R.call_tir( + lv0 = R.call_dps_packed("my_concat", (b, c), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("my_matmul", (a, lv0), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_dps_packed( "my_split", (lv1,), [R.Tensor((16, 32), dtype="float32"), R.Tensor((16, 32), dtype="float32")], @@ -676,15 +676,15 @@ def main( with PatternContext() as ctx: ( - is_call_tir_extern("my_concat") - >> is_call_tir_extern("my_matmul") - >> is_call_tir_extern("my_split") + is_call_dps_packed("my_concat") + >> is_call_dps_packed("my_matmul") + >> is_call_dps_packed("my_split") ) dfb = CMS["main"].body.blocks[0] assert ctx.match_dfb(dfb) with PatternContext() as ctx: - split = is_call_tir_extern("my_split") + split = is_call_dps_packed("my_split") lv3 = TupleGetItemPattern(split, 0).has_shape([16, 32]) lv4 = TupleGetItemPattern(split, 1).has_shape([16, 32]) split.fork_to(lv3, lv4) @@ -711,18 +711,26 @@ def main( ) -> R.Tensor: b, s, n, h = T.int64(), T.int64(), T.int64(), T.int64() with R.dataflow(): - fcq = R.call_tir("my_fc", (x, wq), R.Tensor((b, s, n, h), dtype="float32")) - tpq = R.call_tir("my_transpose", (fcq,), R.Tensor((b, s, h, n), dtype="float32")) + fcq = R.call_dps_packed("my_fc", (x, wq), R.Tensor((b, s, n, h), dtype="float32")) + tpq = R.call_dps_packed( + "my_transpose", (fcq,), R.Tensor((b, s, h, n), dtype="float32") + ) - fck = R.call_tir("my_fc", (x, wk), R.Tensor((b, s, n, h), dtype="float32")) - tpk = R.call_tir("my_transpose", (fck,), R.Tensor((b, s, h, n), dtype="float32")) + fck = R.call_dps_packed("my_fc", (x, wk), R.Tensor((b, s, n, h), dtype="float32")) + tpk = R.call_dps_packed( + "my_transpose", (fck,), R.Tensor((b, s, h, n), dtype="float32") + ) mul = R.multiply(tpq, tpk) scale = R.multiply(mul, R.const(1.1, "float32")) - softmax = R.call_tir("softmax", (scale,), R.Tensor((b, s, n, h), dtype="float32")) + softmax = R.call_dps_packed( + "softmax", (scale,), R.Tensor((b, s, n, h), dtype="float32") + ) - fcv = R.call_tir("my_fc", (x, wv), R.Tensor((b, s, n, h), dtype="float32")) - tpv = R.call_tir("my_transpose", (fcv,), R.Tensor((b, s, h, n), dtype="float32")) + fcv = R.call_dps_packed("my_fc", (x, wv), R.Tensor((b, s, n, h), dtype="float32")) + tpv = R.call_dps_packed( + "my_transpose", (fcv,), R.Tensor((b, s, h, n), dtype="float32") + ) out = R.multiply(softmax, tpv) R.output(out) @@ -730,7 +738,7 @@ def main( return out 
with PatternContext() as ctx: - fc_trans_q = is_call_tir_extern("my_fc") >> is_call_tir_extern("my_transpose") + fc_trans_q = is_call_dps_packed("my_fc") >> is_call_dps_packed("my_transpose") fc_trans_k = fc_trans_q.dup() fc_trans_v = fc_trans_q.dup() @@ -752,43 +760,59 @@ def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> # add5 add6 # \ / # add7 - lv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32), dtype="float32")) - lv1 = R.call_tir("tir_matmul", (x, w), R.Tensor((32, 32), dtype="float32")) - lv2 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) - lv3 = R.call_tir("tir_sigmoid", (lv1), R.Tensor((32, 32), dtype="float32")) - lv4 = R.call_tir("tir_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32")) - lv5 = R.call_tir("tir_add", (lv2, lv4), R.Tensor((32, 32), dtype="float32")) - lv6 = R.call_tir("tir_add", (lv3, lv4), R.Tensor((32, 32), dtype="float32")) - lv7 = R.call_tir("tir_add", (lv5, lv6), R.Tensor((32, 32), dtype="float32")) + lv0 = R.call_dps_packed( + "extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32") + ) + lv1 = R.call_dps_packed( + "extern_matmul", (x, w), R.Tensor((32, 32), dtype="float32") + ) + lv2 = R.call_dps_packed( + "extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32") + ) + lv3 = R.call_dps_packed( + "extern_sigmoid", (lv1), R.Tensor((32, 32), dtype="float32") + ) + lv4 = R.call_dps_packed( + "extern_add", (lv0, lv1), R.Tensor((32, 32), dtype="float32") + ) + lv5 = R.call_dps_packed( + "extern_add", (lv2, lv4), R.Tensor((32, 32), dtype="float32") + ) + lv6 = R.call_dps_packed( + "extern_add", (lv3, lv4), R.Tensor((32, 32), dtype="float32") + ) + lv7 = R.call_dps_packed( + "extern_add", (lv5, lv6), R.Tensor((32, 32), dtype="float32") + ) R.output(lv7) return lv7 # match matmul0 diamond with PatternContext() as ctx: - sigmoid2 = is_call_tir_extern("tir_sigmoid") - add4 = is_call_tir_extern("tir_add") - is_call_tir_extern("tir_matmul").fork_to(sigmoid2, add4) - add5 = is_call_tir_extern("tir_add") + sigmoid2 = is_call_dps_packed("extern_sigmoid") + add4 = is_call_dps_packed("extern_add") + is_call_dps_packed("extern_matmul").fork_to(sigmoid2, add4) + add5 = is_call_dps_packed("extern_add") sigmoid2 >> add5 add4 ^ add5 assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) # counter case: mis-match matmul0 diamond with PatternContext() as ctx: - sigmoid2 = is_call_tir_extern("tir_sigmoid") - add4 = is_call_tir_extern("tir_add") - is_call_tir_extern("tir_matmul").fork_to(sigmoid2, add4) - add5 = is_call_tir_extern("tir_add") + sigmoid2 = is_call_dps_packed("extern_sigmoid") + add4 = is_call_dps_packed("extern_add") + is_call_dps_packed("extern_matmul").fork_to(sigmoid2, add4) + add5 = is_call_dps_packed("extern_add") sigmoid2 >> add5 add4 >> add5 # not only-used-by relation assert not ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) # match matmul1 diamond with PatternContext() as ctx: - sigmoid3 = is_call_tir_extern("tir_sigmoid") - add4 = is_call_tir_extern("tir_add") - is_call_tir_extern("tir_matmul").fork_to(sigmoid3, add4) - add6 = is_call_tir_extern("tir_add") + sigmoid3 = is_call_dps_packed("extern_sigmoid") + add4 = is_call_dps_packed("extern_add") + is_call_dps_packed("extern_matmul").fork_to(sigmoid3, add4) + add6 = is_call_dps_packed("extern_add") sigmoid3 >> add6 add4 ^ add6 assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) @@ -796,11 +820,11 @@ def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> # match add-4-5-6-7 with 
PatternContext() as ctx: add5, add6, add7 = ( - is_call_tir_extern("tir_add"), - is_call_tir_extern("tir_add"), - is_call_tir_extern("tir_add"), + is_call_dps_packed("extern_add"), + is_call_dps_packed("extern_add"), + is_call_dps_packed("extern_add"), ) - is_call_tir_extern("tir_add").fork_to(add5, add6) # add4 + is_call_dps_packed("extern_add").fork_to(add5, add6) # add4 add5 >> add7 add6 >> add7 assert ctx.match_dfb(DiamondInDiamond["main"].body.blocks[0]) @@ -811,15 +835,15 @@ def test_incremental_solving(): def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): # relu -> sigmoid -> neg - lv0 = R.call_tir("tir_relu", (x), R.Tensor((32, 32), dtype="float32")) - lv1 = R.call_tir("tir_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) - lv2 = R.call_tir("tir_neg", (lv1), R.Tensor((32, 32), dtype="float32")) + lv0 = R.call_dps_packed("extern_relu", (x), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("extern_sigmoid", (lv0), R.Tensor((32, 32), dtype="float32")) + lv2 = R.call_dps_packed("extern_neg", (lv1), R.Tensor((32, 32), dtype="float32")) R.output(lv2) return lv2 - relu = is_call_tir_extern("tir_relu") - sigmoid = is_call_tir_extern("tir_sigmoid") - neg = is_call_tir_extern("tir_neg") + relu = is_call_dps_packed("extern_relu") + sigmoid = is_call_dps_packed("extern_sigmoid") + neg = is_call_dps_packed("extern_neg") with PatternContext() as ctx0: relu >> sigmoid @@ -840,14 +864,14 @@ def test_incremental_solving_counter(): def simple_chain(x: R.Tensor((32, 32), "float32")) -> R.Tensor: with R.dataflow(): # sigmoid -> neg - lv0 = R.call_tir("tir_sigmoid", (x), R.Tensor((32, 32), dtype="float32")) - lv1 = R.call_tir("tir_neg", (lv0), R.Tensor((32, 32), dtype="float32")) + lv0 = R.call_dps_packed("extern_sigmoid", (x), R.Tensor((32, 32), dtype="float32")) + lv1 = R.call_dps_packed("extern_neg", (lv0), R.Tensor((32, 32), dtype="float32")) R.output(lv1) return lv1 - relu = is_call_tir_extern("tir_relu") - sigmoid = is_call_tir_extern("tir_sigmoid") - neg = is_call_tir_extern("tir_neg") + relu = is_call_dps_packed("extern_relu") + sigmoid = is_call_dps_packed("extern_sigmoid") + neg = is_call_dps_packed("extern_neg") with PatternContext() as ctx0: relu >> sigmoid # cannot match diff --git a/tests/python/relax/test_op_misc.py b/tests/python/relax/test_op_misc.py index 523a628fa98b..fd2391153324 100644 --- a/tests/python/relax/test_op_misc.py +++ b/tests/python/relax/test_op_misc.py @@ -39,7 +39,7 @@ def identity_tir(a: T.handle, b: T.handle) -> None: def test_call_tir() -> None: v0 = rx.Var("v0", R.Tensor([54, 96], "float32")) - v1 = rx.call_tir(rx.extern("test.op.identity"), [v0], R.Tensor((54, 96), "float32")) + v1 = rx.call_dps_packed(rx.extern("test.op.identity"), [v0], R.Tensor((54, 96), "float32")) v1 = rx.call_tir(identity_tir, [v0], R.Tensor((54, 96), "float32")) diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 85de4f912ecf..cdfb4e635e8f 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -32,8 +32,10 @@ class TestToNonDataflow: def foo(x: R.Tensor(("m", "n"), "float32")): m, n = T.int64(), T.int64() with R.dataflow(): - lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) - gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32")) + lv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + gv0 = R.call_dps_packed( + "test.op.identity", (lv0,), R.Tensor((m, n), 
dtype="float32") + ) R.output(gv0) return gv0 @@ -73,10 +75,14 @@ def fvisit(e): def test_call_tir_rewrite(): @tvm.script.ir_module class TestCallTIRRewrite: + @T.prim_func + def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + T.evaluate(0) + @R.function def foo(x: R.Tensor(("m", "n"), "float32")): m, n = T.int64(), T.int64() - gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + gv0 = R.call_tir(exp, (x,), R.Tensor((m, n), dtype="float32")) return gv0 mod = TestCallTIRRewrite @@ -94,6 +100,40 @@ def foo(x: R.Tensor(("m", "n"), "float32")): block = func.body.blocks[0] assert not isinstance(block, relax.DataflowBlock) + s1 = block.bindings[0].value + assert isinstance(s1, relax.Call) + assert s1.op.name == "relax.builtin.alloc_tensor" + assert isinstance(s1.args[0], relax.ShapeExpr) + assert structural_equal(s1.args[0], s0.sinfo_args[0].shape) + s2 = block.bindings[1].value + tvm.ir.expr.GlobalVar + assert s2.op.name_hint == "exp" + + +def test_call_dps_packed_rewrite(): + @tvm.script.ir_module + class TestCallDPSPackedRewrite: + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + m, n = T.int64(), T.int64() + gv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + return gv0 + + mod = TestCallDPSPackedRewrite + + # before rewrite + v0 = mod["foo"].body.blocks[0].bindings[0].var + s0 = mod["foo"].body.blocks[0].bindings[0].value + assert isinstance(s0, relax.Call) + assert s0.op.name == "relax.call_dps_packed" + + # after rewrite + new_mod = relax.transform.CallDPSPackedRewrite()(mod) + func = new_mod["foo"] + + block = func.body.blocks[0] + assert not isinstance(block, relax.DataflowBlock) + s1 = block.bindings[0].value assert isinstance(s1, relax.Call) assert s1.op.name == "relax.builtin.alloc_tensor" diff --git a/tests/python/relax/test_transform_attach_global_symbol.py b/tests/python/relax/test_transform_attach_global_symbol.py index cef3842e3e49..7fc6798e37c4 100644 --- a/tests/python/relax/test_transform_attach_global_symbol.py +++ b/tests/python/relax/test_transform_attach_global_symbol.py @@ -45,7 +45,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: @R.function def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")) -> R.Tensor: m, n, k = T.int64(), T.int64(), T.int64() - gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), dtype="float32")) + gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) return gv0 @@ -75,7 +75,7 @@ def main( ) -> R.Tensor: R.func_attr({"global_symbol": "main"}) m, n, k = T.int64(), T.int64(), T.int64() - gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), dtype="float32")) + gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) return gv0 before = Before diff --git a/tests/python/relax/test_transform_bind_params.py b/tests/python/relax/test_transform_bind_params.py index 1dfd9e0c8e19..2a30586b1b16 100644 --- a/tests/python/relax/test_transform_bind_params.py +++ b/tests/python/relax/test_transform_bind_params.py @@ -87,10 +87,10 @@ def main( m = T.Var("m", "int64") n = T.Var("n", "int64") with R.dataflow(): - lv0 = R.call_tir( + lv0 = R.call_dps_packed( "linear0", (x, w0, b0), out_sinfo=R.Tensor((batch, n), dtype="float32") ) - out = R.call_tir( + out = R.call_dps_packed( "linear1", (lv0, w1, b1), out_sinfo=R.Tensor((batch, k), dtype="float32") ) R.output(out) diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py index 
3e9501147aa0..d82706200aa9 100644 --- a/tests/python/relax/test_transform_codegen_pass.py +++ b/tests/python/relax/test_transform_codegen_pass.py @@ -235,12 +235,12 @@ def main( weight2: R.Tensor((16, 3, 3, 16), dtype="float16"), ) -> R.Tensor((16, 32, 32, 16), dtype="float16"): with R.dataflow(): - lv = R.call_tir( + lv = R.call_dps_packed( "fused_relax_nn_conv2d_tensorrt", (data, weight1), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"), ) - gv = R.call_tir( + gv = R.call_dps_packed( "fused_relax_nn_conv2d_tensorrt", (lv, weight2), out_sinfo=R.Tensor((16, 32, 32, 16), dtype="float16"), diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index a7a6066c4b60..8ec45a3d9b21 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -826,7 +826,7 @@ class Module: @R.function def main(x: R.Tensor((2, 3), "float32")): with R.dataflow(): - y = R.call_tir("func_packed_dps", x, R.Tensor((2, 3), "float32")) + y = R.call_dps_packed("func_packed_dps", x, R.Tensor((2, 3), "float32")) R.output(y) return y @@ -842,7 +842,7 @@ def main(x: R.Tensor((2, 3), "float32")): with R.dataflow(): a = R.call_tir(exp, (x,), out_sinfo=R.Tensor((2, 3), "float32")) b = R.call_tir(exp, (a,), out_sinfo=R.Tensor((2, 3), "float32")) - c = R.call_tir("packed_dps", (a,), out_sinfo=R.Tensor((2, 3), "float32")) + c = R.call_dps_packed("packed_dps", (a,), out_sinfo=R.Tensor((2, 3), "float32")) R.output(b, c) return R.tuple(b, c) diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index c2784edec733..f8d488e43b8e 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -690,7 +690,7 @@ class Module: @R.function def main(x: R.Tensor((2, 3), "float32")): with R.dataflow(): - y = R.call_tir("func_packed_dps", x, R.Tensor((2, 3), "float32")) + y = R.call_dps_packed("func_packed_dps", x, R.Tensor((2, 3), "float32")) R.output(y) return y diff --git a/tests/python/relax/test_transform_normalize.py b/tests/python/relax/test_transform_normalize.py index da123f956d59..874e83c7f955 100644 --- a/tests/python/relax/test_transform_normalize.py +++ b/tests/python/relax/test_transform_normalize.py @@ -124,8 +124,10 @@ class ANFMod2: def foo(x: R.Tensor(("m", "n"), "float32")): m, n = T.int64(), T.int64() with R.dataflow(): - lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) - gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32")) + lv0 = R.call_dps_packed("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + gv0 = R.call_dps_packed( + "test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32") + ) R.output(gv0) return gv0 diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py index 014b00af0097..9f6c02d2b938 100644 --- a/tests/python/relax/test_tvmscript_ir_builder.py +++ b/tests/python/relax/test_tvmscript_ir_builder.py @@ -25,7 +25,7 @@ def test_function_simple(): """ @R.function def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): - out = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) + out = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) return out """ # create with Script IRBuilder @@ -35,8 +35,15 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) R.func_attr({"Primitive": 1}) x = 
R.arg("x", relax.TensorStructInfo((128, 128), "float32")) R.func_ret_struct_info(relax.TensorStructInfo(dtype="float32", ndim=2)) + y = R.emit( + R.call_dps_packed( + "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") + ) + ) out = R.emit( - R.call_tir("extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32")) + R.call_dps_packed( + "extern_dps_func", y, relax.TensorStructInfo((128, 128), dtype="float32") + ) ) IRBuilder.name("out", out) R.func_ret_value(out) @@ -45,8 +52,15 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) x = relax.Var("x", relax.TensorStructInfo((128, 128), "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x,), attrs={"Primitive": 1}): + y = bb.emit( + relax.call_dps_packed( + "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") + ) + ) out = bb.emit( - relax.call_tir("extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32")) + relax.call_dps_packed( + "extern_dps_func", y, relax.TensorStructInfo((128, 128), dtype="float32") + ) ) bb.emit_func_output(out) mod = bb.get() @@ -116,7 +130,7 @@ def test_dataflow_block(): def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2): # block 0 with R.dataflow(): - lv0 = R.call_tir("extern_func", (x,), R.Tensor((128, 128), dtype="float32")) + lv0 = R.call_dps_packed("extern_func", (x,), R.Tensor((128, 128), dtype="float32")) gv: Tensor((128, 128), "float32") = lv0 R.output(gv) return gv @@ -128,7 +142,7 @@ def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2): x = R.arg("x", relax.TensorStructInfo((128, 128), "float32")) with R.dataflow() as df: lv0 = R.emit( - R.call_tir( + R.call_dps_packed( "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") ) ) @@ -146,7 +160,7 @@ def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2): with bb.function("foo", (x,)): with bb.dataflow(): lv0 = bb.emit( - relax.call_tir( + relax.call_dps_packed( "extern_func", x, relax.TensorStructInfo((128, 128), dtype="float32") ) ) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index b885697c735f..ed8204817bfb 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -43,13 +43,17 @@ def test_simple_func(): @R.function def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): R.func_attr({"Primitive": 1}) - gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) - return gv0 + gv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) + gv1 = R.call_dps_packed("extern_dps_func", gv0, R.Tensor((128, 128), dtype="float32")) + return gv1 x = relax.Var("x", R.Tensor((128, 128), "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x,), attrs={"Primitive": 1}): - out = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))) + y = bb.emit(relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32"))) + out = bb.emit( + relax.call_dps_packed("extern_dps_func", y, R.Tensor((128, 128), dtype="float32")) + ) bb.emit_func_output(out) _check(foo, bb.get()["foo"]) @@ -264,14 +268,14 @@ def test_symbolic_shape(): def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): m = T.int64() n = T.int64() - gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32")) + gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), 
dtype="float32")) return gv0 @R.function def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): m = T.int64() n = T.int64() - gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32")) + gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32")) return gv0 with pytest.raises(tvm.error.DiagnosticError): @@ -280,7 +284,7 @@ def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): def mismatch_dtype(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2): m = T.int64() n = T.int32() # The shape dtype should be int64 - gv0 = R.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32")) + gv0 = R.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32")) return gv0 def _expected(name: str): @@ -288,7 +292,9 @@ def _expected(name: str): x = relax.Var("x", R.Tensor([m, n], "float32")) bb = relax.BlockBuilder() with bb.function(name, (x,)): - out = bb.emit(relax.call_tir("extern_func", x, R.Tensor((m, n), dtype="float32"))) + out = bb.emit( + relax.call_dps_packed("extern_func", x, R.Tensor((m, n), dtype="float32")) + ) bb.emit_func_output(out) return bb.get()[name] @@ -352,15 +358,15 @@ def foo(x: R.Tensor("float32"), y: R.Tensor("float32")): def test_tuple_return(): @R.function def foo(x: R.Tensor((4, 4), "float32")): - gv0 = R.call_tir("extern_func_0", x, R.Tensor((4, 4), dtype="float32")) - gv1 = R.call_tir("extern_func_1", x, R.Tensor((4, 4), dtype="float32")) + gv0 = R.call_dps_packed("extern_func_0", x, R.Tensor((4, 4), dtype="float32")) + gv1 = R.call_dps_packed("extern_func_1", x, R.Tensor((4, 4), dtype="float32")) return (gv0, gv1) x = relax.Var("x", R.Tensor((4, 4), "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x,)): - gv0 = bb.emit(relax.call_tir("extern_func_0", x, R.Tensor((4, 4), dtype="float32"))) - gv1 = bb.emit(relax.call_tir("extern_func_1", x, R.Tensor((4, 4), dtype="float32"))) + gv0 = bb.emit(relax.call_dps_packed("extern_func_0", x, R.Tensor((4, 4), dtype="float32"))) + gv1 = bb.emit(relax.call_dps_packed("extern_func_1", x, R.Tensor((4, 4), dtype="float32"))) bb.emit_func_output(relax.Tuple((gv0, gv1))) _check(foo, bb.get()["foo"]) @@ -432,8 +438,8 @@ def test_dataflow_block(): @R.function def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): with R.dataflow(): - lv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) - lv1 = R.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) + lv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) + lv1 = R.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) gv = lv1 R.output(gv) return gv @@ -442,8 +448,12 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) bb = relax.BlockBuilder() with bb.function("foo", (x,)): with bb.dataflow(): - lv0 = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))) - lv1 = bb.emit(relax.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32"))) + lv0 = bb.emit( + relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) + ) + lv1 = bb.emit( + relax.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) + ) gv = bb.emit_output(lv1) bb.emit_func_output(gv) @@ -453,22 +463,22 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) def test_dataflow_block_advanced(): @R.function def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, 
"float32", ndim=2): - gv0 = R.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32")) - gv1 = R.call_tir("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) + gv0 = R.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) + gv1 = R.call_dps_packed("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) with R.dataflow(): m = T.int64() n = T.int64() - lv0 = R.call_tir("extern_func", gv1, R.Tensor((128, 128), dtype="float32")) + lv0 = R.call_dps_packed("extern_func", gv1, R.Tensor((128, 128), dtype="float32")) lv1 = R.match_cast(lv0, R.Tensor((m, n), "float32")) - gv2 = R.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) - gv2 = R.call_tir("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) + gv2 = R.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) + gv2 = R.call_dps_packed("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) gv3 = R.match_cast(gv2, R.Tensor((m, n), "float32")) gv3 = R.match_cast(lv0, R.Tensor((m, n), "float32")) gv4 = gv3 gv5 = gv2 R.output(gv5, gv4) - gv6 = R.call_tir("extern_func", gv5, R.Tensor((128, 128), dtype="float32")) - gv7 = R.call_tir("extern_func", gv6, R.Tensor((128, 128), dtype="float32")) + gv6 = R.call_dps_packed("extern_func", gv5, R.Tensor((128, 128), dtype="float32")) + gv7 = R.call_dps_packed("extern_func", gv6, R.Tensor((128, 128), dtype="float32")) return gv7 x = relax.Var("x", R.Tensor((128, 128), "float32")) @@ -476,21 +486,33 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2) m = tir.Var("m", dtype="int64") n = tir.Var("n", dtype="int64") with bb.function("foo", (x,)): - gv0 = bb.emit(relax.call_tir("extern_func", x, R.Tensor((128, 128), dtype="float32"))) - gv1 = bb.emit(relax.call_tir("extern_func", gv0, R.Tensor((128, 128), dtype="float32"))) + gv0 = bb.emit( + relax.call_dps_packed("extern_func", x, R.Tensor((128, 128), dtype="float32")) + ) + gv1 = bb.emit( + relax.call_dps_packed("extern_func", gv0, R.Tensor((128, 128), dtype="float32")) + ) with bb.dataflow(): - lv0 = bb.emit(relax.call_tir("extern_func", gv1, R.Tensor((128, 128), dtype="float32"))) + lv0 = bb.emit( + relax.call_dps_packed("extern_func", gv1, R.Tensor((128, 128), dtype="float32")) + ) lv1 = bb.match_cast(lv0, R.Tensor((m, n), "float32")) - gv2 = bb.emit(relax.call_tir("extern_func", lv0, R.Tensor((128, 128), dtype="float32"))) + gv2 = bb.emit( + relax.call_dps_packed("extern_func", lv0, R.Tensor((128, 128), dtype="float32")) + ) gv21 = bb.emit( - relax.call_tir("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) + relax.call_dps_packed("extern_func", gv2, R.Tensor((128, 128), dtype="float32")) ) gv3 = bb.match_cast(gv21, R.Tensor((m, n), "float32")) gv31 = bb.match_cast(lv0, R.Tensor((m, n), "float32")) gv32 = bb.emit_output(gv31) gv22 = bb.emit_output(gv21) - gv4 = bb.emit(relax.call_tir("extern_func", gv22, R.Tensor((128, 128), dtype="float32"))) - gv5 = bb.emit(relax.call_tir("extern_func", gv4, R.Tensor((128, 128), dtype="float32"))) + gv4 = bb.emit( + relax.call_dps_packed("extern_func", gv22, R.Tensor((128, 128), dtype="float32")) + ) + gv5 = bb.emit( + relax.call_dps_packed("extern_func", gv4, R.Tensor((128, 128), dtype="float32")) + ) bb.emit_func_output(gv5) _check(foo, bb.get()["foo"]) @@ -589,13 +611,13 @@ def foo(x: R.Tensor((128, 128), "float32")): def test_tensor_type_without_args(): @R.function def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: - v = R.call_tir("tir_relu", x, R.Tensor((32, 32), dtype="float32")) + v = 
R.call_dps_packed("extern_relu", x, R.Tensor((32, 32), dtype="float32")) return v x = relax.Var("x", R.Tensor((32, 32), "float32")) bb = relax.BlockBuilder() with bb.function("foo", (x)): - v = bb.emit(relax.call_tir("tir_relu", x, R.Tensor((32, 32), dtype="float32"))) + v = bb.emit(relax.call_dps_packed("extern_relu", x, R.Tensor((32, 32), dtype="float32"))) bb.emit_func_output(v) _check(foo, bb.get()["foo"]) @@ -973,7 +995,7 @@ def bar( x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32") ) -> R.Tensor(("T.max(m, 20) + 1",), "float32"): m = T.int64() - z = R.call_tir("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 1,), dtype="float32")) + z = R.call_dps_packed("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 1,), dtype="float32")) return z m = tir.Var("m", "int64") @@ -982,7 +1004,9 @@ def bar( bb = relax.BlockBuilder() with bb.function("bar", (x, y)): z = bb.emit( - relax.call_tir("test_intrin", (x, y), R.Tensor((tir.max(m, 20) + 1,), dtype="float32")) + relax.call_dps_packed( + "test_intrin", (x, y), R.Tensor((tir.max(m, 20) + 1,), dtype="float32") + ) ) bb.emit_func_output(z) @@ -992,7 +1016,7 @@ def bar( @R.function def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")): m = T.int64() - z = R.call_tir("test_intrin", y, R.Tensor((m * 2,), dtype="float32")) + z = R.call_dps_packed("test_intrin", y, R.Tensor((m * 2,), dtype="float32")) return z m = tir.Var("m", "int64") @@ -1000,7 +1024,7 @@ def baz(x: R.Shape(("m",)), y: R.Tensor(("m * 2",), "float32")): y = relax.Var("y", relax.TensorStructInfo([m * 2], "float32")) bb = relax.BlockBuilder() with bb.function("baz", (x, y)): - z = bb.emit(relax.call_tir("test_intrin", (y), R.Tensor((m * 2,), dtype="float32"))) + z = bb.emit(relax.call_dps_packed("test_intrin", (y), R.Tensor((m * 2,), dtype="float32"))) bb.emit_func_output(z) _check(baz, bb.get()["baz"]) diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 464591f2592b..270c91299aeb 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -299,13 +299,22 @@ def test_shape_expr(): def test_call(): x = tir.Var("x", "int64") a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) - obj = relax.call_tir("my_func", args=a, out_sinfo=a.struct_info, tir_vars=[x]) + o0 = relax.call_tir("tir_func", args=a, out_sinfo=a.struct_info, tir_vars=[x]) + o1 = relax.call_dps_packed("my_dps_func", args=a, out_sinfo=a.struct_info) _assert_print( - obj, + o0, + """ +x = T.int64() +a: R.Tensor((1, x, 3), dtype="float32") +R.call_tir(tir_func, (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=R.shape([x])) +""", + ) + _assert_print( + o1, """ x = T.int64() a: R.Tensor((1, x, 3), dtype="float32") -R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=R.shape([x])) +R.call_dps_packed("my_dps_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32")) """, ) diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index e51e22e3233c..0e92e6a6fbf4 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -121,7 +121,7 @@ class TestVMCompileStage3: @R.function def foo(x: R.Tensor((32, 16), "float32")) -> R.Tensor: with R.dataflow(): - y = R.call_tir("test.vm.identity", (x), R.Tensor((32, 16), dtype="float32")) + y = R.call_dps_packed("test.vm.identity", (x), R.Tensor((32, 16), dtype="float32")) R.output(y) return y @@ -145,7 +145,7 @@ 
def foo(x: R.Tensor(dtype="float32")) -> R.Tensor: with R.dataflow(): n, m = T.int64(), T.int64() _ = R.match_cast(x, R.Tensor((n, m), "float32")) - y = R.call_tir("test.vm.tile", (x), R.Tensor((n, m * 2), dtype="float32")) + y = R.call_dps_packed("test.vm.tile", (x), R.Tensor((n, m * 2), dtype="float32")) R.output(y) return y @@ -714,7 +714,7 @@ def test_vm_nested_tuple( @R.function def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: - gv0 = R.call_tir("test_vm_mul", (x, w), R.Tensor((32, 32), dtype="float32")) + gv0 = R.call_dps_packed("test_vm_mul", (x, w), R.Tensor((32, 32), dtype="float32")) return gv0 From be660e3b1d51b7543674b2d886daafd35766b3f2 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Thu, 2 Mar 2023 17:48:00 -0800 Subject: [PATCH 2/7] fix lint --- src/relax/op/op.cc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index e0f1d26bdf4a..001c76dcf3a2 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -106,7 +106,6 @@ Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, TVM_REGISTER_GLOBAL("relax.op.call_tir").set_body_typed(MakeCallTIR); - // call_dps_packed StructInfo InferStructInfoCallDPSPacked(const Call& call, const BlockBuilder& ctx) { @@ -126,9 +125,10 @@ RELAY_REGISTER_OP("relax.call_dps_packed") Expr MakeCallDPSPacked(Expr func, Tuple args, Array out_sinfo_list) { for (const TensorStructInfo& sinfo : out_sinfo_list) { const auto* shape = sinfo->shape.as(); - CHECK(shape != nullptr) << "out_sinfo of call_dps_packed should have defined ShapeExpr as shape. " - "However, one given structure info is " - << sinfo; + CHECK(shape != nullptr) + << "out_sinfo of call_dps_packed should have defined ShapeExpr as shape. " + "However, one given structure info is " + << sinfo; } StructInfo out_sinfo{nullptr}; @@ -144,7 +144,6 @@ Expr MakeCallDPSPacked(Expr func, Tuple args, Array out_sinfo_ TVM_REGISTER_GLOBAL("relax.op.call_dps_packed").set_body_typed(MakeCallDPSPacked); - // call builtin StructInfo InferStructInfoCallBuiltinWithCtx(const Call& call, const BlockBuilder& ctx) { if (call->sinfo_args.size() == 0) { From f1eedae04e5815a4573df644469ae531e3673c30 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Mon, 6 Mar 2023 20:13:54 -0800 Subject: [PATCH 3/7] Fix comments --- include/tvm/relax/transform.h | 9 +- python/tvm/relax/op/base.py | 5 +- python/tvm/relax/transform/transform.py | 12 +- python/tvm/relax/vm_build.py | 1 - src/relax/analysis/well_formed.cc | 14 ++ .../transform/call_dps_packed_rewrite.cc | 140 ------------------ src/relax/transform/call_tir_rewrite.cc | 5 +- .../python/relax/test_analysis_well_formed.py | 49 ++++++ tests/python/relax/test_transform.py | 4 +- tests/python/relax/test_tvmscript_parser.py | 29 +++- .../relax/test_tvmscript_printer_relax.py | 2 +- tests/python/relax/test_vm_build.py | 2 +- 12 files changed, 97 insertions(+), 175 deletions(-) delete mode 100644 src/relax/transform/call_dps_packed_rewrite.cc diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index e7e7dfcf319b..0819e48aee4a 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -83,19 +83,12 @@ TVM_DLL Pass LambdaLift(); TVM_DLL Pass ToNonDataflow(); /*! - * \brief Perform explicit tensor allocation for call_tir. + * \brief Perform explicit tensor allocation for call_tir and call_dps_packed. * * \return The Pass. */ TVM_DLL Pass CallTIRRewrite(); -/*! - * \brief Perform explicit tensor allocation for call_dps_packed. 
- * - * \return The Pass. - */ -TVM_DLL Pass CallDPSPackedRewrite(); - /*! * \brief Convert all reshape-like call_tir whose corresponding binding * vars are DataflowVars to relax.reshape operator calls. The relax.reshape diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 2bfe9e85aae8..23ce8668f632 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -22,7 +22,7 @@ from tvm.runtime.object import Object from . import _ffi_api -from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar +from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc from ..expr import Tuple as RxTuple from ..struct_info import StructInfo, TensorStructInfo from ...ir import PrimExpr @@ -74,9 +74,6 @@ def call_tir( ret: Call A call node for the call_tir operator. """ - if isinstance(func, str): - func = GlobalVar(func) - if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore args = RxTuple((args,)) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index cbd8e61cf2b7..d262eeafc7f3 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -60,7 +60,7 @@ def LambdaLift(): def CallTIRRewrite() -> tvm.ir.transform.Pass: - """Perform explicit tensor allocation for call_tir. + """Perform explicit tensor allocation for call_tir and call_dps_packed. Returns ------- @@ -69,16 +69,6 @@ def CallTIRRewrite() -> tvm.ir.transform.Pass: return _ffi_api.CallTIRRewrite() # type: ignore -def CallDPSPackedRewrite() -> tvm.ir.transform.Pass: - """Perform explicit tensor allocation for call_dps_packed. - - Returns - ------- - ret: tvm.ir.transform.Pass - """ - return _ffi_api.CallDPSPackedRewrite() # type: ignore - - def Normalize() -> tvm.ir.transform.Pass: """Transforming Relax IR to normal form, i.e., the expressions are normalized(no nesting and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are available. 
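For reference, a minimal sketch of the unified lowering that CallTIRRewrite now performs for call_dps_packed as well (this mirrors the updated test_transform.py case later in this series; "extern_identity" is a placeholder packed-func name):

    import tvm
    from tvm import relax
    from tvm.script import relax as R

    @tvm.script.ir_module
    class Before:
        @R.function
        def foo(x: R.Tensor((16, 16), "float32")):
            gv = R.call_dps_packed(
                "extern_identity", (x,), R.Tensor((16, 16), dtype="float32")
            )
            return gv

    # The single pass rewrites the call_dps_packed binding into an explicit
    # relax.builtin.alloc_tensor followed by a destination-passing call,
    # i.e. roughly: alloc = alloc_tensor(...); extern_identity(x, alloc)
    After = relax.transform.CallTIRRewrite()(Before)
    print(After)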
diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py index 3fca75dfc933..0586bf9217a2 100644 --- a/python/tvm/relax/vm_build.py +++ b/python/tvm/relax/vm_build.py @@ -296,7 +296,6 @@ def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")): passes.append(relax.transform.RewriteDataflowReshape()) passes.append(relax.transform.ToNonDataflow()) passes.append(relax.transform.CallTIRRewrite()) - passes.append(relax.transform.CallDPSPackedRewrite()) passes.append(relax.transform.StaticPlanBlockMemory()) passes.append(relax.transform.VMBuiltinLower()) passes.append(relax.transform.VMShapeLower()) diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 9a97931136c8..b8168bd49d90 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -250,6 +250,7 @@ class WellFormedChecker : public relax::ExprVisitor, } else { Malformed(Diagnostic::Error(op) << "The called expression must be a leaf expression"); } + for (size_t i = 0; i < op->args.size(); i++) { Expr arg = op->args[i]; if (IsLeafOrTuple(arg)) { @@ -260,6 +261,19 @@ class WellFormedChecker : public relax::ExprVisitor, } } + // call_tir works for tir PrimFunc, call_dps_packed works for ExternFunc + static const Op& call_tir_op = Op::Get("relax.call_tir"); + static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); + if (op->op == call_tir_op && !op->args[0]->IsInstance()) { + Malformed(Diagnostic::Error(op->args[0]->span) + << "call_tir expects a prim_func, but gets " << op->args[0]->GetTypeKey()); + } + if (op->op == call_dps_packed_op && !op->args[0]->IsInstance()) { + Malformed(Diagnostic::Error(op->args[0]->span) + << "call_dps_packed expects an extern func, but gets " + << op->args[0]->GetTypeKey()); + } + for (const StructInfo& sinfo_arg : op->sinfo_args) { this->VisitStructInfo(sinfo_arg); } diff --git a/src/relax/transform/call_dps_packed_rewrite.cc b/src/relax/transform/call_dps_packed_rewrite.cc deleted file mode 100644 index 714e0c59eeea..000000000000 --- a/src/relax/transform/call_dps_packed_rewrite.cc +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file src/relax/transform/call_dps_packed_rewrite.cc - * \brief Perform explicit tensor allocation for call_dps_packed. - */ -#include -#include -#include -#include -#include - -#include "../../relay/transforms/pattern_utils.h" - -namespace tvm { -namespace relax { - -// ================== -// CallDPSPackedMutator -// Perform explicit tensor allocation for call_dps_packed. 
-// Example: -// lv0: Tensor(n, m) = rx.call_dps_packed(func, (x), (n, m), dtype="float32") -// --> -// gv0 = rx.call("relax.builtin.alloc_tensor", [n, m], dtype="float32") -// rx.call_packed(func, x, gv0) - -class CallDPSPackedMutator : public ExprMutator { - public: - using ExprMutator::VisitExpr_; - Expr VisitExpr_(const CallNode* call) override { - // post-order mutation - Expr expr = VisitExprPostOrder_(call); - call = expr.as(); - - static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); - static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); - static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn"); - - if (call->op == call_dps_packed_op) { - Array outs; - if (const auto& _tensor_sinfo = MatchStructInfo(expr)) { - // single output case - const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value(); - ICHECK(tensor_sinfo->shape.defined()) - << "the TensorStructInfo shape of call_dps_packed has not populated"; - outs.push_back( - builder_->Emit(Call(alloc_tensor_op, // - {Downcast(tensor_sinfo->shape.value()), - DataTypeImm(tensor_sinfo->dtype), PrimValue::Int64(0)}, // - Attrs()), - "alloc")); - } else if (const auto& _tuple_sinfo = MatchStructInfo(expr)) { - // multiple output case - const TupleStructInfo& tuple_sinfo = _tuple_sinfo.value(); - for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { - const auto& field = tuple_sinfo->fields[i]; - - ICHECK(field->IsInstance()) - << "call_dps_packed expects Tuple of TensorStructInfo, but got " << field - << " as an element of TupleStructInfo"; - const auto& field_tensor = Downcast(field); - ICHECK(field_tensor->shape.defined()) - << "call_dps_packed expects all TensorStructInfo has shape, but got " << field_tensor - << " as an element of TupleStructInfo"; - outs.push_back( - builder_->Emit(Call(alloc_tensor_op, - {Downcast(field_tensor->shape.value()), - DataTypeImm(field_tensor->dtype), PrimValue::Int64(0)}, - Attrs()), - "alloc")); - } - } else { - LOG(FATAL) - << "TypeError: The struct info of call_dps_packed expects to be TensorStructInfo or " - "TupleStructInfo, but got" - << expr->struct_info_; - } - - Array args; - if (call->args[1].as()) { - args = Downcast(call->args[1])->fields; - args.insert(args.end(), outs.begin(), outs.end()); - - if (call->args.size() == 2) { - builder_->Emit(Call(call->args[0], args), "_"); - } else { - // unpack semantics - args.push_back(call->args[2]); - builder_->Emit(Call(call_tir_dyn_op, {call->args[0], Tuple(args)}), "_"); - } - } else { - args = outs; - args.insert(args.begin(), call->args[1]); - builder_->Emit(Call(call->args[0], args), "_"); - } - - if (outs.size() == 1) { - return outs[0]; - } - return std::move(Tuple(outs)); - } - - return GetRef(call); - } -}; - -Expr CallDPSPackedRewrite(const Expr& e) { return CallDPSPackedMutator().VisitExpr(e); } - -namespace transform { - -Pass CallDPSPackedRewrite() { - runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CallDPSPackedRewrite(f)); - }; - return CreateFunctionPass(pass_func, 0, "CallDPSPackedRewrite", {}); -} - -TVM_REGISTER_GLOBAL("relax.transform.CallDPSPackedRewrite").set_body_typed(CallDPSPackedRewrite); - -} // namespace transform - -} // namespace relax -} // namespace tvm diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc index 2ea039e0229b..6066ed8d2a7d 100644 --- a/src/relax/transform/call_tir_rewrite.cc +++ b/src/relax/transform/call_tir_rewrite.cc @@ -33,7 +33,7 @@ namespace relax { 
// ================== // CallTIRMutator -// Perform explicit tensor allocation for call_tir. +// Perform explicit tensor allocation for call_tir or call_dps_packed. // Example: // lv0: Tensor(n, m) = rx.call_tir(func, (x), (n, m), dtype="float32") // --> @@ -49,10 +49,11 @@ class CallTIRMutator : public ExprMutator { call = expr.as(); static const Op& call_tir_op = Op::Get("relax.call_tir"); + static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn"); - if (call->op == call_tir_op) { + if (call->op == call_tir_op || call->op == call_dps_packed_op) { Array outs; if (const auto& _tensor_sinfo = MatchStructInfo(expr)) { // single output case diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 49d2b7601137..0614b0c2e289 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -440,6 +440,55 @@ def test_ANF(): assert not rx.analysis.well_formed(mod, check_struct_info=False) +def test_call_tir(): + @tvm.script.ir_module + class TestWellCallTIR: + @T.prim_func + def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: + T.func_attr(({"global_symbol": "tir_addone"})) + for i, j in T.grid(16, 16): + with T.block("tir_addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.int32(1) + + @R.function + def foo(x: R.Tensor((16, 16), "float32")): + gv = R.call_tir(tir_addone, (x,), R.Tensor((16, 16), dtype="float32")) + return gv + + # well-formed check + assert rx.analysis.well_formed(TestWellCallTIR, check_struct_info=False) + + +def test_call_dps_packed(): + @tvm.script.ir_module + class TestWellCallDPSPacked: + @R.function + def foo(x: R.Tensor((16, 16), "float32")): + gv = R.call_dps_packed("extern_func", (x,), R.Tensor((16, 16), dtype="float32")) + return gv + + assert rx.analysis.well_formed(TestWellCallDPSPacked, check_struct_info=False) + + @tvm.script.ir_module + class TestMalCallDPSPacked: + @T.prim_func + def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: + T.func_attr(({"global_symbol": "tir_addone"})) + for i, j in T.grid(16, 16): + with T.block("tir_addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.int32(1) + + @R.function + def foo(x: R.Tensor((16, 16), "float32")): + gv = R.call_dps_packed(tir_addone, (x,), R.Tensor((16, 16), dtype="float32")) + return gv + + # call_dps_packed is not able to call tir prim_func + assert not rx.analysis.well_formed(TestMalCallDPSPacked, check_struct_info=False) + + def test_global_var_vs_gsymbol(): # Error: gsymbol "main1" not equals to the name in global var "main" gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index cdfb4e635e8f..3e6305c49287 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -127,8 +127,8 @@ def foo(x: R.Tensor(("m", "n"), "float32")): assert isinstance(s0, relax.Call) assert s0.op.name == "relax.call_dps_packed" - # after rewrite - new_mod = relax.transform.CallDPSPackedRewrite()(mod) + # CallTIRRewrite also works for call_dps_packed + new_mod = relax.transform.CallTIRRewrite()(mod) func = new_mod["foo"] block = func.body.blocks[0] diff --git a/tests/python/relax/test_tvmscript_parser.py 
b/tests/python/relax/test_tvmscript_parser.py index ed8204817bfb..10362e7d0d86 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -115,15 +115,34 @@ def f(x: R.Tensor(("m",), "float32")): return R.call_tir("foo", (x,), R.Tensor((T.cast("int32", m, 1),), dtype="float32")) -def test_unexpected_tir_max_args(): +def test_unexpected_tir_args(): + + with pytest.raises(tvm.error.DiagnosticError): + + @tvm.script.ir_module + class TestWellCallTIR: + @T.prim_func + def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: + T.func_attr(({"global_symbol": "tir_addone"})) + for i, j in T.grid(16, 16): + with T.block("tir_addone"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] + T.int32(1) + + @R.function + def foo(x: R.Tensor(("m", "m"), "float32")): + m = T.int64() + # tir.max expects 2 arguments, but got 1 + gv = R.call_tir(tir_addone, (x,), R.Tensor((T.max(16),), dtype="float32")) + return gv with pytest.raises(tvm.error.DiagnosticError): @R.function def f(x: R.Tensor(("m", "n"), "float32")): m = T.int64() - # tir.max expects 2 arguments, but got 1 - return relax.call_tir("foo", (x,), R.Tensor((T.max(m),), dtype="float32")) + # call_tir expected a tir prim_func + return relax.call_tir("extern_func", (x,), R.Tensor((T.max(m),), dtype="float32")) def test_func_type_annotation_fail(): @@ -724,10 +743,10 @@ def bar(x: R.Tensor): assert isinstance(z_bind.var.struct_info, relax.TensorStructInfo) -def test_call_tir_empty_shape(): +def test_call_dps_packed_empty_shape(): @R.function def foo(x: R.Tensor((), "float32")): - z = R.call_tir("scalar_add", x, R.Tensor((), dtype="float32")) + z = R.call_dps_packed("scalar_add", x, R.Tensor((), dtype="float32")) return z (z_bind,) = foo.body.blocks[0].bindings diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 270c91299aeb..76bb3bb81229 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -299,7 +299,7 @@ def test_shape_expr(): def test_call(): x = tir.Var("x", "int64") a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32")) - o0 = relax.call_tir("tir_func", args=a, out_sinfo=a.struct_info, tir_vars=[x]) + o0 = relax.call_tir(relax.GlobalVar("tir_func"), args=a, out_sinfo=a.struct_info, tir_vars=[x]) o1 = relax.call_dps_packed("my_dps_func", args=a, out_sinfo=a.struct_info) _assert_print( o0, diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 0e92e6a6fbf4..776679103f86 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -714,7 +714,7 @@ def test_vm_nested_tuple( @R.function def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: - gv0 = R.call_dps_packed("test_vm_mul", (x, w), R.Tensor((32, 32), dtype="float32")) + gv0 = R.call_tir(test_vm_mul, (x, w), R.Tensor((32, 32), dtype="float32")) return gv0 From b2436613bc18d4c3666cf951daa9f11099159105 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Wed, 8 Mar 2023 14:17:01 -0800 Subject: [PATCH 4/7] Remove well_form update, enforce in InferStructInfoCallTIR --- src/relax/analysis/well_formed.cc | 14 ------ src/relax/op/op.cc | 3 ++ .../python/relax/test_analysis_well_formed.py | 49 ------------------- 3 files changed, 3 insertions(+), 63 deletions(-) diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 
b8168bd49d90..9a97931136c8 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -250,7 +250,6 @@ class WellFormedChecker : public relax::ExprVisitor, } else { Malformed(Diagnostic::Error(op) << "The called expression must be a leaf expression"); } - for (size_t i = 0; i < op->args.size(); i++) { Expr arg = op->args[i]; if (IsLeafOrTuple(arg)) { @@ -261,19 +260,6 @@ class WellFormedChecker : public relax::ExprVisitor, } } - // call_tir works for tir PrimFunc, call_dps_packed works for ExternFunc - static const Op& call_tir_op = Op::Get("relax.call_tir"); - static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); - if (op->op == call_tir_op && !op->args[0]->IsInstance()) { - Malformed(Diagnostic::Error(op->args[0]->span) - << "call_tir expects a prim_func, but gets " << op->args[0]->GetTypeKey()); - } - if (op->op == call_dps_packed_op && !op->args[0]->IsInstance()) { - Malformed(Diagnostic::Error(op->args[0]->span) - << "call_dps_packed expects an extern func, but gets " - << op->args[0]->GetTypeKey()); - } - for (const StructInfo& sinfo_arg : op->sinfo_args) { this->VisitStructInfo(sinfo_arg); } diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 001c76dcf3a2..400ba45bfe39 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -65,6 +65,9 @@ StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { ctx->ReportFatal(Diagnostic::Error(call) << "sinfo_args should have exact 1 output struct info."); } + CHECK(call->args[0]->IsInstance()) + << "call_tir expects the first argument as a GlobalVar referring tir PrimFunc. " + << "However, gets " << call->args[0]; return call->sinfo_args[0]; } diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 0614b0c2e289..49d2b7601137 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -440,55 +440,6 @@ def test_ANF(): assert not rx.analysis.well_formed(mod, check_struct_info=False) -def test_call_tir(): - @tvm.script.ir_module - class TestWellCallTIR: - @T.prim_func - def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: - T.func_attr(({"global_symbol": "tir_addone"})) - for i, j in T.grid(16, 16): - with T.block("tir_addone"): - vi, vj = T.axis.remap("SS", [i, j]) - B[vi, vj] = A[vi, vj] + T.int32(1) - - @R.function - def foo(x: R.Tensor((16, 16), "float32")): - gv = R.call_tir(tir_addone, (x,), R.Tensor((16, 16), dtype="float32")) - return gv - - # well-formed check - assert rx.analysis.well_formed(TestWellCallTIR, check_struct_info=False) - - -def test_call_dps_packed(): - @tvm.script.ir_module - class TestWellCallDPSPacked: - @R.function - def foo(x: R.Tensor((16, 16), "float32")): - gv = R.call_dps_packed("extern_func", (x,), R.Tensor((16, 16), dtype="float32")) - return gv - - assert rx.analysis.well_formed(TestWellCallDPSPacked, check_struct_info=False) - - @tvm.script.ir_module - class TestMalCallDPSPacked: - @T.prim_func - def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: - T.func_attr(({"global_symbol": "tir_addone"})) - for i, j in T.grid(16, 16): - with T.block("tir_addone"): - vi, vj = T.axis.remap("SS", [i, j]) - B[vi, vj] = A[vi, vj] + T.int32(1) - - @R.function - def foo(x: R.Tensor((16, 16), "float32")): - gv = R.call_dps_packed(tir_addone, (x,), R.Tensor((16, 16), dtype="float32")) - return gv - - # call_dps_packed is not able to call tir prim_func - assert not 
rx.analysis.well_formed(TestMalCallDPSPacked, check_struct_info=False) - - def test_global_var_vs_gsymbol(): # Error: gsymbol "main1" not equals to the name in global var "main" gv0 = rx.Var("gv0", R.Tensor([m, n], "float32")) From 06a6a003ed3ab84a91151ae4b64f1e5db79cf3af Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Wed, 8 Mar 2023 14:39:08 -0800 Subject: [PATCH 5/7] Update src/relax/op/op.cc Co-authored-by: Steven S. Lyubomirsky --- src/relax/op/op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 400ba45bfe39..f1ea913b1580 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -66,7 +66,7 @@ StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { << "sinfo_args should have exact 1 output struct info."); } CHECK(call->args[0]->IsInstance()) - << "call_tir expects the first argument as a GlobalVar referring tir PrimFunc. " + << "call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. " << "However, gets " << call->args[0]; return call->sinfo_args[0]; } From 3cdf6ea9a553b2f08e5e06056c231365b7f9de06 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Wed, 8 Mar 2023 15:05:08 -0800 Subject: [PATCH 6/7] Update description of call_tir --- python/tvm/relax/op/base.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 23ce8668f632..aef0e731db51 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -22,7 +22,7 @@ from tvm.runtime.object import Object from . import _ffi_api -from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc +from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar from ..expr import Tuple as RxTuple from ..struct_info import StructInfo, TensorStructInfo from ...ir import PrimExpr @@ -45,18 +45,18 @@ def null_value() -> Call: @args_converter.auto def call_tir( - func: Union[str, Expr], + gvar: GlobalVar, args: Expr, out_sinfo: Union[TensorStructInfo, List[TensorStructInfo]], tir_vars: Optional[Union[ShapeExpr, Tuple[PrimExpr], List[PrimExpr]]] = None, ) -> Call: """ - Call a tir.prim_func function and return the output. + Call a tir.prim_func and return the output. Parameters ---------- - func : Union[str, Expr] - The tir PrimFunc function. + gvar : GlobalVar + The GlobalVar referring to a tir PrimFunc. args : Expr The input arguments. 
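A short sketch contrasting the two ops under the updated signatures; the PrimFunc mirrors tir_addone above, "my_packed_func" is a hypothetical packed-func name, and BlockBuilder.add_func is assumed as the way to register the PrimFunc and obtain its GlobalVar:

    import tvm
    from tvm import relax
    from tvm.script import relax as R, tir as T

    @T.prim_func
    def addone(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")) -> None:
        for i, j in T.grid(16, 16):
            with T.block("addone"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] + T.float32(1)

    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((16, 16), "float32"))
    with bb.function("main", (x,)):
        # call_tir takes the GlobalVar of a TIR PrimFunc in the module
        gvar = bb.add_func(addone, "addone")
        lv0 = bb.emit(relax.call_tir(gvar, (x,), R.Tensor((16, 16), dtype="float32")))
        # call_dps_packed takes a packed-func name (converted to ExternFunc)
        lv1 = bb.emit(
            relax.call_dps_packed("my_packed_func", (lv0,), R.Tensor((16, 16), dtype="float32"))
        )
        bb.emit_func_output(lv1)
    mod = bb.get()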
@@ -83,7 +83,7 @@ def call_tir( if isinstance(tir_vars, (list, tuple)): tir_vars = ShapeExpr(tir_vars) - return _ffi_api.call_tir(func, args, out_sinfo, tir_vars) # type: ignore + return _ffi_api.call_tir(gvar, args, out_sinfo, tir_vars) # type: ignore @args_converter.auto From e8b3b685ccda57c4ad44b6d921780f0ce1cfad47 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Thu, 9 Mar 2023 13:19:28 -0800 Subject: [PATCH 7/7] Remove unnecessary check in passes --- src/relax/backend/task_extraction.cc | 5 ---- src/relax/transform/fold_constant.cc | 12 ++++------ src/relax/transform/fuse_ops.cc | 20 ++++++++-------- src/relax/transform/fuse_tir.cc | 23 +++---------------- src/relax/transform/legalize_ops.cc | 3 ++- .../transform/rewrite_dataflow_reshape.cc | 7 ++---- src/script/printer/relax/call.cc | 5 ++-- 7 files changed, 24 insertions(+), 51 deletions(-) diff --git a/src/relax/backend/task_extraction.cc b/src/relax/backend/task_extraction.cc index beb3950af1d1..5bd764c68e78 100644 --- a/src/relax/backend/task_extraction.cc +++ b/src/relax/backend/task_extraction.cc @@ -73,11 +73,6 @@ class TaskExtractor : public ExprVisitor { return; } - // Do not extract external function - if (call->args[0].as()) { - return; - } - const GlobalVar& global_var = Downcast(call->args[0]); const tir::PrimFunc& func = Downcast(mod_->Lookup(global_var)); diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 6b28f3188915..622dd9ad09b7 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -87,13 +87,11 @@ class ConstantFolder : public ExprMutator { * \return The TIR function, or nullopt if pattern match fails. */ Optional MatchPrimFunc(const Expr& op) { - if (auto* ptr = op.as()) { - // NOTE: as check works for nullptr(returns null) - Optional base_func = - builder_->GetContextIRModule()->functions.Get(GetRef(ptr)); - if (auto* pfunc = base_func.as()) { - return GetRef(pfunc); - } + const GlobalVar& global_var = Downcast(op); + // NOTE: as check works for nullptr(returns null) + Optional base_func = builder_->GetContextIRModule()->functions.Get(global_var); + if (auto* pfunc = base_func.as()) { + return GetRef(pfunc); } return NullOpt; } diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index d6013c8874ee..0aec5a070d8c 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -185,19 +185,17 @@ class GraphCreator : public ExprVisitor { // recurse into the call expression. const auto* op = call->op.as(); if (op == call_tir_op_.get()) { - // Skip ExternFunc for call_dps_packed. - if (const auto* global_var = call->args[0].as()) { - tir::PrimFunc func = Downcast(mod_->Lookup(GetRef(global_var))); + const GlobalVar& global_var = Downcast(call->args[0]); + tir::PrimFunc func = Downcast(mod_->Lookup(global_var)); - // Override args for call_tir - args = Downcast(call->args[1])->fields; + // Override args for call_tir + args = Downcast(call->args[1])->fields; - Optional opt_pattern = func->GetAttr("op_pattern"); - if (opt_pattern.defined()) { - pattern = static_cast(Downcast(opt_pattern)->value); - } else { - pattern = OpPatternKind::kOpaque; - } + Optional opt_pattern = func->GetAttr("op_pattern"); + if (opt_pattern.defined()) { + pattern = static_cast(Downcast(opt_pattern)->value); + } else { + pattern = OpPatternKind::kOpaque; } } // The pattern of the current binding variable node is set to the pattern of this operator. 
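A hypothetical Python analogue of the simplified MatchPrimFunc above, illustrating the invariant these pass cleanups rely on: after this series the first argument of call_tir is always a GlobalVar, so the ExternFunc guard can be dropped (this sketch assumes the GlobalVar is defined in the module):

    import tvm
    from tvm import relax

    def match_prim_func(mod: tvm.IRModule, call: relax.Call):
        gvar = call.args[0]  # guaranteed to be a GlobalVar for relax.call_tir
        func = mod[gvar]     # direct lookup; no ExternFunc case to handle
        return func if isinstance(func, tvm.tir.PrimFunc) else None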
diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 925f09d85d34..e90d6e4bc1d1 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -302,11 +302,10 @@ class FusedTIRConstructor : public ExprVisitor { // Step 1. Get Global var and PrimFunc GlobalVar gv = Downcast(call->args[0]); - Optional prim_func_ = GetPrimFunc(gv); - ICHECK(prim_func_.defined()) << "Cannot find the prim_func of the call_tir in the module: " - << gv; + tir::PrimFunc prim_func_ = Downcast(mod_->Lookup(gv)); + // Step 2. Renew all vars/buffer definitions and blocks to avoid duplication - tir::PrimFunc prim_func = tir::RenewDefs(prim_func_.value()); + tir::PrimFunc prim_func = tir::RenewDefs(prim_func_); // Step 3. Check functions are all schedulable funcs. i.e. the body of func is root block // TODO(Siyuan): support un-schedulable functions. @@ -364,22 +363,6 @@ class FusedTIRConstructor : public ExprVisitor { LOG(FATAL) << "Relax.Constant is not supported in primitive functions."; } - /********** Helper Functions **********/ - - /*! - * \brief Pattern match op to a TIR function and look it up. - * \return The TIR function, or NullOpt if patter match fails. - */ - Optional GetPrimFunc(const GlobalVar& global_var) { - // NOTE: as check works for nullptr(returns null) - Optional base_func = mod_->functions.Get(global_var); - if (auto* pfunc = base_func.as()) { - return GetRef(pfunc); - } else { - return NullOpt; - } - } - /*! * \brief Get the number of outputs for a call_tir node. * \return The number of outputs. diff --git a/src/relax/transform/legalize_ops.cc b/src/relax/transform/legalize_ops.cc index f9a84c536101..350a40c37bf8 100644 --- a/src/relax/transform/legalize_ops.cc +++ b/src/relax/transform/legalize_ops.cc @@ -75,6 +75,7 @@ class LegalizeMutator : public ExprMutator { Call visited_call = Downcast(this->VisitExprPostOrder_(call)); static const auto& legalize_map = Op::GetAttrMap("FLegalize"); static const Op& call_tir_op = Op::Get("relax.call_tir"); + static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); auto* op_node = visited_call->op.as(); // Not an OpNode @@ -102,7 +103,7 @@ class LegalizeMutator : public ExprMutator { } // No legalization. 
- if (op != call_tir_op) { + if (op != call_tir_op && op != call_dps_packed_op) { LOG(WARNING) << "No legalization func for " << op->name << " is found."; } return visited_call; diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc index aec0911ecc5a..e5d654fba355 100644 --- a/src/relax/transform/rewrite_dataflow_reshape.cc +++ b/src/relax/transform/rewrite_dataflow_reshape.cc @@ -75,11 +75,8 @@ class DataflowReshapeRewriter : public ExprMutator { if (call->op != call_tir_op) { return false; } - const auto* gv = call->args[0].as(); - if (gv == nullptr) { - return false; - } - const auto* func = mod_->functions.Get(GetRef(gv)).as(); + const GlobalVar& global_var = Downcast(call->args[0]); + const auto* func = mod_->functions.Get(global_var).as(); ICHECK_NOTNULL(func); return HasReshapePattern(GetRef(func)); } diff --git a/src/script/printer/relax/call.cc b/src/script/printer/relax/call.cc index 318a0827899a..e99b81df8b0c 100644 --- a/src/script/printer/relax/call.cc +++ b/src/script/printer/relax/call.cc @@ -95,7 +95,8 @@ ExprDoc PrintCallee(const relax::Expr& n, const ObjectPath& n_p, const IRDocsifi } } -Optional PrintCallTIR(const relax::Call& n, const ObjectPath& n_p, const IRDocsifier& d) { +Optional PrintCallTIRDPSPacked(const relax::Call& n, const ObjectPath& n_p, + const IRDocsifier& d) { static const Op& call_tir_op = Op::Get("relax.call_tir"); static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); if (!n->op.same_as(call_tir_op) && !n->op.same_as(call_dps_packed_op)) { @@ -139,7 +140,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::Call n, ObjectPath n_p, IRDocsifier d) -> Doc { // Special case: call_tir, call_dps_packed - if (Optional doc = PrintCallTIR(n, n_p, d)) { + if (Optional doc = PrintCallTIRDPSPacked(n, n_p, d)) { return doc.value(); } ExprDoc prefix{nullptr};
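As a quick check of the two surface forms pinned down by the printer test above, roughly:

    from tvm import relax, tir

    x = tir.Var("x", "int64")
    a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32"))
    o0 = relax.call_tir(relax.GlobalVar("tir_func"), args=a, out_sinfo=a.struct_info, tir_vars=[x])
    o1 = relax.call_dps_packed("my_dps_func", args=a, out_sinfo=a.struct_info)
    # o0 prints its callee as a bare GlobalVar identifier:
    #   R.call_tir(tir_func, (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=R.shape([x]))
    # o1 prints its callee as a quoted ExternFunc name:
    #   R.call_dps_packed("my_dps_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"))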