Skip to content

Commit

Permalink
tweak to support deferred
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Nov 6, 2024
1 parent 91e1877 commit 3642d4f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
18 changes: 13 additions & 5 deletions cobaya/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from inspect import cleandoc
from packaging import version
from importlib import import_module, resources
from typing import Optional, Union, List, Set
from typing import Optional, Union, List, Set, get_type_hints

from cobaya.log import HasLogger, LoggedError, get_logger
from cobaya.typing import Any, InfoDict, InfoDictIn, empty_dict, validate_type
Expand Down Expand Up @@ -422,7 +422,7 @@ def validate_info(self, name: str, value: Any, annotations: dict):
"""
Does any validation on parameter k read from an input dictionary or yaml file,
before setting the corresponding class attribute.
You could enforce consistency with annotations here, but does not by default.
This check is always done, even if _enforce_types is not set.
:param name: name of parameter
:param value: value
Expand All @@ -434,10 +434,18 @@ def validate_info(self, name: str, value: Any, annotations: dict):
"or False, got '%s'" % (self, name, value))

def validate_attributes(self, annotations: dict):
"""
If _enforce_types or cobaya.typing.enforce_type_checking is set, this
checks all class attributes against the annotation types
:param annotations: resolved inherited dictionary of attributes for this class
:raises: TypeError if any attribute does not match the annotation type
"""
check = cobaya.typing.enforce_type_checking
if check or (self._enforce_types and check is not False):
for name, annotation in annotations.items():
validate_type(annotation, getattr(self, name, None),
if check or self._enforce_types and check is not False:
hints = get_type_hints(self.__class__) # resolve any deferred attributes
for name in annotations:
validate_type(hints[name], getattr(self, name, None),
self.get_name() + ':' + name)

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion cobaya/samplers/polychord/polychord.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class polychord(Sampler):
blocking: Any
measure_speeds: bool
oversample_power: float
nlive: Union[str, NumberWithUnits]
nlive: NumberWithUnits
path: str
logzero: float
max_ndead: int
Expand Down
9 changes: 8 additions & 1 deletion tests/test_type_checking.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class GenericComponent(CobayaComponent):
array: Sequence[float]
array2: Sequence[float]
map: Mapping[float, str]
deferred: 'ParamDict'
unset = 1

_enforce_types = True

Expand All @@ -43,7 +45,8 @@ def test_component_types():
"tuple_params": (0.0, 1.0),
"array": np.arange(2, dtype=np.float64),
"array2": [1, 2],
"map": {1.0: "a", 2.0: "b"}
"map": {1.0: "a", 2.0: "b"},
"deferred": {'value': lambda x: x},
}
GenericComponent(correct_kwargs)

Expand All @@ -68,3 +71,7 @@ def test_component_types():
for case in wrong_cases:
with pytest.raises(TypeError):
GenericComponent({**correct_kwargs, **case})


class NextComponent(CobayaComponent):
pass

0 comments on commit 3642d4f

Please sign in to comment.