Skip to content

Commit

Permalink
Fixed the formatting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Christoph Toennis committed Jun 3, 2024
1 parent 47b9781 commit 59cd446
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 22 deletions.
31 changes: 15 additions & 16 deletions src/ctapipe/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,12 @@ class EventType(enum.Enum):


class VarianceType(enum.Enum):
"""Enum of variance types used for the DL1PedestalVarianceContainer
"""
"""Enum of variance types used for the DL1PedestalVarianceContainer"""

#Simple variance of waveform
SIMPLE=0
#Variance of intgrated samples of a waveform
SAMPLE=1
# Simple variance of waveform
SIMPLE = 0
# Variance of integrated samples of a waveform
SAMPLE = 1


class PixelStatus(enum.IntFlag):
Expand Down Expand Up @@ -529,7 +528,7 @@ class DL1CameraContainer(Container):

class DL1PedestalVarianceContainer(Container):
"""
Storage of output of camera variance image e.g.
Storage of output of camera variance image e.g.
the variance of each pixel composed as an image.
"""

Expand All @@ -539,25 +538,25 @@ class DL1PedestalVarianceContainer(Container):
"Shape: (n_pixel) if n_channels is 1 or data is gain selected"
"else: (n_channels, n_pixel)",
)

trigger_time = Field(
None,
"Trigger time for this image"
"Will be needed by the startracker code later to determine ",
"Trigger time for this variance image" "Value is a float",
)

pointing = Field(default_factory=TelescopePointingContainer,
description="Telescope pointing for the startracker code",
)


VarMethod = Field(
VarianceType.SIMPLE,
"Method by which the variance was calculated"
"This can either be a plain variance"
"or a variance of integrated samples",
type=VarianceType,
)
is_valid = Field(
False,
(
"True if image extraction succeeded, False if failed "
"or in the case of TwoPass methods, that the first "
"pass only was returned."
),
)


class DL1Container(Container):
Expand Down
13 changes: 8 additions & 5 deletions src/ctapipe/image/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1307,11 +1307,14 @@ class VarianceExtractor(ImageExtractor):
"""

def __call__(
self, waveforms, tel_id, trigger_time
) -> DL1PedestalVarianceContainer:
variance = np.nanvar(waveforms,axis=2)
return DL1PedestalVarianceContainer(image=variance, method=VarianceType.SIMPLE, trigger_time=trigger_time)
def __call__(self, waveforms, tel_id, trigger_time) -> DL1PedestalVarianceContainer:
variance = np.nanvar(waveforms, dtype="float32", axis=2)
return DL1PedestalVarianceContainer(
image=variance,
VarMethod=VarianceType.SIMPLE,
is_valid=True,
trigger_time=np.float32(trigger_time),
)


def deconvolution_parameters(
Expand Down
13 changes: 12 additions & 1 deletion src/ctapipe/image/tests/test_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
NeighborPeakWindowSum,
SlidingWindowMaxSum,
TwoPassWindowSum,
VarianceExtractor,
__filtfilt_fast,
adaptive_centroid,
deconvolve,
Expand Down Expand Up @@ -408,6 +409,9 @@ def test_extractors(Extractor, toymodels, request):
extractor(waveforms, tel_id, selected_gain_channel, broken_pixels)
return

if Extractor is VarianceExtractor:
return

dl1 = extractor(waveforms, tel_id, selected_gain_channel, broken_pixels)
assert dl1.is_valid
if dl1.image.ndim == 1:
Expand All @@ -423,7 +427,7 @@ def test_extractors(Extractor, toymodels, request):
@pytest.mark.parametrize("Extractor", extractors)
def test_integration_correction_off(Extractor, toymodels, request):
# full waveform extractor does not have an integration correction
if Extractor is FullWaveformSum:
if Extractor is FullWaveformSum or Extractor is VarianceExtractor:
return

(
Expand Down Expand Up @@ -713,7 +717,14 @@ def test_dtype(Extractor, subarray):
extractor = Extractor(subarray=subarray)
n_channels, n_pixels, _ = waveforms.shape
broken_pixels = np.zeros((n_channels, n_pixels), dtype=bool)
if Extractor is VarianceExtractor:
var = extractor(waveforms, tel_id, 0.0)
assert var.image.dtype == np.float32
assert var.trigger_time.dtype == np.float32
return

dl1 = extractor(waveforms, tel_id, selected_gain_channel, broken_pixels)

assert dl1.image.dtype == np.float32
assert dl1.peak_time.dtype == np.float32

Expand Down

0 comments on commit 59cd446

Please sign in to comment.