Commit bf6e101

Browse files
Deivanayaki-S and deivanayakisankaralingam authored
[Relax][PyTorch] Add Pad Op Support for Exported Program and FX graph (#17821)
* add pad op support into frontend pipelines
  - fix end-of-file formatting, trailing-whitespace, and lint issues (pad.py, the pad file import statement, the test script)
  - update the docstring for pad mode in the nn file; modify the docstring of the pad function
  - fix dtype error in the unity check
  - resolve arg mismatch error and an error while handling the pad value attr; fix the dtype of the pad value attribute
  - add a helper function for the different pad modes, with a docstring
  - enhance and update the test script to check the different pad modes
* fix pad op arg handling in the FX graph
* fix issue by updating the retrieval of the value arg

Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
1 parent c00f52a commit bf6e101

8 files changed: 401 additions, 17 deletions


python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 18 additions & 0 deletions
@@ -901,6 +901,24 @@ def _max_pool2d(self, node: fx.Node) -> relax.Var:
 
         return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode)
 
+    def _pad(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        pad = node.args[1]
+        mode = node.args[2] if len(node.args) > 2 else node.kwargs.get("mode", "constant")
+        value = node.args[3] if len(node.args) > 3 else node.kwargs.get("value", 0.0)
+        value = 0.0 if value is None else value
+
+        # Calculate the symmetric padding width for each dimension and apply the
+        # (before, after) pairs in reverse order to match the input dimension order.
+        input_ndim = x.struct_info.ndim
+        pad_width = [0] * (input_ndim * 2)
+        pad_pairs = [pad[i : i + 2] for i in range(0, len(pad), 2)]
+        reversed_pairs = list(reversed(pad_pairs))
+        flattened = [value for pair in reversed_pairs for value in pair]
+        pad_width[-len(flattened) :] = flattened
+
+        return self.block_builder.emit(relax.op.nn.pad(x, pad_width, mode, value))
+
     def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
         transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
         query = transpose_S_H(self.env[node.args[0]])
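
To make the reversal concrete, here is a minimal standalone sketch of the same pad-width computation (plain Python, no TVM; the helper name and the NCHW example are illustrative, not part of the commit). PyTorch's flat pad tuple pads the last dimensions first, so the (before, after) pairs are reversed to line up with dimension order.

def torch_pad_to_relax_pad_width(pad, input_ndim):
    # Split the flat torch pad tuple into (before, after) pairs, reverse them so
    # they follow dimension order, and left-fill with zeros for untouched dims.
    pad_pairs = [pad[i : i + 2] for i in range(0, len(pad), 2)]
    flattened = [v for pair in reversed(pad_pairs) for v in pair]
    pad_width = [0] * (input_ndim * 2)
    pad_width[-len(flattened):] = flattened
    return pad_width

# F.pad(x, (1, 2, 3, 4)) on an NCHW tensor pads W by (1, 2) and H by (3, 4):
assert torch_pad_to_relax_pad_width((1, 2, 3, 4), 4) == [0, 0, 0, 0, 3, 4, 1, 2]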

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 1 addition & 0 deletions
@@ -299,6 +299,7 @@ def create_convert_map(
             "log1p.default": self._log1p,
             "log_softmax.int": self._log_softmax,
             "neg.default": self._unary_op(relax.op.negative),
+            "pad.default": self._pad,
             "prelu.default": self._prelu,
             "reciprocal.default": self._reciprocal,
             "relu.default": self._unary_op(relax.op.nn.relu),

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 1 addition & 0 deletions
@@ -649,6 +649,7 @@ def create_convert_map(
             "logical_not": self._unary_op(relax.op.logical_not),
             "log_softmax": self._log_softmax,
             "neg": self._unary_op(relax.op.negative),
+            "pad": self._pad,
             "prelu": self._prelu,
             "reciprocal": self._reciprocal,
             "relu": self._unary_op(relax.op.nn.relu),

python/tvm/relax/op/nn/nn.py

Lines changed: 8 additions & 7 deletions
@@ -515,9 +515,9 @@ def conv2d_transpose(
 
 def pad(
     data: Expr,
-    pad_width: Tuple[Tuple[int, int], ...],
+    pad_width: Union[List[int], Tuple[int, ...]],
     pad_mode: Optional[str] = "constant",
-    pad_value: Optional[Union[float, Expr]] = 0.0,
+    pad_value: Optional[float] = 0.0,
 ):
     r"""Padding
 
@@ -528,14 +528,15 @@ def pad(
     ----------
     data: relax.Expr
         The input data to the operator
-    pad_width: Tuple[Tuple[int, int], ...], required
+    pad_width: Union[List[int], Tuple[int, ...]], required
         Number of values padded to the edges of each axis, in the format
         of ((before_1, after_1), ..., (before_N, after_N))
     pad_mode: Optional[str]
-        'constant', 'edge', or 'reflect'
-        'constant' pads with constant_value pad_value
-        'edge' pads using the edge values of the input array
-        'reflect' pads by reflecting values with respect to the edge
+        'constant', 'reflect', 'replicate', 'circular'
+        'constant' pads with constant value pad_value
+        'reflect' pads by mirroring values excluding the edge
+        'replicate' pads by repeating the edge values
+        'circular' pads by looping values from the other side
         Default is 'constant'
     pad_value: Optional[Union[float, Expr]]
        The value used for padding. Default is 0.
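
As a usage sketch of the updated signature (the tensor shape and the flat pad_width values are illustrative; the call only builds a Relax expression and assumes nothing beyond the signature documented above):

from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((1, 3, 32, 32), "float32"))
# Flat layout [before_N, after_N, before_C, after_C, before_H, after_H, before_W, after_W]:
# pad H by (3, 4) and W by (1, 2) with zeros.
y = relax.op.nn.pad(x, pad_width=[0, 0, 0, 0, 3, 4, 1, 2], pad_mode="constant", pad_value=0.0)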

python/tvm/relax/transform/legalize_ops/nn.py

Lines changed: 22 additions & 9 deletions
@@ -222,18 +222,31 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr:
 
 @register_legalize("relax.nn.pad")
 def _nn_pad(bb: BlockBuilder, call: Call) -> Expr:
-    # Unpack pad_width into two separate lists for topi.
+    pad_mode = call.attrs.pad_mode
     pad_widths = call.attrs.pad_width
     pad_before = pad_widths[::2]
    pad_after = pad_widths[1::2]
-    return bb.call_te(
-        topi.nn.pad,
-        call.args[0],
-        pad_before=pad_before,
-        pad_after=pad_after,
-        pad_value=call.attrs.pad_value,
-        primfunc_name_hint="pad",
-    )
+    if pad_mode == "reflect":
+        return bb.call_te(
+            topi.nn.reflect_pad, call.args[0], pad_before=pad_before, pad_after=pad_after
+        )
+    elif pad_mode == "replicate":
+        return bb.call_te(
+            topi.nn.replicate_pad, call.args[0], pad_before=pad_before, pad_after=pad_after
+        )
+    elif pad_mode == "circular":
+        return bb.call_te(
+            topi.nn.circular_pad, call.args[0], pad_before=pad_before, pad_after=pad_after
+        )
+    else:
+        return bb.call_te(
+            topi.nn.pad,
+            call.args[0],
+            pad_before=pad_before,
+            pad_after=pad_after,
+            pad_value=call.attrs.pad_value,
+            primfunc_name_hint="pad",
+        )
 
 
 @register_legalize("relax.nn.max_pool1d")
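
For reference, the even/odd slicing above splits the flat pad_width attribute into per-dimension before/after lists; with the illustrative widths from the earlier example:

pad_widths = [0, 0, 0, 0, 3, 4, 1, 2]  # flat [before_i, after_i, ...] layout
pad_before = pad_widths[::2]           # [0, 0, 3, 1]
pad_after = pad_widths[1::2]           # [0, 0, 4, 2]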

python/tvm/topi/nn/pad.py

Lines changed: 173 additions & 1 deletion
@@ -19,14 +19,46 @@
 
 import tvm
 from tvm import te
+from tvm.tir import if_then_else
 
 from .. import tag
 from ..utils import equal_const_int
 
 
+def get_padded_shape(data, pad_before, pad_after=None):
+    """
+    Calculates the output shape of a tensor after applying padding.
+
+    Args:
+        data (tvm.te.Tensor): The input tensor to which padding is applied.
+        pad_before : list / tuple of n ints
+            Pad width on each dimension, applied before the axis begins.
+        pad_after : list / tuple of n ints, optional
+            Pad width on each dimension, applied after the axis ends.
+
+    Raises:
+        ValueError: If `pad_before` or `pad_after` lengths mismatch with `data` dimensions.
+
+    Returns:
+        tuple: A tuple representing the padded shape of the tensor.
+    """
+    n = data.ndim
+    pad_after = pad_after if pad_after else pad_before
+
+    if len(pad_before) != n:
+        raise ValueError(f"pad_before length {len(pad_before)} != input dims {n}")
+    if len(pad_after) != n:
+        raise ValueError(f"pad_after length {len(pad_after)} != input dims {n}")
+
+    ana = tvm.arith.Analyzer()
+    out_shape = tuple(ana.simplify(data.shape[i] + pad_before[i] + pad_after[i]) for i in range(n))
+
+    return out_shape
+
+
 @tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
 def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput", attrs=None):
-    """Pad Input with zeros.
+    """Pad input using pad values.
 
     Parameters
     ----------
@@ -145,3 +177,143 @@ def _pad(*indices):
         return data(*mapped_tuple)
 
     return te.compute(out_shape, _pad, name=name)
+
+
+@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
+def reflect_pad(data, pad_before, pad_after=None, name="ReflectPadInput"):
+    """
+    Apply reflect padding to the input tensor.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input tensor.
+
+    pad_before : List[int]
+        Amount to pad before each dimension.
+
+    pad_after : List[int], optional
+        Amount to pad after each dimension. If None, defaults to pad_before.
+
+    name : str
+        Name of the resulting tensor.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Reflect-padded tensor.
+    """
+    out_shape = get_padded_shape(data, pad_before, pad_after)
+
+    def _pad(*indices):
+        index_tuple = []
+        for i in range(data.ndim):
+            idx = indices[i]
+            size = data.shape[i]
+            before = pad_before[i]
+
+            orig_idx = idx - before
+
+            reflected_idx = if_then_else(
+                orig_idx < 0,
+                -orig_idx,  # reflect from the start (edge value not repeated)
+                if_then_else(
+                    orig_idx >= size,
+                    (2 * size - 2) - orig_idx,  # reflect from the end
+                    orig_idx,
+                ),
+            )
+            index_tuple.append(reflected_idx)
+        return data(*index_tuple)
+
+    return te.compute(out_shape, _pad, name=name)
+
+
+@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
+def replicate_pad(data, pad_before, pad_after=None, name="ReplicatePadInput"):
+    """
+    Apply replicate padding (edge padding) to the input tensor.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input tensor.
+
+    pad_before : List[int]
+        Amount to pad before each dimension.
+
+    pad_after : List[int], optional
+        Amount to pad after each dimension. If None, defaults to pad_before.
+
+    name : str
+        Name of the resulting tensor.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Replicate-padded tensor.
+    """
+    out_shape = get_padded_shape(data, pad_before, pad_after)
+
+    def _pad(*indices):
+        index_tuple = []
+        for i in range(data.ndim):
+            idx = indices[i]
+            size = data.shape[i]
+            before = pad_before[i]
+
+            orig_idx = idx - before
+            clamped_idx = if_then_else(
+                orig_idx < 0,
+                tvm.tir.const(0, "int32"),  # replicate the first element
+                if_then_else(
+                    orig_idx >= size,
+                    size - 1,  # replicate the last element
+                    orig_idx,
+                ),
+            )
+            index_tuple.append(clamped_idx)
+        return data(*index_tuple)
+
+    return te.compute(out_shape, _pad, name=name)
+
+
+@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad")
+def circular_pad(data, pad_before, pad_after=None, name="CircularPadInput"):
+    """
+    Apply circular padding (wrap around) to the input tensor.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        Input tensor.
+
+    pad_before : List[int]
+        Amount to pad before each dimension.
+
+    pad_after : List[int], optional
+        Amount to pad after each dimension. If None, defaults to pad_before.
+
+    name : str
+        Name of the resulting tensor.
+
+    Returns
+    -------
+    out : tvm.te.Tensor
+        Circular-padded tensor.
+    """
+    out_shape = get_padded_shape(data, pad_before, pad_after)
+
+    def _pad(*indices):
+        index_tuple = []
+        for i in range(data.ndim):
+            idx = indices[i]
+            size = data.shape[i]
+            before = pad_before[i]
+
+            orig_idx = idx - before
+            wrapped_idx = tvm.tir.indexmod(orig_idx + size, size)
+            index_tuple.append(wrapped_idx)
+        return data(*index_tuple)
+
+    return te.compute(out_shape, _pad, name=name)
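
To visualise the three index mappings, here is a NumPy illustration of the equivalent behaviour on a hypothetical 1-D input. np.pad's 'reflect', 'edge', and 'wrap' modes correspond to reflect_pad, replicate_pad, and circular_pad respectively, at least for pad widths no larger than the input size; NumPy is not used by the commit and only serves as a reference point.

import numpy as np

x = np.array([0.0, 1.0, 2.0, 3.0])   # hypothetical 1-D input, padded by 2 on each side
print(np.pad(x, 2, mode="reflect"))  # [2. 1. 0. 1. 2. 3. 2. 1.] -> reflect_pad
print(np.pad(x, 2, mode="edge"))     # [0. 0. 0. 1. 2. 3. 3. 3.] -> replicate_pad
print(np.pad(x, 2, mode="wrap"))     # [2. 3. 0. 1. 2. 3. 0. 1.] -> circular_pad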
