From 59cd446346534225eb6c9b3fdab53b1a8107829a Mon Sep 17 00:00:00 2001 From: Christoph Toennis Date: Mon, 3 Jun 2024 09:08:48 +0200 Subject: [PATCH] Fixed the formatting issues --- src/ctapipe/containers.py | 31 +++++++++++------------ src/ctapipe/image/extractor.py | 13 ++++++---- src/ctapipe/image/tests/test_extractor.py | 13 +++++++++- 3 files changed, 35 insertions(+), 22 deletions(-) diff --git a/src/ctapipe/containers.py b/src/ctapipe/containers.py index cf528bfacd8..739e93092f7 100644 --- a/src/ctapipe/containers.py +++ b/src/ctapipe/containers.py @@ -168,13 +168,12 @@ class EventType(enum.Enum): class VarianceType(enum.Enum): - """Enum of variance types used for the DL1PedestalVarianceContainer - """ + """Enum of variance types used for the DL1PedestalVarianceContainer""" - #Simple variance of waveform - SIMPLE=0 - #Variance of intgrated samples of a waveform - SAMPLE=1 + # Simple variance of waveform + SIMPLE = 0 + # Variance of integrated samples of a waveform + SAMPLE = 1 class PixelStatus(enum.IntFlag): @@ -529,7 +528,7 @@ class DL1CameraContainer(Container): class DL1PedestalVarianceContainer(Container): """ - Storage of output of camera variance image e.g. + Storage of output of camera variance image e.g. the variance of each pixel composed as an image. """ @@ -539,18 +538,10 @@ class DL1PedestalVarianceContainer(Container): "Shape: (n_pixel) if n_channels is 1 or data is gain selected" "else: (n_channels, n_pixel)", ) - trigger_time = Field( None, - "Trigger time for this image" - "Will be needed by the startracker code later to determine ", + "Trigger time for this variance image" "Value is a float", ) - - pointing = Field(default_factory=TelescopePointingContainer, - description="Telescope pointing for the startracker code", - ) - - VarMethod = Field( VarianceType.SIMPLE, "Method by which the variance was calculated" @@ -558,6 +549,14 @@ class DL1PedestalVarianceContainer(Container): "or a variance of integrated samples", type=VarianceType, ) + is_valid = Field( + False, + ( + "True if image extraction succeeded, False if failed " + "or in the case of TwoPass methods, that the first " + "pass only was returned." + ), + ) class DL1Container(Container): diff --git a/src/ctapipe/image/extractor.py b/src/ctapipe/image/extractor.py index 6b834560bd2..4a710deee56 100644 --- a/src/ctapipe/image/extractor.py +++ b/src/ctapipe/image/extractor.py @@ -1307,11 +1307,14 @@ class VarianceExtractor(ImageExtractor): """ - def __call__( - self, waveforms, tel_id, trigger_time - ) -> DL1PedestalVarianceContainer: - variance = np.nanvar(waveforms,axis=2) - return DL1PedestalVarianceContainer(image=variance, method=VarianceType.SIMPLE, trigger_time=trigger_time) + def __call__(self, waveforms, tel_id, trigger_time) -> DL1PedestalVarianceContainer: + variance = np.nanvar(waveforms, dtype="float32", axis=2) + return DL1PedestalVarianceContainer( + image=variance, + VarMethod=VarianceType.SIMPLE, + is_valid=True, + trigger_time=np.float32(trigger_time), + ) def deconvolution_parameters( diff --git a/src/ctapipe/image/tests/test_extractor.py b/src/ctapipe/image/tests/test_extractor.py index a49d2b1324b..1ff22b9a803 100644 --- a/src/ctapipe/image/tests/test_extractor.py +++ b/src/ctapipe/image/tests/test_extractor.py @@ -19,6 +19,7 @@ NeighborPeakWindowSum, SlidingWindowMaxSum, TwoPassWindowSum, + VarianceExtractor, __filtfilt_fast, adaptive_centroid, deconvolve, @@ -408,6 +409,9 @@ def test_extractors(Extractor, toymodels, request): extractor(waveforms, tel_id, selected_gain_channel, broken_pixels) return + if Extractor is VarianceExtractor: + return + dl1 = extractor(waveforms, tel_id, selected_gain_channel, broken_pixels) assert dl1.is_valid if dl1.image.ndim == 1: @@ -423,7 +427,7 @@ def test_extractors(Extractor, toymodels, request): @pytest.mark.parametrize("Extractor", extractors) def test_integration_correction_off(Extractor, toymodels, request): # full waveform extractor does not have an integration correction - if Extractor is FullWaveformSum: + if Extractor is FullWaveformSum or Extractor is VarianceExtractor: return ( @@ -713,7 +717,14 @@ def test_dtype(Extractor, subarray): extractor = Extractor(subarray=subarray) n_channels, n_pixels, _ = waveforms.shape broken_pixels = np.zeros((n_channels, n_pixels), dtype=bool) + if Extractor is VarianceExtractor: + var = extractor(waveforms, tel_id, 0.0) + assert var.image.dtype == np.float32 + assert var.trigger_time.dtype == np.float32 + return + dl1 = extractor(waveforms, tel_id, selected_gain_channel, broken_pixels) + assert dl1.image.dtype == np.float32 assert dl1.peak_time.dtype == np.float32