Skip to content

Commit

Permalink
Fix predict_step for inference.
Browse files Browse the repository at this point in the history
  • Loading branch information
jakob-schloer committed Sep 27, 2024
1 parent a40ec02 commit 5ae42df
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/anemoi/models/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,14 @@ def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
assert (
len(batch.shape) == 4
), f"The input tensor has an incorrect shape: expected a 4-dimensional tensor, got {batch.shape}!"

x = self.pre_processors_state(batch[:, 0 : self.multi_step, ...], in_place=False)

# Dimensions are
# batch, timesteps, horizontal space, variables
x = x[..., None, :] # add dummy ensemble dimension as 3rd index

x = x[..., None, :, :] # add dummy ensemble dimension as 3rd index
if self.prediction_strategy == "tendency":
tendency_hat = self(x)
y_hat = self.add_tendency_to_state(batch[:, self.multi_step, ...], tendency_hat)
y_hat = self.add_tendency_to_state(x[:, -1, ...], tendency_hat)
else:
y_hat = self(x)
y_hat = self.post_processors_state(y_hat, in_place=False)
Expand Down

0 comments on commit 5ae42df

Please sign in to comment.