Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhance type checking #382

Open
wants to merge 21 commits into
base: master
Choose a base branch
from

Conversation

ggalloni
Copy link
Contributor

@ggalloni ggalloni commented Sep 4, 2024

The method validate_info of CobayaComponent, checks only bool.

This PR is enhancing that to check every relevant type, including generic types (List[], Dict[], Tuple[], etc).

@ggalloni
Copy link
Contributor Author

ggalloni commented Sep 4, 2024

The new code is raising a TypeError when max_samples="bad_value", however, the test_mcmc.py (MPI case) is still breaking as if it is not catching that.

Do you have an idea why this could be happening?

@codecov-commenter
Copy link

codecov-commenter commented Sep 4, 2024

⚠️ Please install the 'codecov app svg image' to ensure uploads and comments are reliably processed by Codecov.

Codecov Report

Attention: Patch coverage is 26.31579% with 42 lines in your changes missing coverage. Please review.

Project coverage is 74.25%. Comparing base (735f7a8) to head (19e4143).

Files with missing lines Patch % Lines
cobaya/component.py 22.64% 41 Missing ⚠️
cobaya/samplers/polychord/polychord.py 0.00% 1 Missing ⚠️

❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #382      +/-   ##
==========================================
- Coverage   74.57%   74.25%   -0.33%     
==========================================
  Files         147      147              
  Lines       11200    11247      +47     
==========================================
- Hits         8352     8351       -1     
- Misses       2848     2896      +48     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@ggalloni
Copy link
Contributor Author

ggalloni commented Sep 6, 2024

I left the validation of bools as it was and added an enforce_types attribute that will trigger the new code.

In this way, one can force type checking by setting enforce_types=True in any descendent class of CobayaComponent, without touching the old validation.

Copy link
Collaborator

@cmbant cmbant left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting to see this. Bit mixed feeling though, a bit complicated than hoping (for something is just checking), not very clear how robust it is.

cobaya/model.py Outdated
@@ -589,7 +589,7 @@ def logpost(self,
return self.logposterior(params_values, make_finite=make_finite,
return_derived=False, cached=cached).logpost

def get_valid_point(self, max_tries: int, ignore_fixed_ref: bool = False,
def get_valid_point(self, max_tries: Union[int, str], ignore_fixed_ref: bool = False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has to be int as used in the code?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @cmbant, thanks for the feedback on this!

Indeed I think that, technically, that could be a float, as it is used just for evaluations like if tries<max_tries. Still, I wouldn't feel comfortable with having a floating max_tries, right?

Also, that number comes from the value of a NumberWithUnits, already parsed, so that should be either a float or an int. Thus, I would drop the Union with str (going back as it was before).

@@ -37,7 +37,7 @@ def cosmomc_root_to_cobaya_info_dict(root: str, derived_to_input=()) -> InputDic
name = name.replace('chi2_', 'chi2__')
if name.startswith('minuslogprior') or name == 'chi2':
continue
param_dict: ParamDict = {'latex': par.label}
param_dict: 'ParamDict' = {'latex': par.label}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actual type should be better than deferred where possible

Copy link
Contributor Author

@ggalloni ggalloni Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed completely deferring 'ParamDict'. This comes at the cost of moving the definition of ParamsDict in typing a few lines below, after the definition of ParamDict. Let me know what you think of this

@@ -331,6 +331,8 @@ class CobayaComponent(HasLogger, HasDefaults):
_at_resume_prefer_new: List[str] = ["version"]
_at_resume_prefer_old: List[str] = []

enforce_types: bool = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably should start with _ as not something we want changed by yaml

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure about this. Indeed, adding the underscore we cannot enforce types neither through the yaml nor through the info dict (as I do in the test for instance). So, I guess this becomes a question of what default behavior we want to use. If the default value of enforce_types is False, the user can only enforce types on components he/she defines with that flag to True. But, for example, types of built-in samplers and such cannot be checked (except if we add in their definition the opposite flag). This may be a good choice since I guess that types of built-in stuff are checked anyway, in a way of the other. Vice versa, with enforce_types = True by default, everything will be explicitly checked and there is no need for the user to access that.

Am I missing something?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would probably have to be False by default to keep people from hitting annoying errors and compatibility with existing external likelihoods. Was thinking of this more as an option for likelihood developers who support accurate type hint in their code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok great, then I'll make it private and default at False 👍

elif expected_type is int: # for numpy integers
if value == float('inf'): # for infinite values parsed as floats
return isinstance(value, float)
return isinstance(value, int) or "numpy.int" in str(type(value))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can simplify some of this using generic numbers.Number types?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was thinking about using some more generic types, I will commit that soon.

@@ -42,7 +42,7 @@ def is_derived_param(info_param: ParamInput) -> bool:
return expand_info_param(info_param).get("derived", False) is not False


def expand_info_param(info_param: ParamInput, default_derived=True) -> ParamDict:
def expand_info_param(info_param: ParamInput, default_derived=True) -> 'ParamDict':
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ops, probably a search and replace gone wrong...

@@ -76,7 +76,7 @@ def expand_info_param(info_param: ParamInput, default_derived=True) -> ParamDict
return info_param


def reduce_info_param(info_param: ParamDict) -> ParamInput:
def reduce_info_param(info_param: 'ParamDict') -> ParamInput:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed all of these, see above

@@ -56,11 +56,11 @@ class MCMC(CovmatSampler):

# instance variables from yaml
burn_in: NumberWithUnits
learn_every: NumberWithUnits
output_every: NumberWithUnits
learn_every: Union[NumberWithUnits, str]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why only some of these changed. In terms of usage, optional str is probably not very helpful since should be converted at read in, maybe NumberWithUnits and more flexible check

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed the only ones causing errors, but, following your suggestion, I moved the problem on the type-checking side as we know the NumberWithUnits can come in as a string.

@ggalloni
Copy link
Contributor Author

ggalloni commented Oct 2, 2024

For assessing the robustness of this, I am not sure how to test it. An idea would be to switch the default of _enforce_types to True, let tests here run and, if successful, set it back to False. So at least we know that everything internal is working as expected. What do you think?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants