diff --git a/src/ctapipe/irf/__init__.py b/src/ctapipe/irf/__init__.py index 7c6dfac1750..b5af4f756c4 100644 --- a/src/ctapipe/irf/__init__.py +++ b/src/ctapipe/irf/__init__.py @@ -19,7 +19,6 @@ from .optimize import ( GhPercentileCutCalculator, OptimizationResult, - OptimizationResultStore, PercentileCuts, PointSourceSensitivityOptimizer, ThetaPercentileCutCalculator, @@ -37,7 +36,6 @@ "EffectiveArea2dMaker", "ResultValidRange", "OptimizationResult", - "OptimizationResultStore", "PointSourceSensitivityOptimizer", "PercentileCuts", "EventLoader", diff --git a/src/ctapipe/irf/optimize.py b/src/ctapipe/irf/optimize.py index 0ffd566948e..0c0419c3509 100644 --- a/src/ctapipe/irf/optimize.py +++ b/src/ctapipe/irf/optimize.py @@ -2,6 +2,7 @@ import operator from abc import abstractmethod +from collections.abc import Sequence import astropy.units as u import numpy as np @@ -11,22 +12,45 @@ from pyirf.cuts import calculate_percentile_cut, evaluate_binned_cut from ..core import Component, QualityQuery -from ..core.traits import AstroQuantity, Float, Integer +from ..core.traits import AstroQuantity, Float, Integer, Path from .binning import ResultValidRange, make_bins_per_decade from .select import EventPreProcessor class OptimizationResult: - def __init__(self, precuts, valid_energy, valid_offset, gh, theta): - self.precuts = precuts - self.valid_energy = ResultValidRange( - min=valid_energy["energy_min"], max=valid_energy["energy_max"] - ) - self.valid_offset = ResultValidRange( - min=valid_offset["offset_min"], max=valid_offset["offset_max"] - ) - self.gh_cuts = gh - self.theta_cuts = theta + """Result of an optimization of G/H and theta cuts or only G/H cuts.""" + + def __init__( + self, + valid_energy_min: u.Quantity, + valid_energy_max: u.Quantity, + valid_offset_min: u.Quantity, + valid_offset_max: u.Quantity, + gh_cuts: QTable, + clf_prefix: str, + theta_cuts: QTable | None = None, + precuts: QualityQuery | Sequence | None = None, + ) -> None: + if precuts: + if isinstance(precuts, QualityQuery): + if len(precuts.quality_criteria) == 0: + precuts.quality_criteria = [ + (" ", " ") + ] # Ensures table serialises properly + + self.precuts = precuts + elif isinstance(precuts, list): + self.precuts = QualityQuery(quality_criteria=precuts) + else: + self.precuts = QualityQuery(quality_criteria=list(precuts)) + else: + self.precuts = QualityQuery(quality_criteria=[(" ", " ")]) + + self.valid_energy = ResultValidRange(min=valid_energy_min, max=valid_energy_max) + self.valid_offset = ResultValidRange(min=valid_offset_min, max=valid_offset_max) + self.gh_cuts = gh_cuts + self.clf_prefix = clf_prefix + self.theta_cuts = theta_cuts def __repr__(self): if self.theta_cuts is not None: @@ -45,75 +69,52 @@ def __repr__(self): f"with {len(self.precuts.quality_criteria)} precuts>" ) + def write(self, output_name: Path | str, overwrite: bool = False) -> None: + """Write an ``OptimizationResult`` to a file in FITS format.""" -class OptimizationResultStore: - def __init__(self, precuts=None): - self._init_precuts(precuts) - self._results = None - - def _init_precuts(self, precuts): - if precuts: - if isinstance(precuts, QualityQuery): - self._precuts = precuts.quality_criteria - if len(self._precuts) == 0: - self._precuts = [(" ", " ")] # Ensures table serialises with units - elif isinstance(precuts, list): - self._precuts = precuts - else: - self._precuts = list(precuts) - else: - self._precuts = None - - def set_result( - self, gh_cuts, valid_energy, valid_offset, clf_prefix, theta_cuts=None - ): - if not self._precuts: - raise ValueError("Precuts must be defined before results can be saved") + cut_expr_tab = Table( + rows=self.precuts.quality_criteria, + names=["name", "cut_expr"], + dtype=[np.str_, np.str_], + ) + cut_expr_tab.meta["EXTNAME"] = "QUALITY_CUTS_EXPR" - gh_cuts.meta["EXTNAME"] = "GH_CUTS" - gh_cuts.meta["CLFNAME"] = clf_prefix + self.gh_cuts.meta["EXTNAME"] = "GH_CUTS" + self.gh_cuts.meta["CLFNAME"] = self.clf_prefix - energy_lim_tab = QTable(rows=[valid_energy], names=["energy_min", "energy_max"]) + energy_lim_tab = QTable( + rows=[[self.valid_energy.min, self.valid_energy.max]], + names=["energy_min", "energy_max"], + ) energy_lim_tab.meta["EXTNAME"] = "VALID_ENERGY" - offset_lim_tab = QTable(rows=[valid_offset], names=["offset_min", "offset_max"]) + offset_lim_tab = QTable( + rows=[[self.valid_offset.min, self.valid_offset.max]], + names=["offset_min", "offset_max"], + ) offset_lim_tab.meta["EXTNAME"] = "VALID_OFFSET" - self._results = [gh_cuts, energy_lim_tab, offset_lim_tab] - - if theta_cuts is not None: - theta_cuts.meta["EXTNAME"] = "RAD_MAX" - self._results += [theta_cuts] + results = [cut_expr_tab, self.gh_cuts, energy_lim_tab, offset_lim_tab] - def write(self, output_name, overwrite=False): - if not isinstance(self._results, list): - raise ValueError( - "The results of this object" - " have not been properly initialised," - " call `set_results` before writing." - ) - - cut_expr_tab = Table( - rows=self._precuts, - names=["name", "cut_expr"], - dtype=[np.str_, np.str_], - ) - cut_expr_tab.meta["EXTNAME"] = "QUALITY_CUTS_EXPR" + if self.theta_cuts is not None: + self.theta_cuts.meta["EXTNAME"] = "RAD_MAX" + results.append(self.theta_cuts) - cut_expr_tab.write(output_name, format="fits", overwrite=overwrite) + # Overwrite if needed and allowed + results[0].write(output_name, format="fits", overwrite=overwrite) - for table in self._results: + for table in results[1:]: table.write(output_name, format="fits", append=True) - def read(self, file_name): + @classmethod + def read(cls, file_name): + """Read an ``OptimizationResult`` from a file in FITS format.""" + with fits.open(file_name) as hdul: cut_expr_tab = Table.read(hdul[1]) cut_expr_lst = [(name, expr) for name, expr in cut_expr_tab.iterrows()] - # TODO: this crudely fixes a problem when loading non empty tables, make it nicer - try: + if (" ", " ") in cut_expr_lst: cut_expr_lst.remove((" ", " ")) - except ValueError: - pass precuts = QualityQuery(quality_criteria=cut_expr_lst) gh_cuts = QTable.read(hdul[2]) @@ -121,8 +122,15 @@ def read(self, file_name): valid_offset = QTable.read(hdul[4]) theta_cuts = QTable.read(hdul[5]) if len(hdul) > 5 else None - return OptimizationResult( - precuts, valid_energy, valid_offset, gh_cuts, theta_cuts + return cls( + precuts=precuts, + valid_energy_min=valid_energy["energy_min"], + valid_energy_max=valid_energy["energy_max"], + valid_offset_min=valid_offset["offset_min"], + valid_offset_max=valid_offset["offset_max"], + gh_cuts=gh_cuts, + clf_prefix=gh_cuts.meta["CLFNAME"], + theta_cuts=theta_cuts, ) @@ -173,7 +181,7 @@ def optimize_cuts( precuts: EventPreProcessor, clf_prefix: str, point_like: bool, - ) -> OptimizationResultStore: + ) -> OptimizationResult: """ Optimize G/H (and optionally theta) cuts and fill them in an ``OptimizationResult``. @@ -319,7 +327,7 @@ def optimize_cuts( precuts: EventPreProcessor, clf_prefix: str, point_like: bool, - ) -> OptimizationResultStore: + ) -> OptimizationResult: reco_energy_bins = make_bins_per_decade( self.reco_energy_min.to(u.TeV), self.reco_energy_max.to(u.TeV), @@ -343,16 +351,18 @@ def optimize_cuts( reco_energy_bins, ) - result_saver = OptimizationResultStore(precuts) - result_saver.set_result( + result = OptimizationResult( + precuts=precuts, gh_cuts=gh_cuts, - valid_energy=[self.reco_energy_min, self.reco_energy_max], - # A single set of cuts is calculated for the whole fov atm - valid_offset=[0 * u.deg, np.inf * u.deg], clf_prefix=clf_prefix, + valid_energy_min=self.reco_energy_min, + valid_energy_max=self.reco_energy_max, + # A single set of cuts is calculated for the whole fov atm + valid_offset_min=0 * u.deg, + valid_offset_max=np.inf * u.deg, theta_cuts=theta_cuts if point_like else None, ) - return result_saver + return result class PointSourceSensitivityOptimizer(CutOptimizerBase): @@ -388,7 +398,7 @@ def optimize_cuts( precuts: EventPreProcessor, clf_prefix: str, point_like: bool, - ) -> OptimizationResultStore: + ) -> OptimizationResult: reco_energy_bins = make_bins_per_decade( self.reco_energy_min.to(u.TeV), self.reco_energy_max.to(u.TeV), @@ -463,16 +473,18 @@ def optimize_cuts( reco_energy_bins, ) - result_saver = OptimizationResultStore(precuts) - result_saver.set_result( + result = OptimizationResult( + precuts=precuts, gh_cuts=gh_cuts, - valid_energy=valid_energy, - # A single set of cuts is calculated for the whole fov atm - valid_offset=[self.min_bkg_fov_offset, self.max_bkg_fov_offset], clf_prefix=clf_prefix, + valid_energy_min=valid_energy[0], + valid_energy_max=valid_energy[1], + # A single set of cuts is calculated for the whole fov atm + valid_offset_min=self.min_bkg_fov_offset, + valid_offset_max=self.max_bkg_fov_offset, theta_cuts=theta_cuts_opt if point_like else None, ) - return result_saver + return result def _get_valid_energy_range(self, opt_sens): keep_mask = np.isfinite(opt_sens["significance"]) diff --git a/src/ctapipe/irf/tests/test_optimize.py b/src/ctapipe/irf/tests/test_optimize.py index 282ce82a596..3c93c6ad8f9 100644 --- a/src/ctapipe/irf/tests/test_optimize.py +++ b/src/ctapipe/irf/tests/test_optimize.py @@ -5,49 +5,44 @@ import pytest from astropy.table import QTable -from ctapipe.core import non_abstract_children +from ctapipe.core import QualityQuery, non_abstract_children from ctapipe.irf import EventLoader, Spectra from ctapipe.irf.optimize import CutOptimizerBase -def test_optimization_result_store(tmp_path, irf_event_loader_test_config): +def test_optimization_result(tmp_path, irf_event_loader_test_config): from ctapipe.irf import ( EventPreProcessor, OptimizationResult, - OptimizationResultStore, ResultValidRange, ) result_path = tmp_path / "result.h5" epp = EventPreProcessor(irf_event_loader_test_config) - store = OptimizationResultStore(epp) - - with pytest.raises( - ValueError, - match="The results of this object have not been properly initialised", - ): - store.write(result_path) - gh_cuts = QTable( data=[[0.2, 0.8, 1.5] * u.TeV, [0.8, 1.5, 10] * u.TeV, [0.82, 0.91, 0.88]], names=["low", "high", "cut"], ) - store.set_result( + result = OptimizationResult( + precuts=epp, gh_cuts=gh_cuts, - valid_energy=[0.2 * u.TeV, 10 * u.TeV], - valid_offset=[0 * u.deg, np.inf * u.deg], clf_prefix="ExtraTreesClassifier", + valid_energy_min=0.2 * u.TeV, + valid_energy_max=10 * u.TeV, + valid_offset_min=0 * u.deg, + valid_offset_max=np.inf * u.deg, theta_cuts=None, ) - store.write(result_path) + result.write(result_path) assert result_path.exists() - result = store.read(result_path) - assert isinstance(result, OptimizationResult) - assert isinstance(result.valid_energy, ResultValidRange) - assert isinstance(result.valid_offset, ResultValidRange) - assert isinstance(result.gh_cuts, QTable) - assert result.gh_cuts.meta["CLFNAME"] == "ExtraTreesClassifier" + loaded = OptimizationResult.read(result_path) + assert isinstance(loaded, OptimizationResult) + assert isinstance(loaded.precuts, QualityQuery) + assert isinstance(loaded.valid_energy, ResultValidRange) + assert isinstance(loaded.valid_offset, ResultValidRange) + assert isinstance(loaded.gh_cuts, QTable) + assert loaded.clf_prefix == "ExtraTreesClassifier" def test_gh_percentile_cut_calculator(): @@ -96,7 +91,7 @@ def test_cut_optimizer( proton_full_reco_file, irf_event_loader_test_config, ): - from ctapipe.irf import OptimizationResultStore + from ctapipe.irf import OptimizationResult gamma_loader = EventLoader( config=irf_event_loader_test_config, @@ -128,8 +123,8 @@ def test_cut_optimizer( clf_prefix="ExtraTreesClassifier", point_like=True, ) - assert isinstance(result, OptimizationResultStore) - assert len(result._results) == 4 - assert result._results[1]["energy_min"] >= result._results[0]["low"][0] - assert result._results[1]["energy_max"] <= result._results[0]["high"][-1] - assert result._results[3]["cut"].unit == u.deg + assert isinstance(result, OptimizationResult) + assert result.clf_prefix == "ExtraTreesClassifier" + assert result.valid_energy.min >= result.gh_cuts["low"][0] + assert result.valid_energy.max <= result.gh_cuts["high"][-1] + assert result.theta_cuts["cut"].unit == u.deg diff --git a/src/ctapipe/tools/make_irf.py b/src/ctapipe/tools/make_irf.py index 8958b57d1cd..d21cfeb4c7b 100644 --- a/src/ctapipe/tools/make_irf.py +++ b/src/ctapipe/tools/make_irf.py @@ -15,7 +15,7 @@ from ..irf import ( EventLoader, EventPreProcessor, - OptimizationResultStore, + OptimizationResult, Spectra, check_bins_in_range, ) @@ -209,7 +209,7 @@ class IrfTool(Tool): ) def setup(self): - self.opt_result = OptimizationResultStore().read(self.cuts_file) + self.opt_result = OptimizationResult.read(self.cuts_file) if self.point_like and self.opt_result.theta_cuts is None: raise ToolConfigurationError( @@ -452,14 +452,14 @@ def start(self): quality_criteria=self.opt_result.precuts.quality_criteria, ) - if sel.epp.gammaness_classifier != self.opt_result.gh_cuts.meta["CLFNAME"]: + if sel.epp.gammaness_classifier != self.opt_result.clf_prefix: raise RuntimeError( "G/H cuts are only valid for gammaness scores predicted by " "the same classifier model. Requested model: %s. " "Model used for g/h cuts: %s." % ( sel.epp.gammaness_classifier, - self.opt_result.gh_cuts.meta["CLFNAME"], + self.opt_result.clf_prefix, ) ) diff --git a/src/ctapipe/tools/tests/test_optimize_event_selection.py b/src/ctapipe/tools/tests/test_optimize_event_selection.py index 8cc6cc501df..728cfb2b7f1 100644 --- a/src/ctapipe/tools/tests/test_optimize_event_selection.py +++ b/src/ctapipe/tools/tests/test_optimize_event_selection.py @@ -4,7 +4,7 @@ import pytest from astropy.table import QTable -from ctapipe.core import run_tool +from ctapipe.core import QualityQuery, run_tool @pytest.mark.parametrize("point_like", (True, False)) @@ -17,7 +17,6 @@ def test_cuts_optimization( ): from ctapipe.irf import ( OptimizationResult, - OptimizationResultStore, ResultValidRange, ) from ctapipe.tools.optimize_event_selection import IrfEventSelector @@ -41,12 +40,13 @@ def test_cuts_optimization( ret = run_tool(IrfEventSelector(), argv=argv) assert ret == 0 - result = OptimizationResultStore().read(output_path) + result = OptimizationResult.read(output_path) assert isinstance(result, OptimizationResult) + assert isinstance(result.precuts, QualityQuery) assert isinstance(result.valid_energy, ResultValidRange) assert isinstance(result.valid_offset, ResultValidRange) assert isinstance(result.gh_cuts, QTable) - assert result.gh_cuts.meta["CLFNAME"] == "ExtraTreesClassifier" + assert result.clf_prefix == "ExtraTreesClassifier" assert "cut" in result.gh_cuts.colnames if point_like: assert isinstance(result.theta_cuts, QTable)