Skip to content

Commit

Permalink
fix dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
Egor Baturin committed Dec 2, 2024
1 parent 7b0ce99 commit bd3cbae
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 23 deletions.
6 changes: 3 additions & 3 deletions etna/libs/chronos/chronos_bolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def forward(
# patching
patched_context = self.patch(context)
patched_mask = torch.nan_to_num(self.patch(mask), nan=0.0)
patched_context = torch.where(patched_mask > 0.0, patched_context, torch.tensor(0, dtype=patched_context.dtype)) # replaced 0.0
patched_context = torch.where(patched_mask > 0.0, patched_context, 0.0)
# concat context and mask along patch dim
patched_context = torch.cat([patched_context, patched_mask], dim=-1)

Expand Down Expand Up @@ -650,9 +650,9 @@ def predict( # type: ignore[override]
dtype=torch.float32, # scaling should be done in 32-bit precision
)
with torch.no_grad():
batch_prediction = self.model( # added batch iteration
batch_prediction = self.model( # TODO added batch iteration
context=batch_context_tensor,
).quantile_preds.to(context_tensor)
).quantile_preds.to(batch_context_tensor)
prediction.append(batch_prediction)

prediction = torch.cat(prediction, dim=0)
Expand Down
20 changes: 0 additions & 20 deletions etna/models/nn/chronos/file.py

This file was deleted.

0 comments on commit bd3cbae

Please sign in to comment.