|
2 | 2 | Helper functions for PyTorch forecasting |
3 | 3 | """ |
4 | 4 | from contextlib import redirect_stdout |
5 | | -import functools |
6 | | -import inspect |
7 | 5 | import os |
8 | | -import re |
9 | | -from typing import Callable, Tuple, Union |
| 6 | +from typing import Callable, List, Tuple, Union |
10 | 7 |
|
11 | 8 | import torch |
| 9 | +import torch.nn.functional as F |
12 | 10 | from torch.nn.utils import rnn |
13 | 11 |
|
14 | 12 |
|
@@ -202,3 +200,38 @@ def unpack_sequence(sequence: Union[torch.Tensor, rnn.PackedSequence]) -> Tuple[ |
202 | 200 | else: |
203 | 201 | lengths = torch.ones(sequence.size(0), device=sequence.device, dtype=torch.long) * sequence.size(1) |
204 | 202 | return sequence, lengths |
| 203 | + |
| 204 | + |
| 205 | +def padded_stack( |
| 206 | + tensors: List[torch.Tensor], side: str = "right", mode: str = "constant", value: Union[int, float] = 0 |
| 207 | +) -> torch.Tensor: |
| 208 | + """ |
| 209 | + Stack tensors along first dimension and pad them along last dimension to ensure their size is equal. |
| 210 | +
|
| 211 | + Args: |
| 212 | + tensors (List[torch.Tensor]): list of tensors to stack |
| 213 | + side (str): side on which to pad - "left" or "right". Defaults to "right". |
| 214 | + mode (str): 'constant', 'reflect', 'replicate' or 'circular'. Default: 'constant' |
| 215 | + value (Union[int, float]): value to use for constant padding |
| 216 | +
|
| 217 | + Returns: |
| 218 | + torch.Tensor: stacked tensor |
| 219 | + """ |
| 220 | + full_size = max([x.size(-1) for x in tensors]) |
| 221 | + |
| 222 | + def make_padding(pad): |
| 223 | + if side == "left": |
| 224 | + return (pad, 0) |
| 225 | + elif side == "right": |
| 226 | + return (0, pad) |
| 227 | + else: |
| 228 | + raise ValueError(f"side for padding '{side}' is unknown") |
| 229 | + |
| 230 | + out = torch.stack( |
| 231 | + [ |
| 232 | + F.pad(x, make_padding(full_size - x.size(-1)), mode=mode, value=value) if full_size - x.size(-1) > 0 else x |
| 233 | + for x in tensors |
| 234 | + ], |
| 235 | + dim=0, |
| 236 | + ) |
| 237 | + return out |
0 commit comments