Skip to content

Commit

Permalink
Merge pull request #2543 from cta-observatory/VarianceExtractor
Browse files Browse the repository at this point in the history
Variance extractor
  • Loading branch information
kosack authored Jul 31, 2024
2 parents 8117dcd + 721b67b commit cce22f3
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/changes/2543.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
A new ImageExtractor called ``VarianceExtractor`` was added
An Enum class was added to containers.py that is used in the metadata of the VarianceExtractor output
10 changes: 9 additions & 1 deletion src/ctapipe/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,15 @@ class EventType(enum.Enum):
UNKNOWN = 255


class VarianceType(enum.Enum):
"""Enum of variance types used for the VarianceContainer"""

# Simple variance of waveform
WAVEFORM = 0
# Variance of integrated samples of a waveform
INTEGRATED = 1


class PixelStatus(enum.IntFlag):
"""
Pixel status information
Expand Down Expand Up @@ -510,7 +519,6 @@ class DL1CameraContainer(Container):
"pass only was returned."
),
)

parameters = Field(
None, description="Image parameters", type=ImageParametersContainer
)
Expand Down
19 changes: 18 additions & 1 deletion src/ctapipe/image/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"NeighborPeakWindowSum",
"BaselineSubtractedNeighborPeakWindowSum",
"TwoPassWindowSum",
"VarianceExtractor",
"extract_around_peak",
"extract_sliding_window",
"neighbor_average_maximum",
Expand All @@ -33,7 +34,10 @@
from scipy.ndimage import convolve1d
from traitlets import Bool, Int

from ctapipe.containers import DL1CameraContainer
from ctapipe.containers import (
DL1CameraContainer,
VarianceType,
)
from ctapipe.core import TelescopeComponent
from ctapipe.core.traits import (
BoolTelescopeParameter,
Expand Down Expand Up @@ -1297,6 +1301,19 @@ def __call__(
)


class VarianceExtractor(ImageExtractor):
"""Calculate the variance over samples in each waveform."""

def __call__(
self, waveforms, tel_id, selected_gain_channel, broken_pixels
) -> DL1CameraContainer:
container = DL1CameraContainer(
image=np.nanvar(waveforms, dtype="float32", axis=2),
)
container.meta["ExtractionMethod"] = str(VarianceType.WAVEFORM)
return container


def deconvolution_parameters(
camera: CameraDescription,
upsampling: int,
Expand Down
22 changes: 20 additions & 2 deletions src/ctapipe/image/tests/test_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
NeighborPeakWindowSum,
SlidingWindowMaxSum,
TwoPassWindowSum,
VarianceExtractor,
__filtfilt_fast,
adaptive_centroid,
deconvolve,
Expand Down Expand Up @@ -281,6 +282,17 @@ def test_extract_around_peak_charge_expected():
assert_equal(charge, n_samples)


def test_variance_extractor(toymodel):
_, subarray, _, _, _, _ = toymodel
# make dummy data with known variance
rng = np.random.default_rng(0)
var_data = rng.normal(2.0, 5.0, size=(2, 1855, 5000))
extractor = ImageExtractor.from_name("VarianceExtractor", subarray=subarray)

variance = extractor(var_data, 0, None, None).image
np.testing.assert_allclose(variance, np.var(var_data, axis=2), rtol=1e-3)


@pytest.mark.parametrize("toymodels", camera_toymodels)
def test_neighbor_average_peakpos(toymodels, request):
waveforms, subarray, tel_id, _, _, _ = request.getfixturevalue(toymodels)
Expand Down Expand Up @@ -408,6 +420,9 @@ def test_extractors(Extractor, toymodels, request):
extractor(waveforms, tel_id, selected_gain_channel, broken_pixels)
return

if Extractor is VarianceExtractor:
return

dl1 = extractor(waveforms, tel_id, selected_gain_channel, broken_pixels)
assert dl1.is_valid
if dl1.image.ndim == 1:
Expand All @@ -423,7 +438,7 @@ def test_extractors(Extractor, toymodels, request):
@pytest.mark.parametrize("Extractor", extractors)
def test_integration_correction_off(Extractor, toymodels, request):
# full waveform extractor does not have an integration correction
if Extractor is FullWaveformSum:
if Extractor in (FullWaveformSum, VarianceExtractor):
return

(
Expand Down Expand Up @@ -714,8 +729,11 @@ def test_dtype(Extractor, subarray):
n_channels, n_pixels, _ = waveforms.shape
broken_pixels = np.zeros((n_channels, n_pixels), dtype=bool)
dl1 = extractor(waveforms, tel_id, selected_gain_channel, broken_pixels)

if Extractor is not VarianceExtractor:
assert dl1.peak_time.dtype == np.float32

assert dl1.image.dtype == np.float32
assert dl1.peak_time.dtype == np.float32


def test_global_peak_window_sum_with_pixel_fraction(subarray):
Expand Down

0 comments on commit cce22f3

Please sign in to comment.