Skip to content

Commit

Permalink
support passing losses during test step
Browse files Browse the repository at this point in the history
  • Loading branch information
Sivan Ravid committed Feb 26, 2024
1 parent 1b50c1d commit 9d537ed
Showing 1 changed file with 35 additions and 19 deletions.
54 changes: 35 additions & 19 deletions fuse/dl/lightning/pl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,12 @@ def __init__(
self._validation_metrics = (
validation_metrics if validation_metrics is not None else {}
)

if test_metrics is None:
self._test_metrics = self._validation_metrics
# convert all use-cases to the same format that supports multiple val dataloaders: List[Tuple[str, OrderedDict[str, MetricBase]]]
if isinstance(self._validation_metrics, dict):
self._validation_metrics = [(None, self._validation_metrics)]

self._test_metrics = test_metrics if test_metrics is not None else {}
if log_unit not in [None, "optimizer_step", "epoch"]:
raise Exception(f"Error: unexpected log_unit {log_unit}")

Expand All @@ -158,11 +158,12 @@ def __init__(
self._prediction_keys = None
self._sep = tensorboard_sep

self._training_step_outputs = []

self._validation_step_outputs = {
i: [] for i, _ in enumerate(self._validation_metrics)
}
self._training_step_outputs = []
self._test_step_outputs = []
self._test_step_outputs = {i: [] for i, _ in enumerate(self._test_metrics)}

## forward
def forward(self, batch_dict: NDict) -> NDict:
Expand Down Expand Up @@ -204,17 +205,27 @@ def validation_step(
{"losses": batch_dict["losses"]}
)

def test_step(
    self, batch_dict: NDict, batch_idx: int, dataloader_idx: int = 0
) -> None:
    """Run a single test step, supporting multiple test dataloaders.

    :param batch_dict: batch data in FuseMedML NDict format
    :param batch_idx: index of the batch within the epoch (required by Lightning, unused here)
    :param dataloader_idx: index of the test dataloader; selects the matching
        losses/metrics entry when several test dataloaders are configured
    """
    # add step number to batch_dict
    batch_dict["global_step"] = self.global_step
    # run forward function and store the outputs in batch_dict["model"]
    batch_dict = self.forward(batch_dict)
    # prefer the per-dataloader validation losses when provided,
    # otherwise fall back to the training losses
    if self._validation_losses is not None:
        losses = self._validation_losses[dataloader_idx][1]
    else:
        losses = self._losses
    # compute the losses: the returned total loss is ignored here,
    # per-loss values are stored in batch_dict["losses"]
    _ = step_losses(losses, batch_dict)
    # given the batch_dict and FuseMedML style metrics - collect the required
    # values to compute the metrics on epoch_end
    step_metrics(self._test_metrics[dataloader_idx][1], batch_dict)
    # aggregate losses per dataloader for epoch-end averaging
    if losses:  # if there are losses, collect the results
        self._test_step_outputs[dataloader_idx].append(
            {"losses": batch_dict["losses"]}
        )

def predict_step(self, batch_dict: NDict, batch_idx: int) -> dict:
if self._prediction_keys is None:
Expand Down Expand Up @@ -267,20 +278,25 @@ def on_validation_epoch_end(self) -> None:
}

def on_test_epoch_end(self) -> None:
    """Compute and log the test losses and metrics accumulated over the epoch.

    Supports multiple test dataloaders: each dataloader's results are logged
    under the plain "test" prefix when there is a single dataloader, or under
    "test.<dataloader name>" when there are several.
    """
    step_outputs_lst = self._test_step_outputs
    # for the logs to be at each epoch, not each step
    if self._log_unit == "epoch":
        # float() so the logged step value is a scalar tensor-compatible type
        self.log("step", float(self.current_epoch), on_epoch=True, sync_dist=True)
    for dataloader_idx, step_outputs in step_outputs_lst.items():
        # single dataloader keeps the plain "test" prefix for backward compatibility
        if len(self._test_metrics) == 1:
            prefix = "test"
        else:
            prefix = f"test.{self._test_metrics[dataloader_idx][0]}"
        # calc average epoch loss and log it
        epoch_end_compute_and_log_losses(
            self, prefix, [e["losses"] for e in step_outputs], sep=self._sep
        )
        # evaluate and log it
        epoch_end_compute_and_log_metrics(
            self, prefix, self._test_metrics[dataloader_idx][1], sep=self._sep
        )
    # reset state for the next test run (one empty list per dataloader)
    self._test_step_outputs = {i: [] for i, _ in enumerate(self._test_metrics)}

# configuration
def configure_callbacks(self) -> Sequence[pl.Callback]:
Expand Down

0 comments on commit 9d537ed

Please sign in to comment.