Skip to content

Commit

Permalink
Move _apply_correction into the ImageExtractor parent class
Browse files Browse the repository at this point in the history
 - _apply_correction is now a static method from ImageExtractors parent
   class
 - _calculate_correction is moved to the parent ImageExtractors class
   to avoid writing the docstrings multiple times. It needs to be
   overwritten by the child components
  • Loading branch information
Hckjs committed Apr 17, 2024
1 parent d894480 commit bcb9512
Showing 1 changed file with 36 additions and 108 deletions.
144 changes: 36 additions & 108 deletions src/ctapipe/image/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,15 +359,6 @@ def integration_correction(
return correction


def _apply_correction(charge, correction, selected_gain_channel):
"""
Helper function for applying the integration correction for certain `ImageExtractor`s.
"""
if selected_gain_channel is None:
return (charge * correction[:, np.newaxis]).astype(charge.dtype)
return (charge * correction[selected_gain_channel]).astype(charge.dtype)


class ImageExtractor(TelescopeComponent):
def __init__(self, subarray, config=None, parent=None, **kwargs):
"""
Expand Down Expand Up @@ -402,6 +393,37 @@ def __init__(self, subarray, config=None, parent=None, **kwargs):
for tel_id, telescope in subarray.tel.items()
}

def _calculate_correction(self, tel_id):
"""
Calculate the correction for the extracted charge such that the value
returned would equal 1 for a noise-less unit pulse. `ImageExtractor` types
calculating corrections need to overwrite this method.
This method should be decorated with @lru_cache to ensure it is only
calculated once per telescope.
Parameters
----------
tel_id : int
Returns
-------
correction : ndarray
The correction to apply to an extracted charge using this ImageExtractor
Has size n_channels, as a different correction value might be required
for different gain channels.
"""
pass

@staticmethod
def _apply_correction(charge, correction, selected_gain_channel):
"""
Helper function for applying the integration correction for certain `ImageExtractor`s.
"""
if selected_gain_channel is None:
return (charge * correction[:, np.newaxis]).astype(charge.dtype)
return (charge * correction[selected_gain_channel]).astype(charge.dtype)

@abstractmethod
def __call__(
self, waveforms, tel_id, selected_gain_channel, broken_pixels
Expand Down Expand Up @@ -476,24 +498,6 @@ class FixedWindowSum(ImageExtractor):

@lru_cache(maxsize=128)
def _calculate_correction(self, tel_id):
"""
Calculate the correction for the extracted charge such that the value
returned would equal 1 for a noise-less unit pulse.
This method is decorated with @lru_cache to ensure it is only
calculated once per telescope.
Parameters
----------
tel_id : int
Returns
-------
correction : ndarray
The correction to apply to an extracted charge using this ImageExtractor
Has size n_channels, as a different correction value might be required
for different gain channels.
"""
readout = self.subarray.tel[tel_id].camera.readout
return integration_correction(
readout.reference_pulse_shape,
Expand All @@ -515,7 +519,7 @@ def __call__(
)
if self.apply_integration_correction.tel[tel_id]:
correction = self._calculate_correction(tel_id=tel_id)
charge = _apply_correction(charge, correction, selected_gain_channel)
charge = self._apply_correction(charge, correction, selected_gain_channel)

# reduce dimensions for gain selected data to (n_pixels, )
if selected_gain_channel is not None:
Expand Down Expand Up @@ -562,24 +566,6 @@ class GlobalPeakWindowSum(ImageExtractor):

@lru_cache(maxsize=128)
def _calculate_correction(self, tel_id):
"""
Calculate the correction for the extracted charge such that the value
returned would equal 1 for a noise-less unit pulse.
This method is decorated with @lru_cache to ensure it is only
calculated once per telescope.
Parameters
----------
tel_id : int
Returns
-------
correction : ndarray
The correction to apply to an extracted charge using this ImageExtractor
Has size n_channels, as a different correction value might be required
for different gain channels.
"""
readout = self.subarray.tel[tel_id].camera.readout
return integration_correction(
readout.reference_pulse_shape,
Expand Down Expand Up @@ -619,7 +605,7 @@ def __call__(
)
if self.apply_integration_correction.tel[tel_id]:
correction = self._calculate_correction(tel_id=tel_id)
charge = _apply_correction(charge, correction, selected_gain_channel)
charge = self._apply_correction(charge, correction, selected_gain_channel)

# reduce dimensions for gain selected data to (n_pixels, )
if selected_gain_channel is not None:
Expand Down Expand Up @@ -651,24 +637,6 @@ class LocalPeakWindowSum(ImageExtractor):

@lru_cache(maxsize=128)
def _calculate_correction(self, tel_id):
"""
Calculate the correction for the extracted charge such that the value
returned would equal 1 for a noise-less unit pulse.
This method is decorated with @lru_cache to ensure it is only
calculated once per telescope.
Parameters
----------
tel_id : int
Returns
-------
correction : ndarray
The correction to apply to an extracted charge using this ImageExtractor
Has size n_channels, as a different correction value might be required
for different gain channels.
"""
readout = self.subarray.tel[tel_id].camera.readout
return integration_correction(
readout.reference_pulse_shape,
Expand All @@ -691,7 +659,7 @@ def __call__(
)
if self.apply_integration_correction.tel[tel_id]:
correction = self._calculate_correction(tel_id=tel_id)
charge = _apply_correction(charge, correction, selected_gain_channel)
charge = self._apply_correction(charge, correction, selected_gain_channel)

# reduce dimensions for gain selected data to (n_pixels, )
if selected_gain_channel is not None:
Expand All @@ -716,28 +684,6 @@ class SlidingWindowMaxSum(ImageExtractor):

@lru_cache(maxsize=128)
def _calculate_correction(self, tel_id):
"""
Calculate the correction for the extracted charge such that the value
returned would equal 1 for a noise-less unit pulse.
This method is decorated with @lru_cache to ensure it is only
calculated once per telescope.
The same procedure as for the actual SlidingWindowMaxSum extractor is used, but
on the reference pulse_shape (that is also more finely binned)
Parameters
----------
tel_id : int
Returns
-------
correction : ndarray
The correction to apply to an extracted charge using this ImageExtractor
Has size n_channels, as a different correction value might be required
for different gain channels.
"""

readout = self.subarray.tel[tel_id].camera.readout

# compute the number of slices to integrate in the pulse template
Expand Down Expand Up @@ -775,7 +721,7 @@ def __call__(

if self.apply_integration_correction.tel[tel_id]:
correction = self._calculate_correction(tel_id=tel_id)
charge = _apply_correction(charge, correction, selected_gain_channel)
charge = self._apply_correction(charge, correction, selected_gain_channel)

# reduce dimensions for gain selected data to (n_pixels, )
if selected_gain_channel is not None:
Expand Down Expand Up @@ -813,24 +759,6 @@ class NeighborPeakWindowSum(ImageExtractor):

@lru_cache(maxsize=128)
def _calculate_correction(self, tel_id):
"""
Calculate the correction for the extracted charge such that the value
returned would equal 1 for a noise-less unit pulse.
This method is decorated with @lru_cache to ensure it is only
calculated once per telescope.
Parameters
----------
tel_id : int
Returns
-------
correction : ndarray
The correction to apply to an extracted charge using this ImageExtractor
Has size n_channels, as a different correction value might be required
for different gain channels.
"""
readout = self.subarray.tel[tel_id].camera.readout
return integration_correction(
readout.reference_pulse_shape,
Expand Down Expand Up @@ -861,7 +789,7 @@ def __call__(

if self.apply_integration_correction.tel[tel_id]:
correction = self._calculate_correction(tel_id=tel_id)
charge = _apply_correction(charge, correction, selected_gain_channel)
charge = self._apply_correction(charge, correction, selected_gain_channel)

# reduce dimensions for gain selected data to (n_pixels, )
if selected_gain_channel is not None:
Expand Down

0 comments on commit bcb9512

Please sign in to comment.