From 35739a8087b504a8a1cccba11d601ed11ddb9403 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20N=C3=B6the?= Date: Sat, 2 Jul 2022 12:59:14 +0200 Subject: [PATCH 1/2] Replace deepcopying defaults with default_factory MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Applying deepcopy to the defaults of a container is quite slow, especially for deep hierarchies, as the copy is not just made once, but at each level. Here I replace the deepcopying with a default_factory approach as known e.g. from the defaultdict. This dramatically speeds up container creation, especially for deeply nested containers. Timing code: ``` import numpy as np import timeit from ctapipe.containers import HillasParametersContainer, ArrayEventContainer import astropy.units as u width = u.Quantity(0.5, u.deg) length = u.Quantity(1.0, u.deg) code = ''' h = HillasParametersContainer( intensity=500, width=width, length=length, prefix="hillas_foo", ) ''' exec(code) N = 100 repeat = 1000 times = np.array(timeit.repeat(code, number=N, repeat=repeat, globals=globals())) print('HillasParametersContainer') print(f'{np.mean(times / N) * 1e6:.2f} ± {np.std(times / N) * 1e6:.2f} µs') code = ''' a = ArrayEventContainer() ''' exec(code) N = 100 repeat = 100 times = np.array(timeit.repeat(code, number=N, repeat=repeat, globals=globals())) print('ArrayEventContainer') print(f'{np.mean(times / N) * 1e6:.2f} ± {np.std(times / N) * 1e6:.2f} µs') ``` Result on master: HillasParametersContainer 14.45 ± 0.24 µs ArrayEventContainer 161.24 ± 8.34 µs Results on this branch: HillasParametersContainer 2.32 ± 0.10 µs ArrayEventContainer 10.67 ± 0.42 µs --- ctapipe/containers.py | 299 ++++++++++++++++++--------- ctapipe/core/container.py | 65 ++++-- ctapipe/core/tests/test_container.py | 8 +- 3 files changed, 248 insertions(+), 124 deletions(-) diff --git a/ctapipe/containers.py b/ctapipe/containers.py index f12f5c4aabf..020e0950371 100644 --- a/ctapipe/containers.py +++ 
b/ctapipe/containers.py @@ -2,6 +2,7 @@ Container structures for data that should be read or written to disk """ import enum +from functools import partial from astropy import units as u from astropy.time import Time @@ -291,26 +292,37 @@ class ImageParametersContainer(Container): container_prefix = "params" hillas = Field( - HillasParametersContainer(), - "Hillas Parameters", + default_factory=HillasParametersContainer, + description="Hillas Parameters", type=BaseHillasParametersContainer, ) timing = Field( - TimingParametersContainer(), - "Timing Parameters", + default_factory=TimingParametersContainer, + description="Timing Parameters", type=BaseTimingParametersContainer, ) - leakage = Field(LeakageContainer(), "Leakage Parameters") - concentration = Field(ConcentrationContainer(), "Concentration Parameters") - morphology = Field(MorphologyContainer(), "Image Morphology Parameters") + leakage = Field( + default_factory=LeakageContainer, + description="Leakage Parameters", + ) + concentration = Field( + default_factory=ConcentrationContainer, + description="Concentration Parameters", + ) + morphology = Field( + default_factory=MorphologyContainer, description="Image Morphology Parameters" + ) intensity_statistics = Field( - IntensityStatisticsContainer(), "Intensity image statistics" + default_factory=IntensityStatisticsContainer, + description="Intensity image statistics", ) peak_time_statistics = Field( - PeakTimeStatisticsContainer(), "Peak time image statistics" + default_factory=PeakTimeStatisticsContainer, + description="Peak time image statistics", ) core = Field( - CoreParametersContainer(), "Image direction in the Tilted/Ground Frame" + default_factory=CoreParametersContainer, + description="Image direction in the Tilted/Ground Frame", ) @@ -349,13 +361,18 @@ class DL1CameraContainer(Container): ), ) - parameters = Field(None, "Image parameters", type=ImageParametersContainer) + parameters = Field( + None, description="Image parameters", 
type=ImageParametersContainer + ) class DL1Container(Container): """DL1 Calibrated Camera Images and associated data""" - tel = Field(Map(DL1CameraContainer), "map of tel_id to DL1CameraContainer") + tel = Field( + default_factory=partial(Map, DL1CameraContainer), + description="map of tel_id to DL1CameraContainer", + ) class DL1CameraCalibrationContainer(Container): @@ -402,7 +419,10 @@ class R0Container(Container): Storage of a Merged Raw Data Event """ - tel = Field(Map(R0CameraContainer), "map of tel_id to R0CameraContainer") + tel = Field( + default_factory=partial(Map, R0CameraContainer), + description="map of tel_id to R0CameraContainer", + ) class R1CameraContainer(Container): @@ -431,7 +451,10 @@ class R1Container(Container): Storage of a r1 calibrated Data Event """ - tel = Field(Map(R1CameraContainer), "map of tel_id to R1CameraContainer") + tel = Field( + default_factory=partial(Map, R1CameraContainer), + description="map of tel_id to R1CameraContainer", + ) class DL0CameraContainer(Container): @@ -463,7 +486,10 @@ class DL0Container(Container): Storage of a data volume reduced Event """ - tel = Field(Map(DL0CameraContainer), "map of tel_id to DL0CameraContainer") + tel = Field( + default_factory=partial(Map, DL0CameraContainer), + description="map of tel_id to DL0CameraContainer", + ) class TelescopeImpactParameterContainer(Container): @@ -518,15 +544,23 @@ class SimulatedCameraContainer(Container): ) true_parameters = Field( - None, "Parameters derived from the true_image", type=ImageParametersContainer + None, + description="Parameters derived from the true_image", + type=ImageParametersContainer, ) - impact = Field(TelescopeImpactParameterContainer(), "true impact parameter") + impact = Field( + default_factory=TelescopeImpactParameterContainer, + description="true impact parameter", + ) class SimulatedEventContainer(Container): - shower = Field(SimulatedShowerContainer(), "True event information") - tel = Field(Map(SimulatedCameraContainer)) + 
shower = Field( + default_factory=SimulatedShowerContainer, + description="True event information", + ) + tel = Field(default_factory=partial(Map, SimulatedCameraContainer)) class SimulationConfigContainer(Container): @@ -534,76 +568,107 @@ class SimulationConfigContainer(Container): Configuration parameters of the simulation """ - corsika_version = Field(nan, "CORSIKA version * 1000") - simtel_version = Field(nan, "sim_telarray version * 1000") + corsika_version = Field(nan, description="CORSIKA version * 1000") + simtel_version = Field(nan, description="sim_telarray version * 1000") energy_range_min = Field( - nan * u.TeV, "Lower limit of energy range of primary particle", unit=u.TeV + nan * u.TeV, + description="Lower limit of energy range of primary particle", + unit=u.TeV, ) energy_range_max = Field( - nan * u.TeV, "Upper limit of energy range of primary particle", unit=u.TeV + nan * u.TeV, + description="Upper limit of energy range of primary particle", + unit=u.TeV, + ) + prod_site_B_total = Field( + nan * u.uT, description="total geomagnetic field", unit=u.uT + ) + prod_site_B_declination = Field( + nan * u.rad, description="magnetic declination", unit=u.rad ) - prod_site_B_total = Field(nan * u.uT, "total geomagnetic field", unit=u.uT) - prod_site_B_declination = Field(nan * u.rad, "magnetic declination", unit=u.rad) - prod_site_B_inclination = Field(nan * u.rad, "magnetic inclination", unit=u.rad) - prod_site_alt = Field(nan * u.m, "height of observation level", unit=u.m) - spectral_index = Field(nan, "Power-law spectral index of spectrum") + prod_site_B_inclination = Field( + nan * u.rad, description="magnetic inclination", unit=u.rad + ) + prod_site_alt = Field( + nan * u.m, description="height of observation level", unit=u.m + ) + spectral_index = Field(nan, description="Power-law spectral index of spectrum") shower_prog_start = Field( - nan, "Time when shower simulation started, CORSIKA: only date" - ) - shower_prog_id = Field(nan, "CORSIKA=1, 
ALTAI=2, KASCADE=3, MOCCA=4") - detector_prog_start = Field(nan, "Time when detector simulation started") - detector_prog_id = Field(nan, "simtelarray=1") - num_showers = Field(nan, "Number of showers simulated") - shower_reuse = Field(nan, "Numbers of uses of each shower") - max_alt = Field(nan * u.rad, "Maximimum shower altitude", unit=u.rad) - min_alt = Field(nan * u.rad, "Minimum shower altitude", unit=u.rad) - max_az = Field(nan * u.rad, "Maximum shower azimuth", unit=u.rad) - min_az = Field(nan * u.rad, "Minimum shower azimuth", unit=u.rad) - diffuse = Field(False, "Diffuse Mode On/Off") - max_viewcone_radius = Field(nan * u.deg, "Maximum viewcone radius", unit=u.deg) - min_viewcone_radius = Field(nan * u.deg, "Minimum viewcone radius", unit=u.deg) - max_scatter_range = Field(nan * u.m, "Maximum scatter range", unit=u.m) - min_scatter_range = Field(nan * u.m, "Minimum scatter range", unit=u.m) - core_pos_mode = Field(nan, "Core Position Mode (0=Circular, 1=Rectangular)") - injection_height = Field(nan * u.m, "Height of particle injection", unit=u.m) - atmosphere = Field(nan * u.m, "Atmospheric model number") - corsika_iact_options = Field(nan, "CORSIKA simulation options for IACTs") - corsika_low_E_model = Field(nan, "CORSIKA low-energy simulation physics model") + nan, description="Time when shower simulation started, CORSIKA: only date" + ) + shower_prog_id = Field(nan, description="CORSIKA=1, ALTAI=2, KASCADE=3, MOCCA=4") + detector_prog_start = Field( + nan, description="Time when detector simulation started" + ) + detector_prog_id = Field(nan, description="simtelarray=1") + num_showers = Field(nan, description="Number of showers simulated") + shower_reuse = Field(nan, description="Numbers of uses of each shower") + max_alt = Field(nan * u.rad, description="Maximimum shower altitude", unit=u.rad) + min_alt = Field(nan * u.rad, description="Minimum shower altitude", unit=u.rad) + max_az = Field(nan * u.rad, description="Maximum shower azimuth", unit=u.rad) 
+ min_az = Field(nan * u.rad, description="Minimum shower azimuth", unit=u.rad) + diffuse = Field(False, description="Diffuse Mode On/Off") + max_viewcone_radius = Field( + nan * u.deg, description="Maximum viewcone radius", unit=u.deg + ) + min_viewcone_radius = Field( + nan * u.deg, description="Minimum viewcone radius", unit=u.deg + ) + max_scatter_range = Field(nan * u.m, description="Maximum scatter range", unit=u.m) + min_scatter_range = Field(nan * u.m, description="Minimum scatter range", unit=u.m) + core_pos_mode = Field( + nan, description="Core Position Mode (0=Circular, 1=Rectangular)" + ) + injection_height = Field( + nan * u.m, description="Height of particle injection", unit=u.m + ) + atmosphere = Field(nan * u.m, description="Atmospheric model number") + corsika_iact_options = Field( + nan, description="CORSIKA simulation options for IACTs" + ) + corsika_low_E_model = Field( + nan, description="CORSIKA low-energy simulation physics model" + ) corsika_high_E_model = Field( nan, "CORSIKA physics model ID for high energies " "(1=VENUS, 2=SIBYLL, 3=QGSJET, 4=DPMJET, 5=NeXus, 6=EPOS) ", ) - corsika_bunchsize = Field(nan, "Number of Cherenkov photons per bunch") + corsika_bunchsize = Field(nan, description="Number of Cherenkov photons per bunch") corsika_wlen_min = Field( - nan * u.m, "Minimum wavelength of cherenkov light", unit=u.nm + nan * u.m, description="Minimum wavelength of cherenkov light", unit=u.nm ) corsika_wlen_max = Field( - nan * u.m, "Maximum wavelength of cherenkov light", unit=u.nm + nan * u.m, description="Maximum wavelength of cherenkov light", unit=u.nm ) corsika_low_E_detail = Field( - nan, "More details on low E interaction model (version etc.)" + nan, description="More details on low E interaction model (version etc.)" ) corsika_high_E_detail = Field( - nan, "More details on high E interaction model (version etc.)" + nan, description="More details on high E interaction model (version etc.)" ) class 
TelescopeTriggerContainer(Container): container_prefix = "" - time = Field(NAN_TIME, "Telescope trigger time") - n_trigger_pixels = Field(-1, "Number of trigger groups (sectors) listed") - trigger_pixels = Field(None, "pixels involved in the camera trigger") + time = Field(NAN_TIME, description="Telescope trigger time") + n_trigger_pixels = Field( + -1, description="Number of trigger groups (sectors) listed" + ) + trigger_pixels = Field(None, description="pixels involved in the camera trigger") class TriggerContainer(Container): container_prefix = "" - time = Field(NAN_TIME, "central average time stamp") + time = Field(NAN_TIME, description="central average time stamp") tels_with_trigger = Field( - None, "List of telescope ids that triggered the array event" + None, description="List of telescope ids that triggered the array event" + ) + event_type = Field(EventType.SUBARRAY, description="Event type") + tel = Field( + default_factory=partial(Map, TelescopeTriggerContainer), + description="telescope-wise trigger information", ) - event_type = Field(EventType.SUBARRAY, "Event type") - tel = Field(Map(TelescopeTriggerContainer), "telescope-wise trigger information") class ReconstructedGeometryContainer(Container): @@ -722,16 +787,16 @@ class ReconstructedContainer(Container): # but most will compute only fill or two of these sub-Contaiers: geometry = Field( - Map(ReconstructedGeometryContainer), - "map of algorithm to reconstructed shower parameters", + default_factory=partial(Map, ReconstructedGeometryContainer), + description="map of algorithm to reconstructed shower parameters", ) energy = Field( - Map(ReconstructedEnergyContainer), - "map of algorithm to reconstructed energy parameters", + default_factory=partial(Map, ReconstructedEnergyContainer), + description="map of algorithm to reconstructed energy parameters", ) classification = Field( - Map(ParticleClassificationContainer), - "map of algorithm to classification parameters", + default_factory=partial(Map, 
ParticleClassificationContainer), + description="map of algorithm to classification parameters", ) @@ -739,8 +804,8 @@ class TelescopeReconstructedContainer(ReconstructedContainer): """Telescope-wise reconstructed quantities""" impact = Field( - Map(TelescopeImpactParameterContainer), - "map of algorithm to impact parameter info", + default_factory=partial(Map, TelescopeImpactParameterContainer), + description="map of algorithm to impact parameter info", ) @@ -751,10 +816,13 @@ class DL2Container(Container): """ tel = Field( - Map(TelescopeReconstructedContainer), - "map of tel_id to single-telescope reconstruction (DL2a)", + default_factory=partial(Map, TelescopeReconstructedContainer), + description="map of tel_id to single-telescope reconstruction (DL2a)", + ) + stereo = Field( + default_factory=ReconstructedContainer, + description="Stereo Shower reconstruction results", ) - stereo = Field(ReconstructedContainer(), "Stereo Shower reconstruction results") class TelescopePointingContainer(Container): @@ -770,7 +838,10 @@ class TelescopePointingContainer(Container): class PointingContainer(Container): - tel = Field(Map(TelescopePointingContainer), "Telescope pointing positions") + tel = Field( + default_factory=partial(Map, TelescopePointingContainer), + description="Telescope pointing positions", + ) array_azimuth = Field(nan * u.rad, "Array pointing azimuth", unit=u.rad) array_altitude = Field(nan * u.rad, "Array pointing altitude", unit=u.rad) array_ra = Field(nan * u.rad, "Array pointing right ascension", unit=u.rad) @@ -783,7 +854,8 @@ class EventCameraCalibrationContainer(Container): """ dl1 = Field( - DL1CameraCalibrationContainer(), "Container for DL1 calibration coefficients" + default_factory=DL1CameraCalibrationContainer, + description="Container for DL1 calibration coefficients", ) @@ -794,8 +866,8 @@ class EventCalibrationContainer(Container): # create the camera container tel = Field( - Map(EventCameraCalibrationContainer), - "map of tel_id to 
EventCameraCalibrationContainer", + default_factory=partial(Map, EventCameraCalibrationContainer), + description="map of tel_id to EventCameraCalibrationContainer", ) @@ -983,13 +1055,21 @@ class MonitoringCameraContainer(Container): Container for camera monitoring data """ - flatfield = Field(FlatFieldContainer(), "Data from flat-field event distributions") - pedestal = Field(PedestalContainer(), "Data from pedestal event distributions") + flatfield = Field( + default_factory=FlatFieldContainer, + description="Data from flat-field event distributions", + ) + pedestal = Field( + default_factory=PedestalContainer, + description="Data from pedestal event distributions", + ) pixel_status = Field( - PixelStatusContainer(), "Container for masks with pixel status" + default_factory=PixelStatusContainer, + description="Container for masks with pixel status", ) calibration = Field( - WaveformCalibrationContainer(), "Container for calibration coefficients" + default_factory=WaveformCalibrationContainer, + description="Container for calibration coefficients", ) @@ -1000,7 +1080,8 @@ class MonitoringContainer(Container): # create the camera container tel = Field( - Map(MonitoringCameraContainer), "map of tel_id to MonitoringCameraContainer" + default_factory=partial(Map, MonitoringCameraContainer), + description="map of tel_id to MonitoringCameraContainer", ) @@ -1012,35 +1093,53 @@ class SimulatedShowerDistribution(Container): container_prefix = "" - obs_id = Field(-1, "links to which events this corresponds to") - hist_id = Field(-1, "Histogram ID") - num_entries = Field(-1, "Number of entries in the histogram") + obs_id = Field(-1, description="links to which events this corresponds to") + hist_id = Field(-1, description="Histogram ID") + num_entries = Field(-1, description="Number of entries in the histogram") bins_energy = Field( - None, "array of energy bin lower edges, as in np.histogram", unit=u.TeV + None, + description="array of energy bin lower edges, as in 
np.histogram", + unit=u.TeV, ) bins_core_dist = Field( - None, "array of core-distance bin lower edges, as in np.histogram", unit=u.m + None, + description="array of core-distance bin lower edges, as in np.histogram", + unit=u.m, + ) + histogram = Field( + None, description="array of histogram entries, size (n_bins_x, n_bins_y)" ) - histogram = Field(None, "array of histogram entries, size (n_bins_x, n_bins_y)") class ArrayEventContainer(Container): """Top-level container for all event information""" - index = Field(EventIndexContainer(), "event indexing information") - r0 = Field(R0Container(), "Raw Data") - r1 = Field(R1Container(), "R1 Calibrated Data") - dl0 = Field(DL0Container(), "DL0 Data Volume Reduced Data") - dl1 = Field(DL1Container(), "DL1 Calibrated image") - dl2 = Field(DL2Container(), "DL2 reconstruction info") + index = Field( + default_factory=EventIndexContainer, description="event indexing information" + ) + r0 = Field(default_factory=R0Container, description="Raw Data") + r1 = Field(default_factory=R1Container, description="R1 Calibrated Data") + dl0 = Field( + default_factory=DL0Container, description="DL0 Data Volume Reduced Data" + ) + dl1 = Field(default_factory=DL1Container, description="DL1 Calibrated image") + dl2 = Field(default_factory=DL2Container, description="DL2 reconstruction info") simulation = Field( - None, "Simulated Event Information", type=SimulatedEventContainer + None, description="Simulated Event Information", type=SimulatedEventContainer + ) + trigger = Field( + default_factory=TriggerContainer, description="central trigger information" + ) + count = Field(0, description="number of events processed") + pointing = Field( + default_factory=PointingContainer, + description="Array and telescope pointing positions", ) - trigger = Field(TriggerContainer(), "central trigger information") - count = Field(0, "number of events processed") - pointing = Field(PointingContainer(), "Array and telescope pointing positions") calibration 
= Field( - EventCalibrationContainer(), - "Container for calibration coefficients for the current event", + default_factory=EventCalibrationContainer, + description="Container for calibration coefficients for the current event", + ) + mon = Field( + default_factory=MonitoringContainer, + description="container for event-wise monitoring data (MON)", ) - mon = Field(MonitoringContainer(), "container for event-wise monitoring data (MON)") diff --git a/ctapipe/core/container.py b/ctapipe/core/container.py index 68322fc7626..5ae30d2942c 100644 --- a/ctapipe/core/container.py +++ b/ctapipe/core/container.py @@ -1,5 +1,5 @@ from collections import defaultdict -from copy import deepcopy +from functools import partial from pprint import pformat from textwrap import wrap, dedent import warnings @@ -15,6 +15,10 @@ __all__ = ["Container", "Field", "FieldValidationError", "Map"] +def _fqdn(obj): + return f"{obj.__module__}.{obj.__qualname__}" + + class FieldValidationError(ValueError): pass @@ -26,8 +30,10 @@ class Field: Parameters ---------- default : - default value of the item. This will be set when the `Container` - is constructed, as well as when ``Container.reset`` is called + Default value of the item. This will be set when the `Container` + is constructed, as well as when ``Container.reset`` is called. + This should only be used for immutable values. For mutable values, + use ``default_factory`` instead. description : str Help text associated with the item unit : str or astropy.units.core.UnitBase @@ -46,6 +52,8 @@ class Field: max_len : int if type is str, max_len is the maximum number of bytes of the utf-8 encoded string to be used. + default_factory : Callable + A callable providing a fresh instance as default value. 
""" def __init__( @@ -59,8 +67,10 @@ def __init__( ndim=None, allow_none=True, max_length=None, + default_factory=None, ): self.default = default + self.default_factory = default_factory self.description = description self.unit = Unit(unit) if unit is not None else None self.ucd = ucd @@ -70,15 +80,23 @@ def __init__( self.allow_none = allow_none self.max_length = max_length + if default_factory is not None and default is not None: + raise ValueError("Must only provide one of default or default_factory") + def __repr__(self): - if isinstance(self.default, Container): - default = f"{self.default.__class__.__name__}" - elif isinstance(self.default, Map): - if isclass(self.default.default_factory): - cls = self.default.default_factory - default = f"Map({cls.__module__}.{cls.__name__})" + if self.default_factory is not None: + if isclass(self.default_factory): + default = _fqdn(self.default_factory) + elif isinstance(self.default_factory, partial): + # case for `partial(Map, Container)` + cls = _fqdn(self.default_factory.args[0]) + if self.default_factory.func is Map: + func = "Map" + else: + func = repr(self.default_factory.func) + default = f"{func}({cls})" else: - default = f"Map({repr(self.default.default_factory)}" + default = str(self.default_factory()) else: default = str(self.default) cmps = [f"Field(default={default}"] @@ -295,10 +313,11 @@ def __init__(self, **fields): for k in set(self.fields).difference(fields): # deepcopy of None is surprisingly slow - default = self.fields[k].default - if default is not None: - default = deepcopy(default) - + field = self.fields[k] + if field.default_factory is not None: + default = field.default_factory() + else: + default = field.default setattr(self, k, default) for k, v in fields.items(): @@ -362,7 +381,7 @@ def as_dict(self, recursive=False, flatten=False, add_prefix=False): d[key] = val return d - def reset(self, recursive=True): + def reset(self): """ Reset all values back to their default values @@ -372,12 +391,11 
@@ def reset(self, recursive=True): If true, also reset all sub-containers """ - for name, value in self.fields.items(): - if isinstance(value, Container): - if recursive: - getattr(self, name).reset() + for name, field in self.fields.items(): + if field.default_factory is not None: + setattr(self, name, field.default_factory()) else: - setattr(self, name, deepcopy(self.fields[name].default)) + setattr(self, name, field.default) def update(self, **values): """ @@ -458,3 +476,10 @@ def reset(self, recursive=True): for val in self.values(): if isinstance(val, Container): val.reset(recursive=recursive) + + def __repr__(self): + if isclass(self.default_factory): + default = _fqdn(self.default_factory) + else: + default = repr(self.default_factory) + return f"{self.__class__.__name__}({default}, {dict.__repr__(self)!s})" diff --git a/ctapipe/core/tests/test_container.py b/ctapipe/core/tests/test_container.py index 0c0e0a01818..8a93dda323e 100644 --- a/ctapipe/core/tests/test_container.py +++ b/ctapipe/core/tests/test_container.py @@ -126,7 +126,7 @@ class ChildContainer(Container): class ParentContainer(Container): x = Field(0, "some value") - child = Field(ChildContainer(), "a child") + child = Field(default_factory=ChildContainer, description="a child") cont = ParentContainer() assert cont.child.z == 1 @@ -138,7 +138,7 @@ class ChildContainer(Container): class ParentContainer(Container): x = Field(0, "some value") - children = Field(Map(), "map of tel_id to child") + children = Field(default_factory=Map, description="map of tel_id to child") cont = ParentContainer() cont.children[10] = ChildContainer() @@ -157,7 +157,7 @@ class ChildContainer(Container): class ParentContainer(Container): x = Field(0, "some value") - child = Field(ChildContainer(), "a child") + child = Field(default_factory=ChildContainer, description="a child") cont = ParentContainer() @@ -255,7 +255,7 @@ def test_field_validation(): def test_container_validation(): - """ check that we can validate 
all fields in a container""" + """check that we can validate all fields in a container""" class MyContainer(Container): x = Field(3.2, "test", unit="m") From a91ffbf1ce52e7625edfc103137255847c752e6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20N=C3=B6the?= Date: Sat, 2 Jul 2022 14:27:54 +0200 Subject: [PATCH 2/2] Update container example notebook --- docs/examples/containers.ipynb | 141 ++++++++++++++++++++++++++++----- 1 file changed, 119 insertions(+), 22 deletions(-) diff --git a/docs/examples/containers.ipynb b/docs/examples/containers.ipynb index ca7a3ac0667..b9398789ffb 100644 --- a/docs/examples/containers.ipynb +++ b/docs/examples/containers.ipynb @@ -17,7 +17,8 @@ "source": [ "from ctapipe.core import Container, Field, Map\n", "import numpy as np\n", - "from astropy import units as u" + "from astropy import units as u\n", + "from functools import partial" ] }, { @@ -34,17 +35,35 @@ "outputs": [], "source": [ "class SubContainer(Container):\n", - " junk = Field(\"nothing\",\"Some junk\")\n", + " junk = Field(-1, \"Some junk\")\n", " value = Field(0.0, \"some value\", unit=u.deg)\n", "\n", + " \n", + "class TelContainer(Container):\n", + " # defaults should match the other requirements, e.g. the defaults\n", + " # should have the correct unit. 
It most often also makes sense to use\n", + " # an invalid value marker like nan for floats or -1 for positive integers\n", + " # as default\n", + " tel_id = Field(-1, \"telescope ID number\")\n", + " \n", + " \n", + " # For mutable structures like lists, arrays or containers, use a `default_factory` function or class,\n", + " # not an instance, to ensure each container gets a fresh instance and there is no hidden \n", + " # shared state between containers.\n", + " image = Field(default_factory=lambda: np.zeros(10), description=\"camera pixel data\")\n", + "\n", + "\n", "class EventContainer(Container):\n", " event_id = Field(-1,\"event id number\")\n", - " tels_with_data = Field([], \"list of telescopes with data\")\n", - " sub = Field(SubContainer(), \"stuff\") # a sub-container in the hierarchy\n", "\n", - " # for dicts of sub-containers, use Map instead \n", - " # of a dict() as the default value to support serialization\n", - " tel = Field(Map(), \"telescopes\") \n" + " tels_with_data = Field(default_factory=list, description=\"list of telescopes with data\")\n", + " sub = Field(default_factory=SubContainer, description=\"stuff\") # a sub-container in the hierarchy\n", + "\n", + " # A Map is like a defaultdict with a specific container type as default.\n", + " # This can be used to e.g.
store a container per telescope\n", + " # we use partial here to automatically get a function that creates a map with the correct container type\n", + " # as default\n", + " tel = Field(default_factory=partial(Map, TelContainer) , description=\"telescopes\") " ] }, { @@ -77,15 +96,19 @@ "outputs": [], "source": [ "print(ev.event_id)\n", + "print(ev.sub)\n", + "print(ev.tel)\n", "print(ev.tel.keys())\n", - "print(ev.tel)" + "\n", + "# default dict access will create container:\n", + "print(ev.tel[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "print the json representation" + "print the dict representation" ] }, { @@ -97,6 +120,33 @@ "print(ev)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We also get docstrings \"for free\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "EventContainer?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "SubContainer?" 
+ ] + }, { "cell_type": "markdown", "metadata": {}, @@ -136,7 +186,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now, let's define a sub-container that we can add per telescope:" + "and we can add a few of these to the parent container inside the tel dict:" ] }, { @@ -145,17 +195,26 @@ "metadata": {}, "outputs": [], "source": [ - "class TelContainer(Container):\n", - " tel_id = Field(-1, \"telescope ID number\")\n", - " image = Field(np.zeros(10), \"camera pixel data\")\n", - "\n" + "ev.tel[10] = TelContainer()\n", + "ev.tel[5] = TelContainer()\n", + "ev.tel[42] = TelContainer()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# because we are using a default_factory to handle mutable defaults, the images are actually different:\n", + "ev.tel[42].image is ev.tel[32]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "and we can add a few of these to the parent container inside the tel dict:" + "Be careful to use the `default_factory` mechanism for mutable fields, see this **negative** example:" ] }, { @@ -164,9 +223,21 @@ "metadata": {}, "outputs": [], "source": [ - "ev.tel[10] = TelContainer()\n", - "ev.tel[5] = TelContainer()\n", - "ev.tel[42] = TelContainer()" + "class DangerousContainer(Container):\n", + " image = Field(\n", + " np.zeros(10),\n", + " description=\"Attention!!!! Globally mutable shared state. Use default_factory instead\"\n", + " )\n", + " \n", + " \n", + "c1 = DangerousContainer()\n", + "c2 = DangerousContainer()\n", + "\n", + "c1.image[5] = 9999\n", + "\n", + "print(c1.image)\n", + "print(c2.image)\n", + "print(c1.image is c2.image)" ] }, { @@ -243,7 +314,7 @@ "outputs": [], "source": [ "ev.reset()\n", - "ev.as_dict(recursive=True, flatten=True)" + "ev.as_dict(recursive=True)" ] }, { @@ -268,7 +339,7 @@ "metadata": {}, "outputs": [], "source": [ - "shower = SimulatedShowerContainer()" + "SimulatedShowerContainer?" 
] }, { @@ -279,14 +350,40 @@ }, "outputs": [], "source": [ + "shower = SimulatedShowerContainer()\n", "shower" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Container prefixes\n", + "\n", + "To store the same container in the same table in a file or give more information, containers support setting\n", + "a custom prefix:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "c1 = SubContainer(junk=5, value=3, prefix=\"foo\")\n", + "c2 = SubContainer(junk=10, value=9001, prefix=\"bar\")\n", + "\n", + "# create a common dict with data from both containers:\n", + "d = c1.as_dict(add_prefix=True)\n", + "d.update(c2.as_dict(add_prefix=True))\n", + "d" + ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -300,7 +397,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.8" + "version": "3.9.13" } }, "nbformat": 4,