Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance type checking #382

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 73 additions & 9 deletions cobaya/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from inspect import cleandoc
from packaging import version
from importlib import import_module, resources
from typing import Optional, Union, List, Set
from numbers import Integral, Real
from typing import ClassVar, ForwardRef, Optional, Union, List, Set

from cobaya.log import HasLogger, LoggedError, get_logger
from cobaya.typing import Any, InfoDict, InfoDictIn, empty_dict
from cobaya.typing import Any, InfoDict, InfoDictIn, ParamDict, empty_dict
from cobaya.tools import resolve_packages_path, load_module, get_base_classes, \
get_internal_class_component_name, deepcopy_where_possible, VersionCheckError
get_internal_class_component_name, deepcopy_where_possible, NumberWithUnits, VersionCheckError
from cobaya.conventions import kinds, cobaya_package, reserved_attributes
from cobaya.yaml import yaml_load_file, yaml_dump, yaml_load
from cobaya.mpi import is_main_process
Expand Down Expand Up @@ -278,7 +279,7 @@ def get_defaults(cls, return_yaml=False, yaml_expand_defaults=True,
"(type declarations without values are fine "
"with yaml file as well).",
cls.get_qualified_class_name(), list(both))
options |= yaml_options
options.update(yaml_options)
yaml_text = None
if return_yaml and not yaml_expand_defaults:
return yaml_text or ""
Expand Down Expand Up @@ -331,6 +332,8 @@ class CobayaComponent(HasLogger, HasDefaults):
_at_resume_prefer_new: List[str] = ["version"]
_at_resume_prefer_old: List[str] = []

_enforce_types: bool = False

def __init__(self, info: InfoDictIn = empty_dict,
name: Optional[str] = None,
timing: Optional[bool] = None,
Expand All @@ -349,7 +352,7 @@ def __init__(self, info: InfoDictIn = empty_dict,
# set attributes from the info (from yaml file or directly input dictionary)
annotations = self.get_annotations()
for k, value in info.items():
self.validate_info(k, value, annotations)
self.validate_bool(k, value, annotations)
try:
setattr(self, k, value)
except AttributeError:
Expand All @@ -366,6 +369,9 @@ def __init__(self, info: InfoDictIn = empty_dict,
" are set (%s, %s)", self, e)
raise

if self._enforce_types:
self.validate_attributes()

def set_timing_on(self, on):
self.timer = Timer() if on else None

Expand Down Expand Up @@ -412,7 +418,7 @@ def has_version(self):
"""
return True

def validate_info(self, k: str, value: Any, annotations: dict):
def validate_bool(self, name: str, value: Any, annotations: dict):
"""
Does any validation on parameter k read from an input dictionary or yaml file,
before setting the corresponding class attribute.
Expand All @@ -423,10 +429,68 @@ def validate_info(self, k: str, value: Any, annotations: dict):
:param annotations: resolved inherited dictionary of attributes for this class
"""

# by default just test booleans, e.g. for typos of "false" which evaluate true
if annotations.get(k) is bool and value and isinstance(value, str):
if annotations.get(name) is bool and value and isinstance(value, str):
raise AttributeError("Class '%s' parameter '%s' should be True "
"or False, got '%s'" % (self, k, value))
"or False, got '%s'" % (self, name, value))

def validate_info(self, name: str, value: Any, annotations: dict):
if name in annotations:
expected_type = annotations[name]
if not self._validate_type(expected_type, value):
msg = f"Attribute '{name}' must be of type {expected_type}, not {type(value)}(value={value})"
raise TypeError(msg)

def _validate_composite_type(self, expected_type, value):
origin = expected_type.__origin__
try: # for Callable and Sequence types, which have no __args__
args = expected_type.__args__
except AttributeError:
pass

if origin is Union:
return any(self._validate_type(t, value) for t in args)
elif origin is Optional:
return value is None or self._validate_type(args[0], value)
elif origin is list:
return all(self._validate_type(args[0], item) for item in value)
elif origin is dict:
return all(
self._validate_type(args[0], k) and self._validate_type(args[1], v)
for k, v in value.items()
)
elif origin is tuple:
return len(args) == len(value) and all(
self._validate_type(t, v) for t, v in zip(args, value)
)
elif origin is ClassVar:
return self._validate_type(args[0], value)
else:
return isinstance(value, origin)

def _validate_type(self, expected_type, value):
if value is None or expected_type is Any: # Any is always valid
return True

if hasattr(expected_type, "__origin__"):
return self._validate_composite_type(expected_type, value)
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are only the values general here, not also the expected_types?
e.g. any Mapping type could accept any Mapping value? (e.g. empty_dict is MappingProxyType)

Note also in numpy can also end up with "numbers" that are zero-rank arrays like np.array(1), which I suspect may not pass isinstance(Real), though not sure if that's ever an issue for setting parameters.

# Exceptions for some types
if expected_type is ParamDict:
return isinstance(value, dict)
elif expected_type is int:
if value == float('inf'): # for infinite values parsed as floats
return isinstance(value, float)
return isinstance(value, Integral)
elif expected_type is float:
return isinstance(value, Real)
elif expected_type is NumberWithUnits:
return isinstance(value, Real) or isinstance(value, str)
return isinstance(value, expected_type)

def validate_attributes(self):
annotations = self.get_annotations()
for name in annotations.keys():
self.validate_info(name, getattr(self, name, None), annotations)

@classmethod
def get_kind(cls):
Expand Down
2 changes: 1 addition & 1 deletion cobaya/samplers/polychord/polychord.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class polychord(Sampler):

# variables from yaml
do_clustering: bool
num_repeats: int
num_repeats: Union[int, str]
confidence_for_unbounded: float
callback_function: Callable
blocking: Any
Expand Down
25 changes: 13 additions & 12 deletions cobaya/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,6 @@
SamplersDict = Dict[str, Optional[SamplerDict]]
PriorsDict = Dict[str, Union[str, Callable]]

# parameters in a params list can be specified on input by
# 1. a ParamDict dictionary
# 2. constant value
# 3. a string giving lambda function of other parameters
# 4. None - must be a computed output parameter
# 5. Sequence specifying uniform prior range [min, max] and optionally
# 'ref' mean and standard deviation for starting positions, and optionally
# proposal width. Allowed lengths, 2, 4, 5
ParamInput = Union['ParamDict', None, str, float, Sequence[float]]
ParamsDict = Dict[str, ParamInput]
ExpandedParamsDict = Dict[str, 'ParamDict']

partags = {"prior", "ref", "proposal", "value", "drop",
"derived", "latex", "renames", "min", "max"}

Expand Down Expand Up @@ -71,6 +59,19 @@ class ParamDict(TypedDict, total=False):
max: float


# parameters in a params list can be specified on input by
# 1. a ParamDict dictionary
# 2. constant value
# 3. a string giving lambda function of other parameters
# 4. None - must be a computed output parameter
# 5. Sequence specifying uniform prior range [min, max] and optionally
# 'ref' mean and standard deviation for starting positions, and optionally
# proposal width. Allowed lengths, 2, 4, 5
ParamInput = Union[ParamDict, None, str, float, Sequence[float]]
ParamsDict = Dict[str, ParamInput]
ExpandedParamsDict = Dict[str, ParamDict]


class ModelDict(TypedDict, total=False):
theory: TheoriesDict
likelihood: LikesDict
Expand Down
1 change: 1 addition & 0 deletions tests/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def test_mcmc_drag_results(temperature):


@pytest.mark.mpionly
@pytest.mark.skip("Setting 'max_samples' to a bad value raises an error which is not caught.")
def test_mcmc_sync():
info: InputDict = yaml_load(yaml)
logger.info('Test end synchronization')
Expand Down
115 changes: 115 additions & 0 deletions tests/test_type_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""General test for types of components."""

from typing import Any, ClassVar, Dict, List, Optional, Tuple
import numpy as np
import pytest

from cobaya.component import CobayaComponent
from cobaya.likelihood import Likelihood
from cobaya.tools import NumberWithUnits
from cobaya.typing import InputDict, ParamDict
from cobaya.run import run


class GenericLike(Likelihood):
any: Any
classvar: ClassVar[int] = 1
# forwardref_params: 'ParamDict' = {"d": [0.0, 1.0]}
infinity: int = float("inf")
mean: NumberWithUnits = 1
noise: float = 0
none: int = None
numpy_int: int = np.int64(1)
optional: Optional[int] = None
paramdict_params: ParamDict = {"c": [0.0, 1.0]}
params: Dict[str, List[float]] = {"a": [0.0, 1.0], "b": [0, 1]}
tuple_params: Tuple[float, float] = (0.0, 1.0)

_enforce_types = True

def logp(self, **params_values):
return 1


def test_sampler_types():
original_info: InputDict = {
"likelihood": {"like": GenericLike},
"sampler": {"mcmc": {"max_samples": 1}},
}
_ = run(original_info)

info = original_info.copy()
info["sampler"]["mcmc"]["max_samples"] = "not_an_int"
with pytest.raises(TypeError):
run(info)


class GenericComponent(CobayaComponent):
any: Any
classvar: ClassVar[int] = 1
infinity: int = float("inf")
mean: NumberWithUnits = 1
noise: float = 0
none: int = None
numpy_int: int = np.int64(1)
optional: Optional[int] = None
paramdict_params: ParamDict = {"c": [0.0, 1.0]}
params: Dict[str, List[float]] = {"a": [0.0, 1.0], "b": [0, 1]}
tuple_params: Tuple[float, float] = (0.0, 1.0)

_enforce_types = True

def __init__(
self,
any,
classvar,
infinity,
mean,
noise,
none,
numpy_int,
optional,
paramdict_params,
params,
tuple_params,
):
if self._enforce_types:
super().validate_attributes()


def test_component_types():
correct_kwargs = {
"any": 1,
"classvar": 1,
"infinity": float("inf"),
"mean": 1,
"noise": 0,
"none": None,
"numpy_int": 1,
"optional": 3,
"paramdict_params": {"c": [0.0, 1.0]},
"params": {"a": [0.0, 1.0], "b": [0, 1]},
"tuple_params": (0.0, 1.0),
}
_ = GenericComponent(**correct_kwargs)

wrong_cases = [
{"any": "not_an_int"},
{"classvar": "not_an_int"},
{"infinity": "not_an_int"},
{"mean": "not_a_numberwithunits"},
{"noise": "not_a_float"},
{"none": "not_a_none"},
{"numpy_int": "not_an_int"},
{"paramdict_params": "not_a_paramdict"},
{"params": "not_a_dict"},
{"params": {1: [0.0, 1.0]}},
{"params": {"a": "not_a_list"}},
{"params": {"a": [0.0, "not_a_float"]}},
{"optional": "not_an_int"},
{"tuple_params": "not_a_tuple"},
{"tuple_params": (0.0, "not_a_float")},
]
for case in wrong_cases:
with pytest.raises(TypeError):
_ = GenericComponent({**correct_kwargs, **case})
Loading