Commit c00f52a
[Relax][PyTorch] Add Stack Op Support for Exported Program (#17819)
* add op support for stack
* trailing whitespace issue fixed
* fixed lint issues
* fixed whitespace issue
* fixed lint error
* fixing lint issues
* fixed whitespace issue
* add test script for fx_graph
* fix lint issues
* fixed unity check issues
* unity check
* fixed unity check issues
* lint issues
1 parent ba9f174 commit c00f52a

File tree: 14 files changed, +435 -94 lines


include/tvm/relax/attrs/manipulate.h

Lines changed: 13 additions & 0 deletions
@@ -119,6 +119,19 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
   }
 };  // struct SqueezeAttrs
 
+/*! \brief Attributes used in stack operators */
+struct StackAttrs : public tvm::AttrsNode<StackAttrs> {
+  Optional<Integer> axis;
+
+  TVM_DECLARE_ATTRS(StackAttrs, "relax.attrs.StackAttrs") {
+    TVM_ATTR_FIELD(axis).describe(
+        "The axis along which to stack the input tensors. "
+        "The axis will be inserted at this position in the output, "
+        "so it must be in range [-ndim-1, ndim] where ndim is the "
+        "number of dimensions of the input tensors.");
+  }
+};  // struct StackAttrs
+
 /*! \brief Attributes used in repeat operators */
 struct RepeatAttrs : public tvm::AttrsNode<RepeatAttrs> {
   int repeats;
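On the Python side, these attributes surface through the StackAttrs mirror registered in op_attrs.py below. A minimal sketch of inspecting them, assuming only what this diff adds (variable names are illustrative):

from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((2, 3), "float32"))
y = relax.Var("y", relax.TensorStructInfo((2, 3), "float32"))
call = relax.op.stack([x, y], axis=1)
# axis is stored in the Optional<Integer> field declared above
print(call.attrs.axis)  # -> 1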

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 2 additions & 14 deletions
@@ -1194,21 +1194,9 @@ def _squeeze(self, node: fx.Node) -> relax.Var:
 
     def _stack(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
+        tensor_list = args[0]
         axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
-        in_args = args[0]
-        assert all(
-            a.struct_info.shape[axis] == in_args[0].struct_info.shape[axis] for a in in_args[1:]
-        ), "Expect all dim at {} to be the same, get {}".format(
-            axis, [a.struct_info.shape for a in args]
-        )
-        cat = self.block_builder.emit(relax.op.concat(in_args, axis=axis))
-        s_shape = []
-        for idx, s in enumerate(cat.struct_info.shape):
-            if idx == axis:
-                s_shape.extend([len(in_args), in_args[0].struct_info.shape[axis]])
-            else:
-                s_shape.append(s)
-        return self.block_builder.emit(relax.op.reshape(cat, s_shape))
+        return self.block_builder.emit(relax.op.stack(tensor_list, axis=axis))
 
     def _take(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
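With the handler reduced from a concat-plus-reshape workaround to a single relax.op.stack emission, a torch.stack captured by torch.export imports directly. A hedged end-to-end sketch (model and shapes are illustrative, not part of the diff):

import torch
from torch.export import export
from tvm.relax.frontend.torch import from_exported_program

class StackModel(torch.nn.Module):
    def forward(self, x, y):
        # routed through the _stack handler above
        return torch.stack([x, y], dim=1)

exported = export(StackModel(), (torch.randn(2, 3), torch.randn(2, 3)))
mod = from_exported_program(exported)
mod.show()  # main should contain a single R.stack call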

python/tvm/relax/op/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -104,6 +104,7 @@
     scatter_nd,
     split,
     squeeze,
+    stack,
     tile,
 )
 from .mask import masked_fill

python/tvm/relax/op/manipulate.py

Lines changed: 24 additions & 0 deletions
@@ -279,6 +279,30 @@ def squeeze(x: Expr, axis: Optional[Union[int, List[int]]] = None) -> Expr:
     return _ffi_api.squeeze(x, axis)  # type: ignore
 
 
+def stack(tensors: Union[Expr, List[Expr]], axis: int = 0) -> Expr:
+    """Stack the input tensors along a new axis.
+
+    Parameters
+    ----------
+    tensors : Union[relax.Expr, List[relax.Expr]]
+        An Expr in Tuple type, containing the tensors to be stacked,
+        or a list of Tensors. All input tensors must have the same shape.
+
+    axis : int
+        The axis in the resulting tensor along which the input tensors will be stacked.
+        Negative values wrap around. Default is 0.
+
+    Returns
+    -------
+    result : relax.Expr
+        The stacked tensor with an additional dimension compared to the input tensors.
+    """
+    if isinstance(tensors, (list, tuple)):
+        tensors = RxTuple(tensors)
+    return _ffi_api.stack(tensors, axis)  # type: ignore
+
+
 def collapse_sum_like(data: Expr, collapse_target: Expr) -> Expr:
     """Return a summation of data to the shape of collapse_target.
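A short usage sketch of this wrapper, assuming the relax.stack op registration added elsewhere in this commit (shapes are illustrative):

from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((4, 8), "float32"))
y = relax.Var("y", relax.TensorStructInfo((4, 8), "float32"))
# The plain Python list is wrapped into an in-IR tuple before the FFI
# call; stacking two (4, 8) tensors along axis=-1 yields a (4, 8, 2) tensor.
call = relax.op.stack([x, y], axis=-1)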

python/tvm/relax/op/op_attrs.py

Lines changed: 5 additions & 0 deletions
@@ -139,6 +139,11 @@ class SqueezeAttrs(Attrs):
     """Attributes for squeeze operator"""
 
 
+@tvm._ffi.register_object("relax.attrs.StackAttrs")
+class StackAttrs(Attrs):
+    """Attributes for stack operator"""
+
+
 @tvm._ffi.register_object("relax.attrs.LayoutTransformAttrs")
 class LayoutTransformAttrs(Attrs):
     """Attributes used in layout_transform operator"""

python/tvm/relax/transform/legalize_ops/manipulate.py

Lines changed: 22 additions & 0 deletions
@@ -118,6 +118,28 @@ def _squeeze(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(topi.squeeze, call.args[0], call.attrs.axis)
 
 
+@register_legalize("relax.stack")
+def _stack(bb: BlockBuilder, call: Call) -> Expr:
+    t = call.args[0]
+    n_field = len(t.struct_info.fields)
+
+    # Follow bindings to find the actual tuple
+    while isinstance(t, Var):
+        binding = bb.lookup_binding(t)
+        if not isinstance(binding, (Tuple, Var)):
+            break
+        t = binding
+
+    assert isinstance(t, (Tuple, Var))
+
+    # Extract fields from either Tuple or bound Var
+    fields = (
+        t.fields if isinstance(t, Tuple) else [bb.emit(TupleGetItem(t, i)) for i in range(n_field)]
+    )
+
+    return bb.call_te(topi.stack, fields, 0 if call.attrs.axis is None else call.attrs.axis.value)
+
+
 @register_legalize("relax.repeat")
 def _repeat(bb: BlockBuilder, call: Call) -> Expr:
     def te_repeat(data: te.Tensor, repeats: IntImm, axis: Optional[IntImm]):
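The binding walk above matters when the stacked tuple reaches the legalizer as a bound Var rather than an inline Tuple. A hedged sketch that exercises that path (shapes illustrative):

from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((2, 3), "float32"))
y = relax.Var("y", relax.TensorStructInfo((2, 3), "float32"))
with bb.function("main", [x, y]):
    t = bb.emit(relax.Tuple([x, y]))         # tuple bound to a Var
    gv = bb.emit(relax.op.stack(t, axis=0))  # stack consumes the Var
    bb.emit_func_output(gv)
legalized = relax.transform.LegalizeOps()(bb.get())
legalized.show()  # relax.stack becomes a call_tir into a topi.stack PrimFunc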

python/tvm/script/ir_builder/relax/ir.py

Lines changed: 2 additions & 0 deletions
@@ -158,6 +158,7 @@
     sqrt,
     square,
     squeeze,
+    stack,
     std,
     strided_slice,
     subtract,
@@ -851,6 +852,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "square",
     "squeeze",
     "sqrt",
+    "stack",
     "stop_lift_params",
     "str",
     "strided_slice",

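With stack exported from the Relax ir_builder, the op is usable directly in TVMScript. A hedged sketch, assuming the struct-info inference added elsewhere in this commit:

from tvm.script import ir as I, relax as R

@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")
    ) -> R.Tensor((2, 2, 3), "float32"):
        # stacking two (2, 3) tensors along axis 0 yields (2, 2, 3)
        gv = R.stack((x, y), axis=0)
        return gv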
python/tvm/topi/transform.py

Lines changed: 9 additions & 7 deletions
@@ -403,23 +403,25 @@ def concatenate(a_tuple, axis=0):
     return cpp.concatenate(a_tuple, axis)
 
 
-def stack(a, axis):
-    """Repeats the whole array multiple times.
+def stack(tensors, axis=0):
+    """Join a sequence of tensors along a new axis.
 
     Parameters
     ----------
-    a : tvm.te.Tensor
-        The tensor to be stacked.
+    tensors : tuple or list of tvm.te.Tensor
+        The tensors to be stacked. All tensors must have the same shape.
 
     axis : int, optional
-        The axis in the result array along which the input arrays are stacked.
-
+        The axis in the resulting tensor along which the input tensors will be stacked.
+        Negative values wrap around. Default is 0.
 
     Returns
     -------
     ret : tvm.te.Tensor
+        The stacked tensor with an additional dimension compared to the input tensors.
+
     """
-    return cpp.stack(a, axis)
+    return cpp.stack(tensors, axis)
 
 
 def split(ary, indices_or_sections, axis=0):
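A quick TE-level check of the updated signature (shapes illustrative):

import tvm
from tvm import te, topi

a = te.placeholder((2, 3), name="a")
b = te.placeholder((2, 3), name="b")
c = topi.stack([a, b], axis=1)  # new axis inserted at position 1
print(c.shape)  # expected [2, 2, 3]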

src/contrib/msc/framework/torch/torch_opcode.cc

Lines changed: 8 additions & 0 deletions
@@ -209,6 +209,13 @@ class TorchConcatCodeGen : public TorchOpCode {
   void CodeGenForward() final { stack_.op_call().op_inputs_arg().op_arg<int>("axis", "dim"); }
 };
 
+class TorchStackCodeGen : public TorchOpCode {
+  TORCH_OP_CODEGEN_METHODS(TorchStackCodeGen);
+
+ protected:
+  void CodeGenForward() final { stack_.op_call().op_inputs_arg().op_arg<int>("axis", "dim"); }
+};
+
 class TorchConstantCodeGen : public TorchOpCode {
   TORCH_OP_CODEGEN_METHODS(TorchConstantCodeGen);
 
@@ -789,6 +796,7 @@ const std::shared_ptr<std::unordered_map<String, std::shared_ptr<TorchOpCode>>>
           std::make_shared<TorchScatterElementsCodeGen>("", "torch.scatter"));
   map->emplace("scatter_nd", std::make_shared<TorchScatterNDCodeGen>("", ""));
   map->emplace("split", std::make_shared<TorchSplitCodeGen>("", "torch.split"));
+  map->emplace("stack", std::make_shared<TorchStackCodeGen>("", "torch.stack"));
   map->emplace("strided_slice", std::make_shared<TorchStridedSliceCodeGen>("", ""));
   map->emplace("take", std::make_shared<TorchTakeCodeGen>("", ""));
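Illustrative only, not emitted verbatim: for a graph node lowered through TorchStackCodeGen, the generated forward method calls torch.stack with the MSC "axis" attribute printed as torch's "dim", roughly:

import torch

inp_0, inp_1 = torch.randn(2, 3), torch.randn(2, 3)
out = torch.stack((inp_0, inp_1), dim=0)  # out.shape == (2, 2, 3)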
