Skip to content

Commit aa4f0d9

Browse files
authored
Merge pull request #109 from jdb78/fix/stack_variable_lengths_tensors
FIx second stacking variable length tensors occurance
2 parents c3992ee + 7930b9a commit aa4f0d9

File tree

1 file changed

+1
-1
lines changed
  • pytorch_forecasting/models/temporal_fusion_transformer

1 file changed

+1
-1
lines changed

pytorch_forecasting/models/temporal_fusion_transformer/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -823,7 +823,7 @@ def _log_interpretation(self, outputs, label="train"):
823823
# log lengths of encoder/decoder
824824
for type in ["encoder", "decoder"]:
825825
fig, ax = plt.subplots()
826-
lengths = torch.stack([out["interpretation"][f"{type}_length_histogram"] for out in outputs]).sum(0).cpu()
826+
lengths = padded_stack([out["interpretation"][f"{type}_length_histogram"] for out in outputs]).sum(0).cpu()
827827
if type == "decoder":
828828
start = 1
829829
else:

0 commit comments

Comments
 (0)