Skip to content

Commit

Permalink
Remove well_form update, enforce in InferStructInfoCallTIR
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Mar 8, 2023
1 parent f1eedae commit 388972e
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 62 deletions.
13 changes: 0 additions & 13 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,19 +261,6 @@ class WellFormedChecker : public relax::ExprVisitor,
}
}

// call_tir works for tir PrimFunc, call_dps_packed works for ExternFunc
static const Op& call_tir_op = Op::Get("relax.call_tir");
static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
if (op->op == call_tir_op && !op->args[0]->IsInstance<GlobalVarNode>()) {
Malformed(Diagnostic::Error(op->args[0]->span)
<< "call_tir expects a prim_func, but gets " << op->args[0]->GetTypeKey());
}
if (op->op == call_dps_packed_op && !op->args[0]->IsInstance<ExternFuncNode>()) {
Malformed(Diagnostic::Error(op->args[0]->span)
<< "call_dps_packed expects an extern func, but gets "
<< op->args[0]->GetTypeKey());
}

for (const StructInfo& sinfo_arg : op->sinfo_args) {
this->VisitStructInfo(sinfo_arg);
}
Expand Down
3 changes: 3 additions & 0 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "sinfo_args should have exact 1 output struct info.");
}
CHECK(call->args[0]->IsInstance<GlobalVarNode>())
<< "call_tir expects the first argument as a GlobalVar referring tir PrimFunc. "
<< "However, gets " << call->args[0];
return call->sinfo_args[0];
}

Expand Down
49 changes: 0 additions & 49 deletions tests/python/relax/test_analysis_well_formed.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,55 +440,6 @@ def test_ANF():
assert not rx.analysis.well_formed(mod, check_struct_info=False)


def test_call_tir():
@tvm.script.ir_module
class TestWellCallTIR:
@T.prim_func
def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None:
T.func_attr(({"global_symbol": "tir_addone"}))
for i, j in T.grid(16, 16):
with T.block("tir_addone"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] + T.int32(1)

@R.function
def foo(x: R.Tensor((16, 16), "float32")):
gv = R.call_tir(tir_addone, (x,), R.Tensor((16, 16), dtype="float32"))
return gv

# well-formed check
assert rx.analysis.well_formed(TestWellCallTIR, check_struct_info=False)


def test_call_dps_packed():
@tvm.script.ir_module
class TestWellCallDPSPacked:
@R.function
def foo(x: R.Tensor((16, 16), "float32")):
gv = R.call_dps_packed("extern_func", (x,), R.Tensor((16, 16), dtype="float32"))
return gv

assert rx.analysis.well_formed(TestWellCallDPSPacked, check_struct_info=False)

@tvm.script.ir_module
class TestMalCallDPSPacked:
@T.prim_func
def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None:
T.func_attr(({"global_symbol": "tir_addone"}))
for i, j in T.grid(16, 16):
with T.block("tir_addone"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] + T.int32(1)

@R.function
def foo(x: R.Tensor((16, 16), "float32")):
gv = R.call_dps_packed(tir_addone, (x,), R.Tensor((16, 16), dtype="float32"))
return gv

# call_dps_packed is not able to call tir prim_func
assert not rx.analysis.well_formed(TestMalCallDPSPacked, check_struct_info=False)


def test_global_var_vs_gsymbol():
# Error: gsymbol "main1" not equals to the name in global var "main"
gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
Expand Down

0 comments on commit 388972e

Please sign in to comment.