From 69a970083b21e6ddfa539fe396673d339f43aaac Mon Sep 17 00:00:00 2001
From: Siyuan Feng
Date: Tue, 14 Feb 2023 03:50:35 +0800
Subject: [PATCH] [TVMScript] Use explicit `R.shape` in TVMScript (#435)

As we've introduced `arg_sinfo` in CallNode, the implicit shape
constructor is no longer widely used in TVMScript. This PR removes the
implicit shape constructor, since it can cause confusion between shapes
and tuples: shapes must now be written explicitly with `R.shape`.
---
 python/tvm/relax/utils.py                     | 16 +--
 python/tvm/script/ir_builder/relax/ir.py      | 22 +++-
 python/tvm/script/parser/relax/entry.py       | 22 +++-
 src/script/printer/relax/expr.cc              |  2 +-
 src/script/printer/relax/struct_info.cc       | 14 ++-
 .../test_analysis_estimate_memory_usage.py    | 16 +--
 .../test_backend_transform_shape_lower.py     |  2 +-
 tests/python/relax/test_frontend_from_fx.py   |  4 +-
 tests/python/relax/test_parser.py             |  6 +-
 tests/python/relax/test_transform.py          |  6 +-
 ...test_transform_static_plan_block_memory.py | 102 +++++++++---------
 tests/python/relax/test_tvmscript_parser.py   | 15 ++-
 .../relax/test_tvmscript_printer_relax.py     |  4 +-
 tests/python/relax/test_vm_build.py           |  6 +-
 tests/python/relax/test_vm_codegen_only.py    | 18 ++--
 15 files changed, 146 insertions(+), 109 deletions(-)

diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index 81d8b2f8bf..d6b405f183 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -23,7 +23,7 @@
 from ..runtime import String, convert_to_object
 from ..tir import PrimExpr
 from . import _ffi_api
-from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm
+from .expr import Expr, Function, PrimValue, StringImm
 from .expr import Tuple as rx_Tuple
 
 
@@ -74,14 +74,12 @@ def convert_to_expr(value: Any) -> Expr:
     1. Return the input itself if it's already a `relax.Expr`;
     2. Return `relax.PrimValue` if the input is a `PrimExpr`;
     3. Return `relax.StringImm` if the input is `tvm.String` or `str`;
-    4. Return `relax.ShapeExpr` if the input is a tuple/list of `PrimExpr` w/ int dtype;
-    5. Return `relax.Tuple` if the input is a tuple/list of `Expr`.
+    4. Return `relax.Tuple` if the input is a tuple/list of `Expr`.
 
     Notes
     -----
     1. `tvm.tir.StringImm` is not allowed because of ambiguity,
        which can be either `relax.StringImm` or `relax.PrimValue`.
-    2. We regard empty tuple/list as `relax.Tuple` instead of `relax.ShapeExpr`
     """
     if isinstance(value, int):
         return PrimValue(tir.IntImm("int64", value))
@@ -102,16 +100,8 @@ def convert_to_expr(value: Any) -> Expr:
     # Case 3
     if isinstance(tvm_value, String):
         return StringImm(value)
-    # Case 4 & 5
+    # Case 4
     if isinstance(value, (tuple, list)):
-        # Note 2
-        if len(value) == 0:
-            return rx_Tuple([])
-        # Case 4
-        opt_prim_value = [convert_to_object(v) for v in value]
-        if all([isinstance(v, PrimExpr) and v.dtype.startswith("int") for v in opt_prim_value]):
-            return ShapeExpr(value)
-        # Case 5
         # `convert_to_expr` ensures that all elements are `Expr` if no exception raises
         return rx_Tuple([convert_to_expr(v) for v in value])
     raise TypeError(f"Cannot convert {value} with type {type(value)} to `relax.Expr`")
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index ccefcc3a4b..c1b3cb0707 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -382,11 +382,11 @@ def Else() -> frame.ElseFrame:  # pylint: disable=invalid-name
 ############################### R.tuple ################################
 
 
-def tuple(*fields: List[Expr]) -> Expr:
+def tuple(*fields: Expr) -> Expr:
     """Create a tuple expression.
 
     Parameters
     ----------
-    fields : List[Expr]
+    *fields : Expr
         The fields of the tuple.
 
     Returns
     -------
     res : Expr
         The result tuple.
@@ -399,6 +399,23 @@ def tuple(*fields: List[Expr]) -> Expr:
     return relax.Tuple(fields)  # pylint: disable=no-member # type: ignore
 
 
+############################### R.shape ################################
+
+
+def shape(value: List[PrimExpr]) -> Expr:
+    """Create a ShapeExpr.
+
+    Parameters
+    ----------
+    value : List[PrimExpr]
+        The values of the shape.
+
+    Returns
+    -------
+    res : Expr
+        The result ShapeExpr.
+    """
+    return relax.ShapeExpr(value)  # pylint: disable=no-member # type: ignore
+
+
 ############################### PrimValue ##############################
 
 
@@ -524,6 +541,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "prod",
     "reshape",
     "round",
+    "shape",
     "shape_of",
     "sigmoid",
     "sign",
diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py
index d93f9a2826..7e51264cb3 100644
--- a/python/tvm/script/parser/relax/entry.py
+++ b/python/tvm/script/parser/relax/entry.py
@@ -22,6 +22,7 @@
 
 from tvm.relax import (
     Expr,
+    ShapeExpr,
     FuncStructInfo,
     Function,
     ObjectStructInfo,
@@ -84,17 +85,22 @@ class TensorProxy(StructInfoProxy):
     def __init__(
         self,
-        shape: Optional[List[Union[PrimExpr, str]]] = None,
+        shape: Optional[Union[List[Union[PrimExpr, str]], Expr]] = None,
         dtype: Optional[str] = None,
         ndim: int = -1,
     ) -> None:
         self.shape = shape
+        if isinstance(shape, Expr) and not isinstance(shape, ShapeExpr):
+            raise ValueError(
+                "Only ShapeExpr is allowed as shape expr, but got: "
+                f"{shape} with type: {type(shape)}"
+            )
         self.dtype = dtype
         self.ndim = ndim
         super().__init__()
 
     def get_symbolic_vars(self) -> Set[str]:
-        if self.shape is None:
+        if self.shape is None or isinstance(self.shape, Expr):
             return {}
         else:
             return {s for s in self.shape if isinstance(s, str) and s.isidentifier()}
@@ -102,6 +108,8 @@ def get_symbolic_vars(self) -> Set[str]:
     def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TensorStructInfo:
         if self.shape is None:
             return TensorStructInfo(None, self.dtype, self.ndim)
+        elif isinstance(self.shape, ShapeExpr):
+            return TensorStructInfo(self.shape, self.dtype, self.ndim)
         else:
             if dict_globals is None and any([isinstance(s, str) for s in self.shape]):
                 raise ValueError(
@@ -113,7 +121,7 @@ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> Tenso
 
 
 def Tensor(
-    shape: Optional[List[Union[PrimExpr, str]]] = None,
+    shape: Optional[Union[List[Union[PrimExpr, str]], ShapeExpr]] = None,
     dtype: Optional[str] = None,
     ndim: int = -1,
 ) -> TensorProxy:
@@ -124,8 +132,12 @@ def Tensor(
         dtype = shape
         shape = None
 
-    if shape is not None and not isinstance(shape, (tuple, list)):
-        raise ValueError(f"shape must be a list or tuple, but got: {shape}")
+    if (
+        shape is not None
+        and not isinstance(shape, (tuple, list))
+        and not isinstance(shape, ShapeExpr)
+    ):
+        raise ValueError(f"shape must be a list/tuple or a ShapeExpr, but got: {shape}")
     return TensorProxy(shape, dtype, ndim)
 
 
diff --git a/src/script/printer/relax/expr.cc b/src/script/printer/relax/expr.cc
index a786932fc3..66d7d187d0 100644
--- a/src/script/printer/relax/expr.cc
+++ b/src/script/printer/relax/expr.cc
@@ -71,7 +71,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
       for (int i = 0, l = n->values.size(); i < l; ++i) {
         values_doc.push_back(PrintShapeVar(n->values[i], values_p->ArrayIndex(i), d));
       }
-      return TupleDoc(values_doc);
+      return Relax(d, "shape")->Call({ListDoc(values_doc)});
     });
 
 Optional<ExprDoc> SpecialScalar(const runtime::NDArray& n, const ObjectPath& p) {
diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc
index 6f4a66c991..c541619ec8 100644
--- a/src/script/printer/relax/struct_info.cc
+++ b/src/script/printer/relax/struct_info.cc
@@ -89,7 +89,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
       Array<String> kwargs_keys;
       Array<ExprDoc> kwargs_values;
       if (n->shape.defined()) {
-        args.push_back(d->AsDoc<ExprDoc>(n->shape.value(), n_p->Attr("shape")));
+        // Dig into ShapeExpr so the annotation prints as a plain tuple, not `R.shape`
+        if (const auto* shape = n->shape.value().as<relax::ShapeExprNode>()) {
+          auto shape_expr = GetRef<relax::ShapeExpr>(shape);
+          ObjectPath shape_p = n_p->Attr("shape")->Attr("values");
+          Array<ExprDoc> shape_docs;
+          for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) {
+            shape_docs.push_back(
+                PrintShapeVar(shape_expr->values[i], shape_p->ArrayIndex(i), d));
+          }
+          args.push_back(TupleDoc(shape_docs));
+        } else {
+          args.push_back(d->AsDoc<ExprDoc>(n->shape.value(), n_p->Attr("shape")));
+        }
       }
       if (!n->IsUnknownDtype()) {
         kwargs_keys.push_back("dtype");
diff --git a/tests/python/relax/test_analysis_estimate_memory_usage.py b/tests/python/relax/test_analysis_estimate_memory_usage.py
index 8185ed622e..d9d243309c 100644
--- a/tests/python/relax/test_analysis_estimate_memory_usage.py
+++ b/tests/python/relax/test_analysis_estimate_memory_usage.py
@@ -69,40 +69,40 @@ def pad(
     @R.function
     def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
         storage: R.Object = R.memory.alloc_storage(
-            (32,), virtual_device_index=0, storage_scope="global", dtype="float32"
+            R.shape([32]), virtual_device_index=0, storage_scope="global", dtype="float32"
         )
         alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(
-            storage, offset=0, shape=(2, 4), dtype="float32"
+            storage, offset=0, shape=R.shape([2, 4]), dtype="float32"
         )
         _: R.Tuple() = exp(x, alloc)
         lv: R.Tensor((2, 4), dtype="float32") = alloc
         lv1: R.Tensor((8,), dtype="float32") = R.call_packed(
-            "vm.builtin.reshape", lv, (8,), sinfo_args=[R.Tensor((8,), dtype="float32")]
+            "vm.builtin.reshape", lv, R.shape([8]), sinfo_args=[R.Tensor((8,), dtype="float32")]
         )
         storage1: R.Object = R.memory.alloc_storage(
-            (40,), virtual_device_index=0, storage_scope="global", dtype="float32"
+            R.shape([40]), virtual_device_index=0, storage_scope="global", dtype="float32"
        )
         alloc1: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(
-            storage1, offset=0, shape=(8,), dtype="float32"
+            storage1, offset=0, shape=R.shape([8]), dtype="float32"
        )
         _1: R.Tuple() = relu(lv1, alloc1)
         _2: R.Tuple() = R.memory.kill_tensor(alloc)
         _3: R.Tuple() = R.memory.kill_tensor(lv1)
         lv2: R.Tensor((8,), dtype="float32") = alloc1
         alloc2: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(
-            storage, offset=0, shape=(8,), dtype="float32"
+            storage, offset=0, shape=R.shape([8]), dtype="float32"
        )
         _4: R.Tuple() = add(lv2, R.const(1, "float32"), alloc2)
         _5: R.Tuple() = R.memory.kill_tensor(alloc1)
         lv3: R.Tensor((8,), dtype="float32") = alloc2
         alloc3: R.Tensor((10,), dtype="float32") = R.memory.alloc_tensor(
-            storage1, offset=0, shape=(10,), dtype="float32"
+            storage1, offset=0, shape=R.shape([10]), dtype="float32"
        )
         _6: R.Tuple() = pad(lv3, alloc3)
         _7: R.Tuple() = R.memory.kill_tensor(alloc2)
         lv4: R.Tensor((10,), dtype="float32") = alloc3
         alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(
-            (10,), dtype="float32", runtime_device_index=0
+            R.shape([10]), dtype="float32", runtime_device_index=0
        )
         _8: R.Tuple() = log(lv4, alloc4)
         _9: R.Tuple() = R.memory.kill_tensor(alloc3)
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py
index bf1bc61a6e..85d5a24044 100644
--- a/tests/python/relax/test_backend_transform_shape_lower.py
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -167,7 +167,7 @@ def main(
             n = T.Var("n", "int64")
             k = T.Var("k", "int64")
             z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None))
-            return (k + 1, m, 2)
+            return R.shape([k + 1, m, 2])
 
     # slot assignment:
     # 0: n, 1: m, 2:k, 3: k+1
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index eeaf5871ec..9b35d34bd3 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -1041,7 +1041,7 @@ class expected1:
         def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Shape([1, 3, 10, 10]):
             # block 0
             with R.dataflow():
-                gv: R.Shape([1, 3, 10, 10]) = (1, 3, 10, 10)
+                gv: R.Shape([1, 3, 10, 10]) = R.shape([1, 3, 10, 10])
                 R.output(gv)
             return gv
 
@@ -1116,7 +1116,7 @@ class expected1:
         def main(input_1: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Shape([1, 3, 10, 10]):
             # block 0
             with R.dataflow():
-                gv: R.Shape([1, 3, 10, 10]) = (1, 3, 10, 10)
+                gv: R.Shape([1, 3, 10, 10]) = R.shape([1, 3, 10, 10])
                 R.output(gv)
             return gv
 
diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py
index ef47539aaa..99c3a66d52 100644
--- a/tests/python/relax/test_parser.py
+++ b/tests/python/relax/test_parser.py
@@ -624,7 +624,7 @@ def f(x: R.Tensor(("n", "m"), "float32")):
         z: R.Tensor((n * m,), "float32") = R.call_packed(
             "my_flatten", (x,), sinfo_args=(R.Tensor(ndim=1, dtype="float32"))
         )
-        sh: R.Shape = (n + m, n // m)
+        sh: R.Shape = R.shape([n + m, n // m])
         return z
 
     x = f.params[0]
@@ -870,8 +870,8 @@ def mul_add(x: R.Tensor) -> R.Tensor:
 def test_memory_op():
     @R.function
     def memory(x: R.Tensor) -> R.Tensor:
-        storage = R.memory.alloc_storage((1024,), -1, "global", "float32")
-        alloca = R.memory.alloc_tensor(storage, 0, (1, 256), "float32")
+        storage = R.memory.alloc_storage(R.shape([1024]), -1, "global", "float32")
+        alloca = R.memory.alloc_tensor(storage, 0, R.shape([1, 256]), "float32")
         _ = R.memory.kill_tensor(alloca)
         _ = R.memory.kill_storage(storage)
         return alloca
diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py
index 059f4e43bc..8b8491bff7 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -292,7 +292,7 @@ class TestVMBuiltinLower:
         @R.function
         def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
             m, n = T.var("int64"), T.var("int64")
-            alloc = R.builtin.alloc_tensor((m, n), runtime_device_index=0, dtype="float32")
+            alloc = R.builtin.alloc_tensor(R.shape([m, n]), runtime_device_index=0, dtype="float32")
             _ = R.call_packed(
                 "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))
             )
@@ -325,7 +325,7 @@ def test_vm_builtin_lower_reshape():
     class TestVMReshape:
         @R.function
         def main(x: R.Tensor((3, 4), "float32")):
-            y = R.reshape(x, (6, 2))
+            y = R.reshape(x, R.shape([6, 2]))
             return y
 
     @tvm.script.ir_module
@@ -333,7 +333,7 @@ class Expected:
         @R.function
         def main(x: R.Tensor((3, 4), "float32")):
             y: R.Tensor((6, 2), "float32") = R.call_packed(
-                "vm.builtin.reshape", x, (6, 2), sinfo_args=R.Tensor((6, 2), "float32")
+                "vm.builtin.reshape", x, R.shape([6, 2]), sinfo_args=R.Tensor((6, 2), "float32")
             )
             return y
 
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 764f5d1c4c..1db8b8a15b 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -51,20 +51,20 @@ def pad(rxplaceholder: T.Buffer[T.int64(8), "float32"], PadInput: T.Buffer[T.int
 
     @R.function
     def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
-        alloc: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor((2, 4), dtype="float32", runtime_device_index=0)
+        alloc: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), dtype="float32", runtime_device_index=0)
         _: R.Tuple() = exp(x, alloc)
         lv: R.Tensor((2, 4), dtype="float32") = alloc
         lv1: R.Tensor((8,), dtype="float32") = R.reshape(lv, (8,))
-        alloc1: R.Tensor((8,), dtype="float32") = R.builtin.alloc_tensor((8,), dtype="float32", runtime_device_index=0)
+        alloc1: R.Tensor((8,), dtype="float32") = R.builtin.alloc_tensor(R.shape([8]), dtype="float32", runtime_device_index=0)
         _1: R.Tuple() = relu(lv1, alloc1)
         lv2: R.Tensor((8,), dtype="float32") = alloc1
-        alloc2: R.Tensor((8,), dtype="float32") = R.builtin.alloc_tensor((8,), dtype="float32", runtime_device_index=0)
+        alloc2: R.Tensor((8,), dtype="float32") = R.builtin.alloc_tensor(R.shape([8]), dtype="float32", runtime_device_index=0)
         _2: R.Tuple() = add(lv2, R.const(1, "float32"), alloc2)
         lv3: R.Tensor((8,), dtype="float32") = alloc2
-        alloc3: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor((10,), dtype="float32", runtime_device_index=0)
+        alloc3: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0)
         _3: R.Tuple() = pad(lv3, alloc3)
         lv4: R.Tensor((10,), dtype="float32") = alloc3
-        alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor((10,), dtype="float32", runtime_device_index=0)
+        alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0)
         _4: R.Tuple() = log(lv4, alloc4)
         gv: R.Tensor((10,), dtype="float32") = alloc4
         return gv
@@ -97,26 +97,26 @@ def pad(rxplaceholder: T.Buffer[T.int64(8), "float32"], PadInput: T.Buffer[T.int
 
     @R.function
     def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
-        storage: R.Object = R.memory.alloc_storage((32,), virtual_device_index=0, storage_scope="global", dtype="float32")
-        alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, 0, (2, 4), dtype="float32")
+        storage: R.Object = R.memory.alloc_storage(R.shape([32]), virtual_device_index=0, storage_scope="global", dtype="float32")
+        alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), dtype="float32")
         _: R.Tuple() = exp(x, alloc)
         lv: R.Tensor((2, 4), dtype="float32") = alloc
         lv1: R.Tensor((8,), dtype="float32") = R.reshape(lv, (8,))
-        storage1: R.Object = R.memory.alloc_storage((40,), virtual_device_index=0, storage_scope="global", dtype="float32")
-        alloc1: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(storage1, 0, (8,), dtype="float32")
+        storage1: R.Object = R.memory.alloc_storage(R.shape([40]), virtual_device_index=0, storage_scope="global", dtype="float32")
+        alloc1: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(storage1, 0, R.shape([8]), dtype="float32")
         _1: R.Tuple() = relu(lv1, alloc1)
         _2: R.Tuple() = R.memory.kill_tensor(alloc)
         _3: R.Tuple() = R.memory.kill_tensor(lv1)
         lv2: R.Tensor((8,), dtype="float32") = alloc1
-        alloc2: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(storage, 0, (8,), dtype="float32")
dtype="float32") = R.memory.alloc_tensor(storage, 0, (8,), dtype="float32") + alloc2: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor(storage, 0, R.shape([8]), dtype="float32") _4: R.Tuple() = add(lv2, R.const(1, "float32"), alloc2) _5: R.Tuple() = R.memory.kill_tensor(alloc1) lv3: R.Tensor((8,), dtype="float32") = alloc2 - alloc3: R.Tensor((10,), dtype="float32") = R.memory.alloc_tensor(storage1, 0, (10,), dtype="float32") + alloc3: R.Tensor((10,), dtype="float32") = R.memory.alloc_tensor(storage1, 0, R.shape([10]), dtype="float32") _6: R.Tuple() = pad(lv3, alloc3) _7: R.Tuple() = R.memory.kill_tensor(alloc2) lv4: R.Tensor((10,), dtype="float32") = alloc3 - alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor((10,), dtype="float32", runtime_device_index=0) + alloc4: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0) _8: R.Tuple() = log(lv4, alloc4) _9: R.Tuple() = R.memory.kill_tensor(alloc3) gv5: R.Tensor((10,), dtype="float32") = alloc4 @@ -153,12 +153,12 @@ def main( x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32") ) -> R.Tensor((2, 3), dtype="float32"): alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( - (2, 3), dtype="float32", runtime_device_index=0 + R.shape([2, 3]), dtype="float32", runtime_device_index=0 ) _: R.Tuple() = add(x, x, alloc) gv: R.Tensor((2, 3), dtype="float32") = alloc alloc1: R.Tensor((2, 3), dtype="int32") = R.builtin.alloc_tensor( - (2, 3), dtype="int32", runtime_device_index=0 + R.shape([2, 3]), dtype="int32", runtime_device_index=0 ) _1: R.Tuple() = add1(y, y, alloc1) gv1: R.Tensor((2, 3), dtype="int32") = alloc1 @@ -187,19 +187,19 @@ def main( x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="int32") ) -> R.Tensor((2, 3), dtype="float32"): storage: R.Object = R.memory.alloc_storage( - (24,), virtual_device_index=0, storage_scope="global", dtype="float32" + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" ) alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( - storage, 0, (2, 3), dtype="float32" + storage, 0, R.shape([2, 3]), dtype="float32" ) _: R.Tuple() = add(x, x, alloc) _1: R.Tuple() = R.memory.kill_tensor(alloc) gv1: R.Tensor((2, 3), dtype="float32") = alloc storage1: R.Object = R.memory.alloc_storage( - (24,), virtual_device_index=0, storage_scope="global", dtype="int32" + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="int32" ) alloc1: R.Tensor((2, 3), dtype="int32") = R.memory.alloc_tensor( - storage1, 0, (2, 3), dtype="int32" + storage1, 0, R.shape([2, 3]), dtype="int32" ) _2: R.Tuple() = add1(y, y, alloc1) _3: R.Tuple() = R.memory.kill_tensor(alloc1) @@ -228,12 +228,12 @@ def main( x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32") ) -> R.Tensor((2, 3), dtype="float32"): alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( - (2, 3), dtype="float32", runtime_device_index=0 + R.shape([2, 3]), dtype="float32", runtime_device_index=0 ) _: R.Tuple() = add(x, x, alloc) gv: R.Tensor((2, 3), dtype="float32") = alloc alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( - (2, 3), dtype="float32", runtime_device_index=0 + R.shape([2, 3]), dtype="float32", runtime_device_index=0 ) _1: R.Tuple() = add(y, y, alloc1) gv1: R.Tensor((2, 3), dtype="float32") = alloc1 @@ -254,16 +254,16 @@ def main( x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32") ) -> R.Tensor((2, 3), dtype="float32"): 
         storage: R.Object = R.memory.alloc_storage(
-            (24,), virtual_device_index=0, storage_scope="global", dtype="float32"
+            R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
         )
         alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
-            storage, 0, (2, 3), dtype="float32"
+            storage, 0, R.shape([2, 3]), dtype="float32"
         )
         _: R.Tuple() = add(x, x, alloc)
         _1: R.Tuple() = R.memory.kill_tensor(alloc)
         gv1: R.Tensor((2, 3), dtype="float32") = alloc
         alloc1: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
-            storage, 0, (2, 3), dtype="float32"
+            storage, 0, R.shape([2, 3]), dtype="float32"
         )
         _2: R.Tuple() = add(y, y, alloc1)
         _3: R.Tuple() = R.memory.kill_tensor(alloc1)
@@ -289,7 +289,7 @@ def exp(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(2, 3), "float32"]):
     @R.function
     def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
         alloc: R.Tensor((), dtype="bool") = R.builtin.alloc_tensor(
-            (), dtype="bool", runtime_device_index=0
+            R.shape([]), dtype="bool", runtime_device_index=0
         )
         _: R.Tuple() = all_less_than_zero(x, alloc)
         x1: R.Tensor((), dtype="bool") = alloc
@@ -297,7 +297,7 @@ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float3
             y: R.Tensor((2, 3), dtype="float32") = x
         else:
             alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
-                (2, 3), dtype="float32", runtime_device_index=0
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
             )
             _1: R.Tuple() = exp(x, alloc1)
             gv3: R.Tensor((2, 3), dtype="float32") = alloc1
@@ -321,7 +321,7 @@ def main(
         cond: R.Tensor((), dtype="bool"), x: R.Tensor((2, 3), dtype="float32")
     ) -> R.Tensor((2, 3), dtype="float32"):
         alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
-            (2, 3), dtype="float32", runtime_device_index=0
+            R.shape([2, 3]), dtype="float32", runtime_device_index=0
         )
         _: R.Tuple() = exp(x, alloc)
         y: R.Tensor((2, 3), dtype="float32") = alloc
@@ -348,20 +348,20 @@ def main(
         cond: R.Tensor((), dtype="bool"), x: R.Tensor((2, 3), dtype="float32")
     ) -> R.Tensor((2, 3), dtype="float32"):
         alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
-            (2, 3), dtype="float32", runtime_device_index=0
+            R.shape([2, 3]), dtype="float32", runtime_device_index=0
         )
         _: R.Tuple() = exp(x, alloc)
         y: R.Tensor((2, 3), dtype="float32") = alloc
         if cond:
             alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
-                (2, 3), dtype="float32", runtime_device_index=0
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
             )
             _1: R.Tuple() = exp(y, alloc1)
             y2: R.Tensor((2, 3), dtype="float32") = alloc1
             z: R.Tensor((2, 3), dtype="float32") = y2
         else:
             alloc2: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
-                (2, 3), dtype="float32", runtime_device_index=0
+                R.shape([2, 3]), dtype="float32", runtime_device_index=0
             )
             _2: R.Tuple() = exp(y, alloc2)
             y2: R.Tensor((2, 3), dtype="float32") = alloc2
@@ -383,17 +383,17 @@ def exp(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(2, 3), "float32"]):
     @R.function
     def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
         alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
-            (2, 3), dtype="float32", runtime_device_index=0
+            R.shape([2, 3]), dtype="float32", runtime_device_index=0
         )
         _: R.Tuple() = exp(x, alloc)
         y1: R.Tensor((2, 3), dtype="float32") = alloc
         alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
-            (2, 3), dtype="float32", runtime_device_index=0
+            R.shape([2, 3]), dtype="float32", runtime_device_index=0
         )
         _1: R.Tuple() = exp(x, alloc1)
         y2: R.Tensor((2, 3), dtype="float32") = alloc1
         alloc2: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
-            (2, 3), dtype="float32", runtime_device_index=0
+            R.shape([2, 3]), dtype="float32", runtime_device_index=0
         )
         _2: R.Tuple() = exp(x, alloc2)
         y3: R.Tensor((2, 3), dtype="float32") = alloc2
@@ -412,17 +412,17 @@ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float3
         y2_: R.Tensor((2, 3), dtype="float32") = nt0[1]
         y3_: R.Tensor((2, 3), dtype="float32") = nt[1]
         alloc3: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
-            (2, 3), dtype="float32", runtime_device_index=0
+            R.shape([2, 3]), dtype="float32", runtime_device_index=0
         )
         _3: R.Tuple() = exp(y1_, alloc3)
         z1: R.Tensor((2, 3), dtype="float32") = alloc3
         alloc4: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
-            (2, 3), dtype="float32", runtime_device_index=0
+            R.shape([2, 3]), dtype="float32", runtime_device_index=0
         )
         _4: R.Tuple() = exp(y2_, alloc4)
         z2: R.Tensor((2, 3), dtype="float32") = alloc4
         alloc5: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
-            (2, 3), dtype="float32", runtime_device_index=0
+            R.shape([2, 3]), dtype="float32", runtime_device_index=0
         )
         _5: R.Tuple() = exp(y3_, alloc5)
         z3: R.Tensor((2, 3), dtype="float32") = alloc5
@@ -437,26 +437,26 @@ def exp(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(2, 3), "float32"]):
     @R.function
     def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
         storage: R.Object = R.memory.alloc_storage(
-            (24,), virtual_device_index=0, storage_scope="global", dtype="float32"
+            R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
         )
         alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
-            storage, 0, (2, 3), dtype="float32"
+            storage, 0, R.shape([2, 3]), dtype="float32"
         )
         _: R.Tuple() = exp(x, alloc)
         y1: R.Tensor((2, 3), dtype="float32") = alloc
         storage1: R.Object = R.memory.alloc_storage(
-            (24,), virtual_device_index=0, storage_scope="global", dtype="float32"
+            R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
         )
         alloc1: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
-            storage1, 0, (2, 3), dtype="float32"
+            storage1, 0, R.shape([2, 3]), dtype="float32"
         )
         _1: R.Tuple() = exp(x, alloc1)
         y2: R.Tensor((2, 3), dtype="float32") = alloc1
         storage2: R.Object = R.memory.alloc_storage(
-            (24,), virtual_device_index=0, storage_scope="global", dtype="float32"
+            R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
         )
         alloc2: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
-            storage2, 0, (2, 3), dtype="float32"
+            storage2, 0, R.shape([2, 3]), dtype="float32"
         )
         _2: R.Tuple() = exp(x, alloc2)
         y3: R.Tensor((2, 3), dtype="float32") = alloc2
@@ -475,24 +475,24 @@ def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float3
         y2_: R.Tensor((2, 3), dtype="float32") = nt0[1]
         y3_: R.Tensor((2, 3), dtype="float32") = nt[1]
         storage3: R.Object = R.memory.alloc_storage(
-            (24,), virtual_device_index=0, storage_scope="global", dtype="float32"
+            R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
        )
         alloc3: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
-            storage3, 0, (2, 3), dtype="float32"
+            storage3, 0, R.shape([2, 3]), dtype="float32"
        )
         _3: R.Tuple() = exp(y1_, alloc3)
         _4: R.Tuple() = R.memory.kill_tensor(alloc)
         _11: R.Tuple() = R.memory.kill_tensor(alloc3)
         z1: R.Tensor((2, 3), dtype="float32") = alloc3
         alloc4: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
-            storage, 0, (2, 3), dtype="float32"
+            storage, 0, R.shape([2, 3]), dtype="float32"
         )
         _41: R.Tuple() = exp(y2_, alloc4)
         _21: R.Tuple() = R.memory.kill_tensor(alloc1)
         _31: R.Tuple() = R.memory.kill_tensor(alloc4)
         z2: R.Tensor((2, 3), dtype="float32") = alloc4
         alloc5: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
-            storage3, 0, (2, 3), dtype="float32"
+            storage3, 0, R.shape([2, 3]), dtype="float32"
         )
         _5: R.Tuple() = exp(y3_, alloc5)
         _42: R.Tuple() = R.memory.kill_tensor(alloc2)
@@ -514,7 +514,7 @@ class Module:
     @R.function
     def main(x: R.Tensor((2, 3), "float32")):
         alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
-            (2, 3), dtype="float32", runtime_device_index=0
+            R.shape([2, 3]), dtype="float32", runtime_device_index=0
         )
         _ = R.add(x, alloc)
         y: R.Tensor((2, 3), dtype="float32") = alloc
@@ -541,7 +541,7 @@ def main(x: R.Tensor(("m", "n"), "float32")):
         m = T.var("int64")
         n = T.var("int64")
         alloc: R.Tensor((m, n), dtype="float32") = R.builtin.alloc_tensor(
-            (m, n), dtype="float32", runtime_device_index=0
+            R.shape([m, n]), dtype="float32", runtime_device_index=0
         )
         _ = exp(x, alloc)
         y: R.Tensor((m, n), dtype="float32") = alloc
@@ -558,7 +558,7 @@ class Module:
     @R.function
     def main(x: R.Tensor((2, 3), "float32")):
         alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
-            (2, 3), dtype="float32", runtime_device_index=0
+            R.shape([2, 3]), dtype="float32", runtime_device_index=0
         )
         return x
 
@@ -567,10 +567,10 @@ class Expected:
     @R.function
     def main(x: R.Tensor((2, 3), "float32")):
         storage: R.Object = R.memory.alloc_storage(
-            (24,), virtual_device_index=0, storage_scope="global", dtype="float32"
+            R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32"
         )
         alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
-            storage, 0, (2, 3), dtype="float32"
+            storage, 0, R.shape([2, 3]), dtype="float32"
         )
         _: R.Tuple() = R.memory.kill_storage(storage)
         return x
@@ -597,7 +597,7 @@ def main(
         lv: R.Tensor((2, 25, 2), dtype="float32") = R.reshape(x, (2, 25, 2))
         lv1: R.Tensor((2, 25, 2), dtype="float32") = R.reshape(y, (2, 25, 2))
         alloc: R.Tensor((2, 25, 2), dtype="float32") = R.builtin.alloc_tensor(
-            (2, 25, 2), dtype="float32", runtime_device_index=0
+            R.shape([2, 25, 2]), dtype="float32", runtime_device_index=0
         )
         _: R.Tuple() = add(lv, lv1, alloc)
         gv: R.Tensor((2, 25, 2), dtype="float32") = alloc
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 68cf0ea3b9..2c77109d30 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -22,7 +22,6 @@
 import tvm.script
 import tvm.testing
 from tvm import IRModule, relax, tir, topi
-from tvm.relax import DynTensorType
 from tvm.script.parser import ir as I
 from tvm.script.parser import relax as R
 from tvm.script.parser import tir as T
@@ -113,7 +112,7 @@ def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"):
 def test_relax_base_op():
     @R.function
     def foo(x: R.Tensor((4, 4), "float32")):
-        alloc = R.builtin.alloc_tensor((4, 4), runtime_device_index=0, dtype="float32")
+        alloc = R.builtin.alloc_tensor(R.shape([4, 4]), runtime_device_index=0, dtype="float32")
         shape = R.shape_of(alloc)
         return shape
 
@@ -199,7 +198,7 @@ def foo(x: R.Tensor("float32"), y: R.Tensor("float32")):
             y0 = R.match_cast(y, R.Tensor([n], "float32"))
             gv = y0
             R.output(gv)
-        return (x0, (m, n * 2))
+        return (x0, R.shape([m, n * 2]))
 
     x = relax.Var("x", R.Tensor("float32"))
     y = relax.Var("y", R.Tensor("float32"))
R.Tensor("float32")) @@ -239,7 +238,7 @@ def test_tuple_return_2(): def foo(x: R.Tensor("float32", ndim=2)): n, m = T.var("int64"), T.var("int64") x0 = R.match_cast(x, R.Tensor((n, m), "float32")) - return (x0, (n + 1, m, 1)) + return (x0, R.shape([n + 1, m, 1])) x = relax.Var("x", R.Tensor("float32", ndim=2)) n, m = tir.Var("n", "int64"), tir.Var("m", "int64") @@ -257,7 +256,7 @@ def foo(x: R.Tensor("float32", ndim=2)): n, m = T.var("int64"), T.var("int64") x0 = R.match_cast(x, R.Tensor((n, m), "float32")) t0 = (x, x0) - t1 = (x, (n, m), t0) + t1 = (x, R.shape([n, m]), t0) return t1 x = relax.Var("x", R.Tensor("float32", ndim=2)) @@ -934,9 +933,9 @@ def test_vm_ops(): def foo(x: R.Tensor(("m", "n"), dtype="float32")): m = T.var("int64") n = T.var("int64") - storage = R.vm.alloc_storage((4 * m * n,), dtype="float32", runtime_device_index=0) - alloc = R.vm.alloc_tensor(storage, shape=(m, n), offset=0, dtype="float32") - tensor = R.builtin.alloc_tensor((m, n), dtype="float32", runtime_device_index=0) + storage = R.vm.alloc_storage(R.shape([4 * m * n]), dtype="float32", runtime_device_index=0) + alloc = R.vm.alloc_tensor(storage, shape=R.shape([m, n]), offset=0, dtype="float32") + tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", runtime_device_index=0) _ = R.vm.call_tir_dyn("te_func", (x, tensor, (m, n))) gv = tensor return alloc, gv diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index e678063e40..8b8302190e 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -291,7 +291,7 @@ def test_tuple_get_item(): def test_shape_expr(): obj = relax.ShapeExpr([1, 2, 3]) - _assert_print(obj, "(1, 2, 3)") + _assert_print(obj, "R.shape([1, 2, 3])") def test_call(): @@ -303,7 +303,7 @@ def test_call(): """ x = T.Var("x", "int64") a: R.Tensor((1, x, 3), dtype="float32") -R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=(x,)) +R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=R.shape([x])) """, ) diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 52a4d0bec6..26560ba80d 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -86,7 +86,7 @@ class TestVMCompileStage2: def foo(x: R.Tensor(dtype="float32")) -> R.Shape: n, m = T.var("int64"), T.var("int64") _ = R.match_cast(x, R.Tensor((n, m), "float32")) - return (n * 2, m * 3) + return R.shape([n * 2, m * 3]) mod = TestVMCompileStage2 target = tvm.target.Target("llvm", host="llvm") @@ -509,9 +509,9 @@ class TestMemoryAllocStorageTensor: @R.function def main(x: R.Tensor((2, 3), dtype="float32")): storage = R.memory.alloc_storage( - (24,), virtual_device_index=0, storage_scope="global", dtype="float32" + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" ) - y = R.memory.alloc_tensor(storage, 0, (2, 3), dtype="float32") + y = R.memory.alloc_tensor(storage, 0, R.shape([2, 3]), dtype="float32") _ = copy(x, y) return y diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py index c3aa8fd033..38fa47bc14 100644 --- a/tests/python/relax/test_vm_codegen_only.py +++ b/tests/python/relax/test_vm_codegen_only.py @@ -18,13 +18,15 @@ Restrictions: all shape lowered, explicit allocation. 
""" -import tvm -import pytest import numpy as np -from tvm import relax, TVMError -from tvm.script import relax as R, tir as T +import pytest +import tvm +import tvm.testing +from tvm import relax +from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode from tvm.relax.testing.vm import check_saved_func -from tvm.relax.testing.runtime_builtin import MatchShapeCode, MakeShapeCode +from tvm.script import relax as R +from tvm.script import tir as T def codegen(mod, target, exec_mode="bytecode"): @@ -310,7 +312,7 @@ class TestVMBuiltinReshape: def main(x: R.Tensor((3, 4), "float32")): R.func_attr({"global_symbol": "main"}) y = R.call_packed( - "vm.builtin.reshape", x, (6, 2), sinfo_args=R.Tensor((6, 2), "float32") + "vm.builtin.reshape", x, R.shape([6, 2]), sinfo_args=R.Tensor((6, 2), "float32") ) return y @@ -325,3 +327,7 @@ def main(x: R.Tensor((3, 4), "float32")): res = vm["main"](input) expected = input_np.reshape(6, 2) tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-7, atol=1e-7) + + +if __name__ == "__main__": + tvm.testing.main()