-
Notifications
You must be signed in to change notification settings - Fork 127
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enhance type checking #382
Closed
Closed
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
930db00
Allow all types for checking
ggalloni 71fc109
Handle exceptions
ggalloni e4c2971
Add `str` when allowed
ggalloni dc48e84
Clean imports
ggalloni 7a568e7
Handle `ParamDict`
ggalloni 7bf99d8
Try skipping test for now
ggalloni 61a0405
Separate bool validation and optional type enforcing
ggalloni af5c727
`ParamDict` -> `ForwardRef['ParamDict']` for type check
ggalloni a16b804
Handle `ClassVar`
ggalloni 726f522
Add test for type checking
ggalloni bbb1f6b
Test for compatibility
ggalloni 8e62704
Test `Optional`
ggalloni 39b75b4
Merge branch 'master' into checking_types
ggalloni 3fe9b84
Remove deferred types
ggalloni b985a40
Remove `ForwardRef` handling
ggalloni 6ecc79a
Handle ints and floats with generic types
ggalloni 00bcbb3
Allow `NumberWithUnits` to be a `str`
ggalloni 5988a15
Remove useless type
ggalloni 4cdd11e
Merge branch 'checking_types' of https://github.com/ggalloni/cobaya i…
ggalloni 436e70d
Change `enforce_types` to private attribute
ggalloni 19e4143
Clean
ggalloni dad347c
Merge branch 'master' into checking_types
ggalloni File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
"""General test for types of components.""" | ||
|
||
from typing import Any, ClassVar, Dict, List, Optional, Tuple | ||
import numpy as np | ||
import pytest | ||
|
||
from cobaya.component import CobayaComponent | ||
from cobaya.likelihood import Likelihood | ||
from cobaya.tools import NumberWithUnits | ||
from cobaya.typing import InputDict, ParamDict | ||
from cobaya.run import run | ||
|
||
|
||
class GenericLike(Likelihood): | ||
any: Any | ||
classvar: ClassVar[int] = 1 | ||
# forwardref_params: 'ParamDict' = {"d": [0.0, 1.0]} | ||
infinity: int = float("inf") | ||
mean: NumberWithUnits = 1 | ||
noise: float = 0 | ||
none: int = None | ||
numpy_int: int = np.int64(1) | ||
optional: Optional[int] = None | ||
paramdict_params: ParamDict = {"c": [0.0, 1.0]} | ||
params: Dict[str, List[float]] = {"a": [0.0, 1.0], "b": [0, 1]} | ||
tuple_params: Tuple[float, float] = (0.0, 1.0) | ||
|
||
_enforce_types = True | ||
|
||
def logp(self, **params_values): | ||
return 1 | ||
|
||
|
||
def test_sampler_types(): | ||
original_info: InputDict = { | ||
"likelihood": {"like": GenericLike}, | ||
"sampler": {"mcmc": {"max_samples": 1}}, | ||
} | ||
_ = run(original_info) | ||
|
||
info = original_info.copy() | ||
info["sampler"]["mcmc"]["max_samples"] = "not_an_int" | ||
with pytest.raises(TypeError): | ||
run(info) | ||
|
||
|
||
class GenericComponent(CobayaComponent): | ||
any: Any | ||
classvar: ClassVar[int] = 1 | ||
infinity: int = float("inf") | ||
mean: NumberWithUnits = 1 | ||
noise: float = 0 | ||
none: int = None | ||
numpy_int: int = np.int64(1) | ||
optional: Optional[int] = None | ||
paramdict_params: ParamDict = {"c": [0.0, 1.0]} | ||
params: Dict[str, List[float]] = {"a": [0.0, 1.0], "b": [0, 1]} | ||
tuple_params: Tuple[float, float] = (0.0, 1.0) | ||
|
||
_enforce_types = True | ||
|
||
def __init__( | ||
self, | ||
any, | ||
classvar, | ||
infinity, | ||
mean, | ||
noise, | ||
none, | ||
numpy_int, | ||
optional, | ||
paramdict_params, | ||
params, | ||
tuple_params, | ||
): | ||
if self._enforce_types: | ||
super().validate_attributes() | ||
|
||
|
||
def test_component_types(): | ||
correct_kwargs = { | ||
"any": 1, | ||
"classvar": 1, | ||
"infinity": float("inf"), | ||
"mean": 1, | ||
"noise": 0, | ||
"none": None, | ||
"numpy_int": 1, | ||
"optional": 3, | ||
"paramdict_params": {"c": [0.0, 1.0]}, | ||
"params": {"a": [0.0, 1.0], "b": [0, 1]}, | ||
"tuple_params": (0.0, 1.0), | ||
} | ||
_ = GenericComponent(**correct_kwargs) | ||
|
||
wrong_cases = [ | ||
{"any": "not_an_int"}, | ||
{"classvar": "not_an_int"}, | ||
{"infinity": "not_an_int"}, | ||
{"mean": "not_a_numberwithunits"}, | ||
{"noise": "not_a_float"}, | ||
{"none": "not_a_none"}, | ||
{"numpy_int": "not_an_int"}, | ||
{"paramdict_params": "not_a_paramdict"}, | ||
{"params": "not_a_dict"}, | ||
{"params": {1: [0.0, 1.0]}}, | ||
{"params": {"a": "not_a_list"}}, | ||
{"params": {"a": [0.0, "not_a_float"]}}, | ||
{"optional": "not_an_int"}, | ||
{"tuple_params": "not_a_tuple"}, | ||
{"tuple_params": (0.0, "not_a_float")}, | ||
] | ||
for case in wrong_cases: | ||
with pytest.raises(TypeError): | ||
_ = GenericComponent({**correct_kwargs, **case}) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are only the values general here, not also the expected_types?
e.g. any Mapping type could accept any Mapping value? (e.g. empty_dict is MappingProxyType)
Note also in numpy can also end up with "numbers" that are zero-rank arrays like np.array(1), which I suspect may not pass isinstance(Real), though not sure if that's ever an issue for setting parameters.