Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
early stopping on meeting threshold
Browse files Browse the repository at this point in the history
casperdcl committed Oct 1, 2024

Verified

This commit was signed with the committer’s verified signature.
jkoenig134 Julian König
1 parent f543f9c commit d206fe8
Showing 2 changed files with 17 additions and 3 deletions.
2 changes: 1 addition & 1 deletion SIRF_data_preparation/evaluation_utilities.py
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@ def get_metrics(qm: QualityMetrics, iters: Iterable[int], srcdir='.'):
list(qm.evaluate(STIR.ImageData(str(Path(srcdir) / f'iter_{i:04d}.hv'))).values()) for i in iters])


def pass_index(metrics: np.ndarray, thresh: Iterable, window: int = 1) -> int:
def pass_index(metrics: np.ndarray, thresh: Iterable, window: int = 10) -> int:
"""
Returns first index of `metrics` with value <= `thresh`.
The values must remain below the respective thresholds for at least `window` number of entries.
18 changes: 16 additions & 2 deletions petric.py
Original file line number Diff line number Diff line change
@@ -111,21 +111,35 @@ def __call__(self, algo: Algorithm):

class QualityMetrics(ImageQualityCallback, Callback):
"""From https://github.com/SyneRBI/PETRIC/wiki#metrics-and-thresholds"""
def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 1, **kwargs):
THRESHOLD = {"AEM_VOI": 0.005, "RMSE_whole_object": 0.01, "RMSE_background": 0.01}

def __init__(self, reference_image, whole_object_mask, background_mask, interval: int = 1,
threshold_window: int = 10, **kwargs):
# TODO: drop multiple inheritance once `interval` included in CIL
Callback.__init__(self, interval=interval)
ImageQualityCallback.__init__(self, reference_image, **kwargs)
self.whole_object_indices = np.where(whole_object_mask.as_array())
self.background_indices = np.where(background_mask.as_array())
self.ref_im_arr = reference_image.as_array()
self.norm = self.ref_im_arr[self.background_indices].mean()
self.threshold_window = threshold_window
self.threshold_iters = 0

def __call__(self, algo: Algorithm):
if self.skip_iteration(algo):
return
t = self._time_
for tag, value in self.evaluate(algo.x).items():
# log metrics
metrics = self.evaluate(algo.x)
for tag, value in metrics.items():
self.tb_summary_writer.add_scalar(tag, value, algo.iteration, t)
# stop if `all(metrics < THRESHOLD)` for `threshold_window` iters
if all(metrics[tag] <= self.THRESHOLD.get(tag, self.THRESHOLD[tag[:len("AEM_VOI")]]) for tag in metrics):
self.threshold_iters += 1
if self.threshold_iters >= self.threshold_window:
raise StopIteration
else:
self.threshold_iters = 0

def evaluate(self, test_im: STIR.ImageData) -> dict[str, float]:
assert not any(self.filter.values()), "Filtering not implemented"

0 comments on commit d206fe8

Please sign in to comment.