From b2436613bc18d4c3666cf951daa9f11099159105 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Wed, 8 Mar 2023 14:17:01 -0800 Subject: [PATCH] Remove well_form update, enforce in InferStructInfoCallTIR --- src/relax/analysis/well_formed.cc | 14 ------ src/relax/op/op.cc | 3 ++ .../python/relax/test_analysis_well_formed.py | 49 ------------------- 3 files changed, 3 insertions(+), 63 deletions(-) diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index b8168bd49d90..9a97931136c8 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -250,7 +250,6 @@ class WellFormedChecker : public relax::ExprVisitor, } else { Malformed(Diagnostic::Error(op) << "The called expression must be a leaf expression"); } - for (size_t i = 0; i < op->args.size(); i++) { Expr arg = op->args[i]; if (IsLeafOrTuple(arg)) { @@ -261,19 +260,6 @@ class WellFormedChecker : public relax::ExprVisitor, } } - // call_tir works for tir PrimFunc, call_dps_packed works for ExternFunc - static const Op& call_tir_op = Op::Get("relax.call_tir"); - static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed"); - if (op->op == call_tir_op && !op->args[0]->IsInstance()) { - Malformed(Diagnostic::Error(op->args[0]->span) - << "call_tir expects a prim_func, but gets " << op->args[0]->GetTypeKey()); - } - if (op->op == call_dps_packed_op && !op->args[0]->IsInstance()) { - Malformed(Diagnostic::Error(op->args[0]->span) - << "call_dps_packed expects an extern func, but gets " - << op->args[0]->GetTypeKey()); - } - for (const StructInfo& sinfo_arg : op->sinfo_args) { this->VisitStructInfo(sinfo_arg); } diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 001c76dcf3a2..400ba45bfe39 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -65,6 +65,9 @@ StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { ctx->ReportFatal(Diagnostic::Error(call) << "sinfo_args should have exact 1 output struct info."); } + CHECK(call->args[0]->IsInstance()) + << "call_tir expects the first argument as a GlobalVar referring tir PrimFunc. " + << "However, gets " << call->args[0]; return call->sinfo_args[0]; } diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 0614b0c2e289..49d2b7601137 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -440,55 +440,6 @@ def test_ANF(): assert not rx.analysis.well_formed(mod, check_struct_info=False) -def test_call_tir(): - @tvm.script.ir_module - class TestWellCallTIR: - @T.prim_func - def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: - T.func_attr(({"global_symbol": "tir_addone"})) - for i, j in T.grid(16, 16): - with T.block("tir_addone"): - vi, vj = T.axis.remap("SS", [i, j]) - B[vi, vj] = A[vi, vj] + T.int32(1) - - @R.function - def foo(x: R.Tensor((16, 16), "float32")): - gv = R.call_tir(tir_addone, (x,), R.Tensor((16, 16), dtype="float32")) - return gv - - # well-formed check - assert rx.analysis.well_formed(TestWellCallTIR, check_struct_info=False) - - -def test_call_dps_packed(): - @tvm.script.ir_module - class TestWellCallDPSPacked: - @R.function - def foo(x: R.Tensor((16, 16), "float32")): - gv = R.call_dps_packed("extern_func", (x,), R.Tensor((16, 16), dtype="float32")) - return gv - - assert rx.analysis.well_formed(TestWellCallDPSPacked, check_struct_info=False) - - @tvm.script.ir_module - class TestMalCallDPSPacked: - @T.prim_func - def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: - T.func_attr(({"global_symbol": "tir_addone"})) - for i, j in T.grid(16, 16): - with T.block("tir_addone"): - vi, vj = T.axis.remap("SS", [i, j]) - B[vi, vj] = A[vi, vj] + T.int32(1) - - @R.function - def foo(x: R.Tensor((16, 16), "float32")): - gv = R.call_dps_packed(tir_addone, (x,), R.Tensor((16, 16), dtype="float32")) - return gv - - # call_dps_packed is not able to call tir prim_func - assert not rx.analysis.well_formed(TestMalCallDPSPacked, check_struct_info=False) - - def test_global_var_vs_gsymbol(): # Error: gsymbol "main1" not equals to the name in global var "main" gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))