Commit c3992ee

Merge pull request #108 from jdb78/fix/stack_variable_lengths_tensors
Enable stacking of variable lengths tensors
2 parents d5e6fa8 + 0d299c1

2 files changed: +40 -6


pytorch_forecasting/models/temporal_fusion_transformer/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -21,7 +21,7 @@
     InterpretableMultiHeadAttention,
     VariableSelectionNetwork,
 )
-from pytorch_forecasting.utils import autocorrelation, get_embedding_size, integer_histogram
+from pytorch_forecasting.utils import autocorrelation, get_embedding_size, integer_histogram, padded_stack


 class TemporalFusionTransformer(BaseModel, CovariatesMixin):
@@ -791,7 +791,8 @@ def _log_interpretation(self, outputs, label="train"):
         """
         # extract interpretations
         interpretation = {
-            name: torch.stack([x["interpretation"][name] for x in outputs]).sum(0)
+            # use padded_stack because decoder length histogram can be of different length
+            name: padded_stack([x["interpretation"][name] for x in outputs], side="right", value=0).sum(0)
             for name in outputs[0]["interpretation"].keys()
         }
         # normalize attention with length histogram squared to account for: 1. zeros in attention and
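Context for this change: torch.stack requires all tensors to share the same shape, but the decoder-length histograms collected in the interpretation outputs can differ in length between steps. The new padded_stack helper (added to utils.py below) right-pads them with zeros before stacking. A minimal sketch of the behaviour, using made-up histogram values for illustration:

import torch
from pytorch_forecasting.utils import padded_stack

# hypothetical histograms from two steps with different maximum decoder lengths
hist_a = torch.tensor([4, 3, 1])      # decoder lengths up to 3
hist_b = torch.tensor([5, 2, 2, 1])   # decoder lengths up to 4

# torch.stack([hist_a, hist_b]) would raise a RuntimeError because the sizes differ
stacked = padded_stack([hist_a, hist_b], side="right", value=0)
# tensor([[4, 3, 1, 0],
#         [5, 2, 2, 1]])
aggregated = stacked.sum(0)
# tensor([9, 5, 3, 1])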

pytorch_forecasting/utils.py

Lines changed: 37 additions & 4 deletions
@@ -2,13 +2,11 @@
 Helper functions for PyTorch forecasting
 """
 from contextlib import redirect_stdout
-import functools
-import inspect
 import os
-import re
-from typing import Callable, Tuple, Union
+from typing import Callable, List, Tuple, Union

 import torch
+import torch.nn.functional as F
 from torch.nn.utils import rnn


@@ -202,3 +200,38 @@ def unpack_sequence(sequence: Union[torch.Tensor, rnn.PackedSequence]) -> Tuple[
     else:
         lengths = torch.ones(sequence.size(0), device=sequence.device, dtype=torch.long) * sequence.size(1)
     return sequence, lengths
+
+
+def padded_stack(
+    tensors: List[torch.Tensor], side: str = "right", mode: str = "constant", value: Union[int, float] = 0
+) -> torch.Tensor:
+    """
+    Stack tensors along first dimension and pad them along last dimension to ensure their size is equal.
+
+    Args:
+        tensors (List[torch.Tensor]): list of tensors to stack
+        side (str): side on which to pad - "left" or "right". Defaults to "right".
+        mode (str): 'constant', 'reflect', 'replicate' or 'circular'. Default: 'constant'
+        value (Union[int, float]): value to use for constant padding
+
+    Returns:
+        torch.Tensor: stacked tensor
+    """
+    full_size = max([x.size(-1) for x in tensors])
+
+    def make_padding(pad):
+        if side == "left":
+            return (pad, 0)
+        elif side == "right":
+            return (0, pad)
+        else:
+            raise ValueError(f"side for padding '{side}' is unknown")
+
+    out = torch.stack(
+        [
+            F.pad(x, make_padding(full_size - x.size(-1)), mode=mode, value=value) if full_size - x.size(-1) > 0 else x
+            for x in tensors
+        ],
+        dim=0,
+    )
+    return out
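For reference, a small usage sketch of the new helper with illustrative values, showing both padding sides:

import torch
from pytorch_forecasting.utils import padded_stack

short = torch.tensor([1.0, 2.0])
long = torch.tensor([3.0, 4.0, 5.0, 6.0])

# default: pad on the right with zeros
padded_stack([short, long])
# tensor([[1., 2., 0., 0.],
#         [3., 4., 5., 6.]])

# pad on the left with a custom fill value
padded_stack([short, long], side="left", value=-1)
# tensor([[-1., -1.,  1.,  2.],
#         [ 3.,  4.,  5.,  6.]])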
