From 685bb9753b8d5a0dd2c253fb9e7b21729b5f1b0c Mon Sep 17 00:00:00 2001
From: Prakalp Srivastava
Date: Fri, 3 Feb 2023 06:44:28 -0800
Subject: [PATCH 01/10] Add layout_transform operator in Relax.

---
 include/tvm/relax/attrs/manipulate.h       |  10 ++
 python/tvm/relax/op/manipulate.py          |  25 +++-
 python/tvm/script/ir_builder/relax/ir.py   |   2 +
 src/printer/relax_script_printer.cc        |   3 +
 src/relax/op/tensor/manipulate.cc          |  51 ++++++++
 src/relax/op/tensor/manipulate.h           |   8 ++
 tests/python/relax/test_op_manipulate.py   | 115 ++++++++++++++++++
 .../test_tvmscript_parser_op_manipulate.py |  17 +++
 8 files changed, 229 insertions(+), 2 deletions(-)

diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index 5e82456a87..f347549086 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -25,6 +25,7 @@
 #define TVM_RELAX_ATTRS_MANIPULATE_H_
 
 #include <tvm/relax/expr.h>
+#include <tvm/tir/index_map.h>
 
 namespace tvm {
 namespace relax {
@@ -52,6 +53,15 @@ struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
   }
 };  // struct ExpandDimsAttrs
 
+/*! \brief Attributes used in layout_transform operator */
+struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
+  tir::IndexMap index_map;
+
+  TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
+    TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
+  }
+};  // struct LayoutTransformAttrs
+
 /*! \brief Attributes used in permute_dims operator */
 struct PermuteDimsAttrs : public tvm::AttrsNode<PermuteDimsAttrs> {
   Optional<Array<Integer>> axes;
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 01426e4129..b2b1d9bef3 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -15,10 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 """Manipulation operators."""
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Callable
 
 from tvm.ir.expr import PrimExpr
-from tvm.tir import IntImm
+from tvm.tir import IntImm, IndexMap
 
 from . import _ffi_api
 from ..expr import Expr, ShapeExpr, Tuple as RxTuple
@@ -110,6 +110,27 @@ def flatten(x: Expr) -> Expr:
     return _ffi_api.flatten(x)  # type: ignore
 
 
+def layout_transform(x: Expr, index_map: Union[Callable, IndexMap]):
+    """Modifies the layout of a tensor.
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The input tensor to the operator.
+
+    index_map : Union[Callable, IndexMap]
+        The transformation to apply.
+
+    Returns
+    -------
+    result : relax.Expr
+        The transformed tensor.
+    """
+    if callable(index_map):
+        index_map = IndexMap.from_func(index_map)
+    return _ffi_api.layout_transform(x, index_map)
+
+
 def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr:
     """Permutes the dimensions of an array.
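
[Illustrative aside, not part of the patch: index_map accepts either a
tir.IndexMap or a plain Python callable, which IndexMap.from_func converts.
A minimal sketch, assuming the Python map_shape method mirrors the C++
MapShape used for struct info inference below:

    from tvm.tir import IndexMap

    # (n, c, h, w) -> (n, h, w, c), written as a callable
    imap = IndexMap.from_func(lambda n, c, h, w: (n, h, w, c))
    # map the input shape to the transformed output shape
    print(imap.map_shape([2, 3, 4, 5]))  # expected: [2, 4, 5, 3]
]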
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index ebad4f174b..1755545e02 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -64,6 +64,7 @@
     isfinite,
     isinf,
     isnan,
+    layout_transform,
     less,
     less_equal,
     log,
@@ -495,6 +496,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "isfinite",
     "isinf",
     "isnan",
+    "layout_transform",
     "less",
     "less_equal",
     "log",
diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc
index 539ec3df62..b949b14f57 100644
--- a/src/printer/relax_script_printer.cc
+++ b/src/printer/relax_script_printer.cc
@@ -24,6 +24,7 @@
 #include
 #include
+#include <tvm/relax/attrs/manipulate.h>
 #include
 #include
 
@@ -491,6 +492,8 @@ std::vector<Doc> RelaxScriptPrinter::PrintAttrs(const Attrs& attrs) {
     for (const auto& k : dict_attrs->dict) {
       kwargs.push_back(Doc::Text(k.first) << "=" << Print(k.second));
     }
+  } else if (const LayoutTransformAttrs* layout_attrs = attrs.as<LayoutTransformAttrs>()) {
+    kwargs.push_back(Doc::Text("index_map=") << layout_attrs->index_map->ToPythonString());
   } else {
     AttrPrinter attr_printer(&kwargs, this);
     const_cast<BaseAttrsNode*>(attrs.operator->())->VisitAttrs(&attr_printer);
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 68a525512e..9284c8b995 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -391,6 +391,57 @@ TVM_REGISTER_OP("relax.flatten")
     .add_argument("x", "Tensor", "The input tensor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFlatten);
 
+/* relax.layout_transform */
+TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
+
+Expr layout_transform(Expr x, tir::IndexMap index_map) {
+  ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
+  attrs->index_map = std::move(index_map);
+
+  static const Op& op = Op::Get("relax.layout_transform");
+  return Call(op, {std::move(x)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.layout_transform").set_body_typed(layout_transform);
+
+StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as<LayoutTransformAttrs>();
+  tir::IndexMap index_map = attrs->index_map;
+
+  if (!data_sinfo->shape.defined()) {
+    // For unknown input shape, the best we can do is get the #dims in output from index map
+    return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
+  }
+  Array<PrimExpr> input_shape;
+  if (const auto* shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value())) {
+    if (!shape_sinfo->values.defined()) {
+      // For unknown input shape, the best we can do is get the #dims in output from index map
+      return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
+    }
+    input_shape = shape_sinfo->values.value();
+  } else {
+    const auto* shape_expr = data_sinfo->shape.as<ShapeExprNode>();
+    ICHECK(shape_expr);
+    input_shape = shape_expr->values;
+  }
+
+  if (input_shape.size() != index_map->initial_indices.size()) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "number of dimensions in input must match the number of source dimensions "
+                        "in index map, but got "
+                     << input_shape.size() << " != " << index_map->initial_indices.size());
+  }
+  Array<PrimExpr> output_shape = index_map->MapShape(input_shape);
+  return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
+}
+
+TVM_REGISTER_OP("relax.layout_transform")
+    .set_num_inputs(1)
+    .set_attrs_type<LayoutTransformAttrs>()
+    .add_argument("x", "Tensor", "The input tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayoutTransform);
+
 /* relax.permute_dims */
 TVM_REGISTER_NODE_TYPE(PermuteDimsAttrs);
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 7e8a511e9e..82c2acc9ec 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -59,6 +59,14 @@ Expr expand_dims(Expr x, Array<Integer> axis);
  */
 Expr flatten(Expr x);
 
+/*!
+ * \brief Transform layout of a tensor.
+ * \param x The input data to the operator.
+ * \param index_map The transformation to apply.
+ * \return The transformed result.
+ */
+Expr layout_transform(Expr x, tir::IndexMap index_map);
+
 /*!
  * \brief Permutes the dimensions of an array.
  * \param x The input data to the operator.
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index 6ec983180e..6dca48f46d 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -33,6 +33,9 @@ def test_op_correctness():
     assert relax.op.reshape(x, (4, 5, 3)).op == Op.get("relax.reshape")
     assert relax.op.split(x, indices_or_sections=1).op == Op.get("relax.split")
     assert relax.op.squeeze(x).op == Op.get("relax.squeeze")
+    assert relax.op.layout_transform(x, index_map=lambda a, b, c: (b, c, a)).op == Op.get(
+        "relax.layout_transform"
+    )
 
 
 def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
@@ -649,6 +652,118 @@ def test_expand_dims_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.expand_dims(x1, axis=[]))
 
 
+def test_layout_transform_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((10, 20, 30), "float32"))
+
+    transpose_transform = lambda a, b, c: (a, c, b)
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x, index_map=transpose_transform),
+        relax.TensorStructInfo((10, 30, 20), "float32"),
+    )
+
+    tiling_transform = lambda a, b, c: (a, b // 2, c, b % 2)
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x, index_map=tiling_transform),
+        relax.TensorStructInfo((10, 10, 30, 2), "float32"),
+    )
+
+    implicit_padding_transform = lambda a, b, c: (a, c, b // 3, b % 3)
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x, index_map=implicit_padding_transform),
+        relax.TensorStructInfo((10, 30, 7, 3), "float32"),
+    )
+
+    flatten_transform = lambda a, b, c: (a * 600 + b * 30 + c)
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x, index_map=flatten_transform),
+        relax.TensorStructInfo((6000,), "float32"),
+    )
+
+
+def test_layout_transform_infer_struct_info_unknown_shape():
+    bb = relax.BlockBuilder()
+    x_unknown_shape = relax.Var("x", R.Tensor("float32", ndim=2))
+    x_unknown_rank_dtype = relax.Var("x", R.Tensor())
+
+    tiling_transform = lambda a, b: (a, b // 2, b % 2)
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x_unknown_shape, index_map=tiling_transform),
+        relax.TensorStructInfo(dtype="float32", ndim=3),
+    )
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x_unknown_rank_dtype, index_map=tiling_transform),
+        relax.TensorStructInfo(dtype="", ndim=3),
+    )
+
+
+def test_layout_transform_infer_struct_info_symbolic_shape():
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    x0 = relax.Var("x", R.Tensor((a, b), "float32"))
+
+    tiling_transform = lambda a, b: (a, b // 3, b % 3)
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x0, index_map=tiling_transform),
+        relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"),
+    )
+
+
+def test_layout_transform_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+
+    s = relax.Var("s", relax.ShapeStructInfo((30, 20)))
+    x = relax.Var("x", relax.TensorStructInfo(s, "float32"))
+    tiling_padding_transform = lambda a, b: (a, b // 3, b % 3)
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x, index_map=tiling_padding_transform),
+        relax.TensorStructInfo((30, 7, 3), "float32"),
+    )
+
+    s_unknown_shape = relax.Var("s", relax.ShapeStructInfo(ndim=2))
+    x_unknown_shape = relax.Var("x", relax.TensorStructInfo(s_unknown_shape, "float32"))
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x_unknown_shape, index_map=tiling_padding_transform),
+        relax.TensorStructInfo(ndim=3, dtype="float32"),
+    )
+
+    s_unknown_rank = relax.Var("s", relax.ShapeStructInfo())
+    x_unknown_rank = relax.Var("x", relax.TensorStructInfo(s_unknown_rank, "float32"))
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x_unknown_rank, index_map=tiling_padding_transform),
+        relax.TensorStructInfo(ndim=3, dtype="float32"),
+    )
+
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    s_symbolic_shape = relax.Var("s", relax.ShapeStructInfo((a, b)))
+    x_symbolic_shape = relax.Var("x", relax.TensorStructInfo(s_symbolic_shape, "float32"))
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x_symbolic_shape, index_map=tiling_padding_transform),
+        relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"),
+    )
+
+
+def test_layout_transform_infer_struct_info_invalid_index_map():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((10, 20, 30), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.layout_transform(x, index_map=lambda a, b: (b, a)))
+
+
 def test_squeeze_infer_struct_info():
     bb = relax.BlockBuilder()
     x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32"))
diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
index 68dcc0c820..8b2b47096c 100644
--- a/tests/python/relax/test_tvmscript_parser_op_manipulate.py
+++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
@@ -119,6 +119,23 @@ def foo(x: R.Tensor((3, 4, 5), "float32")) -> R.Tensor((60,), "float32"):
     _check(foo, bb.get()["foo"])
 
 
+def test_layout_transform():
+    transformation = lambda n, c, h, w: (n, h, w, c)
+
+    @R.function
+    def foo(x: R.Tensor((2, 3, 4, 5), "float32")):
+        gv: R.Tensor((2, 4, 5, 3), "float32") = R.layout_transform(x, index_map=transformation)
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.layout_transform(x, index_map=transformation))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 def test_permute_dims():
     @R.function
     def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((2, 4, 3, 1), "float32"):

From 7e9cee2b2949fe88ecbaa954ce74f832feb37866 Mon Sep 17 00:00:00 2001
From: Prakalp Srivastava
Date: Fri, 3 Feb 2023 07:35:34 -0800
Subject: [PATCH 02/10] Add optional pad_value attribute to op.
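
An illustrative worked example (not part of the diff): with input shape
(10, 20, 30) and index map (a, b, c) -> (a, c, b // 3, b % 3), the extent
20 of b is not divisible by 3, so the output shape becomes
(10, 30, ceil(20 / 3), 3) = (10, 30, 7, 3). That materializes 7 * 3 = 21
slots along b where only 20 inputs exist, i.e. one implicitly padded slot
per (a, c) pair. pad_value pins the contents of those slots; leaving it
unset lets the compiler choose any value.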
---
 include/tvm/relax/attrs/manipulate.h | 4 ++++
 python/tvm/relax/op/manipulate.py    | 6 +++---
 src/printer/relax_script_printer.cc  | 3 +++
 src/relax/op/tensor/manipulate.cc    | 3 ++-
 4 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index f347549086..54f32017b1 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -56,9 +56,13 @@ struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
 /*! \brief Attributes used in layout_transform operator */
 struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
   tir::IndexMap index_map;
+  Optional<DataTypeImm> pad_value;
 
   TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
     TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
+    TVM_ATTR_FIELD(pad_value).describe(
+        "The specific value to be used to pad if the layout transform would result in implicit "
+        "padding. If not specified, the compiler is free to choose any value.");
   }
 };  // struct LayoutTransformAttrs
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index b2b1d9bef3..1462aa4f0b 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -21,7 +21,7 @@
 from tvm.tir import IntImm, IndexMap
 
 from . import _ffi_api
-from ..expr import Expr, ShapeExpr, Tuple as RxTuple
+from ..expr import Expr, DataTypeImm, ShapeExpr, Tuple as RxTuple
 
 PrimExprLike = Union[int, PrimExpr]
 
@@ -110,7 +110,7 @@ def flatten(x: Expr) -> Expr:
     return _ffi_api.flatten(x)  # type: ignore
 
 
-def layout_transform(x: Expr, index_map: Union[Callable, IndexMap]):
+def layout_transform(x: Expr, index_map: Union[Callable, IndexMap], pad_value: Optional[DataTypeImm]):
     """Modifies the layout of a tensor.
 
     Parameters
@@ -128,7 +128,7 @@ def layout_transform(x: Expr, index_map: Union[Callable, IndexMap]):
     """
     if callable(index_map):
         index_map = IndexMap.from_func(index_map)
-    return _ffi_api.layout_transform(x, index_map)
+    return _ffi_api.layout_transform(x, index_map, pad_value)
diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc
index b949b14f57..abc0066a22 100644
--- a/src/printer/relax_script_printer.cc
+++ b/src/printer/relax_script_printer.cc
@@ -494,6 +494,9 @@ std::vector<Doc> RelaxScriptPrinter::PrintAttrs(const Attrs& attrs) {
     }
   } else if (const LayoutTransformAttrs* layout_attrs = attrs.as<LayoutTransformAttrs>()) {
     kwargs.push_back(Doc::Text("index_map=") << layout_attrs->index_map->ToPythonString());
+    if (layout_attrs->pad_value.defined()) {
+      kwargs.push_back(Doc::Text("pad_value=") << Print(layout_attrs->pad_value.value()));
+    }
   } else {
     AttrPrinter attr_printer(&kwargs, this);
     const_cast<BaseAttrsNode*>(attrs.operator->())->VisitAttrs(&attr_printer);
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 9284c8b995..3cf58c09dc 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -394,9 +394,10 @@ TVM_REGISTER_OP("relax.flatten")
 /* relax.layout_transform */
 TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
 
-Expr layout_transform(Expr x, tir::IndexMap index_map) {
+Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<DataTypeImm> pad_value) {
   ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
   attrs->index_map = std::move(index_map);
+  attrs->pad_value = std::move(pad_value);
 
   static const Op& op = Op::Get("relax.layout_transform");
   return Call(op, {std::move(x)}, Attrs{attrs}, {});

From 5b2828edcba968f17fa23db87963fa59a61b05c2 Mon Sep 17 00:00:00 2001
From: Prakalp Srivastava
Date: Fri, 3 Feb 2023 07:41:25 -0800
Subject: [PATCH 03/10] Add documentation for pad_value.
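
A minimal usage sketch of the documented parameter (illustrative only;
`pad` stands in for a suitably typed pad value, which later patches in
this series allow to be a plain Python scalar):

    x = relax.Var("x", R.Tensor((10, 20), "float32"))
    # tiling by 3 rounds 20 up to 7 * 3 = 21, so the inferred struct info
    # is R.Tensor((10, 7, 3), "float32") with ten implicitly padded slots
    y = relax.op.layout_transform(
        x, index_map=lambda a, b: (a, b // 3, b % 3), pad_value=pad
    )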
---
 python/tvm/relax/op/manipulate.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 1462aa4f0b..fea325357d 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -110,7 +110,9 @@ def flatten(x: Expr) -> Expr:
     return _ffi_api.flatten(x)  # type: ignore
 
 
-def layout_transform(x: Expr, index_map: Union[Callable, IndexMap], pad_value: Optional[DataTypeImm]):
+def layout_transform(
+    x: Expr, index_map: Union[Callable, IndexMap], pad_value: Optional[DataTypeImm]
+):
     """Modifies the layout of a tensor.
 
     Parameters
@@ -121,6 +123,10 @@ def layout_transform(x: Expr, index_map: Union[Callable, IndexMap], pad_value: Optional[DataTypeImm]):
     index_map : Union[Callable, IndexMap]
         The transformation to apply.
 
+    pad_value : Optional[DataTypeImm]
+        The value used for padding if the transformation results in implicit padding.
+        If not specified, any value can be used.
+
     Returns
     -------
     result : relax.Expr
         The transformed tensor.

From 3e2bec205b8a334516306d24f7f83baf5610277d Mon Sep 17 00:00:00 2001
From: Prakalp Srivastava
Date: Fri, 3 Feb 2023 08:20:27 -0800
Subject: [PATCH 04/10] Change pad_value to be of PrimValue type

---
 include/tvm/relax/attrs/manipulate.h     | 2 +-
 python/tvm/relax/op/manipulate.py        | 8 +++++---
 src/relax/op/tensor/manipulate.cc        | 7 +++++--
 tests/python/relax/test_op_manipulate.py | 2 +-
 4 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index 54f32017b1..eb704d9c28 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -56,7 +56,7 @@ struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
 /*! \brief Attributes used in layout_transform operator */
 struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
   tir::IndexMap index_map;
-  Optional<DataTypeImm> pad_value;
+  Optional<PrimValue> pad_value;
 
   TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
     TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index fea325357d..c1aec5be69 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -21,7 +21,7 @@
 from tvm.tir import IntImm, IndexMap
 
 from . import _ffi_api
-from ..expr import Expr, DataTypeImm, ShapeExpr, Tuple as RxTuple
+from ..expr import Expr, PrimValue, ShapeExpr, Tuple as RxTuple
 
 PrimExprLike = Union[int, PrimExpr]
 
@@ -111,7 +111,7 @@ def flatten(x: Expr) -> Expr:
 
 
 def layout_transform(
-    x: Expr, index_map: Union[Callable, IndexMap], pad_value: Optional[DataTypeImm]
+    x: Expr, index_map: Union[Callable, IndexMap], pad_value: Optional[Union[int, PrimValue]] = None
 ):
     """Modifies the layout of a tensor.
 
@@ -123,7 +123,7 @@ def layout_transform(
     index_map : Union[Callable, IndexMap]
         The transformation to apply.
 
-    pad_value : Optional[DataTypeImm]
+    pad_value : Optional[Union[int, float, PrimValue]]
         The value used for padding if the transformation results in implicit padding.
         If not specified, any value can be used.
@@ -134,6 +134,8 @@ def layout_transform(
     """
     if callable(index_map):
         index_map = IndexMap.from_func(index_map)
+    if isinstance(pad_value, int):
+        pad_value = PrimValue(pad_value)
     return _ffi_api.layout_transform(x, index_map, pad_value)
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 3cf58c09dc..2182186546 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -394,7 +394,7 @@ TVM_REGISTER_OP("relax.flatten")
 /* relax.layout_transform */
 TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
 
-Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<DataTypeImm> pad_value) {
+Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value) {
   ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
   attrs->index_map = std::move(index_map);
   attrs->pad_value = std::move(pad_value);
@@ -403,7 +403,10 @@ Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value) {
   return Call(op, {std::move(x)}, Attrs{attrs}, {});
 }
 
-TVM_REGISTER_GLOBAL("relax.op.layout_transform").set_body_typed(layout_transform);
+TVM_REGISTER_GLOBAL("relax.op.layout_transform")
+    .set_body_typed([](Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value) {
+      return layout_transform(x, index_map, pad_value);
+    });
 
 StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& ctx) {
   TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index 6dca48f46d..58b168e2f0 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -673,7 +673,7 @@ def test_layout_transform_infer_struct_info():
     implicit_padding_transform = lambda a, b, c: (a, c, b // 3, b % 3)
     _check_inference(
         bb,
-        relax.op.layout_transform(x, index_map=implicit_padding_transform),
+        relax.op.layout_transform(x, index_map=implicit_padding_transform, pad_value=2),
         relax.TensorStructInfo((10, 30, 7, 3), "float32"),
     )

From d2b3170bbff191be785383adf2787ff63bba2192 Mon Sep 17 00:00:00 2001
From: Prakalp Srivastava
Date: Fri, 3 Feb 2023 11:58:02 -0800
Subject: [PATCH 05/10] Fix python side conversion of pad_value to PrimValue.
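
Without the explicit conversion below, a bare Python scalar would take the
default FFI conversion (int -> int32, float -> float32) and could mismatch
the input tensor's dtype. A sketch of the intended behavior, with
illustrative values:

    x = relax.Var("x", R.Tensor((4,), "int64"))
    # pad_value=2 becomes PrimValue(IntImm("int64", 2)) rather than a
    # default int32 immediate, matching x's dtype
    y = relax.op.layout_transform(x, lambda i: (i // 2, i % 2), pad_value=2)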
---
 python/tvm/relax/op/manipulate.py          | 18 +++++++++++++---
 python/tvm/relax/op/op_attrs.py            |  5 +++++
 src/printer/relax_script_printer.cc        |  4 ++++
 src/printer/text_printer.h                 |  1 +
 .../test_tvmscript_parser_op_manipulate.py | 21 ++++++++++++++++++-
 5 files changed, 45 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index c1aec5be69..5ec070f969 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -18,7 +18,7 @@
 from typing import List, Optional, Tuple, Union, Callable
 
 from tvm.ir.expr import PrimExpr
-from tvm.tir import IntImm, IndexMap
+from tvm.tir import IntImm, FloatImm, IndexMap
 
 from . import _ffi_api
 from ..expr import Expr, PrimValue, ShapeExpr, Tuple as RxTuple
@@ -111,7 +111,9 @@ def flatten(x: Expr) -> Expr:
 
 
 def layout_transform(
-    x: Expr, index_map: Union[Callable, IndexMap], pad_value: Optional[Union[int, PrimValue]] = None
+    x: Expr,
+    index_map: Union[Callable, IndexMap],
+    pad_value: Optional[Union[int, float, PrimValue]] = None,
 ):
     """Modifies the layout of a tensor.
 
@@ -134,6 +136,17 @@ def layout_transform(
     """
     if callable(index_map):
         index_map = IndexMap.from_func(index_map)
-    if isinstance(pad_value, int):
-        pad_value = PrimValue(pad_value)
+    x_dtype = x.checked_type.dtype
+
+    # Explicitly convert python int/float pad_value to the x's type. If the default behavior
+    # is applied, it would be converted to int32/float32, which may not match the x's type.
+    if pad_value is None:
+        pass
+    elif not isinstance(pad_value, PrimValue):
+        if "int" in x_dtype and isinstance(pad_value, int):
+            pad_value = IntImm(x_dtype, pad_value)
+        elif "float" in x_dtype and (isinstance(pad_value, float) or isinstance(pad_value, int)):
+            pad_value = FloatImm(x_dtype, float(pad_value))
+        pad_value = PrimValue(pad_value)
     return _ffi_api.layout_transform(x, index_map, pad_value)
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index 45f6e74ec4..f61549a6f1 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -139,6 +139,11 @@ class BatchNormAttrs(Attrs):
     """Attributes used in batch_norm operator"""
 
 
+@tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs")
+class LayoutTransformAttrs(Attrs):
+    """Attributes used in layout_transform operator"""
+
+
 @tvm._ffi.register_object("relax.attrs.LayerNormAttrs")
 class LayerNormAttrs(Attrs):
     """Attributes used in layer_norm operator"""
diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc
index abc0066a22..aa60cdb9e1 100644
--- a/src/printer/relax_script_printer.cc
+++ b/src/printer/relax_script_printer.cc
@@ -401,6 +401,10 @@ Doc RelaxScriptPrinter::VisitExpr_(const tir::IntImmNode* op) {
   return Doc::Text(std::to_string(op->value));
 }
 
+Doc RelaxScriptPrinter::VisitExpr_(const tir::FloatImmNode* op) {
+  return Doc::Text(std::to_string(op->value));
+}
+
 #define TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(OpName, OpString) \
   Doc RelaxScriptPrinter::VisitExpr_(const OpName* op) {          \
     Doc doc;                                                      \
diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h
index 601dfb0bd9..1c44672583 100644
--- a/src/printer/text_printer.h
+++ b/src/printer/text_printer.h
@@ -290,6 +290,7 @@ class RelaxScriptPrinter : public relax::IRFunctor<Doc(const ObjectRef&)>,
   // PrimExpr nodes allowed in Relax
   Doc VisitExpr_(const tir::VarNode* op) override;
   Doc VisitExpr_(const tir::IntImmNode* op) override;
+  Doc VisitExpr_(const tir::FloatImmNode* op) override;
   Doc VisitExpr_(const tir::AddNode* op) override;
   Doc VisitExpr_(const tir::SubNode* op) override;
   Doc VisitExpr_(const tir::MulNode* op) override;
diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
index 8b2b47096c..44003df853 100644
--- a/tests/python/relax/test_tvmscript_parser_op_manipulate.py
+++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
@@ -132,7 +132,26 @@ def foo(x: R.Tensor((2, 3, 4, 5), "float32")):
     with bb.function("foo", [x]):
         gv = bb.emit(relax.op.layout_transform(x, index_map=transformation))
         bb.emit_func_output(gv)
-
+
     _check(foo, bb.get()["foo"])
+
+
+def test_layout_transform_with_padding():
+    transformation = lambda n, c, h, w: (n, c // 3, h, w, c % 3)
+
+    @R.function
+    def foo(x: R.Tensor((10, 20, 2, 2), "float32")):
+        gv: R.Tensor((10, 7, 2, 2, 3), "float32") = R.layout_transform(
+            x, index_map=transformation, pad_value=2
+        )
+        return gv
+
+    x = relax.Var("x", R.Tensor((10, 20, 2, 2), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.layout_transform(x, index_map=transformation, pad_value=2))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])

From 5b607fb3f5b9e21f8dd469f44bb82aec4faabc90 Mon Sep 17 00:00:00 2001
From: Prakalp Srivastava
Date: Fri, 3 Feb 2023 12:11:02 -0800
Subject: [PATCH 06/10] fix lint

---
 python/tvm/relax/op/manipulate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 5ec070f969..0fb0d433bf 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -145,7 +145,7 @@ def layout_transform(
     elif not isinstance(pad_value, PrimValue):
         if "int" in x_dtype and isinstance(pad_value, int):
             pad_value = IntImm(x_dtype, pad_value)
-        elif "float" in x_dtype and (isinstance(pad_value, float) or isinstance(pad_value, int)):
+        elif "float" in x_dtype and (isinstance(pad_value, (int, float))):
             pad_value = FloatImm(x_dtype, float(pad_value))
         pad_value = PrimValue(pad_value)
     return _ffi_api.layout_transform(x, index_map, pad_value)

From f6a1a40063e0b2e33a6e67cb94ff3c25f0a04bd1 Mon Sep 17 00:00:00 2001
From: Prakalp Srivastava
Date: Fri, 3 Feb 2023 12:55:49 -0800
Subject: [PATCH 07/10] fix lint

---
 python/tvm/relax/op/manipulate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 0fb0d433bf..a46c62e1f1 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -148,7 +148,7 @@ def layout_transform(
         elif "float" in x_dtype and (isinstance(pad_value, (int, float))):
             pad_value = FloatImm(x_dtype, float(pad_value))
         pad_value = PrimValue(pad_value)
-    return _ffi_api.layout_transform(x, index_map, pad_value)
+    return _ffi_api.layout_transform(x, index_map, pad_value)  # type: ignore

From f170dd7340e6b34d606c04d15e31f4ccdc6f2677 Mon Sep 17 00:00:00 2001
From: Prakalp Srivastava
Date: Fri, 3 Feb 2023 13:52:30 -0800
Subject: [PATCH 08/10] fix signature

---
 src/relax/op/tensor/manipulate.h | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 82c2acc9ec..6a2b23ecbd 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -63,9 +63,11 @@ Expr flatten(Expr x);
  * \brief Transform layout of a tensor.
  * \param x The input data to the operator.
  * \param index_map The transformation to apply.
+ * \param pad_value The value used for padding if the transformation results in implicit padding. If
+ * not specified, any value can be used.
 * \return The transformed result.
 */
-Expr layout_transform(Expr x, tir::IndexMap index_map);
+Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value);

From 151d59f4461fce3005b9c8c86baff86f9195b3ba Mon Sep 17 00:00:00 2001
From: Prakalp Srivastava
Date: Tue, 7 Feb 2023 06:48:21 -0800
Subject: [PATCH 09/10] Address comments and add a test about mismatched dtype.

---
 include/tvm/relax/attrs/manipulate.h     |  2 +
 src/relax/op/tensor/manipulate.cc        | 49 ++++++++++++++----------
 tests/python/relax/test_op_manipulate.py |  9 +++++
 3 files changed, 40 insertions(+), 20 deletions(-)

diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index eb704d9c28..bd6ae17bcf 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -56,6 +56,8 @@ struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
 /*! \brief Attributes used in layout_transform operator */
 struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
   tir::IndexMap index_map;
+  // pad_value is chosen to be of PrimValue type, as it represents constant TIR POD expression. This
+  // needs to be revisited in case PrimValue is evolved to represent symbolic expression in future.
   Optional<PrimValue> pad_value;
 
   TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
     TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 2182186546..db36549f37 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -403,39 +403,48 @@ Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value) {
   return Call(op, {std::move(x)}, Attrs{attrs}, {});
 }
 
-TVM_REGISTER_GLOBAL("relax.op.layout_transform")
-    .set_body_typed([](Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value) {
-      return layout_transform(x, index_map, pad_value);
-    });
+TVM_REGISTER_GLOBAL("relax.op.layout_transform").set_body_typed(layout_transform);
 
 StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& ctx) {
   TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
   const auto* attrs = call->attrs.as<LayoutTransformAttrs>();
   tir::IndexMap index_map = attrs->index_map;
+  Optional<PrimValue> optional_pad_value = attrs->pad_value;
 
-  if (!data_sinfo->shape.defined()) {
-    // For unknown input shape, the best we can do is get the #dims in output from index map
-    return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
-  }
-  Array<PrimExpr> input_shape;
-  if (const auto* shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value())) {
-    if (!shape_sinfo->values.defined()) {
-      // For unknown input shape, the best we can do is get the #dims in output from index map
-      return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
+  // Check pad_value has same dtype as input.
+  if (optional_pad_value.defined()) {
+    PrimExpr padded_value = optional_pad_value.value()->value;
+    if (padded_value->dtype != data_sinfo->dtype) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "layout_transform pad_value dtype (" << padded_value->dtype
+                       << ") and input dtype (" << data_sinfo->dtype << ") must be the same");
+      return {};
     }
-    input_shape = shape_sinfo->values.value();
-  } else {
-    const auto* shape_expr = data_sinfo->shape.as<ShapeExprNode>();
-    ICHECK(shape_expr);
-    input_shape = shape_expr->values;
   }
 
-  if (input_shape.size() != index_map->initial_indices.size()) {
+  // For unknown input shape, the best we can do is get the #dims in output from index map
+  auto inferred_sinfo_for_unknown_shape = [&]() {
+    return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
+  };
+
+  if (data_sinfo->IsUnknownNdim()) return inferred_sinfo_for_unknown_shape();
+
+  // If rank is known, check that it is compatible with the index_map, i.e., #dims match.
+  if (index_map->initial_indices.size() != static_cast<size_t>(data_sinfo->ndim)) {
     ctx->ReportFatal(Diagnostic::Error(call)
                      << "number of dimensions in input must match the number of source dimensions "
                         "in index map, but got "
-                     << input_shape.size() << " != " << index_map->initial_indices.size());
+                     << data_sinfo->ndim << " != " << index_map->initial_indices.size());
   }
+
+  if (!data_sinfo->shape.defined()) return inferred_sinfo_for_unknown_shape();
+
+  // If input shape is known, get the ShapeStructInfo of the shape expr.
+  const auto* shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value());
+
+  if (!shape_sinfo->values.defined()) return inferred_sinfo_for_unknown_shape();
+
+  Array<PrimExpr> input_shape = shape_sinfo->values.value();
   Array<PrimExpr> output_shape = index_map->MapShape(input_shape);
   return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
 }
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index 58b168e2f0..d2c6a93218 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -685,6 +685,15 @@ def test_layout_transform_infer_struct_info():
     )
 
 
+def test_layout_transform_infer_struct_info_mismatch_dtype():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((10, 20, 30), "int32"))
+
+    transpose_transform = lambda a, b, c: (a, c, b)
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.layout_transform(x, index_map=transpose_transform, pad_value=2.2))
+
+
 def test_layout_transform_infer_struct_info_unknown_shape():
     bb = relax.BlockBuilder()
     x_unknown_shape = relax.Var("x", R.Tensor("float32", ndim=2))

From 9352d45554d0af99c5388a3c964cb93c5d98ac23 Mon Sep 17 00:00:00 2001
From: Prakalp Srivastava
Date: Tue, 7 Feb 2023 08:16:22 -0800
Subject: [PATCH 10/10] Report error for any unknown shape.

---
 src/relax/op/tensor/manipulate.cc        | 25 ++++++++++++-------
 tests/python/relax/test_op_manipulate.py | 33 ++++++++----------------
 2 files changed, 28 insertions(+), 30 deletions(-)

diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index db36549f37..65a3c0e601 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -418,16 +418,25 @@ StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder&
       ctx->ReportFatal(Diagnostic::Error(call)
                        << "layout_transform pad_value dtype (" << padded_value->dtype
                        << ") and input dtype (" << data_sinfo->dtype << ") must be the same");
-      return {};
     }
   }
 
-  // For unknown input shape, the best we can do is get the #dims in output from index map
-  auto inferred_sinfo_for_unknown_shape = [&]() {
-    return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
-  };
+  // We would like to ensure safety, and therefore placed a stronger requirement for user to
+  // use MatchCast before layout_transform if the shape of input is not known at compile time.
+  // Todo(relax-team): At this moment, enforcing MatchCast is fine. But we may need to revisit
+  // this requirement to reduce the workload of importers and better support dynamic shapes.
+  auto report_error_for_unknown_shape =
+      [&]() {
+        ctx->ReportFatal(
+            Diagnostic::Error(call)
+            << "layout_transform expects the input tensor to have known rank (expected rank = "
+            << index_map->initial_indices.size()
+            << ") and shape. For input tensors, whose shape cannot be determined at compile time, "
+               "please use MatchCast to get input with symbolic shape.");
+        return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
+      };
 
-  if (data_sinfo->IsUnknownNdim()) return inferred_sinfo_for_unknown_shape();
+  if (data_sinfo->IsUnknownNdim()) return report_error_for_unknown_shape();
 
   // If rank is known, check that it is compatible with the index_map, i.e., #dims match.
   if (index_map->initial_indices.size() != static_cast<size_t>(data_sinfo->ndim)) {
@@ -437,12 +446,12 @@ StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder&
                      << data_sinfo->ndim << " != " << index_map->initial_indices.size());
   }
 
-  if (!data_sinfo->shape.defined()) return inferred_sinfo_for_unknown_shape();
+  if (!data_sinfo->shape.defined()) return report_error_for_unknown_shape();
 
   // If input shape is known, get the ShapeStructInfo of the shape expr.
   const auto* shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value());
 
-  if (!shape_sinfo->values.defined()) return inferred_sinfo_for_unknown_shape();
+  if (!shape_sinfo->values.defined()) return report_error_for_unknown_shape();
 
   Array<PrimExpr> input_shape = shape_sinfo->values.value();
   Array<PrimExpr> output_shape = index_map->MapShape(input_shape);
   return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
 }
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index d2c6a93218..02f2547680 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -696,20 +696,15 @@ def test_layout_transform_infer_struct_info_mismatch_dtype():
 
 def test_layout_transform_infer_struct_info_unknown_shape():
     bb = relax.BlockBuilder()
+    tiling_transform = lambda a, b: (a, b // 2, b % 2)
+
     x_unknown_shape = relax.Var("x", R.Tensor("float32", ndim=2))
-    x_unknown_rank_dtype = relax.Var("x", R.Tensor())
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.layout_transform(x_unknown_shape, index_map=tiling_transform))
 
-    tiling_transform = lambda a, b: (a, b // 2, b % 2)
-    _check_inference(
-        bb,
-        relax.op.layout_transform(x_unknown_shape, index_map=tiling_transform),
-        relax.TensorStructInfo(dtype="float32", ndim=3),
-    )
-    _check_inference(
-        bb,
-        relax.op.layout_transform(x_unknown_rank_dtype, index_map=tiling_transform),
-        relax.TensorStructInfo(dtype="", ndim=3),
-    )
+    x_unknown_rank_dtype = relax.Var("x", R.Tensor())
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.layout_transform(x_unknown_rank_dtype, index_map=tiling_transform))
 
 
 def test_layout_transform_infer_struct_info_symbolic_shape():
@@ -740,19 +735,13 @@ def test_layout_transform_infer_struct_info_shape_var():
 
     s_unknown_shape = relax.Var("s", relax.ShapeStructInfo(ndim=2))
     x_unknown_shape = relax.Var("x", relax.TensorStructInfo(s_unknown_shape, "float32"))
-    _check_inference(
-        bb,
-        relax.op.layout_transform(x_unknown_shape, index_map=tiling_padding_transform),
-        relax.TensorStructInfo(ndim=3, dtype="float32"),
-    )
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.layout_transform(x_unknown_shape, index_map=tiling_padding_transform))
 
     s_unknown_rank = relax.Var("s", relax.ShapeStructInfo())
     x_unknown_rank = relax.Var("x", relax.TensorStructInfo(s_unknown_rank, "float32"))
-    _check_inference(
-        bb,
-        relax.op.layout_transform(x_unknown_rank, index_map=tiling_padding_transform),
-        relax.TensorStructInfo(ndim=3, dtype="float32"),
-    )
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.layout_transform(x_unknown_rank, index_map=tiling_padding_transform))
 
     a = tir.Var("a", "int64")
     b = tir.Var("b", "int64")
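
[Illustrative aside, not part of the patches: after PATCH 10, an input whose
shape is unknown at compile time must first be bound to a symbolic shape. A
minimal sketch of the expected pattern, assuming the BlockBuilder match_cast
API used alongside these tests:

    a = tir.Var("a", "int64")
    b = tir.Var("b", "int64")
    x = relax.Var("x", R.Tensor("float32", ndim=2))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x]):
        # bind the runtime shape to symbolic vars (a, b) via MatchCast
        x0 = bb.match_cast(x, relax.TensorStructInfo((a, b), "float32"))
        gv = bb.emit(
            relax.op.layout_transform(x0, index_map=lambda i, j: (i, j // 2, j % 2))
        )
        bb.emit_func_output(gv)
]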