This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

[TVMScript] Use explicit R.shape in TVMScript (#435)
Since `arg_sinfo` was introduced in `CallNode`, the implicit shape constructor
is no longer widely used in TVMScript. This PR removes implicit shape
construction, since it may cause confusion between shapes and tuples.
Hzfengsy authored Feb 13, 2023
1 parent 5302925 commit 69a9700
Showing 15 changed files with 146 additions and 109 deletions.
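
The user-visible effect of this change is that shape values in TVMScript must now be written with the explicit `R.shape([...])` constructor instead of a bare tuple, which was previously converted to a `ShapeExpr` implicitly. A minimal sketch of the new spelling, mirroring the updated frontend tests further down:

```python
import tvm
from tvm.script import relax as R


@tvm.script.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Shape([1, 3, 10, 10]):
        with R.dataflow():
            # Previously written as the bare tuple (1, 3, 10, 10); the explicit
            # constructor removes the shape-vs-tuple ambiguity.
            gv: R.Shape([1, 3, 10, 10]) = R.shape([1, 3, 10, 10])
            R.output(gv)
        return gv
```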
16 changes: 3 additions & 13 deletions python/tvm/relax/utils.py
@@ -23,7 +23,7 @@
from ..runtime import String, convert_to_object
from ..tir import PrimExpr
from . import _ffi_api
from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm
from .expr import Expr, Function, PrimValue, StringImm
from .expr import Tuple as rx_Tuple


@@ -74,14 +74,12 @@ def convert_to_expr(value: Any) -> Expr:
1. Return the input itself if it's already a `relax.Expr`;
2. Return `relax.PrimValue` if the input is a `PrimExpr`;
3. Return `relax.StringImm` if the input is `tvm.String` or `str`;
4. Return `relax.ShapeExpr` if the input is a tuple/list of `PrimExpr` w/ int dtype;
5. Return `relax.Tuple` if the input is a tuple/list of `Expr`.
4. Return `relax.Tuple` if the input is a tuple/list of `Expr`.
Notes
-----
1. `tvm.tir.StringImm` is not allowed because of ambiguity,
which can be either `relax.StringImm` or `relax.PrimValue`.
2. We regard empty tuple/list as `relax.Tuple` instead of `relax.ShapeExpr`
"""
if isinstance(value, int):
return PrimValue(tir.IntImm("int64", value))
@@ -102,16 +100,8 @@ def convert_to_expr(value: Any) -> Expr:
# Case 3
if isinstance(tvm_value, String):
return StringImm(value)
# Case 4 & 5
# Case 4
if isinstance(value, (tuple, list)):
# Note 2
if len(value) == 0:
return rx_Tuple([])
# Case 4
opt_prim_value = [convert_to_object(v) for v in value]
if all([isinstance(v, PrimExpr) and v.dtype.startswith("int") for v in opt_prim_value]):
return ShapeExpr(value)
# Case 5
# `convert_to_expr` ensures that all elements are `Expr` if no exception is raised
return rx_Tuple([convert_to_expr(v) for v in value])
raise TypeError(f"Cannot convert {value} with type {type(value)} to `relax.Expr`")
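With the implicit case removed from `convert_to_expr`, a Python list of integers is now converted to a `relax.Tuple` of `PrimValue`s rather than a `relax.ShapeExpr`; callers that actually want a shape construct it explicitly. A small illustrative sketch of the expected behavior:

```python
from tvm import relax
from tvm.relax.utils import convert_to_expr

# A list of ints no longer becomes a ShapeExpr implicitly.
expr = convert_to_expr([2, 4])
assert isinstance(expr, relax.Tuple)

# Code that needs a shape now builds it explicitly.
shape = relax.ShapeExpr([2, 4])
assert isinstance(shape, relax.ShapeExpr)
```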
22 changes: 20 additions & 2 deletions python/tvm/script/ir_builder/relax/ir.py
@@ -382,11 +382,11 @@ def Else() -> frame.ElseFrame: # pylint: disable=invalid-name
############################### R.tuple ################################


def tuple(*fields: List[Expr]) -> Expr:
def tuple(*fields: Expr) -> Expr:
"""Create a tuple expression.
Parameters
----------
fields : List[Expr]
*fields : Expr
The fields of the tuple.
Returns
-------
@@ -399,6 +399,23 @@ def tuple(*fields: List[Expr]) -> Expr:
return relax.Tuple(fields) # pylint: disable=no-member # type: ignore


############################### R.shape ################################


def shape(value: List[PrimExpr]) -> Expr:
"""Create a ShapeExpr.
Parameters
----------
value : List[PrimExpr]
The values of the shape expression.
Returns
-------
res : Expr
The result ShapeExpr.
"""
return relax.ShapeExpr(value) # pylint: disable=no-member # type: ignore


############################### PrimValue ##############################


@@ -524,6 +541,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"prod",
"reshape",
"round",
"shape",
"shape_of",
"sigmoid",
"sign",
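The new `R.shape` builder above is a thin wrapper around `relax.ShapeExpr`, so scripts that pass shape operands to operators now spell them explicitly. A minimal sketch, mirroring the updated `test_transform.py` case below:

```python
import tvm
from tvm.script import relax as R


@tvm.script.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((3, 4), "float32")):
        # The shape operand is an explicit ShapeExpr, not a tuple literal.
        y = R.reshape(x, R.shape([6, 2]))
        return y
```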
22 changes: 17 additions & 5 deletions python/tvm/script/parser/relax/entry.py
@@ -22,6 +22,7 @@

from tvm.relax import (
Expr,
ShapeExpr,
FuncStructInfo,
Function,
ObjectStructInfo,
@@ -84,24 +85,31 @@ class TensorProxy(StructInfoProxy):

def __init__(
self,
shape: Optional[List[Union[PrimExpr, str]]] = None,
shape: Optional[Union[List[Union[PrimExpr, str]], Expr]] = None,
dtype: Optional[str] = None,
ndim: int = -1,
) -> None:
self.shape = shape
if isinstance(shape, Expr) and not isinstance(shape, ShapeExpr):
raise ValueError(
"Only ShapeExpr is allowed as shape expr, but got: "
f"{shape} with type: {type(shape)}"
)
self.dtype = dtype
self.ndim = ndim
super().__init__()

def get_symbolic_vars(self) -> Set[str]:
if self.shape is None:
if self.shape is None or isinstance(self.shape, Expr):
return {}
else:
return {s for s in self.shape if isinstance(s, str) and s.isidentifier()}

def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TensorStructInfo:
if self.shape is None:
return TensorStructInfo(None, self.dtype, self.ndim)
elif isinstance(self.shape, ShapeExpr):
return TensorStructInfo(self.shape, self.dtype, self.ndim)
else:
if dict_globals is None and any([isinstance(s, str) for s in self.shape]):
raise ValueError(
@@ -113,7 +121,7 @@ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TensorStructInfo:


def Tensor(
shape: Optional[List[Union[PrimExpr, str]]] = None,
shape: Optional[Union[List[Union[PrimExpr, str]], ShapeExpr]] = None,
dtype: Optional[str] = None,
ndim: int = -1,
) -> TensorProxy:
@@ -124,8 +132,12 @@ def Tensor(
dtype = shape
shape = None

if shape is not None and not isinstance(shape, (tuple, list)):
raise ValueError(f"shape must be a list or tuple, but got: {shape}")
if (
shape is not None
and not isinstance(shape, (tuple, list))
and not isinstance(shape, ShapeExpr)
):
raise ValueError(f"shape must be a list/tuple or a ShapeExpr, but got: {shape}")
return TensorProxy(shape, dtype, ndim)


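With this parser change, the `shape` argument of `R.Tensor` may be a list/tuple of `PrimExpr`/`str` or an explicit `ShapeExpr`; any other `relax.Expr` is rejected. An illustrative sketch of the accepted and rejected forms (direct calls to the proxy helper, not taken from the test suite):

```python
from tvm import relax
from tvm.script.parser.relax.entry import Tensor

# Still accepted: a plain tuple/list of dimensions.
Tensor((2, 4), "float32")

# Newly accepted: an explicit ShapeExpr.
Tensor(relax.ShapeExpr([2, 4]), "float32")

# Rejected: any other relax.Expr used as the shape.
try:
    Tensor(relax.Var("s"), "float32")
except ValueError:
    pass  # "shape must be a list/tuple or a ShapeExpr, but got: ..."
```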
2 changes: 1 addition & 1 deletion src/script/printer/relax/expr.cc
@@ -71,7 +71,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
for (int i = 0, l = n->values.size(); i < l; ++i) {
values_doc.push_back(PrintShapeVar(n->values[i], values_p->ArrayIndex(i), d));
}
return TupleDoc(values_doc);
return Relax(d, "shape")->Call({ListDoc(values_doc)});
});

Optional<ExprDoc> SpecialScalar(const runtime::NDArray& n, const ObjectPath& p) {
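On the printer side, a `ShapeExprNode` is now emitted as a call to `R.shape` instead of a bare `TupleDoc`, so printed modules round-trip through a parser that no longer accepts the implicit tuple form. A hedged sketch of the expected round trip:

```python
import tvm
from tvm.script import relax as R


@tvm.script.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((2, 4), "float32")) -> R.Shape([2, 4]):
        gv: R.Shape([2, 4]) = R.shape([2, 4])
        return gv

# The printed script is expected to keep the explicit `R.shape([2, 4])`
# spelling on the binding, and therefore to re-parse cleanly.
print(Module.script())
```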
14 changes: 13 additions & 1 deletion src/script/printer/relax/struct_info.cc
@@ -89,7 +89,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Array<String> kwargs_keys;
Array<ExprDoc> kwargs_values;
if (n->shape.defined()) {
args.push_back(d->AsDoc<ExprDoc>(n->shape.value(), n_p->Attr("shape")));
// Need to dig into ShapeExpr to preserve the `R.shape` prefix
if (const auto* shape = n->shape.value().as<relax::ShapeExprNode>()) {
auto shape_expr = GetRef<relax::ShapeExpr>(shape);
ObjectPath shape_p = n_p->Attr("shape")->Attr("values");
Array<ExprDoc> shape_docs;
for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) {
shape_docs.push_back(
PrintShapeVar(shape_expr->values[i], shape_p->ArrayIndex(i), d));
}
args.push_back(TupleDoc(shape_docs));
} else {
args.push_back(d->AsDoc<ExprDoc>(n->shape.value(), n_p->Attr("shape")));
}
}
if (!n->IsUnknownDtype()) {
kwargs_keys.push_back("dtype");
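The struct-info printer, by contrast, digs into the underlying `ShapeExprNode` so that `R.Tensor` annotations keep their tuple spelling rather than switching to `R.Tensor(R.shape([...]), ...)`. A small sketch of the assumed printed form:

```python
import tvm
from tvm.script import relax as R


@tvm.script.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((2, 4), "float32")) -> R.Tensor((2, 4), "float32"):
        return x

# The annotation is expected to print with the tuple form of the shape,
# e.g. `R.Tensor((2, 4), dtype="float32")`, not with `R.shape`.
print(Module.script())
```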
16 changes: 8 additions & 8 deletions tests/python/relax/test_analysis_estimate_memory_usage.py
@@ -69,40 +69,40 @@ def pad(
@R.function
def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
storage: R.Object = R.memory.alloc_storage(
(32,), virtual_device_index=0, storage_scope="global", dtype="float32"
R.shape([32]), virtual_device_index=0, storage_scope="global", dtype="float32"
)
alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(
storage, offset=0, shape=(2, 4), dtype="float32"
storage, offset=0, shape=R.shape([2, 4]), dtype="float32"
)
_: R.Tuple() = exp(x, alloc)
lv: R.Tensor((2, 4), dtype="float32") = alloc
lv1: R.Tensor((8,), dtype="float32") = R.call_packed(
"vm.builtin.reshape", lv, (8,), sinfo_args=[R.Tensor((8,), dtype="float32")]
"vm.builtin.reshape", lv, R.shape([8]), sinfo_args=[R.Tensor((8,), dtype="float32")]
)
storage1: R.Object = R.memory.alloc_storage(
(40,), virtual_device_index=0, storage_scope="global", dtype="float32"
R.shape([40]), virtual_device_index=0, storage_scope="global", dtype="float32"
)
alloc1: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(
storage1, offset=0, shape=(8,), dtype="float32"
storage1, offset=0, shape=R.shape([8]), dtype="float32"
)
_1: R.Tuple() = relu(lv1, alloc1)
_2: R.Tuple() = R.memory.kill_tensor(alloc)
_3: R.Tuple() = R.memory.kill_tensor(lv1)
lv2: R.Tensor((8,), dtype="float32") = alloc1
alloc2: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(
storage, offset=0, shape=(8,), dtype="float32"
storage, offset=0, shape=R.shape([8]), dtype="float32"
)
_4: R.Tuple() = add(lv2, R.const(1, "float32"), alloc2)
_5: R.Tuple() = R.memory.kill_tensor(alloc1)
lv3: R.Tensor((8,), dtype="float32") = alloc2
alloc3: R.Tensor((10,), dtype="float32") = R.memory.alloc_tensor(
storage1, offset=0, shape=(10,), dtype="float32"
storage1, offset=0, shape=R.shape([10]), dtype="float32"
)
_6: R.Tuple() = pad(lv3, alloc3)
_7: R.Tuple() = R.memory.kill_tensor(alloc2)
lv4: R.Tensor((10,), dtype="float32") = alloc3
alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(
(10,), dtype="float32", runtime_device_index=0
R.shape([10]), dtype="float32", runtime_device_index=0
)
_8: R.Tuple() = log(lv4, alloc4)
_9: R.Tuple() = R.memory.kill_tensor(alloc3)
2 changes: 1 addition & 1 deletion tests/python/relax/test_backend_transform_shape_lower.py
@@ -167,7 +167,7 @@ def main(
n = T.Var("n", "int64")
k = T.Var("k", "int64")
z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None))
return (k + 1, m, 2)
return R.shape([k + 1, m, 2])

# slot assignment:
# 0: n, 1: m, 2:k, 3: k+1
4 changes: 2 additions & 2 deletions tests/python/relax/test_frontend_from_fx.py
@@ -1041,7 +1041,7 @@ class expected1:
def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Shape([1, 3, 10, 10]):
# block 0
with R.dataflow():
gv: R.Shape([1, 3, 10, 10]) = (1, 3, 10, 10)
gv: R.Shape([1, 3, 10, 10]) = R.shape([1, 3, 10, 10])
R.output(gv)
return gv

@@ -1116,7 +1116,7 @@ class expected1:
def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Shape([1, 3, 10, 10]):
# block 0
with R.dataflow():
gv: R.Shape([1, 3, 10, 10]) = (1, 3, 10, 10)
gv: R.Shape([1, 3, 10, 10]) = R.shape([1, 3, 10, 10])
R.output(gv)
return gv

6 changes: 3 additions & 3 deletions tests/python/relax/test_parser.py
@@ -624,7 +624,7 @@ def f(x: R.Tensor(("n", "m"), "float32")):
z: R.Tensor((n * m,), "float32") = R.call_packed(
"my_flatten", (x,), sinfo_args=(R.Tensor(ndim=1, dtype="float32"))
)
sh: R.Shape = (n + m, n // m)
sh: R.Shape = R.shape([n + m, n // m])
return z

x = f.params[0]
@@ -870,8 +870,8 @@ def mul_add(x: R.Tensor) -> R.Tensor:
def test_memory_op():
@R.function
def memory(x: R.Tensor) -> R.Tensor:
storage = R.memory.alloc_storage((1024,), -1, "global", "float32")
alloca = R.memory.alloc_tensor(storage, 0, (1, 256), "float32")
storage = R.memory.alloc_storage(R.shape([1024]), -1, "global", "float32")
alloca = R.memory.alloc_tensor(storage, 0, R.shape([1, 256]), "float32")
_ = R.memory.kill_tensor(alloca)
_ = R.memory.kill_storage(storage)
return alloca
6 changes: 3 additions & 3 deletions tests/python/relax/test_transform.py
@@ -292,7 +292,7 @@ class TestVMBuiltinLower:
@R.function
def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
m, n = T.var("int64"), T.var("int64")
alloc = R.builtin.alloc_tensor((m, n), runtime_device_index=0, dtype="float32")
alloc = R.builtin.alloc_tensor(R.shape([m, n]), runtime_device_index=0, dtype="float32")
_ = R.call_packed(
"test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))
)
@@ -325,15 +325,15 @@ def test_vm_builtin_lower_reshape():
class TestVMReshape:
@R.function
def main(x: R.Tensor((3, 4), "float32")):
y = R.reshape(x, (6, 2))
y = R.reshape(x, R.shape([6, 2]))
return y

@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor((3, 4), "float32")):
y: R.Tensor((6, 2), "float32") = R.call_packed(
"vm.builtin.reshape", x, (6, 2), sinfo_args=R.Tensor((6, 2), "float32")
"vm.builtin.reshape", x, R.shape([6, 2]), sinfo_args=R.Tensor((6, 2), "float32")
)
return y

(The remaining 5 changed files are not shown.)

