From ba4ef472858dd8a1b42c65f571555618d0b2d5c8 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 20 Oct 2024 19:32:11 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/anemoi/training/diagnostics/callbacks/__init__.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py
index cf085eab..a671d6e2 100644
--- a/src/anemoi/training/diagnostics/callbacks/__init__.py
+++ b/src/anemoi/training/diagnostics/callbacks/__init__.py
@@ -378,7 +378,7 @@ def _plot(
 
         # prepare predicted output tensor for plotting
         output_tensor = self.post_processors(
-            y_pred[self.sample_idx : self.sample_idx + 1, ...].cpu()
+            y_pred[self.sample_idx : self.sample_idx + 1, ...].cpu(),
         ).numpy()
 
         fig = plot_predicted_multilevel_flat_sample(
@@ -457,7 +457,7 @@ def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.Lightning
         epoch = trainer.current_epoch
 
         if model.trainable_data is not None:
-            data_coords = np.rad2deg(graph[(self._graph_name_data, "to", self._graph_name_data)].ecoords_rad.numpy())
+            data_coords = np.rad2deg(graph[self._graph_name_data, "to", self._graph_name_data].ecoords_rad.numpy())
 
             self.plot(
                 trainer,
@@ -470,7 +470,7 @@ def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.Lightning
 
         if model.trainable_hidden is not None:
             hidden_coords = np.rad2deg(
-                graph[(self._graph_name_hidden, "to", self._graph_name_hidden)].hcoords_rad.numpy(),
+                graph[self._graph_name_hidden, "to", self._graph_name_hidden].hcoords_rad.numpy(),
             )
 
             self.plot(
@@ -609,7 +609,7 @@ def _plot(
         for rollout_step in range(pl_module.rollout):
             y_hat = outputs[1][rollout_step]
             y_true = batch[
-                :, pl_module.multi_step + rollout_step, ..., pl_module.data_indices.internal_data.output.full
+                :, pl_module.multi_step + rollout_step, ..., pl_module.data_indices.internal_data.output.full,
             ]
             loss = pl_module.loss(y_hat, y_true, squash=False).cpu().numpy()
 
@@ -971,7 +971,7 @@ def tracker_metadata(self, trainer: pl.Trainer) -> dict:
 
         return {}
 
-    def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
+    def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None:
         """Calls the strategy to remove the checkpoint file."""
         super()._remove_checkpoint(trainer, filepath)
         trainer.strategy.remove_checkpoint(self._get_inference_checkpoint_filepath(filepath))