From d206fe8023f1a1a1e844c39c14effb12642b1e1d Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Tue, 1 Oct 2024 20:56:56 +0100 Subject: [PATCH] early stopping on meeting threshold --- SIRF_data_preparation/evaluation_utilities.py | 2 +- petric.py | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/SIRF_data_preparation/evaluation_utilities.py b/SIRF_data_preparation/evaluation_utilities.py index 20a371e..41f87d4 100644 --- a/SIRF_data_preparation/evaluation_utilities.py +++ b/SIRF_data_preparation/evaluation_utilities.py @@ -25,7 +25,7 @@ def get_metrics(qm: QualityMetrics, iters: Iterable[int], srcdir='.'): list(qm.evaluate(STIR.ImageData(str(Path(srcdir) / f'iter_{i:04d}.hv'))).values()) for i in iters]) -def pass_index(metrics: np.ndarray, thresh: Iterable, window: int = 1) -> int: +def pass_index(metrics: np.ndarray, thresh: Iterable, window: int = 10) -> int: """ Returns first index of `metrics` with value <= `thresh`. The values must remain below the respective thresholds for at least `window` number of entries. diff --git a/petric.py b/petric.py index 8226558..725b31b 100755 --- a/petric.py +++ b/petric.py @@ -111,7 +111,10 @@ def __call__(self, algo: Algorithm): class QualityMetrics(ImageQualityCallback, Callback): """From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds""" - def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 1, **kwargs): + THRESHOLD = {"AEM_VOI": 0.005, "RMSE_whole_object": 0.01, "RMSE_background": 0.01} + + def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 1, + threshold_window: int = 10, **kwargs): # TODO: drop multiple inheritance once `interval` included in CIL Callback.__init__(self, interval=interval) ImageQualityCallback.__init__(self, reference_image, **kwargs) @@ -119,13 +122,24 @@ def __init__(self, reference_image, whole_object_mask, background_mask, interval self.background_indices = np.where(background_mask.as_array()) self.ref_im_arr = reference_image.as_array() self.norm = self.ref_im_arr[self.background_indices].mean() + self.threshold_window = threshold_window + self.threshold_iters = 0 def __call__(self, algo: Algorithm): if self.skip_iteration(algo): return t = self._time_ - for tag, value in self.evaluate(algo.x).items(): + # log metrics + metrics = self.evaluate(algo.x) + for tag, value in metrics.items(): self.tb_summary_writer.add_scalar(tag, value, algo.iteration, t) + # stop if `all(metrics < THRESHOLD)` for `threshold_window` iters + if all(metrics[tag] <= self.THRESHOLD.get(tag, self.THRESHOLD[tag[:len("AEM_VOI")]]) for tag in metrics): + self.threshold_iters += 1 + if self.threshold_iters >= self.threshold_window: + raise StopIteration + else: + self.threshold_iters = 0 def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]: assert not any(self.filter.values()), "Filtering not implemented"