-
Notifications
You must be signed in to change notification settings - Fork 127
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enhance type checking #382
Changes from 13 commits
930db00
71fc109
e4c2971
dc48e84
7a568e7
7bf99d8
61a0405
af5c727
a16b804
726f522
bbb1f6b
8e62704
39b75b4
3fe9b84
b985a40
6ecc79a
00bcbb3
5988a15
4cdd11e
436e70d
19e4143
dad347c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,12 +5,12 @@ | |
from inspect import cleandoc | ||
from packaging import version | ||
from importlib import import_module, resources | ||
from typing import Optional, Union, List, Set | ||
from typing import ClassVar, ForwardRef, Optional, Union, List, Set | ||
|
||
from cobaya.log import HasLogger, LoggedError, get_logger | ||
from cobaya.typing import Any, InfoDict, InfoDictIn, empty_dict | ||
from cobaya.typing import Any, InfoDict, InfoDictIn, ParamDict, empty_dict | ||
from cobaya.tools import resolve_packages_path, load_module, get_base_classes, \ | ||
get_internal_class_component_name, deepcopy_where_possible, VersionCheckError | ||
get_internal_class_component_name, deepcopy_where_possible, NumberWithUnits, VersionCheckError | ||
from cobaya.conventions import kinds, cobaya_package, reserved_attributes | ||
from cobaya.yaml import yaml_load_file, yaml_dump, yaml_load | ||
from cobaya.mpi import is_main_process | ||
|
@@ -278,7 +278,7 @@ def get_defaults(cls, return_yaml=False, yaml_expand_defaults=True, | |
"(type declarations without values are fine " | ||
"with yaml file as well).", | ||
cls.get_qualified_class_name(), list(both)) | ||
options |= yaml_options | ||
options.update(yaml_options) | ||
yaml_text = None | ||
if return_yaml and not yaml_expand_defaults: | ||
return yaml_text or "" | ||
|
@@ -331,6 +331,8 @@ class CobayaComponent(HasLogger, HasDefaults): | |
_at_resume_prefer_new: List[str] = ["version"] | ||
_at_resume_prefer_old: List[str] = [] | ||
|
||
enforce_types: bool = False | ||
|
||
def __init__(self, info: InfoDictIn = empty_dict, | ||
name: Optional[str] = None, | ||
timing: Optional[bool] = None, | ||
|
@@ -349,7 +351,7 @@ def __init__(self, info: InfoDictIn = empty_dict, | |
# set attributes from the info (from yaml file or directly input dictionary) | ||
annotations = self.get_annotations() | ||
for k, value in info.items(): | ||
self.validate_info(k, value, annotations) | ||
self.validate_bool(k, value, annotations) | ||
try: | ||
setattr(self, k, value) | ||
except AttributeError: | ||
|
@@ -366,6 +368,9 @@ def __init__(self, info: InfoDictIn = empty_dict, | |
" are set (%s, %s)", self, e) | ||
raise | ||
|
||
if self.enforce_types: | ||
self.validate_attributes() | ||
|
||
def set_timing_on(self, on): | ||
self.timer = Timer() if on else None | ||
|
||
|
@@ -412,7 +417,7 @@ def has_version(self): | |
""" | ||
return True | ||
|
||
def validate_info(self, k: str, value: Any, annotations: dict): | ||
def validate_bool(self, name: str, value: Any, annotations: dict): | ||
""" | ||
Does any validation on parameter k read from an input dictionary or yaml file, | ||
before setting the corresponding class attribute. | ||
|
@@ -423,10 +428,71 @@ def validate_info(self, k: str, value: Any, annotations: dict): | |
:param annotations: resolved inherited dictionary of attributes for this class | ||
""" | ||
|
||
# by default just test booleans, e.g. for typos of "false" which evaluate true | ||
if annotations.get(k) is bool and value and isinstance(value, str): | ||
if annotations.get(name) is bool and value and isinstance(value, str): | ||
raise AttributeError("Class '%s' parameter '%s' should be True " | ||
"or False, got '%s'" % (self, k, value)) | ||
"or False, got '%s'" % (self, name, value)) | ||
|
||
def validate_info(self, name: str, value: Any, annotations: dict): | ||
if name in annotations: | ||
expected_type = annotations[name] | ||
if not self._validate_type(expected_type, value): | ||
msg = f"Attribute '{name}' must be of type {expected_type}, not {type(value)}(value={value})" | ||
raise TypeError(msg) | ||
|
||
def _validate_composite_type(self, expected_type, value): | ||
origin = expected_type.__origin__ | ||
try: # for Callable and Sequence types, which have no __args__ | ||
args = expected_type.__args__ | ||
except AttributeError: | ||
pass | ||
|
||
if origin is Union: | ||
return any(self._validate_type(t, value) for t in args) | ||
elif origin is Optional: | ||
return value is None or self._validate_type(args[0], value) | ||
elif origin is list: | ||
return all(self._validate_type(args[0], item) for item in value) | ||
elif origin is dict: | ||
return all( | ||
self._validate_type(args[0], k) and self._validate_type(args[1], v) | ||
for k, v in value.items() | ||
) | ||
elif origin is tuple: | ||
return len(args) == len(value) and all( | ||
self._validate_type(t, v) for t, v in zip(args, value) | ||
) | ||
elif origin is ClassVar: | ||
return self._validate_type(args[0], value) | ||
else: | ||
return isinstance(value, origin) | ||
|
||
def _validate_type(self, expected_type, value): | ||
if value is None or expected_type is Any: # Any is always valid | ||
return True | ||
|
||
if hasattr(expected_type, "__origin__"): | ||
return self._validate_composite_type(expected_type, value) | ||
else: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are only the values general here, and not also the expected_types? Note also that with numpy you can end up with "numbers" that are zero-rank arrays like np.array(1), which I suspect may not pass isinstance(Real), though I am not sure if that's ever an issue for setting parameters. |
||
# Exceptions for types that are not exactly the same | ||
if isinstance(expected_type, ForwardRef): # for custom types as ParamDict | ||
if "Dict" in expected_type.__forward_arg__: | ||
expected_type = dict | ||
elif expected_type is ParamDict or "ParamDict" == str(expected_type): | ||
return isinstance(value, dict) | ||
if expected_type is NumberWithUnits: | ||
return isinstance(value, (int, float)) | ||
elif expected_type is int: # for numpy integers | ||
if value == float('inf'): # for infinite values parsed as floats | ||
return isinstance(value, float) | ||
return isinstance(value, int) or "numpy.int" in str(type(value)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can simplify some of this using generic numbers.Number types? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I was thinking about using some more generic types, I will commit that soon. |
||
elif expected_type is float: # for ints that can be floats | ||
return isinstance(value, (int, float)) | ||
return isinstance(value, expected_type) | ||
|
||
def validate_attributes(self): | ||
annotations = self.get_annotations() | ||
for name in annotations.keys(): | ||
self.validate_info(name, getattr(self, name, None), annotations) | ||
|
||
@classmethod | ||
def get_kind(cls): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,7 +37,7 @@ def cosmomc_root_to_cobaya_info_dict(root: str, derived_to_input=()) -> InputDic | |
name = name.replace('chi2_', 'chi2__') | ||
if name.startswith('minuslogprior') or name == 'chi2': | ||
continue | ||
param_dict: ParamDict = {'latex': par.label} | ||
param_dict: 'ParamDict' = {'latex': par.label} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The actual type should be preferred over a deferred one where possible. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I completely removed the deferred 'ParamDict'. This comes at the cost of moving the definition of |
||
d[name] = param_dict | ||
if par.renames: | ||
param_dict['renames'] = par.renames | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -589,7 +589,7 @@ def logpost(self, | |
return self.logposterior(params_values, make_finite=make_finite, | ||
return_derived=False, cached=cached).logpost | ||
|
||
def get_valid_point(self, max_tries: int, ignore_fixed_ref: bool = False, | ||
def get_valid_point(self, max_tries: Union[int, str], ignore_fixed_ref: bool = False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Has to be int as used in the code? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hello @cmbant, thanks for the feedback on this! Indeed I think that, technically, that could be a float, as it is used just for evaluations like Also, that number comes from the |
||
logposterior_as_dict: bool = False, random_state=None, | ||
) -> Union[Tuple[np.ndarray, LogPosterior], | ||
Tuple[np.ndarray, dict]]: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,7 +42,7 @@ def is_derived_param(info_param: ParamInput) -> bool: | |
return expand_info_param(info_param).get("derived", False) is not False | ||
|
||
|
||
def expand_info_param(info_param: ParamInput, default_derived=True) -> ParamDict: | ||
def expand_info_param(info_param: ParamInput, default_derived=True) -> 'ParamDict': | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Oops, probably a search and replace gone wrong... |
||
""" | ||
Expands the info of a parameter, from the user-friendly, shorter format | ||
to a more unambiguous one. | ||
|
@@ -76,7 +76,7 @@ def expand_info_param(info_param: ParamInput, default_derived=True) -> ParamDict | |
return info_param | ||
|
||
|
||
def reduce_info_param(info_param: ParamDict) -> ParamInput: | ||
def reduce_info_param(info_param: 'ParamDict') -> ParamInput: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I removed all of these, see above |
||
""" | ||
Compresses the info of a parameter, suppressing default values. | ||
This is the opposite of :func:`~input.expand_info_param`. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -56,11 +56,11 @@ class MCMC(CovmatSampler): | |
|
||
# instance variables from yaml | ||
burn_in: NumberWithUnits | ||
learn_every: NumberWithUnits | ||
output_every: NumberWithUnits | ||
learn_every: Union[NumberWithUnits, str] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not sure why only some of these changed. In terms of usage, an optional str is probably not very helpful since it should be converted at read-in; maybe keep NumberWithUnits and use a more flexible check. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I changed only the ones causing errors, but, following your suggestion, I moved the problem to the type-checking side as we know the |
||
output_every: Union[NumberWithUnits, str] | ||
callback_every: NumberWithUnits | ||
temperature: float | ||
max_tries: NumberWithUnits | ||
max_tries: Union[NumberWithUnits, str] | ||
max_samples: int | ||
drag: bool | ||
callback_function: Optional[Callable] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
"""General test for types of components.""" | ||
|
||
from typing import Any, ClassVar, Dict, List, Optional, Tuple | ||
import numpy as np | ||
import pytest | ||
|
||
from cobaya.component import CobayaComponent | ||
from cobaya.likelihood import Likelihood | ||
from cobaya.tools import NumberWithUnits | ||
from cobaya.typing import InputDict, ParamDict | ||
from cobaya.run import run | ||
|
||
|
||
class GenericLike(Likelihood): | ||
any: Any | ||
classvar: ClassVar[int] = 1 | ||
forwardref_params: "ParamDict" = {"d": [0.0, 1.0]} | ||
infinity: int = float("inf") | ||
mean: NumberWithUnits = 1 | ||
noise: float = 0 | ||
none: int = None | ||
numpy_int: int = np.int64(1) | ||
optional: Optional[int] = None | ||
paramdict_params: ParamDict = {"c": [0.0, 1.0]} | ||
params: Dict[str, List[float]] = {"a": [0.0, 1.0], "b": [0, 1]} | ||
tuple_params: Tuple[float, float] = (0.0, 1.0) | ||
|
||
enforce_types = True | ||
|
||
def logp(self, **params_values): | ||
return 1 | ||
|
||
|
||
def test_sampler_types(): | ||
original_info: InputDict = { | ||
"likelihood": {"like": GenericLike}, | ||
"sampler": {"mcmc": {"max_samples": 1, "enforce_types": True}}, | ||
} | ||
_ = run(original_info) | ||
|
||
info = original_info.copy() | ||
info["sampler"]["mcmc"]["max_samples"] = "not_an_int" | ||
with pytest.raises(TypeError): | ||
run(info) | ||
|
||
|
||
class GenericComponent(CobayaComponent): | ||
any: Any | ||
classvar: ClassVar[int] = 1 | ||
forwardref_params: "ParamDict" = {"d": [0.0, 1.0]} | ||
infinity: int = float("inf") | ||
mean: NumberWithUnits = 1 | ||
noise: float = 0 | ||
none: int = None | ||
numpy_int: int = np.int64(1) | ||
optional: Optional[int] = None | ||
paramdict_params: ParamDict = {"c": [0.0, 1.0]} | ||
params: Dict[str, List[float]] = {"a": [0.0, 1.0], "b": [0, 1]} | ||
tuple_params: Tuple[float, float] = (0.0, 1.0) | ||
|
||
enforce_types = True | ||
|
||
def __init__( | ||
self, | ||
any, | ||
classvar, | ||
forwardref_params, | ||
infinity, | ||
mean, | ||
noise, | ||
none, | ||
numpy_int, | ||
optional, | ||
paramdict_params, | ||
params, | ||
tuple_params, | ||
): | ||
if self.enforce_types: | ||
super().validate_attributes() | ||
|
||
|
||
def test_component_types(): | ||
correct_kwargs = { | ||
"any": 1, | ||
"classvar": 1, | ||
"forwardref_params": {"d": [0.0, 1.0]}, | ||
"infinity": float("inf"), | ||
"mean": 1, | ||
"noise": 0, | ||
"none": None, | ||
"numpy_int": 1, | ||
"optional": 3, | ||
"paramdict_params": {"c": [0.0, 1.0]}, | ||
"params": {"a": [0.0, 1.0], "b": [0, 1]}, | ||
"tuple_params": (0.0, 1.0), | ||
} | ||
_ = GenericComponent(**correct_kwargs) | ||
|
||
wrong_cases = [ | ||
{"any": "not_an_int"}, | ||
{"classvar": "not_an_int"}, | ||
{"forwardref_params": "not_a_paramdict"}, | ||
{"infinity": "not_an_int"}, | ||
{"mean": "not_a_numberwithunits"}, | ||
{"noise": "not_a_float"}, | ||
{"none": "not_a_none"}, | ||
{"numpy_int": "not_an_int"}, | ||
{"paramdict_params": "not_a_paramdict"}, | ||
{"params": "not_a_dict"}, | ||
{"params": {1: [0.0, 1.0]}}, | ||
{"params": {"a": "not_a_list"}}, | ||
{"params": {"a": [0.0, "not_a_float"]}}, | ||
{"optional": "not_an_int"}, | ||
{"tuple_params": "not_a_tuple"}, | ||
{"tuple_params": (0.0, "not_a_float")}, | ||
] | ||
for case in wrong_cases: | ||
with pytest.raises(TypeError): | ||
_ = GenericComponent({**correct_kwargs, **case}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably should start with _ as not something we want changed by yaml
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure about this. Indeed, with the underscore added we can enforce types neither through the `yaml` nor through the info dict (as I do in the test, for instance). So, I guess this becomes a question of what default behavior we want to use. If the default value of `enforce_types` is `False`, the user can only enforce types on components he/she defines with that flag set to `True`. But, for example, types of built-in samplers and such cannot be checked (unless we add the opposite flag in their definition). This may be a good choice, since I guess that types of built-in stuff are checked anyway, in one way or another. Vice versa, with `enforce_types = True` by default, everything will be explicitly checked and there is no need for the user to access that.
Am I missing something?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would probably have to be False by default, to keep people from hitting annoying errors and to keep compatibility with existing external likelihoods. I was thinking of this more as an option for likelihood developers who support accurate type hints in their code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok great, then I'll make it private and default to `False` 👍