From 9d537edd1054be3c303959d9a99d4fd972c5dc33 Mon Sep 17 00:00:00 2001 From: Sivan Ravid Date: Mon, 26 Feb 2024 23:40:41 +0200 Subject: [PATCH] support passing losses during test step --- fuse/dl/lightning/pl_module.py | 54 ++++++++++++++++++++++------------ 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/fuse/dl/lightning/pl_module.py b/fuse/dl/lightning/pl_module.py index ee44e8aac..2e52e02d2 100644 --- a/fuse/dl/lightning/pl_module.py +++ b/fuse/dl/lightning/pl_module.py @@ -138,12 +138,12 @@ def __init__( self._validation_metrics = ( validation_metrics if validation_metrics is not None else {} ) - + if test_metrics is None: + self._test_metrics = self._validation_metrics # convert all use-cases to the same format that supports multiple val dataloaders: List[Tuple[str, OrderedDict[str, MetricBase]]] if isinstance(self._validation_metrics, dict): self._validation_metrics = [(None, self._validation_metrics)] - self._test_metrics = test_metrics if test_metrics is not None else {} if log_unit not in [None, "optimizer_step", "epoch"]: raise Exception(f"Error: unexpected log_unit {log_unit}") @@ -158,11 +158,12 @@ def __init__( self._prediction_keys = None self._sep = tensorboard_sep + self._training_step_outputs = [] + self._validation_step_outputs = { i: [] for i, _ in enumerate(self._validation_metrics) } - self._training_step_outputs = [] - self._test_step_outputs = [] + self._test_step_outputs = {i: [] for i, _ in enumerate(self._test_metrics)} ## forward def forward(self, batch_dict: NDict) -> NDict: @@ -204,17 +205,27 @@ def validation_step( {"losses": batch_dict["losses"]} ) - def test_step(self, batch_dict: NDict, batch_idx: int) -> None: + def test_step( + self, batch_dict: NDict, batch_idx: int, dataloader_idx: int = 0 + ) -> None: # add step number to batch_dict batch_dict["global_step"] = self.global_step # run forward function and store the outputs in batch_dict["model"] batch_dict = self.forward(batch_dict) # given the batch_dict and FuseMedML style losses - compute the losses, return the total loss (ignored) and save losses values in batch_dict["losses"] - _ = step_losses(self._losses, batch_dict) + if self._validation_losses is not None: + losses = self._validation_losses[dataloader_idx][1] + else: + losses = self._losses + + _ = step_losses(losses, batch_dict) # given the batch_dict and FuseMedML style metrics - collect the required values to compute the metrics on epoch_end - step_metrics(self._test_metrics, batch_dict) + step_metrics(self._test_metrics[dataloader_idx][1], batch_dict) # aggregate losses - self._test_step_outputs.append({"losses": batch_dict["losses"]}) + if losses: # if there are losses, collect the results + self._test_step_outputs[dataloader_idx].append( + {"losses": batch_dict["losses"]} + ) def predict_step(self, batch_dict: NDict, batch_idx: int) -> dict: if self._prediction_keys is None: @@ -267,20 +278,25 @@ def on_validation_epoch_end(self) -> None: } def on_test_epoch_end(self) -> None: - step_outputs = self._test_step_outputs + step_outputs_lst = self._test_step_outputs # for the logs to be at each epoch, not each step if self._log_unit == "epoch": - self.log("step", self.current_epoch, on_epoch=True, sync_dist=True) - # calc average epoch loss and log it - epoch_end_compute_and_log_losses( - self, "test", [e["losses"] for e in step_outputs], sep=self._sep - ) - # evaluate and log it - epoch_end_compute_and_log_metrics( - self, "test", self._test_metrics, sep=self._sep - ) + self.log("step", float(self.current_epoch), on_epoch=True, sync_dist=True) + for dataloader_idx, step_outputs in step_outputs_lst.items(): + if len(self._test_metrics) == 1: + prefix = "test" + else: + prefix = f"test.{self._test_metrics[dataloader_idx][0]}" + # calc average epoch loss and log it + epoch_end_compute_and_log_losses( + self, prefix, [e["losses"] for e in step_outputs], sep=self._sep + ) + # evaluate and log it + epoch_end_compute_and_log_metrics( + self, prefix, self._test_metrics[dataloader_idx][1], sep=self._sep + ) # reset state - self._test_step_outputs.clear() + self._test_step_outputs = {i: [] for i, _ in enumerate(self._test_metrics)} # configuration def configure_callbacks(self) -> Sequence[pl.Callback]: