Skip to content

Commit

Permalink
Merge OptimizationResult and OptimizationResultStore
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasBeiske committed Oct 18, 2024
1 parent 62dd84c commit ed3afb7
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 118 deletions.
2 changes: 0 additions & 2 deletions src/ctapipe/irf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from .optimize import (
GhPercentileCutCalculator,
OptimizationResult,
OptimizationResultStore,
PercentileCuts,
PointSourceSensitivityOptimizer,
ThetaPercentileCutCalculator,
Expand All @@ -37,7 +36,6 @@
"EffectiveArea2dMaker",
"ResultValidRange",
"OptimizationResult",
"OptimizationResultStore",
"PointSourceSensitivityOptimizer",
"PercentileCuts",
"EventLoader",
Expand Down
174 changes: 93 additions & 81 deletions src/ctapipe/irf/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import operator
from abc import abstractmethod
from collections.abc import Sequence

import astropy.units as u
import numpy as np
Expand All @@ -11,22 +12,45 @@
from pyirf.cuts import calculate_percentile_cut, evaluate_binned_cut

from ..core import Component, QualityQuery
from ..core.traits import AstroQuantity, Float, Integer
from ..core.traits import AstroQuantity, Float, Integer, Path
from .binning import ResultValidRange, make_bins_per_decade
from .select import EventPreProcessor


class OptimizationResult:
def __init__(self, precuts, valid_energy, valid_offset, gh, theta):
self.precuts = precuts
self.valid_energy = ResultValidRange(
min=valid_energy["energy_min"], max=valid_energy["energy_max"]
)
self.valid_offset = ResultValidRange(
min=valid_offset["offset_min"], max=valid_offset["offset_max"]
)
self.gh_cuts = gh
self.theta_cuts = theta
"""Result of an optimization of G/H and theta cuts or only G/H cuts."""

def __init__(
self,
valid_energy_min: u.Quantity,
valid_energy_max: u.Quantity,
valid_offset_min: u.Quantity,
valid_offset_max: u.Quantity,
gh_cuts: QTable,
clf_prefix: str,
theta_cuts: QTable | None = None,
precuts: QualityQuery | Sequence | None = None,
) -> None:
if precuts:
if isinstance(precuts, QualityQuery):
if len(precuts.quality_criteria) == 0:
precuts.quality_criteria = [
(" ", " ")
] # Ensures table serialises properly

self.precuts = precuts
elif isinstance(precuts, list):
self.precuts = QualityQuery(quality_criteria=precuts)
else:
self.precuts = QualityQuery(quality_criteria=list(precuts))
else:
self.precuts = QualityQuery(quality_criteria=[(" ", " ")])

self.valid_energy = ResultValidRange(min=valid_energy_min, max=valid_energy_max)
self.valid_offset = ResultValidRange(min=valid_offset_min, max=valid_offset_max)
self.gh_cuts = gh_cuts
self.clf_prefix = clf_prefix
self.theta_cuts = theta_cuts

def __repr__(self):
if self.theta_cuts is not None:
Expand All @@ -45,84 +69,68 @@ def __repr__(self):
f"with {len(self.precuts.quality_criteria)} precuts>"
)

def write(self, output_name: Path | str, overwrite: bool = False) -> None:
"""Write an ``OptimizationResult`` to a file in FITS format."""

class OptimizationResultStore:
def __init__(self, precuts=None):
self._init_precuts(precuts)
self._results = None

def _init_precuts(self, precuts):
if precuts:
if isinstance(precuts, QualityQuery):
self._precuts = precuts.quality_criteria
if len(self._precuts) == 0:
self._precuts = [(" ", " ")] # Ensures table serialises with units
elif isinstance(precuts, list):
self._precuts = precuts
else:
self._precuts = list(precuts)
else:
self._precuts = None

def set_result(
self, gh_cuts, valid_energy, valid_offset, clf_prefix, theta_cuts=None
):
if not self._precuts:
raise ValueError("Precuts must be defined before results can be saved")
cut_expr_tab = Table(
rows=self.precuts.quality_criteria,
names=["name", "cut_expr"],
dtype=[np.str_, np.str_],
)
cut_expr_tab.meta["EXTNAME"] = "QUALITY_CUTS_EXPR"

gh_cuts.meta["EXTNAME"] = "GH_CUTS"
gh_cuts.meta["CLFNAME"] = clf_prefix
self.gh_cuts.meta["EXTNAME"] = "GH_CUTS"
self.gh_cuts.meta["CLFNAME"] = self.clf_prefix

energy_lim_tab = QTable(rows=[valid_energy], names=["energy_min", "energy_max"])
energy_lim_tab = QTable(
rows=[[self.valid_energy.min, self.valid_energy.max]],
names=["energy_min", "energy_max"],
)
energy_lim_tab.meta["EXTNAME"] = "VALID_ENERGY"

offset_lim_tab = QTable(rows=[valid_offset], names=["offset_min", "offset_max"])
offset_lim_tab = QTable(
rows=[[self.valid_offset.min, self.valid_offset.max]],
names=["offset_min", "offset_max"],
)
offset_lim_tab.meta["EXTNAME"] = "VALID_OFFSET"

self._results = [gh_cuts, energy_lim_tab, offset_lim_tab]

if theta_cuts is not None:
theta_cuts.meta["EXTNAME"] = "RAD_MAX"
self._results += [theta_cuts]
results = [cut_expr_tab, self.gh_cuts, energy_lim_tab, offset_lim_tab]

def write(self, output_name, overwrite=False):
if not isinstance(self._results, list):
raise ValueError(
"The results of this object"
" have not been properly initialised,"
" call `set_results` before writing."
)

cut_expr_tab = Table(
rows=self._precuts,
names=["name", "cut_expr"],
dtype=[np.str_, np.str_],
)
cut_expr_tab.meta["EXTNAME"] = "QUALITY_CUTS_EXPR"
if self.theta_cuts is not None:
self.theta_cuts.meta["EXTNAME"] = "RAD_MAX"
results.append(self.theta_cuts)

cut_expr_tab.write(output_name, format="fits", overwrite=overwrite)
# Overwrite if needed and allowed
results[0].write(output_name, format="fits", overwrite=overwrite)

for table in self._results:
for table in results[1:]:
table.write(output_name, format="fits", append=True)

def read(self, file_name):
@classmethod
def read(cls, file_name):
"""Read an ``OptimizationResult`` from a file in FITS format."""

with fits.open(file_name) as hdul:
cut_expr_tab = Table.read(hdul[1])
cut_expr_lst = [(name, expr) for name, expr in cut_expr_tab.iterrows()]
# TODO: this crudely fixes a problem when loading non empty tables, make it nicer
try:
if (" ", " ") in cut_expr_lst:
cut_expr_lst.remove((" ", " "))
except ValueError:
pass

precuts = QualityQuery(quality_criteria=cut_expr_lst)
gh_cuts = QTable.read(hdul[2])
valid_energy = QTable.read(hdul[3])
valid_offset = QTable.read(hdul[4])
theta_cuts = QTable.read(hdul[5]) if len(hdul) > 5 else None

return OptimizationResult(
precuts, valid_energy, valid_offset, gh_cuts, theta_cuts
return cls(
precuts=precuts,
valid_energy_min=valid_energy["energy_min"],
valid_energy_max=valid_energy["energy_max"],
valid_offset_min=valid_offset["offset_min"],
valid_offset_max=valid_offset["offset_max"],
gh_cuts=gh_cuts,
clf_prefix=gh_cuts.meta["CLFNAME"],
theta_cuts=theta_cuts,
)


Expand Down Expand Up @@ -173,7 +181,7 @@ def optimize_cuts(
precuts: EventPreProcessor,
clf_prefix: str,
point_like: bool,
) -> OptimizationResultStore:
) -> OptimizationResult:
"""
Optimize G/H (and optionally theta) cuts
and fill them in an ``OptimizationResult``.
Expand Down Expand Up @@ -319,7 +327,7 @@ def optimize_cuts(
precuts: EventPreProcessor,
clf_prefix: str,
point_like: bool,
) -> OptimizationResultStore:
) -> OptimizationResult:
reco_energy_bins = make_bins_per_decade(
self.reco_energy_min.to(u.TeV),
self.reco_energy_max.to(u.TeV),
Expand All @@ -343,16 +351,18 @@ def optimize_cuts(
reco_energy_bins,
)

result_saver = OptimizationResultStore(precuts)
result_saver.set_result(
result = OptimizationResult(
precuts=precuts,
gh_cuts=gh_cuts,
valid_energy=[self.reco_energy_min, self.reco_energy_max],
# A single set of cuts is calculated for the whole fov atm
valid_offset=[0 * u.deg, np.inf * u.deg],
clf_prefix=clf_prefix,
valid_energy_min=self.reco_energy_min,
valid_energy_max=self.reco_energy_max,
# A single set of cuts is calculated for the whole fov atm
valid_offset_min=0 * u.deg,
valid_offset_max=np.inf * u.deg,
theta_cuts=theta_cuts if point_like else None,
)
return result_saver
return result


class PointSourceSensitivityOptimizer(CutOptimizerBase):
Expand Down Expand Up @@ -388,7 +398,7 @@ def optimize_cuts(
precuts: EventPreProcessor,
clf_prefix: str,
point_like: bool,
) -> OptimizationResultStore:
) -> OptimizationResult:
reco_energy_bins = make_bins_per_decade(
self.reco_energy_min.to(u.TeV),
self.reco_energy_max.to(u.TeV),
Expand Down Expand Up @@ -463,16 +473,18 @@ def optimize_cuts(
reco_energy_bins,
)

result_saver = OptimizationResultStore(precuts)
result_saver.set_result(
result = OptimizationResult(
precuts=precuts,
gh_cuts=gh_cuts,
valid_energy=valid_energy,
# A single set of cuts is calculated for the whole fov atm
valid_offset=[self.min_bkg_fov_offset, self.max_bkg_fov_offset],
clf_prefix=clf_prefix,
valid_energy_min=valid_energy[0],
valid_energy_max=valid_energy[1],
# A single set of cuts is calculated for the whole fov atm
valid_offset_min=self.min_bkg_fov_offset,
valid_offset_max=self.max_bkg_fov_offset,
theta_cuts=theta_cuts_opt if point_like else None,
)
return result_saver
return result

def _get_valid_energy_range(self, opt_sens):
keep_mask = np.isfinite(opt_sens["significance"])
Expand Down
49 changes: 22 additions & 27 deletions src/ctapipe/irf/tests/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,49 +5,44 @@
import pytest
from astropy.table import QTable

from ctapipe.core import non_abstract_children
from ctapipe.core import QualityQuery, non_abstract_children
from ctapipe.irf import EventLoader, Spectra
from ctapipe.irf.optimize import CutOptimizerBase


def test_optimization_result_store(tmp_path, irf_event_loader_test_config):
def test_optimization_result(tmp_path, irf_event_loader_test_config):
from ctapipe.irf import (
EventPreProcessor,
OptimizationResult,
OptimizationResultStore,
ResultValidRange,
)

result_path = tmp_path / "result.h5"
epp = EventPreProcessor(irf_event_loader_test_config)
store = OptimizationResultStore(epp)

with pytest.raises(
ValueError,
match="The results of this object have not been properly initialised",
):
store.write(result_path)

gh_cuts = QTable(
data=[[0.2, 0.8, 1.5] * u.TeV, [0.8, 1.5, 10] * u.TeV, [0.82, 0.91, 0.88]],
names=["low", "high", "cut"],
)
store.set_result(
result = OptimizationResult(
precuts=epp,
gh_cuts=gh_cuts,
valid_energy=[0.2 * u.TeV, 10 * u.TeV],
valid_offset=[0 * u.deg, np.inf * u.deg],
clf_prefix="ExtraTreesClassifier",
valid_energy_min=0.2 * u.TeV,
valid_energy_max=10 * u.TeV,
valid_offset_min=0 * u.deg,
valid_offset_max=np.inf * u.deg,
theta_cuts=None,
)
store.write(result_path)
result.write(result_path)
assert result_path.exists()

result = store.read(result_path)
assert isinstance(result, OptimizationResult)
assert isinstance(result.valid_energy, ResultValidRange)
assert isinstance(result.valid_offset, ResultValidRange)
assert isinstance(result.gh_cuts, QTable)
assert result.gh_cuts.meta["CLFNAME"] == "ExtraTreesClassifier"
loaded = OptimizationResult.read(result_path)
assert isinstance(loaded, OptimizationResult)
assert isinstance(loaded.precuts, QualityQuery)
assert isinstance(loaded.valid_energy, ResultValidRange)
assert isinstance(loaded.valid_offset, ResultValidRange)
assert isinstance(loaded.gh_cuts, QTable)
assert loaded.clf_prefix == "ExtraTreesClassifier"


def test_gh_percentile_cut_calculator():
Expand Down Expand Up @@ -96,7 +91,7 @@ def test_cut_optimizer(
proton_full_reco_file,
irf_event_loader_test_config,
):
from ctapipe.irf import OptimizationResultStore
from ctapipe.irf import OptimizationResult

gamma_loader = EventLoader(
config=irf_event_loader_test_config,
Expand Down Expand Up @@ -128,8 +123,8 @@ def test_cut_optimizer(
clf_prefix="ExtraTreesClassifier",
point_like=True,
)
assert isinstance(result, OptimizationResultStore)
assert len(result._results) == 4
assert result._results[1]["energy_min"] >= result._results[0]["low"][0]
assert result._results[1]["energy_max"] <= result._results[0]["high"][-1]
assert result._results[3]["cut"].unit == u.deg
assert isinstance(result, OptimizationResult)
assert result.clf_prefix == "ExtraTreesClassifier"
assert result.valid_energy.min >= result.gh_cuts["low"][0]
assert result.valid_energy.max <= result.gh_cuts["high"][-1]
assert result.theta_cuts["cut"].unit == u.deg
Loading

0 comments on commit ed3afb7

Please sign in to comment.