Skip to content

Commit

Permalink
Fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Mar 7, 2023
1 parent be660e3 commit f1eedae
Show file tree
Hide file tree
Showing 12 changed files with 97 additions and 175 deletions.
9 changes: 1 addition & 8 deletions include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,19 +83,12 @@ TVM_DLL Pass LambdaLift();
TVM_DLL Pass ToNonDataflow();

/*!
* \brief Perform explicit tensor allocation for call_tir.
* \brief Perform explicit tensor allocation for call_tir and call_dps_packed.
*
* \return The Pass.
*/
TVM_DLL Pass CallTIRRewrite();

/*!
* \brief Perform explicit tensor allocation for call_dps_packed.
*
* \return The Pass.
*/
TVM_DLL Pass CallDPSPackedRewrite();

/*!
* \brief Convert all reshape-like call_tir whose corresponding binding
* vars are DataflowVars to relax.reshape operator calls. The relax.reshape
Expand Down
5 changes: 1 addition & 4 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tvm.runtime.object import Object

from . import _ffi_api
from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc, GlobalVar
from ..expr import Expr, StringImm, ShapeExpr, Call, ExternFunc
from ..expr import Tuple as RxTuple
from ..struct_info import StructInfo, TensorStructInfo
from ...ir import PrimExpr
Expand Down Expand Up @@ -74,9 +74,6 @@ def call_tir(
ret: Call
A call node for the call_tir operator.
"""
if isinstance(func, str):
func = GlobalVar(func)

if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))

Expand Down
12 changes: 1 addition & 11 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def LambdaLift():


def CallTIRRewrite() -> tvm.ir.transform.Pass:
"""Perform explicit tensor allocation for call_tir.
"""Perform explicit tensor allocation for call_tir and call_dps_packed.
Returns
-------
Expand All @@ -69,16 +69,6 @@ def CallTIRRewrite() -> tvm.ir.transform.Pass:
return _ffi_api.CallTIRRewrite() # type: ignore


def CallDPSPackedRewrite() -> tvm.ir.transform.Pass:
"""Perform explicit tensor allocation for call_dps_packed.
Returns
-------
ret: tvm.ir.transform.Pass
"""
return _ffi_api.CallDPSPackedRewrite() # type: ignore


def Normalize() -> tvm.ir.transform.Pass:
"""Transforming Relax IR to normal form, i.e., the expressions are normalized(no nesting
and hence the AST is in ANF), and all checked_type_ and shape_ of expressions are available.
Expand Down
1 change: 0 additions & 1 deletion python/tvm/relax/vm_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")):
passes.append(relax.transform.RewriteDataflowReshape())
passes.append(relax.transform.ToNonDataflow())
passes.append(relax.transform.CallTIRRewrite())
passes.append(relax.transform.CallDPSPackedRewrite())
passes.append(relax.transform.StaticPlanBlockMemory())
passes.append(relax.transform.VMBuiltinLower())
passes.append(relax.transform.VMShapeLower())
Expand Down
14 changes: 14 additions & 0 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ class WellFormedChecker : public relax::ExprVisitor,
} else {
Malformed(Diagnostic::Error(op) << "The called expression must be a leaf expression");
}

for (size_t i = 0; i < op->args.size(); i++) {
Expr arg = op->args[i];
if (IsLeafOrTuple(arg)) {
Expand All @@ -260,6 +261,19 @@ class WellFormedChecker : public relax::ExprVisitor,
}
}

// call_tir works for tir PrimFunc, call_dps_packed works for ExternFunc
static const Op& call_tir_op = Op::Get("relax.call_tir");
static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
if (op->op == call_tir_op && !op->args[0]->IsInstance<GlobalVarNode>()) {
Malformed(Diagnostic::Error(op->args[0]->span)
<< "call_tir expects a prim_func, but gets " << op->args[0]->GetTypeKey());
}
if (op->op == call_dps_packed_op && !op->args[0]->IsInstance<ExternFuncNode>()) {
Malformed(Diagnostic::Error(op->args[0]->span)
<< "call_dps_packed expects an extern func, but gets "
<< op->args[0]->GetTypeKey());
}

for (const StructInfo& sinfo_arg : op->sinfo_args) {
this->VisitStructInfo(sinfo_arg);
}
Expand Down
140 changes: 0 additions & 140 deletions src/relax/transform/call_dps_packed_rewrite.cc

This file was deleted.

5 changes: 3 additions & 2 deletions src/relax/transform/call_tir_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace relax {

// ==================
// CallTIRMutator
// Perform explicit tensor allocation for call_tir.
// Perform explicit tensor allocation for call_tir or call_dps_packed.
// Example:
// lv0: Tensor(n, m) = rx.call_tir(func, (x), (n, m), dtype="float32")
// -->
Expand All @@ -49,10 +49,11 @@ class CallTIRMutator : public ExprMutator {
call = expr.as<CallNode>();

static const Op& call_tir_op = Op::Get("relax.call_tir");
static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn");

if (call->op == call_tir_op) {
if (call->op == call_tir_op || call->op == call_dps_packed_op) {
Array<Expr> outs;
if (const auto& _tensor_sinfo = MatchStructInfo<TensorStructInfo>(expr)) {
// single output case
Expand Down
49 changes: 49 additions & 0 deletions tests/python/relax/test_analysis_well_formed.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,55 @@ def test_ANF():
assert not rx.analysis.well_formed(mod, check_struct_info=False)


def test_call_tir():
@tvm.script.ir_module
class TestWellCallTIR:
@T.prim_func
def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None:
T.func_attr(({"global_symbol": "tir_addone"}))
for i, j in T.grid(16, 16):
with T.block("tir_addone"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] + T.int32(1)

@R.function
def foo(x: R.Tensor((16, 16), "float32")):
gv = R.call_tir(tir_addone, (x,), R.Tensor((16, 16), dtype="float32"))
return gv

# well-formed check
assert rx.analysis.well_formed(TestWellCallTIR, check_struct_info=False)


def test_call_dps_packed():
@tvm.script.ir_module
class TestWellCallDPSPacked:
@R.function
def foo(x: R.Tensor((16, 16), "float32")):
gv = R.call_dps_packed("extern_func", (x,), R.Tensor((16, 16), dtype="float32"))
return gv

assert rx.analysis.well_formed(TestWellCallDPSPacked, check_struct_info=False)

@tvm.script.ir_module
class TestMalCallDPSPacked:
@T.prim_func
def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None:
T.func_attr(({"global_symbol": "tir_addone"}))
for i, j in T.grid(16, 16):
with T.block("tir_addone"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] + T.int32(1)

@R.function
def foo(x: R.Tensor((16, 16), "float32")):
gv = R.call_dps_packed(tir_addone, (x,), R.Tensor((16, 16), dtype="float32"))
return gv

# call_dps_packed is not able to call a TIR prim_func
assert not rx.analysis.well_formed(TestMalCallDPSPacked, check_struct_info=False)


def test_global_var_vs_gsymbol():
# Error: gsymbol "main1" not equals to the name in global var "main"
gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relax/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ def foo(x: R.Tensor(("m", "n"), "float32")):
assert isinstance(s0, relax.Call)
assert s0.op.name == "relax.call_dps_packed"

# after rewrite
new_mod = relax.transform.CallDPSPackedRewrite()(mod)
# CallTIRRewrite also works for call_dps_packed
new_mod = relax.transform.CallTIRRewrite()(mod)
func = new_mod["foo"]

block = func.body.blocks[0]
Expand Down
29 changes: 24 additions & 5 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,34 @@ def f(x: R.Tensor(("m",), "float32")):
return R.call_tir("foo", (x,), R.Tensor((T.cast("int32", m, 1),), dtype="float32"))


def test_unexpected_tir_max_args():
def test_unexpected_tir_args():

with pytest.raises(tvm.error.DiagnosticError):

@tvm.script.ir_module
class TestWellCallTIR:
@T.prim_func
def tir_addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None:
T.func_attr(({"global_symbol": "tir_addone"}))
for i, j in T.grid(16, 16):
with T.block("tir_addone"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] + T.int32(1)

@R.function
def foo(x: R.Tensor(("m", "m"), "float32")):
m = T.int64()
# tir.max expects 2 arguments, but got 1
gv = R.call_tir(tir_addone, (x,), R.Tensor((T.max(16),), dtype="float32"))
return gv

with pytest.raises(tvm.error.DiagnosticError):

@R.function
def f(x: R.Tensor(("m", "n"), "float32")):
m = T.int64()
# tir.max expects 2 arguments, but got 1
return relax.call_tir("foo", (x,), R.Tensor((T.max(m),), dtype="float32"))
# call_tir expects a TIR prim_func
return relax.call_tir("extern_func", (x,), R.Tensor((T.max(m),), dtype="float32"))


def test_func_type_annotation_fail():
Expand Down Expand Up @@ -724,10 +743,10 @@ def bar(x: R.Tensor):
assert isinstance(z_bind.var.struct_info, relax.TensorStructInfo)


def test_call_tir_empty_shape():
def test_call_dps_packed_empty_shape():
@R.function
def foo(x: R.Tensor((), "float32")):
z = R.call_tir("scalar_add", x, R.Tensor((), dtype="float32"))
z = R.call_dps_packed("scalar_add", x, R.Tensor((), dtype="float32"))
return z

(z_bind,) = foo.body.blocks[0].bindings
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relax/test_tvmscript_printer_relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def test_shape_expr():
def test_call():
x = tir.Var("x", "int64")
a = relax.Var("a", relax.TensorStructInfo([1, x, 3], "float32"))
o0 = relax.call_tir("tir_func", args=a, out_sinfo=a.struct_info, tir_vars=[x])
o0 = relax.call_tir(relax.GlobalVar("tir_func"), args=a, out_sinfo=a.struct_info, tir_vars=[x])
o1 = relax.call_dps_packed("my_dps_func", args=a, out_sinfo=a.struct_info)
_assert_print(
o0,
Expand Down
Loading

0 comments on commit f1eedae

Please sign in to comment.