Skip to content

Commit

Permalink
recursively validate containers (#1976)
Browse files Browse the repository at this point in the history
* recursively validate containers

* also make this work for Fields containing Maps
  • Loading branch information
kosack authored Jul 4, 2022
1 parent 8adc434 commit 0892e62
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 3 deletions.
14 changes: 14 additions & 0 deletions ctapipe/core/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,20 @@ def validate(self, value):
f"{errorstr} Should be an instance of {self.type}"
)

if isinstance(value, Container):
# recursively check sub-containers
value.validate()
return

if isinstance(value, Map):
for key, map_value in value.items():
if isinstance(map_value, Container):
try:
map_value.validate()
except FieldValidationError as err:
raise FieldValidationError(f"[{key}]: {err} ")
return

if self.unit is not None:
if not isinstance(value, Quantity):
raise FieldValidationError(
Expand Down
27 changes: 26 additions & 1 deletion ctapipe/core/tests/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def test_field_validation():


def test_container_validation():
""" check that we can validate all fields in a container"""
"""check that we can validate all fields in a container"""

class MyContainer(Container):
x = Field(3.2, "test", unit="m")
Expand All @@ -264,4 +264,29 @@ class MyContainer(Container):
with pytest.raises(FieldValidationError):
MyContainer().validate() # fails since 3.2 has no units

with pytest.raises(FieldValidationError):
MyContainer(x=10 * u.s).validate() # seconds is not convertable to meters

MyContainer(x=6.4 * u.m).validate() # works


def test_recursive_validation():
"""
Check both sub-containers and Maps work with recursive validation
"""

class ChildContainer(Container):
x = Field(3.2 * u.m, "test", unit="m")

class ParentContainer(Container):
cont = Field(None, "test sub", type=ChildContainer)
map = Field(Map(ChildContainer), "many children")

with pytest.raises(FieldValidationError):
ParentContainer(cont=ChildContainer(x=1 * u.s)).validate()

with pytest.raises(FieldValidationError):
cont = ParentContainer(cont=ChildContainer(x=1 * u.m))
cont.map[0] = ChildContainer(x=1 * u.m)
cont.map[1] = ChildContainer(x=1 * u.s)
cont.validate()
6 changes: 4 additions & 2 deletions examples/load_one_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
if len(sys.argv) >= 2:
filename = sys.argv[1]
else:
filename = get_dataset_path("gamma_test_large.simtel.gz")
filename = get_dataset_path(
"gamma_LaPalma_baseline_20Zd_180Az_prod3b_test.simtel.gz"
)

with EventSource(filename, max_events=1) as source:
with EventSource(filename, max_events=1, focal_length_choice="nominal") as source:
calib = CameraCalibrator(subarray=source.subarray)
process_images = ImageProcessor(subarray=source.subarray)
process_shower = ShowerProcessor(subarray=source.subarray)
Expand Down

0 comments on commit 0892e62

Please sign in to comment.