This repository has been archived by the owner on May 22, 2023. It is now read-only.

[Op][Layout] Add layout_transform operator in Relax. #403

Merged: 10 commits, Feb 7, 2023
14 changes: 14 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
@@ -25,6 +25,7 @@
#define TVM_RELAX_ATTRS_MANIPULATE_H_

#include <tvm/relax/expr.h>
#include <tvm/tir/index_map.h>

namespace tvm {
namespace relax {
@@ -52,6 +53,19 @@ struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
}
}; // struct ExpandDimsAttrs

/*! \brief Attributes used in layout_transform operator */
struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
tir::IndexMap index_map;
Optional<PrimValue> pad_value;

TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
TVM_ATTR_FIELD(pad_value).describe(
"The specific value to be used to pad if the layout transform would result in implicit "
"padding. If not specified, the compiler is free to choose any value.");
}
}; // struct LayoutTransformAttrs

/*! \brief Attributes used in permute_dims operator */
struct PermuteDimsAttrs : public tvm::AttrsNode<PermuteDimsAttrs> {
Optional<Array<Integer>> axes;
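Background for the two new LayoutTransformAttrs fields: a tir::IndexMap records the mapping from source index expressions (initial_indices) to destination index expressions (final_indices), and pad_value only matters when that mapping addresses more destination elements than the source tensor provides. A minimal Python sketch, assuming the usual tvm.tir.IndexMap.from_func helper; the lambda is illustrative and not part of this diff:

from tvm.tir import IndexMap

# A transposing layout transformation: source index (i, j) maps to destination (j, i).
index_map = IndexMap.from_func(lambda i, j: (j, i))

# initial_indices / final_indices are the source and destination index expressions
# that LayoutTransformAttrs::index_map carries. A bijective map like this one never
# pads; tiling maps that overshoot the source extent are the implicit-padding case
# the new pad_value field covers.
assert len(index_map.initial_indices) == 2
assert len(index_map.final_indices) == 2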
47 changes: 44 additions & 3 deletions python/tvm/relax/op/manipulate.py
@@ -15,13 +15,13 @@
# specific language governing permissions and limitations
# under the License.
"""Manipulation operators."""
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, Callable

from tvm.ir.expr import PrimExpr
from tvm.tir import IntImm
from tvm.tir import IntImm, FloatImm, IndexMap

from . import _ffi_api
from ..expr import Expr, ShapeExpr, Tuple as RxTuple
from ..expr import Expr, PrimValue, ShapeExpr, Tuple as RxTuple


PrimExprLike = Union[int, PrimExpr]
@@ -110,6 +110,47 @@ def flatten(x: Expr) -> Expr:
return _ffi_api.flatten(x) # type: ignore


def layout_transform(
x: Expr,
index_map: Union[Callable, IndexMap],
pad_value: Optional[Union[int, float, PrimValue]] = None,
) -> Expr:
"""Modifies the layout of a tensor.

Parameters
----------
x : relax.Expr
The input tensor to the operator.

index_map : Union[Callable, IndexMap]
The transformation to apply.

pad_value : Optional[Union[int, float, PrimValue]]
The value used for padding if the transformation results in implicit padding.
If not specified, the compiler is free to choose any value.

Returns
-------
result : relax.Expr
The transformed tensor.
"""
if callable(index_map):
index_map = IndexMap.from_func(index_map)
x_dtype = x.checked_type.dtype

# Explicitly convert a Python int/float pad_value to x's dtype. With the default
# conversion it would become int32/float32, which may not match x's dtype.
if pad_value is None:
pass
elif not isinstance(pad_value, PrimValue):
if "int" in x_dtype and isinstance(pad_value, int):
pad_value = IntImm(x_dtype, pad_value)
elif "float" in x_dtype and (isinstance(pad_value, (int, float))):
pad_value = FloatImm(x_dtype, float(pad_value))
pad_value = PrimValue(pad_value)
return _ffi_api.layout_transform(x, index_map, pad_value) # type: ignore


def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr:
"""Permutes the dimensions of an array.

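To make the pad_value conversion above concrete, here is a hedged usage sketch mirroring the tests added later in this PR; variable names and shapes are illustrative:

from tvm import relax
from tvm.script import relax as R

x = relax.Var("x", R.Tensor((10, 20, 30), "float32"))

# Tiling the second axis by 3 pads implicitly (20 is not divisible by 3), so a
# pad_value is supplied. Being a Python int while x is float32, it is wrapped as
# a float32 PrimValue by the conversion logic above.
call = relax.op.layout_transform(
    x, index_map=lambda a, b, c: (a, c, b // 3, b % 3), pad_value=2
)

# Normalizing through a BlockBuilder runs struct-info inference for the call.
bb = relax.BlockBuilder()
print(bb.normalize(call).struct_info)  # roughly R.Tensor((10, 30, 7, 3), "float32")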
5 changes: 5 additions & 0 deletions python/tvm/relax/op/op_attrs.py
@@ -139,6 +139,11 @@ class BatchNormAttrs(Attrs):
"""Attributes used in batch_norm operator"""


@tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs")
class LayoutTransformAttrs(Attrs):
"""Attributes used in layout_transform operator"""


@tvm._ffi.register_object("relax.attrs.LayerNormAttrs")
class LayerNormAttrs(Attrs):
"""Attributes used in layer_norm operator"""
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
@@ -64,6 +64,7 @@
isfinite,
isinf,
isnan,
layout_transform,
less,
less_equal,
log,
@@ -495,6 +496,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"isfinite",
"isinf",
"isnan",
"layout_transform",
"less",
"less_equal",
"log",
10 changes: 10 additions & 0 deletions src/printer/relax_script_printer.cc
@@ -24,6 +24,7 @@

#include <tvm/ir/type_functor.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/manipulate.h>
#include <tvm/relax/ir_functor.h>
#include <tvm/relax/utils.h>

Expand Down Expand Up @@ -400,6 +401,10 @@ Doc RelaxScriptPrinter::VisitExpr_(const tir::IntImmNode* op) {
return Doc::Text(std::to_string(op->value));
}

Doc RelaxScriptPrinter::VisitExpr_(const tir::FloatImmNode* op) {
return Doc::Text(std::to_string(op->value));
}

#define TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(OpName, OpString) \
Doc RelaxScriptPrinter::VisitExpr_(const OpName* op) { \
Doc doc; \
@@ -491,6 +496,11 @@ std::vector<Doc> RelaxScriptPrinter::PrintAttrs(const Attrs& attrs) {
for (const auto& k : dict_attrs->dict) {
kwargs.push_back(Doc::Text(k.first) << "=" << Print(k.second));
}
} else if (const LayoutTransformAttrs* layout_attrs = attrs.as<LayoutTransformAttrs>()) {
kwargs.push_back(Doc::Text("index_map=") << layout_attrs->index_map->ToPythonString());
if (layout_attrs->pad_value.defined()) {
kwargs.push_back(Doc::Text("pad_value=") << Print(layout_attrs->pad_value.value()));
}
} else {
AttrPrinter attr_printer(&kwargs, this);
const_cast<BaseAttrsNode*>(attrs.operator->())->VisitAttrs(&attr_printer);
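With the special case above, a layout_transform call is printed with index_map rendered as a Python lambda (via ToPythonString) and with pad_value only when it is defined. Roughly, and only as an illustration (the exact lambda spelling and value formatting come from ToPythonString and the PrimValue printer):

lv = R.layout_transform(x, index_map=lambda i, j: (j, i), pad_value=2)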
1 change: 1 addition & 0 deletions src/printer/text_printer.h
@@ -290,6 +290,7 @@ class RelaxScriptPrinter : public relax::IRFunctor<Doc(const ObjectRef&)>,
// PrimExpr nodes allowed in Relax
Doc VisitExpr_(const tir::VarNode* op) override;
Doc VisitExpr_(const tir::IntImmNode* op) override;
Doc VisitExpr_(const tir::FloatImmNode* op) override;
Doc VisitExpr_(const tir::AddNode* op) override;
Doc VisitExpr_(const tir::SubNode* op) override;
Doc VisitExpr_(const tir::MulNode* op) override;
55 changes: 55 additions & 0 deletions src/relax/op/tensor/manipulate.cc
@@ -391,6 +391,61 @@ TVM_REGISTER_OP("relax.flatten")
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFlatten);

/* relax.layout_transform */
TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);

Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value) {
ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
attrs->index_map = std::move(index_map);
attrs->pad_value = std::move(pad_value);

static const Op& op = Op::Get("relax.layout_transform");
return Call(op, {std::move(x)}, Attrs{attrs}, {});
}

TVM_REGISTER_GLOBAL("relax.op.layout_transform")
.set_body_typed([](Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value) {
return layout_transform(x, index_map, pad_value);
});

StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& ctx) {
TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
const auto* attrs = call->attrs.as<LayoutTransformAttrs>();
tir::IndexMap index_map = attrs->index_map;

if (!data_sinfo->shape.defined()) {
// For unknown input shape, the best we can do is get the #dims in output from index map
return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
}
Array<PrimExpr> input_shape;
if (const auto* shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value())) {
if (!shape_sinfo->values.defined()) {
// For unknown input shape, the best we can do is get the #dims in output from index map
return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
}
input_shape = shape_sinfo->values.value();
} else {
const auto* shape_expr = data_sinfo->shape.as<ShapeExprNode>();
ICHECK(shape_expr);
input_shape = shape_expr->values;
}

if (input_shape.size() != index_map->initial_indices.size()) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "number of dimensions in input must match the number of source dimensions "
"in index map, but got "
<< input_shape.size() << " != " << index_map->initial_indices.size());
}
Array<PrimExpr> output_shape = index_map->MapShape(input_shape);
return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
}

TVM_REGISTER_OP("relax.layout_transform")
.set_num_inputs(1)
.set_attrs_type<LayoutTransformAttrs>()
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayoutTransform);

/* relax.permute_dims */
TVM_REGISTER_NODE_TYPE(PermuteDimsAttrs);

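The fully-known-shape branch in InferStructInfoLayoutTransform delegates the output shape to index_map->MapShape. A small worked example of that arithmetic, assuming the map_shape Python binding of the same IndexMap machinery; the shape is illustrative:

from tvm.tir import IndexMap

# Input shape (10, 20, 30), transform (a, b, c) -> (a, c, b // 3, b % 3).
index_map = IndexMap.from_func(lambda a, b, c: (a, c, b // 3, b % 3))

# b ranges over [0, 20), so b // 3 spans [0, 6] -> extent 7 (i.e. ceildiv(20, 3)),
# and b % 3 has extent 3; 7 * 3 = 21 > 20 means one implicitly padded element.
print(index_map.map_shape([10, 20, 30]))  # expected: [10, 30, 7, 3]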
10 changes: 10 additions & 0 deletions src/relax/op/tensor/manipulate.h
@@ -59,6 +59,16 @@ Expr expand_dims(Expr x, Array<Integer> axis);
*/
Expr flatten(Expr x);

/*!
* \brief Transform layout of a tensor.
* \param x The input data to the operator.
* \param index_map The transformation to apply.
* \param pad_value The value used for padding if the transformation results in implicit padding. If
* not specified, the compiler is free to choose any value.
* \return The transformed result.
*/
Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value);

/*!
* \brief Permutes the dimensions of an array.
* \param x The input data to the operator.
115 changes: 115 additions & 0 deletions tests/python/relax/test_op_manipulate.py
@@ -33,6 +33,9 @@ def test_op_correctness():
assert relax.op.reshape(x, (4, 5, 3)).op == Op.get("relax.reshape")
assert relax.op.split(x, indices_or_sections=1).op == Op.get("relax.split")
assert relax.op.squeeze(x).op == Op.get("relax.squeeze")
assert relax.op.layout_transform(x, index_map=lambda a, b, c: (b, c, a)).op == Op.get(
"relax.layout_transform"
)


def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
@@ -649,6 +652,118 @@ def test_expand_dims_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.expand_dims(x1, axis=[]))


def test_layout_transform_infer_struct_info():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((10, 20, 30), "float32"))

transpose_transform = lambda a, b, c: (a, c, b)
_check_inference(
bb,
relax.op.layout_transform(x, index_map=transpose_transform),
relax.TensorStructInfo((10, 30, 20), "float32"),
)

tiling_transform = lambda a, b, c: (a, b // 2, c, b % 2)
_check_inference(
bb,
relax.op.layout_transform(x, index_map=tiling_transform),
relax.TensorStructInfo((10, 10, 30, 2), "float32"),
)

implicit_padding_transform = lambda a, b, c: (a, c, b // 3, b % 3)
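# Note: 20 is not divisible by 3, so b // 3 spans [0, 6] (extent 7); the output then
# holds 7 * 3 = 21 slots along these axes and the extra slot is implicit padding
# filled with pad_value=2.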
_check_inference(
bb,
relax.op.layout_transform(x, index_map=implicit_padding_transform, pad_value=2),
relax.TensorStructInfo((10, 30, 7, 3), "float32"),
)

flatten_transform = lambda a, b, c: (a * 600 + b * 30 + c)
_check_inference(
bb,
relax.op.layout_transform(x, index_map=flatten_transform),
relax.TensorStructInfo((6000,), "float32"),
)


def test_layout_transform_infer_struct_info_unknown_shape():
bb = relax.BlockBuilder()
x_unknown_shape = relax.Var("x", R.Tensor("float32", ndim=2))
x_unknown_rank_dtype = relax.Var("x", R.Tensor())

tiling_transform = lambda a, b: (a, b // 2, b % 2)
_check_inference(
bb,
relax.op.layout_transform(x_unknown_shape, index_map=tiling_transform),
relax.TensorStructInfo(dtype="float32", ndim=3),
)
_check_inference(
bb,
relax.op.layout_transform(x_unknown_rank_dtype, index_map=tiling_transform),
relax.TensorStructInfo(dtype="", ndim=3),
)


def test_layout_transform_infer_struct_info_symbolic_shape():
bb = relax.BlockBuilder()
a = tir.Var("a", "int64")
b = tir.Var("b", "int64")
x0 = relax.Var("x", R.Tensor((a, b), "float32"))

tiling_transform = lambda a, b: (a, b // 3, b % 3)
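# The expected extent (b - b % (-3)) // 3 below is a canonicalized ceildiv(b, 3);
# whenever b % 3 != 0 the extra slots are implicit padding.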
_check_inference(
bb,
relax.op.layout_transform(x0, index_map=tiling_transform),
relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"),
)


def test_layout_transform_infer_struct_info_shape_var():
bb = relax.BlockBuilder()

s = relax.Var("s", relax.ShapeStructInfo((30, 20)))
x = relax.Var("x", relax.TensorStructInfo(s, "float32"))
tiling_padding_transform = lambda a, b: (a, b // 3, b % 3)
_check_inference(
bb,
relax.op.layout_transform(x, index_map=tiling_padding_transform),
relax.TensorStructInfo((30, 7, 3), "float32"),
)

s_unknown_shape = relax.Var("s", relax.ShapeStructInfo(ndim=2))
x_unknown_shape = relax.Var("x", relax.TensorStructInfo(s_unknown_shape, "float32"))
_check_inference(
bb,
relax.op.layout_transform(x_unknown_shape, index_map=tiling_padding_transform),
relax.TensorStructInfo(ndim=3, dtype="float32"),
)

s_unknown_rank = relax.Var("s", relax.ShapeStructInfo())
x_unknown_rank = relax.Var("x", relax.TensorStructInfo(s_unknown_rank, "float32"))
_check_inference(
bb,
relax.op.layout_transform(x_unknown_rank, index_map=tiling_padding_transform),
relax.TensorStructInfo(ndim=3, dtype="float32"),
)

a = tir.Var("a", "int64")
b = tir.Var("b", "int64")
s_symbolic_shape = relax.Var("s", relax.ShapeStructInfo((a, b)))
x_symbolic_shape = relax.Var("x", relax.TensorStructInfo(s_symbolic_shape, "float32"))
_check_inference(
bb,
relax.op.layout_transform(x_symbolic_shape, index_map=tiling_padding_transform),
relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"),
)


def test_layout_transform_infer_struct_info_invalid_index_map():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((10, 20, 30), "float32"))

with pytest.raises(TVMError):
bb.normalize(relax.op.layout_transform(x, index_map=lambda a, b: (b, a)))


def test_squeeze_infer_struct_info():
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32"))