Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance type checking #382

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 75 additions & 9 deletions cobaya/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from inspect import cleandoc
from packaging import version
from importlib import import_module, resources
from typing import Optional, Union, List, Set
from typing import ClassVar, ForwardRef, Optional, Union, List, Set

from cobaya.log import HasLogger, LoggedError, get_logger
from cobaya.typing import Any, InfoDict, InfoDictIn, empty_dict
from cobaya.typing import Any, InfoDict, InfoDictIn, ParamDict, empty_dict
from cobaya.tools import resolve_packages_path, load_module, get_base_classes, \
get_internal_class_component_name, deepcopy_where_possible, VersionCheckError
get_internal_class_component_name, deepcopy_where_possible, NumberWithUnits, VersionCheckError
from cobaya.conventions import kinds, cobaya_package, reserved_attributes
from cobaya.yaml import yaml_load_file, yaml_dump, yaml_load
from cobaya.mpi import is_main_process
Expand Down Expand Up @@ -278,7 +278,7 @@ def get_defaults(cls, return_yaml=False, yaml_expand_defaults=True,
"(type declarations without values are fine "
"with yaml file as well).",
cls.get_qualified_class_name(), list(both))
options |= yaml_options
options.update(yaml_options)
yaml_text = None
if return_yaml and not yaml_expand_defaults:
return yaml_text or ""
Expand Down Expand Up @@ -331,6 +331,8 @@ class CobayaComponent(HasLogger, HasDefaults):
_at_resume_prefer_new: List[str] = ["version"]
_at_resume_prefer_old: List[str] = []

enforce_types: bool = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably should start with _ as not something we want changed by yaml

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure about this. Indeed, if we add the underscore, we can no longer enforce types either through the yaml or through the info dict (as I do in the test, for instance). So I guess this becomes a question of which default behavior we want. If the default value of enforce_types is False, the user can only enforce types on components that he/she defines with that flag set to True. But, for example, the types of built-in samplers and such cannot be checked (unless we set the opposite flag in their definitions). This may be a good choice, since I guess that the types of built-in stuff are checked anyway, one way or another. Vice versa, with enforce_types = True by default, everything will be checked explicitly and there is no need for the user to access that flag.

Am I missing something?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would probably have to be False by default to keep people from hitting annoying errors and compatibility with existing external likelihoods. Was thinking of this more as an option for likelihood developers who support accurate type hint in their code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok great, then I'll make it private and default at False 👍


def __init__(self, info: InfoDictIn = empty_dict,
name: Optional[str] = None,
timing: Optional[bool] = None,
Expand All @@ -349,7 +351,7 @@ def __init__(self, info: InfoDictIn = empty_dict,
# set attributes from the info (from yaml file or directly input dictionary)
annotations = self.get_annotations()
for k, value in info.items():
self.validate_info(k, value, annotations)
self.validate_bool(k, value, annotations)
try:
setattr(self, k, value)
except AttributeError:
Expand All @@ -366,6 +368,9 @@ def __init__(self, info: InfoDictIn = empty_dict,
" are set (%s, %s)", self, e)
raise

if self.enforce_types:
self.validate_attributes()

def set_timing_on(self, on):
self.timer = Timer() if on else None

Expand Down Expand Up @@ -412,7 +417,7 @@ def has_version(self):
"""
return True

def validate_info(self, k: str, value: Any, annotations: dict):
def validate_bool(self, name: str, value: Any, annotations: dict):
"""
Does any validation on parameter k read from an input dictionary or yaml file,
before setting the corresponding class attribute.
Expand All @@ -423,10 +428,71 @@ def validate_info(self, k: str, value: Any, annotations: dict):
:param annotations: resolved inherited dictionary of attributes for this class
"""

# by default just test booleans, e.g. for typos of "false" which evaluate true
if annotations.get(k) is bool and value and isinstance(value, str):
if annotations.get(name) is bool and value and isinstance(value, str):
raise AttributeError("Class '%s' parameter '%s' should be True "
"or False, got '%s'" % (self, k, value))
"or False, got '%s'" % (self, name, value))

def validate_info(self, name: str, value: Any, annotations: dict):
    """
    Checks ``value`` against the type annotated for the attribute ``name``.

    Names that do not appear in ``annotations`` are not checked.

    :param name: name of the attribute being validated
    :param value: candidate value for that attribute
    :param annotations: mapping from attribute names to their annotated types
    :raises TypeError: if ``self._validate_type`` rejects the value
    """
    if name not in annotations:
        return
    annotated = annotations[name]
    if self._validate_type(annotated, value):
        return
    raise TypeError(
        f"Attribute '{name}' must be of type {annotated}, "
        f"not {type(value)}(value={value})"
    )

def _validate_composite_type(self, expected_type, value):
    """
    Checks ``value`` against a generic ``typing`` annotation, i.e. one that
    exposes ``__origin__`` (``Union``, ``Optional``, ``List[...]``, ``Dict[...]``,
    ``Tuple[...]``, ``ClassVar[...]``...).

    Element types are checked recursively through ``self._validate_type``.
    A mismatch is always reported by returning ``False``: a value of the wrong
    shape (a string for a list, a non-mapping for a dict, a scalar for a tuple)
    never raises from inside this method.

    :param expected_type: annotation exposing an ``__origin__`` attribute
    :param value: the value to check
    :return: ``True`` if ``value`` is compatible with ``expected_type``
    """
    from typing import TypeVar

    origin = expected_type.__origin__
    # Bare generics (``List``, ``Callable``...) expose no ``__args__`` on recent
    # Python versions, and only unbound TypeVars on older ones: in both cases
    # there are no element types to check.
    args = getattr(expected_type, "__args__", None) or ()
    if all(isinstance(arg, TypeVar) for arg in args):
        args = ()

    if origin is Union:  # also covers Optional[X], which is Union[X, None]
        return any(self._validate_type(arg, value) for arg in args)
    if origin is ClassVar:
        return self._validate_type(args[0], value) if args else True
    if origin is dict:
        # Duck-typed so that read-only mappings (e.g. MappingProxyType) pass
        get_items = getattr(value, "items", None)
        if not callable(get_items):
            return False
        if not args:
            return True
        key_type, item_type = args
        return all(
            self._validate_type(key_type, k) and self._validate_type(item_type, v)
            for k, v in get_items()
        )
    if origin is list or origin is tuple:
        # Strings are iterable, but are never a valid list/tuple of items
        if isinstance(value, (str, bytes)):
            return False
        try:
            items = list(value)
        except TypeError:  # not iterable at all
            return False
        if not args:
            return True
        # List[X] and the variadic Tuple[X, ...] share a single item type
        if origin is list or (len(args) == 2 and args[1] is Ellipsis):
            return all(self._validate_type(args[0], item) for item in items)
        return len(args) == len(items) and all(
            self._validate_type(t, v) for t, v in zip(args, items)
        )
    return isinstance(value, origin)

def _validate_type(self, expected_type, value):
if value is None or expected_type is Any: # Any is always valid
return True

if hasattr(expected_type, "__origin__"):
return self._validate_composite_type(expected_type, value)
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are only the values general here, not also the expected_types?
e.g. any Mapping type could accept any Mapping value? (e.g. empty_dict is MappingProxyType)

Note also that with numpy you can end up with "numbers" that are zero-rank arrays like np.array(1), which I suspect may not pass isinstance(Real), though I'm not sure whether that's ever an issue for setting parameters.

# Exceptions for types that are not exactly the same
if isinstance(expected_type, ForwardRef): # for custom types as ParamDict
if "Dict" in expected_type.__forward_arg__:
expected_type = dict
elif expected_type is ParamDict or "ParamDict" == str(expected_type):
return isinstance(value, dict)
if expected_type is NumberWithUnits:
return isinstance(value, (int, float))
elif expected_type is int: # for numpy integers
if value == float('inf'): # for infinite values parsed as floats
return isinstance(value, float)
return isinstance(value, int) or "numpy.int" in str(type(value))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can simplify some of this using generic numbers.Number types?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was thinking about using some more generic types, I will commit that soon.

elif expected_type is float: # for ints that can be floats
return isinstance(value, (int, float))
return isinstance(value, expected_type)

def validate_attributes(self):
    """
    Validates the current value of every annotated attribute of this instance.

    Attributes that are annotated but not set are checked as ``None``.

    :raises TypeError: propagated from ``validate_info`` for a rejected value
    """
    declared = self.get_annotations()
    for attr_name in declared:
        current = getattr(self, attr_name, None)
        self.validate_info(attr_name, current, declared)

@classmethod
def get_kind(cls):
Expand Down
2 changes: 1 addition & 1 deletion cobaya/cosmo_input/convert_cosmomc.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def cosmomc_root_to_cobaya_info_dict(root: str, derived_to_input=()) -> InputDic
name = name.replace('chi2_', 'chi2__')
if name.startswith('minuslogprior') or name == 'chi2':
continue
param_dict: ParamDict = {'latex': par.label}
param_dict: 'ParamDict' = {'latex': par.label}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actual type should be better than deferred where possible

Copy link
Contributor Author

@ggalloni ggalloni Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I completely removed the deferred 'ParamDict' annotations. This comes at the cost of moving the definition of ParamsDict in typing a few lines further down, after the definition of ParamDict. Let me know what you think of this.

d[name] = param_dict
if par.renames:
param_dict['renames'] = par.renames
Expand Down
2 changes: 1 addition & 1 deletion cobaya/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,7 @@ def logpost(self,
return self.logposterior(params_values, make_finite=make_finite,
return_derived=False, cached=cached).logpost

def get_valid_point(self, max_tries: int, ignore_fixed_ref: bool = False,
def get_valid_point(self, max_tries: Union[int, str], ignore_fixed_ref: bool = False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has to be int as used in the code?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @cmbant, thanks for the feedback on this!

Indeed I think that, technically, that could be a float, as it is used just for evaluations like if tries<max_tries. Still, I wouldn't feel comfortable with having a floating max_tries, right?

Also, that number comes from the value of a NumberWithUnits, already parsed, so that should be either a float or an int. Thus, I would drop the Union with str (going back as it was before).

logposterior_as_dict: bool = False, random_state=None,
) -> Union[Tuple[np.ndarray, LogPosterior],
Tuple[np.ndarray, dict]]:
Expand Down
4 changes: 2 additions & 2 deletions cobaya/parameterization.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def is_derived_param(info_param: ParamInput) -> bool:
return expand_info_param(info_param).get("derived", False) is not False


def expand_info_param(info_param: ParamInput, default_derived=True) -> ParamDict:
def expand_info_param(info_param: ParamInput, default_derived=True) -> 'ParamDict':
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, probably a search and replace gone wrong...

"""
Expands the info of a parameter, from the user-friendly, shorter format
to a more unambiguous one.
Expand Down Expand Up @@ -76,7 +76,7 @@ def expand_info_param(info_param: ParamInput, default_derived=True) -> ParamDict
return info_param


def reduce_info_param(info_param: ParamDict) -> ParamInput:
def reduce_info_param(info_param: 'ParamDict') -> ParamInput:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed all of these, see above

"""
Compresses the info of a parameter, suppressing default values.
This is the opposite of :func:`~input.expand_info_param`.
Expand Down
6 changes: 3 additions & 3 deletions cobaya/samplers/mcmc/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ class MCMC(CovmatSampler):

# instance variables from yaml
burn_in: NumberWithUnits
learn_every: NumberWithUnits
output_every: NumberWithUnits
learn_every: Union[NumberWithUnits, str]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why only some of these changed. In terms of usage, an optional str is probably not very helpful, since the value should be converted when it is read in; maybe keep NumberWithUnits and use a more flexible check.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed only the ones causing errors but, following your suggestion, I moved the problem to the type-checking side, since we know that a NumberWithUnits can come in as a string.

output_every: Union[NumberWithUnits, str]
callback_every: NumberWithUnits
temperature: float
max_tries: NumberWithUnits
max_tries: Union[NumberWithUnits, str]
max_samples: int
drag: bool
callback_function: Optional[Callable]
Expand Down
4 changes: 2 additions & 2 deletions cobaya/samplers/polychord/polychord.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ class polychord(Sampler):

# variables from yaml
do_clustering: bool
num_repeats: int
num_repeats: Union[int, str]
confidence_for_unbounded: float
callback_function: Callable
blocking: Any
measure_speeds: bool
oversample_power: float
nlive: NumberWithUnits
nlive: Union[str, NumberWithUnits]
path: str
logzero: float
max_ndead: int
Expand Down
2 changes: 1 addition & 1 deletion tests/test_cosmo_multi_theory.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ class BinnedPk(Theory):
k_min_bin: float = np.log10(0.001)
k_max_bin: float = np.log10(0.35)
scale: float = 1e-9
bin_par: ParamDict = {'prior': {'min': 0, 'max': 100}}
bin_par: 'ParamDict' = {'prior': {'min': 0, 'max': 100}}

def initialize(self):
self.ks = np.logspace(self.k_min_bin, self.k_max_bin, self.nbins)
Expand Down
1 change: 1 addition & 0 deletions tests/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def test_mcmc_drag_results(temperature):


@pytest.mark.mpionly
@pytest.mark.skip("Setting 'max_samples' to a bad value raises an error which is not caught.")
def test_mcmc_sync():
info: InputDict = yaml_load(yaml)
logger.info('Test end synchronization')
Expand Down
119 changes: 119 additions & 0 deletions tests/test_type_checking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""General test for types of components."""

from typing import Any, ClassVar, Dict, List, Optional, Tuple
import numpy as np
import pytest

from cobaya.component import CobayaComponent
from cobaya.likelihood import Likelihood
from cobaya.tools import NumberWithUnits
from cobaya.typing import InputDict, ParamDict
from cobaya.run import run


class GenericLike(Likelihood):
    # Likelihood fixture for the type-checking tests: every attribute pairs an
    # annotation with a default value, and type enforcement is switched on.
    any: Any  # no default value
    classvar: ClassVar[int] = 1
    forwardref_params: "ParamDict" = {"d": [0.0, 1.0]}  # annotation given as a string
    infinity: int = float("inf")  # float infinity under an int annotation
    mean: NumberWithUnits = 1
    noise: float = 0  # int value under a float annotation
    none: int = None  # None under a non-Optional annotation
    numpy_int: int = np.int64(1)  # numpy integer under an int annotation
    optional: Optional[int] = None
    paramdict_params: ParamDict = {"c": [0.0, 1.0]}
    params: Dict[str, List[float]] = {"a": [0.0, 1.0], "b": [0, 1]}
    tuple_params: Tuple[float, float] = (0.0, 1.0)

    enforce_types = True

    def logp(self, **params_values):
        # Constant value, independent of the parameter values passed in
        return 1


def test_sampler_types():
    """
    Runs a minimal MCMC with type enforcement on, then checks that a
    non-integer ``max_samples`` is rejected with a ``TypeError``.
    """
    from copy import deepcopy

    original_info: InputDict = {
        "likelihood": {"like": GenericLike},
        "sampler": {"mcmc": {"max_samples": 1, "enforce_types": True}},
    }
    _ = run(original_info)

    # A shallow ``dict.copy`` would share the nested "sampler" dict, so the bad
    # value written below would leak back into ``original_info``.
    info = deepcopy(original_info)
    info["sampler"]["mcmc"]["max_samples"] = "not_an_int"
    with pytest.raises(TypeError):
        run(info)


class GenericComponent(CobayaComponent):
    """
    Bare component used to exercise attribute type validation directly: the
    constructor stores the values it is given and validates them against the
    annotations below, without going through ``CobayaComponent.__init__``.
    """

    any: Any
    classvar: ClassVar[int] = 1
    forwardref_params: "ParamDict" = {"d": [0.0, 1.0]}
    infinity: int = float("inf")
    mean: NumberWithUnits = 1
    noise: float = 0
    none: int = None
    numpy_int: int = np.int64(1)
    optional: Optional[int] = None
    paramdict_params: ParamDict = {"c": [0.0, 1.0]}
    params: Dict[str, List[float]] = {"a": [0.0, 1.0], "b": [0, 1]}
    tuple_params: Tuple[float, float] = (0.0, 1.0)

    enforce_types = True

    def __init__(
        self,
        any,
        classvar,
        forwardref_params,
        infinity,
        mean,
        noise,
        none,
        numpy_int,
        optional,
        paramdict_params,
        params,
        tuple_params,
    ):
        # Store the passed values on the instance: otherwise validation would
        # only ever see the (valid) class-level defaults.
        self.any = any
        self.classvar = classvar
        self.forwardref_params = forwardref_params
        self.infinity = infinity
        self.mean = mean
        self.noise = noise
        self.none = none
        self.numpy_int = numpy_int
        self.optional = optional
        self.paramdict_params = paramdict_params
        self.params = params
        self.tuple_params = tuple_params
        if self.enforce_types:
            super().validate_attributes()


def test_component_types():
    """
    Checks that well-typed values are accepted, and that each wrongly-typed
    value, replaced one at a time, is rejected with a ``TypeError``.
    """
    correct_kwargs = {
        "any": 1,
        "classvar": 1,
        "forwardref_params": {"d": [0.0, 1.0]},
        "infinity": float("inf"),
        "mean": 1,
        "noise": 0,
        "none": None,
        "numpy_int": 1,
        "optional": 3,
        "paramdict_params": {"c": [0.0, 1.0]},
        "params": {"a": [0.0, 1.0], "b": [0, 1]},
        "tuple_params": (0.0, 1.0),
    }
    _ = GenericComponent(**correct_kwargs)

    # ``Any`` accepts every value, so it belongs with the valid cases
    _ = GenericComponent(**{**correct_kwargs, "any": "not_an_int"})

    wrong_cases = [
        {"classvar": "not_an_int"},
        {"forwardref_params": "not_a_paramdict"},
        {"infinity": "not_an_int"},
        {"mean": "not_a_numberwithunits"},
        {"noise": "not_a_float"},
        {"none": "not_a_none"},
        {"numpy_int": "not_an_int"},
        {"paramdict_params": "not_a_paramdict"},
        # NOTE(review): a non-mapping value for a Dict annotation must be
        # reported as a TypeError, not as an AttributeError from iterating it
        {"params": "not_a_dict"},
        {"params": {1: [0.0, 1.0]}},
        {"params": {"a": "not_a_list"}},
        {"params": {"a": [0.0, "not_a_float"]}},
        {"optional": "not_an_int"},
        {"tuple_params": "not_a_tuple"},
        {"tuple_params": (0.0, "not_a_float")},
    ]
    for case in wrong_cases:
        # Unpack as keyword arguments: a single positional dict would raise a
        # missing-arguments TypeError for every case, whatever its content.
        with pytest.raises(TypeError):
            _ = GenericComponent(**{**correct_kwargs, **case})
Loading