|
19 | 19 |
|
20 | 20 | import tvm |
21 | 21 | from tvm import te |
| 22 | +from tvm.tir import if_then_else |
22 | 23 |
|
23 | 24 | from .. import tag |
24 | 25 | from ..utils import equal_const_int |
25 | 26 |
|
26 | 27 |
|
| 28 | +def get_padded_shape(data, pad_before, pad_after=None): |
| 29 | +    """ |
| 30 | +    Compute the output shape of a tensor after padding is applied. |
| 31 | +
|
| 32 | +    Parameters |
| 33 | +    ---------- |
| 34 | +    data : tvm.te.Tensor |
| 35 | +        The input tensor to which padding is applied. |
| 36 | +    pad_before : list / tuple of n ints |
| 37 | +        Pad width on each dimension, applied before the axis begins. |
| 38 | +    pad_after : list / tuple of n ints, optional |
| 39 | +        Pad width on each dimension, applied after the axis ends. |
| 40 | +        Defaults to ``pad_before`` when None. |
| 41 | +
|
| 42 | +    Raises |
| 43 | +    ------ |
| 44 | +    ValueError |
| 45 | +        If the length of ``pad_before`` or ``pad_after`` does not match |
| 46 | +        the number of dimensions of ``data``. |
| 47 | +
|
| 48 | +    Returns |
| 49 | +    ------- |
| 50 | +    tuple |
| 51 | +        The simplified padded shape of the tensor. |
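| 52 | +
|
| 53 | +    Examples |
| 54 | +    -------- |
| 55 | +    A minimal sketch (placeholder shape is illustrative):: |
| 56 | +
|
| 57 | +        x = te.placeholder((4, 4), name="x") |
| 58 | +        get_padded_shape(x, [1, 1])          # -> (6, 6) |
| 59 | +        get_padded_shape(x, [1, 0], [0, 2])  # -> (5, 6) |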
| 44 | + """ |
| 45 | + n = data.ndim |
| 46 | +    pad_after = pad_after if pad_after is not None else pad_before  # default to symmetric padding |
| 47 | + |
| 48 | + if len(pad_before) != n: |
| 49 | + raise ValueError(f"pad_before length {len(pad_before)} != input dims {n}") |
| 50 | + if len(pad_after) != n: |
| 51 | + raise ValueError(f"pad_after length {len(pad_after)} != input dims {n}") |
| 52 | + |
| 53 | + ana = tvm.arith.Analyzer() |
| 54 | + out_shape = tuple(ana.simplify(data.shape[i] + pad_before[i] + pad_after[i]) for i in range(n)) |
| 55 | + |
| 56 | + return out_shape |
| 57 | + |
| 58 | + |
27 | 59 | @tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad") |
28 | 60 | def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput", attrs=None): |
29 | | - """Pad Input with zeros. |
| 61 | + """Pad Input with using pad values. |
30 | 62 |
|
31 | 63 | Parameters |
32 | 64 | ---------- |
@@ -145,3 +177,143 @@ def _pad(*indices): |
145 | 177 | return data(*mapped_tuple) |
146 | 178 |
|
147 | 179 | return te.compute(out_shape, _pad, name=name) |
| 180 | + |
| 181 | + |
| 182 | +@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad") |
| 183 | +def reflect_pad(data, pad_before, pad_after=None, name="ReflectPadInput"): |
| 184 | + """ |
| 185 | + Apply reflect padding to the input tensor. |
| 186 | +
|
| 187 | + Parameters |
| 188 | + ---------- |
| 189 | + data : tvm.te.Tensor |
| 190 | + Input tensor. |
| 191 | +
|
| 192 | + pad_before : List[int] |
| 193 | + Amount to pad before each dimension. |
| 194 | +
|
| 195 | + pad_after : List[int], optional |
| 196 | + Amount to pad after each dimension. If None, defaults to pad_before. |
| 197 | +
|
| 198 | + name : str |
| 199 | + Name of the resulting tensor. |
| 200 | +
|
| 201 | + Returns |
| 202 | + ------- |
| 203 | + out : tvm.te.Tensor |
| 204 | + Reflect-padded tensor. |
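| 205 | +
|
| 206 | +    Examples |
| 207 | +    -------- |
| 208 | +    A 1-D sketch (values illustrative); reflect padding requires pad |
| 209 | +    widths smaller than the axis size:: |
| 210 | +
|
| 211 | +        # data = [0, 1, 2, 3], pad_before = [2] |
| 212 | +        # reflect_pad(data, [2]) -> [2, 1, 0, 1, 2, 3, 2, 1] |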
| 205 | + """ |
| 206 | + out_shape = get_padded_shape(data, pad_before, pad_after) |
| 207 | + |
| 208 | + def _pad(*indices): |
| 209 | + index_tuple = [] |
| 210 | + for i in range(data.ndim): |
| 211 | + idx = indices[i] |
| 212 | + size = data.shape[i] |
| 213 | + before = pad_before[i] |
| 214 | + |
| 215 | + orig_idx = idx - before |
| 216 | + |
| 217 | +            # Reflect without repeating the edge (PyTorch-style "reflect"); |
| 218 | +            # requires pad widths smaller than the axis size. |
| 219 | +            reflected_idx = if_then_else( |
| 220 | +                orig_idx < 0, |
| 221 | +                -orig_idx,  # reflect from start (no edge repeat) |
| 222 | +                if_then_else( |
| 223 | +                    orig_idx >= size, |
| 224 | +                    (2 * size - 2) - orig_idx,  # reflect from end |
| 225 | +                    orig_idx,  # in-bounds index passes through unchanged |
| 226 | +                ), |
| 227 | +            ) |
| 226 | + index_tuple.append(reflected_idx) |
| 227 | + return data(*index_tuple) |
| 228 | + |
| 229 | + return te.compute(out_shape, _pad, name=name) |
| 230 | + |
| 231 | + |
| 232 | +@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad") |
| 233 | +def replicate_pad(data, pad_before, pad_after=None, name="ReplicatePadInput"): |
| 234 | + """ |
| 235 | + Apply replicate padding (edge padding) to the input tensor. |
| 236 | +
|
| 237 | + Parameters |
| 238 | + ---------- |
| 239 | + data : tvm.te.Tensor |
| 240 | + Input tensor. |
| 241 | +
|
| 242 | + pad_before : List[int] |
| 243 | + Amount to pad before each dimension. |
| 244 | +
|
| 245 | + pad_after : List[int], optional |
| 246 | + Amount to pad after each dimension. If None, defaults to pad_before. |
| 247 | +
|
| 248 | + name : str |
| 249 | + Name of the resulting tensor. |
| 250 | +
|
| 251 | + Returns |
| 252 | + ------- |
| 253 | + out : tvm.te.Tensor |
| 254 | + Replicate-padded tensor. |
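| 255 | +
|
| 256 | +    Examples |
| 257 | +    -------- |
| 258 | +    A 1-D sketch (values illustrative):: |
| 259 | +
|
| 260 | +        # data = [0, 1, 2, 3], pad_before = [2] |
| 261 | +        # replicate_pad(data, [2]) -> [0, 0, 0, 1, 2, 3, 3, 3] |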
| 255 | + """ |
| 256 | + out_shape = get_padded_shape(data, pad_before, pad_after) |
| 257 | + |
| 258 | + def _pad(*indices): |
| 259 | + index_tuple = [] |
| 260 | + for i in range(data.ndim): |
| 261 | + idx = indices[i] |
| 262 | + size = data.shape[i] |
| 263 | + before = pad_before[i] |
| 264 | + |
| 265 | + orig_idx = idx - before |
| 266 | + clamped_idx = if_then_else( |
| 267 | + orig_idx < 0, |
| 268 | + tvm.tir.const(0, "int32"), # replicate first element |
| 269 | + if_then_else( |
| 270 | + orig_idx >= size, |
| 271 | + size - 1, # replicate last element |
| 272 | + orig_idx, |
| 273 | + ), |
| 274 | + ) |
| 275 | + index_tuple.append(clamped_idx) |
| 276 | + return data(*index_tuple) |
| 277 | + |
| 278 | + return te.compute(out_shape, _pad, name=name) |
| 279 | + |
| 280 | + |
| 281 | +@tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad") |
| 282 | +def circular_pad(data, pad_before, pad_after=None, name="CircularPadInput"): |
| 283 | + """ |
| 284 | + Apply circular padding (wrap around) to the input tensor. |
| 285 | +
|
| 286 | + Parameters |
| 287 | + ---------- |
| 288 | + data : tvm.te.Tensor |
| 289 | + Input tensor. |
| 290 | +
|
| 291 | + pad_before : List[int] |
| 292 | + Amount to pad before each dimension. |
| 293 | +
|
| 294 | + pad_after : List[int], optional |
| 295 | + Amount to pad after each dimension. If None, defaults to pad_before. |
| 296 | +
|
| 297 | + name : str |
| 298 | + Name of the resulting tensor. |
| 299 | +
|
| 300 | + Returns |
| 301 | + ------- |
| 302 | + out : tvm.te.Tensor |
| 303 | + Circular-padded tensor. |
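| 304 | +
|
| 305 | +    Examples |
| 306 | +    -------- |
| 307 | +    A 1-D sketch (values illustrative); assumes pad widths do not |
| 308 | +    exceed the axis size:: |
| 309 | +
|
| 310 | +        # data = [0, 1, 2, 3], pad_before = [2] |
| 311 | +        # circular_pad(data, [2]) -> [2, 3, 0, 1, 2, 3, 0, 1] |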
| 304 | + """ |
| 305 | + out_shape = get_padded_shape(data, pad_before, pad_after) |
| 306 | + |
| 307 | + def _pad(*indices): |
| 308 | + index_tuple = [] |
| 309 | + for i in range(data.ndim): |
| 310 | + idx = indices[i] |
| 311 | + size = data.shape[i] |
| 312 | + before = pad_before[i] |
| 313 | + |
| 314 | + orig_idx = idx - before |
| 315 | +            # Shift by one period so indexmod sees a non-negative operand (assumes pad_before[i] <= size). |
| 316 | +            wrapped_idx = tvm.tir.indexmod(orig_idx + size, size) |
| 316 | + index_tuple.append(wrapped_idx) |
| 317 | + return data(*index_tuple) |
| 318 | + |
| 319 | + return te.compute(out_shape, _pad, name=name) |
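
For reviewers, a quick end-to-end sketch of how these kernels can be exercised through the legacy TE flow; the import path is an assumption based on this file's location, and shapes/values are illustrative:

```python
import numpy as np
import tvm
from tvm import te

from tvm.topi.nn.pad import reflect_pad  # assumed import path for this module

# Build a tiny 1-D reflect-pad kernel and run it on CPU.
x = te.placeholder((4,), name="x")
y = reflect_pad(x, [2])  # pad 2 on both sides -> output length 8
s = te.create_schedule(y.op)
f = tvm.build(s, [x, y], "llvm")

a = tvm.nd.array(np.arange(4, dtype="float32"))
b = tvm.nd.empty((8,), "float32")
f(a, b)
print(b.numpy())  # expected: [2. 1. 0. 1. 2. 3. 2. 1.]
```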