Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
[Op][Layout] Add layout_transform operator in Relax. (#403)
Browse files Browse the repository at this point in the history
Adds layout_transform operator in Relax as part of Relax layout planning.

layout_transform takes an input tensor x and an attribute index_map of tir::IndexMap type. It also has an optional pad_value attribute. It transforms x as per the index_map attribute and returns the transformed tensor.

This operator for now allows layout transformations that introduce implicit padding. For example, transforming a tensor of shape (10,) using the lambda i: (i//4, i%4). The output shape will be (3, 4) with two elements padded. The optional pad_value is used to pad if specified, otherwise the compiler is free to choose any value to pad.
  • Loading branch information
psrivas2 authored and junrushao committed Feb 8, 2023
1 parent 3ee354e commit 4e4cbe4
Show file tree
Hide file tree
Showing 10 changed files with 310 additions and 3 deletions.
16 changes: 16 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_RELAX_ATTRS_MANIPULATE_H_

#include <tvm/relax/expr.h>
#include <tvm/tir/index_map.h>

namespace tvm {
namespace relax {
Expand Down Expand Up @@ -52,6 +53,21 @@ struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
}
}; // struct ExpandDimsAttrs

/*! \brief Attributes used in layout_transform operator */
struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
tir::IndexMap index_map;
// pad_value is chosen to be of PrimValue type, as it represents constant TIR POD expression. This
// needs to be revisited in case PrimValue is evolved to represent symbolic expression in future.
Optional<PrimValue> pad_value;

TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
TVM_ATTR_FIELD(pad_value).describe(
"The specific value to be used to pad if the layout transform would result in implicit "
"padding. If not specified, the compiler is free to choose any value.");
}
}; // struct LayoutTransformAttrs

/*! \brief Attributes used in permute_dims operator */
struct PermuteDimsAttrs : public tvm::AttrsNode<PermuteDimsAttrs> {
Optional<Array<Integer>> axes;
Expand Down
47 changes: 44 additions & 3 deletions python/tvm/relax/op/manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
# specific language governing permissions and limitations
# under the License.
"""Manipulation operators."""
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union, Callable

from tvm.ir.expr import PrimExpr
from tvm.tir import IntImm
from tvm.tir import IntImm, FloatImm, IndexMap

from . import _ffi_api
from ..expr import Expr, ShapeExpr, Tuple as RxTuple
from ..expr import Expr, PrimValue, ShapeExpr, Tuple as RxTuple


PrimExprLike = Union[int, PrimExpr]
Expand Down Expand Up @@ -110,6 +110,47 @@ def flatten(x: Expr) -> Expr:
return _ffi_api.flatten(x) # type: ignore


def layout_transform(
x: Expr,
index_map: Union[Callable, IndexMap],
pad_value: Optional[Union[int, float, PrimValue]] = None,
):
"""Modifies the layout of a tensor.
Parameters
----------
x : relax.Expr
The input tensor to the operator.
index_map : Union[Callable, IndexMap]
The transformation to apply.
pad_value : Optional[Union[int, float, PrimValue]]
The value used for padding if the transformation results in implicit padding.
If not specified, any value can be used.
Returns
-------
result : relax.Expr
The transformed tensor.
"""
if callable(index_map):
index_map = IndexMap.from_func(index_map)
x_dtype = x.checked_type.dtype

# Explicitly convert python int/float pad_value to the x's type. If the default behavior
# is applied, it would be converted to int32/float32, which may not match the x's type.
if pad_value is None:
pass
elif not isinstance(pad_value, PrimValue):
if "int" in x_dtype and isinstance(pad_value, int):
pad_value = IntImm(x_dtype, pad_value)
elif "float" in x_dtype and (isinstance(pad_value, (int, float))):
pad_value = FloatImm(x_dtype, float(pad_value))
pad_value = PrimValue(pad_value)
return _ffi_api.layout_transform(x, index_map, pad_value) # type: ignore


def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr:
"""Permutes the dimensions of an array.
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/relax/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ class BatchNormAttrs(Attrs):
"""Attributes used in batch_norm operator"""


@tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs")
class LayoutTransformAttrs(Attrs):
"""Attributes used in layout_transform operator"""


@tvm._ffi.register_object("relax.attrs.LayerNormAttrs")
class LayerNormAttrs(Attrs):
"""Attributes used in layer_norm operator"""
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
isfinite,
isinf,
isnan,
layout_transform,
less,
less_equal,
linear,
Expand Down Expand Up @@ -498,6 +499,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"isfinite",
"isinf",
"isnan",
"layout_transform",
"less",
"less_equal",
"linear",
Expand Down
73 changes: 73 additions & 0 deletions src/relax/op/tensor/manipulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,79 @@ TVM_REGISTER_OP("relax.flatten")
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFlatten);

/* relax.layout_transform */
TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);

Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value) {
ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
attrs->index_map = std::move(index_map);
attrs->pad_value = std::move(pad_value);

static const Op& op = Op::Get("relax.layout_transform");
return Call(op, {std::move(x)}, Attrs{attrs}, {});
}

TVM_REGISTER_GLOBAL("relax.op.layout_transform").set_body_typed(layout_transform);

StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder& ctx) {
TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
const auto* attrs = call->attrs.as<LayoutTransformAttrs>();
tir::IndexMap index_map = attrs->index_map;
Optional<PrimValue> optional_pad_value = attrs->pad_value;

// Check pad_value has same dtype as input.
if (optional_pad_value.defined()) {
PrimExpr padded_value = optional_pad_value.value()->value;
if (padded_value->dtype != data_sinfo->dtype) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "layout_transform pad_value dtype (" << padded_value->dtype
<< ") and input dtype (" << data_sinfo->dtype << ") must be the same");
}
}

// We would like to ensure safety, and therefore placed a stronger requirement for user to
// use MatchCast before layout_transform if the shape of input is not known at compile time.
// Todo(relax-team): At this moment, enforcing MatchCast is fine. But we may need to revisit
// this requirement to reduce the workload of importers and better support dynamic shapes.
auto report_error_for_unknown_shape =
[&]() {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "layout_transform expects the input tensor to have known rank (expected rank = "
<< index_map->initial_indices.size()
<< ") and shape. For input tensors, whose shape cannot be determined at compile time, "
"please use MatchCast to get input with symbolic shape.");
return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
};

if (data_sinfo->IsUnknownNdim()) return report_error_for_unknown_shape();

// If rank is known, check that it is compatible with the index_map, i.e., #dims match.
if (index_map->initial_indices.size() != static_cast<size_t>(data_sinfo->ndim)) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "number of dimensions in input must match the number of source dimensions "
"in index map, but got "
<< data_sinfo->ndim << " != " << index_map->initial_indices.size());
}

if (!data_sinfo->shape.defined()) return report_error_for_unknown_shape();

// If input shape is known, get the ShapeStructInfo of the shape expr.
const auto* shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value());

if (!shape_sinfo->values.defined()) return report_error_for_unknown_shape();

Array<PrimExpr> input_shape = shape_sinfo->values.value();
Array<PrimExpr> output_shape = index_map->MapShape(input_shape);
return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
}

TVM_REGISTER_OP("relax.layout_transform")
.set_num_inputs(1)
.set_attrs_type<LayoutTransformAttrs>()
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayoutTransform);

/* relax.permute_dims */
TVM_REGISTER_NODE_TYPE(PermuteDimsAttrs);

Expand Down
10 changes: 10 additions & 0 deletions src/relax/op/tensor/manipulate.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ Expr expand_dims(Expr x, Array<Integer> axis);
*/
Expr flatten(Expr x);

/*!
* \brief Transform layout of a tensor.
* \param x The input data to the operator.
* \param index_map The transformation to apply.
* \param pad_value The value used for padding if the transformation results in implicit padding. If
* not specified, any value can be used.
* \return The transformed result.
*/
Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value);

/*!
* \brief Permutes the dimensions of an array.
* \param x The input data to the operator.
Expand Down
10 changes: 10 additions & 0 deletions src/relay/printer/relax_script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

#include <tvm/ir/type_functor.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/manipulate.h>
#include <tvm/relax/ir_functor.h>
#include <tvm/relax/utils.h>

Expand Down Expand Up @@ -400,6 +401,10 @@ Doc RelaxScriptPrinter::VisitExpr_(const tir::IntImmNode* op) {
return Doc::Text(std::to_string(op->value));
}

Doc RelaxScriptPrinter::VisitExpr_(const tir::FloatImmNode* op) {
return Doc::Text(std::to_string(op->value));
}

#define TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(OpName, OpString) \
Doc RelaxScriptPrinter::VisitExpr_(const OpName* op) { \
Doc doc; \
Expand Down Expand Up @@ -491,6 +496,11 @@ std::vector<Doc> RelaxScriptPrinter::PrintAttrs(const Attrs& attrs) {
for (const auto& k : dict_attrs->dict) {
kwargs.push_back(Doc::Text(k.first) << "=" << Print(k.second));
}
} else if (const LayoutTransformAttrs* layout_attrs = attrs.as<LayoutTransformAttrs>()) {
kwargs.push_back(Doc::Text("index_map=") << layout_attrs->index_map->ToPythonString());
if (layout_attrs->pad_value.defined()) {
kwargs.push_back(Doc::Text("pad_value=") << Print(layout_attrs->pad_value.value()));
}
} else {
AttrPrinter attr_printer(&kwargs, this);
const_cast<BaseAttrsNode*>(attrs.operator->())->VisitAttrs(&attr_printer);
Expand Down
1 change: 1 addition & 0 deletions src/relay/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ class RelaxScriptPrinter : public relax::IRFunctor<Doc(const ObjectRef&)>,
// PrimExpr nodes allowed in Relax
Doc VisitExpr_(const tir::VarNode* op) override;
Doc VisitExpr_(const tir::IntImmNode* op) override;
Doc VisitExpr_(const tir::FloatImmNode* op) override;
Doc VisitExpr_(const tir::AddNode* op) override;
Doc VisitExpr_(const tir::SubNode* op) override;
Doc VisitExpr_(const tir::MulNode* op) override;
Expand Down
113 changes: 113 additions & 0 deletions tests/python/relax/test_op_manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def test_op_correctness():
assert relax.op.reshape(x, (4, 5, 3)).op == Op.get("relax.reshape")
assert relax.op.split(x, indices_or_sections=1).op == Op.get("relax.split")
assert relax.op.squeeze(x).op == Op.get("relax.squeeze")
assert relax.op.layout_transform(x, index_map=lambda a, b, c: (b, c, a)).op == Op.get(
"relax.layout_transform"
)


def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
Expand Down Expand Up @@ -657,6 +660,116 @@ def test_expand_dims_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.expand_dims(x1, axis=[]))


def test_layout_transform_infer_struct_info():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((10, 20, 30), "float32"))

transpose_transform = lambda a, b, c: (a, c, b)
_check_inference(
bb,
relax.op.layout_transform(x, index_map=transpose_transform),
relax.TensorStructInfo((10, 30, 20), "float32"),
)

tiling_transform = lambda a, b, c: (a, b // 2, c, b % 2)
_check_inference(
bb,
relax.op.layout_transform(x, index_map=tiling_transform),
relax.TensorStructInfo((10, 10, 30, 2), "float32"),
)

implicit_padding_transform = lambda a, b, c: (a, c, b // 3, b % 3)
_check_inference(
bb,
relax.op.layout_transform(x, index_map=implicit_padding_transform, pad_value=2),
relax.TensorStructInfo((10, 30, 7, 3), "float32"),
)

flatten_transform = lambda a, b, c: (a * 600 + b * 30 + c)
_check_inference(
bb,
relax.op.layout_transform(x, index_map=flatten_transform),
relax.TensorStructInfo((6000,), "float32"),
)


def test_layout_transform_infer_struct_info_mismatch_dtype():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((10, 20, 30), "int32"))

transpose_transform = lambda a, b, c: (a, c, b)
with pytest.raises(TVMError):
bb.normalize(relax.op.layout_transform(x, index_map=transpose_transform, pad_value=2.2))


def test_layout_transform_infer_struct_info_unknown_shape():
bb = relax.BlockBuilder()
tiling_transform = lambda a, b: (a, b // 2, b % 2)

x_unknown_shape = relax.Var("x", R.Tensor("float32", ndim=2))
with pytest.raises(TVMError):
bb.normalize(relax.op.layout_transform(x_unknown_shape, index_map=tiling_transform))

x_unknown_rank_dtype = relax.Var("x", R.Tensor())
with pytest.raises(TVMError):
bb.normalize(relax.op.layout_transform(x_unknown_rank_dtype, index_map=tiling_transform))


def test_layout_transform_infer_struct_info_symbolic_shape():
bb = relax.BlockBuilder()
a = tir.Var("a", "int64")
b = tir.Var("b", "int64")
x0 = relax.Var("x", R.Tensor((a, b), "float32"))

tiling_transform = lambda a, b: (a, b // 3, b % 3)
_check_inference(
bb,
relax.op.layout_transform(x0, index_map=tiling_transform),
relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"),
)


def test_layout_transform_infer_struct_info_shape_var():
bb = relax.BlockBuilder()

s = relax.Var("s", relax.ShapeStructInfo((30, 20)))
x = relax.Var("x", relax.TensorStructInfo(s, "float32"))
tiling_padding_transform = lambda a, b: (a, b // 3, b % 3)
_check_inference(
bb,
relax.op.layout_transform(x, index_map=tiling_padding_transform),
relax.TensorStructInfo((30, 7, 3), "float32"),
)

s_unknown_shape = relax.Var("s", relax.ShapeStructInfo(ndim=2))
x_unknown_shape = relax.Var("x", relax.TensorStructInfo(s_unknown_shape, "float32"))
with pytest.raises(TVMError):
bb.normalize(relax.op.layout_transform(x_unknown_shape, index_map=tiling_padding_transform))

s_unknown_rank = relax.Var("s", relax.ShapeStructInfo())
x_unknown_rank = relax.Var("x", relax.TensorStructInfo(s_unknown_rank, "float32"))
with pytest.raises(TVMError):
bb.normalize(relax.op.layout_transform(x_unknown_rank, index_map=tiling_padding_transform))

a = tir.Var("a", "int64")
b = tir.Var("b", "int64")
s_symbolic_shape = relax.Var("s", relax.ShapeStructInfo((a, b)))
x_symbolic_shape = relax.Var("x", relax.TensorStructInfo(s_symbolic_shape, "float32"))
_check_inference(
bb,
relax.op.layout_transform(x_symbolic_shape, index_map=tiling_padding_transform),
relax.TensorStructInfo((a, (b - b % (-3)) // 3, 3), "float32"),
)


def test_layout_transform_infer_struct_info_invalid_index_map():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((10, 20, 30), "float32"))

with pytest.raises(TVMError):
bb.normalize(relax.op.layout_transform(x, index_map=lambda a, b: (b, a)))


def test_squeeze_infer_struct_info():
bb = relax.BlockBuilder()
x0 = relax.Var("x", R.Tensor((2, 1, 3, 1, 1, 4), "float32"))
Expand Down
Loading

0 comments on commit 4e4cbe4

Please sign in to comment.