diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index 5e82456a87..bd6ae17bcf 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -25,6 +25,7 @@
 #define TVM_RELAX_ATTRS_MANIPULATE_H_
 
 #include <tvm/relax/expr.h>
+#include <tvm/tir/index_map.h>
 
 namespace tvm {
 namespace relax {
@@ -52,6 +53,21 @@ struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
   }
 };  // struct ExpandDimsAttrs
 
+/*! \brief Attributes used in layout_transform operator */
+struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
+  tir::IndexMap index_map;
+  // pad_value is chosen to be of PrimValue type, as it represents a constant TIR POD expression.
+  // This needs to be revisited in case PrimValue evolves to represent symbolic expressions in the
+  // future.
+  Optional<PrimValue> pad_value;
+
+  TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
+    TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
+    TVM_ATTR_FIELD(pad_value).describe(
+        "The specific value to be used for padding if the layout transform would result in "
+        "implicit padding. If not specified, the compiler is free to choose any value.");
+  }
+};  // struct LayoutTransformAttrs
+
 /*! \brief Attributes used in permute_dims operator */
 struct PermuteDimsAttrs : public tvm::AttrsNode<PermuteDimsAttrs> {
   Optional<Array<Integer>> axes;
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 01426e4129..a46c62e1f1 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -15,13 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 """Manipulation operators."""
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union, Callable
 
 from tvm.ir.expr import PrimExpr
-from tvm.tir import IntImm
+from tvm.tir import IntImm, FloatImm, IndexMap
 
 from . import _ffi_api
-from ..expr import Expr, ShapeExpr, Tuple as RxTuple
+from ..expr import Expr, PrimValue, ShapeExpr, Tuple as RxTuple
 
 
 PrimExprLike = Union[int, PrimExpr]
@@ -110,6 +110,47 @@ def flatten(x: Expr) -> Expr:
     return _ffi_api.flatten(x)  # type: ignore
 
 
+def layout_transform(
+    x: Expr,
+    index_map: Union[Callable, IndexMap],
+    pad_value: Optional[Union[int, float, PrimValue]] = None,
+):
+    """Modifies the layout of a tensor.
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The input tensor to the operator.
+
+    index_map : Union[Callable, IndexMap]
+        The transformation to apply.
+
+    pad_value : Optional[Union[int, float, PrimValue]]
+        The value used for padding if the transformation results in implicit padding.
+        If not specified, any value can be used.
+
+    Returns
+    -------
+    result : relax.Expr
+        The transformed tensor.
+    """
+    if callable(index_map):
+        index_map = IndexMap.from_func(index_map)
+    x_dtype = x.checked_type.dtype
+
+    # Explicitly convert a Python int/float pad_value to x's dtype. Under the default behavior,
+    # it would be converted to int32/float32, which may not match x's dtype.
+    if pad_value is None:
+        pass
+    elif not isinstance(pad_value, PrimValue):
+        if "int" in x_dtype and isinstance(pad_value, int):
+            pad_value = IntImm(x_dtype, pad_value)
+        elif "float" in x_dtype and (isinstance(pad_value, (int, float))):
+            pad_value = FloatImm(x_dtype, float(pad_value))
+        pad_value = PrimValue(pad_value)
+    return _ffi_api.layout_transform(x, index_map, pad_value)  # type: ignore
+
+
 def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr:
     """Permutes the dimensions of an array.
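A rough usage sketch (not part of the patch): the Python wrapper above turns a callable index_map into a tir.IndexMap and coerces an int/float pad_value to the input's dtype before calling into the FFI. The snippet below assumes a TVM build that contains this change; the shape, dtype, and variable names are illustrative only.

from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((1, 16, 7, 7), "float32"))
with bb.function("main", [x]):
    # Split the channel dimension by 3; since 16 % 3 != 0 the transform implies padding, and the
    # float pad_value below is converted to a float32 PrimValue by the wrapper.
    gv = bb.emit(
        relax.op.layout_transform(
            x, index_map=lambda n, c, h, w: (n, c // 3, h, w, c % 3), pad_value=0.0
        )
    )
    bb.emit_func_output(gv)
mod = bb.get()  # "main" returns a tensor of shape (1, 6, 7, 7, 3)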
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index 45f6e74ec4..f61549a6f1 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -139,6 +139,11 @@ class BatchNormAttrs(Attrs):
     """Attributes used in batch_norm operator"""
 
 
+@tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs")
+class LayoutTransformAttrs(Attrs):
+    """Attributes used in layout_transform operator"""
+
+
 @tvm._ffi.register_object("relax.attrs.LayerNormAttrs")
 class LayerNormAttrs(Attrs):
     """Attributes used in layer_norm operator"""
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 97ff216b84..ccefcc3a4b 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -65,6 +65,7 @@
     isfinite,
     isinf,
     isnan,
+    layout_transform,
     less,
     less_equal,
     linear,
@@ -498,6 +499,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "isfinite",
     "isinf",
    "isnan",
+    "layout_transform",
    "less",
    "less_equal",
    "linear",
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 68a525512e..65a3c0e601 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -391,6 +391,79 @@ TVM_REGISTER_OP("relax.flatten")
     .add_argument("x", "Tensor", "The input tensor.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFlatten);
 
+/* relax.layout_transform */
+TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
+
+Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value) {
+  ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
+  attrs->index_map = std::move(index_map);
+  attrs->pad_value = std::move(pad_value);
+
+  static const Op& op = Op::Get("relax.layout_transform");
+  return Call(op, {std::move(x)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.layout_transform").set_body_typed(layout_transform);
+
+StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as<LayoutTransformAttrs>();
+  tir::IndexMap index_map = attrs->index_map;
+  Optional<PrimValue> optional_pad_value = attrs->pad_value;
+
+  // Check that pad_value has the same dtype as the input.
+  if (optional_pad_value.defined()) {
+    PrimExpr padded_value = optional_pad_value.value()->value;
+    if (padded_value->dtype != data_sinfo->dtype) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "layout_transform pad_value dtype (" << padded_value->dtype
+                       << ") and input dtype (" << data_sinfo->dtype << ") must be the same");
+    }
+  }
+
+  // We would like to ensure safety, and therefore place a stronger requirement on the user to
+  // use MatchCast before layout_transform if the shape of the input is not known at compile time.
+  // Todo(relax-team): At this moment, enforcing MatchCast is fine. But we may need to revisit
+  // this requirement to reduce the workload of importers and better support dynamic shapes.
+  auto report_error_for_unknown_shape = [&]() {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "layout_transform expects the input tensor to have known rank (expected rank = "
+        << index_map->initial_indices.size()
+        << ") and shape. For input tensors whose shape cannot be determined at compile time, "
+           "please use MatchCast to get an input with symbolic shape.");
+    return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
+  };
+
+  if (data_sinfo->IsUnknownNdim()) return report_error_for_unknown_shape();
+
+  // If the rank is known, check that it is compatible with the index_map, i.e. #dims match.
+  if (index_map->initial_indices.size() != static_cast<size_t>(data_sinfo->ndim)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "number of dimensions in input must match the number of source dimensions "
+                        "in index map, but got "
+                     << data_sinfo->ndim << " != " << index_map->initial_indices.size());
+  }
+
+  if (!data_sinfo->shape.defined()) return report_error_for_unknown_shape();
+
+  // If the input shape is known, get the ShapeStructInfo of the shape expr.
+  const auto* shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value());
+
+  if (!shape_sinfo->values.defined()) return report_error_for_unknown_shape();
+
+  Array<PrimExpr> input_shape = shape_sinfo->values.value();
+  Array<PrimExpr> output_shape = index_map->MapShape(input_shape);
+  return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
+}
+
+TVM_REGISTER_OP("relax.layout_transform")
+    .set_num_inputs(1)
+    .set_attrs_type<LayoutTransformAttrs>()
+    .add_argument("x", "Tensor", "The input tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayoutTransform);
+
 /* relax.permute_dims */
 TVM_REGISTER_NODE_TYPE(PermuteDimsAttrs);
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 7e8a511e9e..6a2b23ecbd 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -59,6 +59,16 @@ Expr expand_dims(Expr x, Array<Integer> axis);
  */
 Expr flatten(Expr x);
 
+/*!
+ * \brief Transform the layout of a tensor.
+ * \param x The input data to the operator.
+ * \param index_map The transformation to apply.
+ * \param pad_value The value used for padding if the transformation results in implicit padding.
+ * If not specified, any value can be used.
+ * \return The transformed result.
+ */
+Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value);
+
 /*!
  * \brief Permutes the dimensions of an array.
  * \param x The input data to the operator.
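A minimal sketch (not part of the patch) of the shape arithmetic behind InferStructInfoLayoutTransform: the output shape comes from tir::IndexMap::MapShape, which rounds a split extent up when the transform implies implicit padding. The Python spelling below assumes the existing tir.IndexMap bindings (from_func and map_shape) from mainline TVM.

from tvm import tir

index_map = tir.IndexMap.from_func(lambda a, b: (a, b // 3, b % 3))
# 20 is not divisible by 3, so the split extent is rounded up: (10, 20) -> (10, 7, 3).
print(index_map.map_shape([10, 20]))

With a symbolic extent b, the same computation yields the ceil-division expression (b - b % (-3)) // 3 that appears in the struct-info tests below.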
diff --git a/src/relay/printer/relax_script_printer.cc b/src/relay/printer/relax_script_printer.cc
index 539ec3df62..aa60cdb9e1 100644
--- a/src/relay/printer/relax_script_printer.cc
+++ b/src/relay/printer/relax_script_printer.cc
@@ -24,6 +24,7 @@
 #include
 #include
+#include
 #include
 #include
@@ -400,6 +401,10 @@ Doc RelaxScriptPrinter::VisitExpr_(const tir::IntImmNode* op) {
   return Doc::Text(std::to_string(op->value));
 }
 
+Doc RelaxScriptPrinter::VisitExpr_(const tir::FloatImmNode* op) {
+  return Doc::Text(std::to_string(op->value));
+}
+
 #define TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(OpName, OpString) \
   Doc RelaxScriptPrinter::VisitExpr_(const OpName* op) {          \
     Doc doc;                                                      \
@@ -491,6 +496,11 @@ std::vector<Doc> RelaxScriptPrinter::PrintAttrs(const Attrs& attrs) {
     for (const auto& k : dict_attrs->dict) {
       kwargs.push_back(Doc::Text(k.first) << "=" << Print(k.second));
     }
+  } else if (const LayoutTransformAttrs* layout_attrs = attrs.as<LayoutTransformAttrs>()) {
+    kwargs.push_back(Doc::Text("index_map=") << layout_attrs->index_map->ToPythonString());
+    if (layout_attrs->pad_value.defined()) {
+      kwargs.push_back(Doc::Text("pad_value=") << Print(layout_attrs->pad_value.value()));
+    }
   } else {
     AttrPrinter attr_printer(&kwargs, this);
     const_cast<BaseAttrsNode*>(attrs.operator->())->VisitAttrs(&attr_printer);
diff --git a/src/relay/printer/text_printer.h b/src/relay/printer/text_printer.h
index 9329f9daeb..8dee474d96 100644
--- a/src/relay/printer/text_printer.h
+++ b/src/relay/printer/text_printer.h
@@ -287,6 +287,7 @@ class RelaxScriptPrinter : public relax::IRFunctor<Doc(const ObjectRef&)>,
   // PrimExpr nodes allowed in Relax
   Doc VisitExpr_(const tir::VarNode* op) override;
   Doc VisitExpr_(const tir::IntImmNode* op) override;
+  Doc VisitExpr_(const tir::FloatImmNode* op) override;
   Doc VisitExpr_(const tir::AddNode* op) override;
   Doc VisitExpr_(const tir::SubNode* op) override;
   Doc VisitExpr_(const tir::MulNode* op) override;
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index 66fef07617..ce4f4cdeb7 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -33,6 +33,9 @@ def test_op_correctness():
     assert relax.op.reshape(x, (4, 5, 3)).op == Op.get("relax.reshape")
     assert relax.op.split(x, indices_or_sections=1).op == Op.get("relax.split")
     assert relax.op.squeeze(x).op == Op.get("relax.squeeze")
+    assert relax.op.layout_transform(x, index_map=lambda a, b, c: (b, c, a)).op == Op.get(
+        "relax.layout_transform"
+    )
 
 
 def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
@@ -657,6 +660,116 @@ def test_expand_dims_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.expand_dims(x1, axis=[]))
 
 
+def test_layout_transform_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((10, 20, 30), "float32"))
+
+    transpose_transform = lambda a, b, c: (a, c, b)
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x, index_map=transpose_transform),
+        relax.TensorStructInfo((10, 30, 20), "float32"),
+    )
+
+    tiling_transform = lambda a, b, c: (a, b // 2, c, b % 2)
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x, index_map=tiling_transform),
+        relax.TensorStructInfo((10, 10, 30, 2), "float32"),
+    )
+
+    implicit_padding_transform = lambda a, b, c: (a, c, b // 3, b % 3)
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x, index_map=implicit_padding_transform, pad_value=2),
+        relax.TensorStructInfo((10, 30, 7, 3), "float32"),
+    )
+
+    flatten_transform = lambda a, b, c: (a * 600 + b * 30 + c)
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x, index_map=flatten_transform),
+        relax.TensorStructInfo((6000,), "float32"),
+    )
+
+
+def test_layout_transform_infer_struct_info_mismatch_dtype():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((10, 20, 30), "int32"))
+
+    transpose_transform = lambda a, b, c: (a, c, b)
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.layout_transform(x, index_map=transpose_transform, pad_value=2.2))
+
+
+def test_layout_transform_infer_struct_info_unknown_shape():
+    bb = relax.BlockBuilder()
+    tiling_transform = lambda a, b: (a, b // 2, b % 2)
+
+    x_unknown_shape = relax.Var("x", R.Tensor("float32", ndim=2))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.layout_transform(x_unknown_shape, index_map=tiling_transform))
+
+    x_unknown_rank_dtype = relax.Var("x", R.Tensor())
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.layout_transform(x_unknown_rank_dtype, index_map=tiling_transform))
+
+
+def test_layout_transform_infer_struct_info_symbolic_shape():
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    x0 = relax.Var("x", R.Tensor((a, b), "float32"))
+
+    tiling_transform = lambda a, b: (a, b // 3, b % 3)
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x0, index_map=tiling_transform),
+        relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"),
+    )
+
+
+def test_layout_transform_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+
+    s = relax.Var("s", relax.ShapeStructInfo((30, 20)))
+    x = relax.Var("x", relax.TensorStructInfo(s, "float32"))
+    tiling_padding_transform = lambda a, b: (a, b // 3, b % 3)
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x, index_map=tiling_padding_transform),
+        relax.TensorStructInfo((30, 7, 3), "float32"),
+    )
+
+    s_unknown_shape = relax.Var("s", relax.ShapeStructInfo(ndim=2))
+    x_unknown_shape = relax.Var("x", relax.TensorStructInfo(s_unknown_shape, "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.layout_transform(x_unknown_shape, index_map=tiling_padding_transform))
+
+    s_unknown_rank = relax.Var("s", relax.ShapeStructInfo())
+    x_unknown_rank = relax.Var("x", relax.TensorStructInfo(s_unknown_rank, "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.layout_transform(x_unknown_rank, index_map=tiling_padding_transform))
+
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    s_symbolic_shape = relax.Var("s", relax.ShapeStructInfo((a, b)))
+    x_symbolic_shape = relax.Var("x", relax.TensorStructInfo(s_symbolic_shape, "float32"))
+    _check_inference(
+        bb,
+        relax.op.layout_transform(x_symbolic_shape, index_map=tiling_padding_transform),
+        relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"),
+    )
+
+
+def test_layout_transform_infer_struct_info_invalid_index_map():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((10, 20, 30), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.layout_transform(x, index_map=lambda a, b: (b, a)))
+
+
 def test_squeeze_infer_struct_info():
     bb = relax.BlockBuilder()
     x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32"))
diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
index 68dcc0c820..44003df853 100644
--- a/tests/python/relax/test_tvmscript_parser_op_manipulate.py
+++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
@@ -119,6 +119,42 @@ def foo(x: R.Tensor((3, 4, 5), "float32")) -> R.Tensor((60,), "float32"):
     _check(foo, bb.get()["foo"])
 
 
+def test_layout_transform():
+    transformation = lambda n, c, h, w: (n, h, w, c)
+
+    @R.function
+    def foo(x: R.Tensor((2, 3, 4, 5), "float32")):
+        gv: R.Tensor((2, 4, 5, 3), "float32") = R.layout_transform(x, index_map=transformation)
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.layout_transform(x, index_map=transformation))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_layout_transform_with_padding():
+    transformation = lambda n, c, h, w: (n, c // 3, h, w, c % 3)
+
+    @R.function
+    def foo(x: R.Tensor((10, 20, 2, 2), "float32")):
+        gv: R.Tensor((10, 7, 2, 2, 3), "float32") = R.layout_transform(
+            x, index_map=transformation, pad_value=2
+        )
+        return gv
+
+    x = relax.Var("x", R.Tensor((10, 20, 2, 2), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.layout_transform(x, index_map=transformation, pad_value=2))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 def test_permute_dims():
     @R.function
     def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((2, 4, 3, 1), "float32"):
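A rough sketch (not part of the patch) of the MatchCast workflow that InferStructInfoLayoutTransform expects for inputs whose shape is unknown at compile time: bind a symbolic shape first, then apply layout_transform. It assumes BlockBuilder.match_cast behaves as in mainline Relax; names and shapes are illustrative only.

from tvm import relax, tir

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo(dtype="float32", ndim=2))  # shape unknown
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
with bb.function("main", [x]):
    # Without this match_cast, layout_transform reports an error (see the unknown-shape tests above).
    y = bb.match_cast(x, relax.TensorStructInfo((m, n), "float32"))
    gv = bb.emit(relax.op.layout_transform(y, index_map=lambda a, b: (a, b // 4, b % 4)))
    bb.emit_func_output(gv)
mod = bb.get()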