Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 20, 2024
1 parent 4300695 commit ba4ef47
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/anemoi/training/diagnostics/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def _plot(

# prepare predicted output tensor for plotting
output_tensor = self.post_processors(
y_pred[self.sample_idx : self.sample_idx + 1, ...].cpu()
y_pred[self.sample_idx : self.sample_idx + 1, ...].cpu(),
).numpy()

fig = plot_predicted_multilevel_flat_sample(
Expand Down Expand Up @@ -457,7 +457,7 @@ def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.Lightning
epoch = trainer.current_epoch

if model.trainable_data is not None:
data_coords = np.rad2deg(graph[(self._graph_name_data, "to", self._graph_name_data)].ecoords_rad.numpy())
data_coords = np.rad2deg(graph[self._graph_name_data, "to", self._graph_name_data].ecoords_rad.numpy())

self.plot(
trainer,
Expand All @@ -470,7 +470,7 @@ def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.Lightning

if model.trainable_hidden is not None:
hidden_coords = np.rad2deg(
graph[(self._graph_name_hidden, "to", self._graph_name_hidden)].hcoords_rad.numpy(),
graph[self._graph_name_hidden, "to", self._graph_name_hidden].hcoords_rad.numpy(),
)

self.plot(
Expand Down Expand Up @@ -609,7 +609,7 @@ def _plot(
for rollout_step in range(pl_module.rollout):
y_hat = outputs[1][rollout_step]
y_true = batch[
:, pl_module.multi_step + rollout_step, ..., pl_module.data_indices.internal_data.output.full
:, pl_module.multi_step + rollout_step, ..., pl_module.data_indices.internal_data.output.full,
]
loss = pl_module.loss(y_hat, y_true, squash=False).cpu().numpy()

Expand Down Expand Up @@ -971,7 +971,7 @@ def tracker_metadata(self, trainer: pl.Trainer) -> dict:

return {}

def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None:
"""Calls the strategy to remove the checkpoint file."""
super()._remove_checkpoint(trainer, filepath)
trainer.strategy.remove_checkpoint(self._get_inference_checkpoint_filepath(filepath))
Expand Down

0 comments on commit ba4ef47

Please sign in to comment.