Enhance type checking #382

Closed
wants to merge 22 commits into from

Conversation

ggalloni
Contributor

@ggalloni ggalloni commented Sep 4, 2024

The validate_info method of CobayaComponent currently checks only bool attributes.

This PR extends it to check every relevant type, including generic types (List[], Dict[], Tuple[], etc.).
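
As a rough illustration of what checking generic types involves, here is a minimal sketch built on typing.get_origin and typing.get_args. The helper name matches is hypothetical and this is not the code in the PR; it only covers List, Dict and Tuple hints and falls back to a plain isinstance check for everything else.

```python
from typing import Any, Dict, List, Tuple, get_args, get_origin


def matches(expected: Any, value: Any) -> bool:
    """Return True if value conforms to a (possibly generic) type hint."""
    if expected is Any:
        return True
    origin = get_origin(expected)
    if origin is None:  # plain class such as int or str
        return isinstance(value, expected)
    args = get_args(expected)
    if not args:  # bare List, Dict or Tuple: only the container type is known
        return isinstance(value, origin)
    if origin is list:
        return isinstance(value, list) and all(matches(args[0], v) for v in value)
    if origin is dict:
        return isinstance(value, dict) and all(
            matches(args[0], k) and matches(args[1], v) for k, v in value.items())
    if origin is tuple:
        if not isinstance(value, tuple):
            return False
        if len(args) == 2 and args[1] is Ellipsis:  # Tuple[X, ...]
            return all(matches(args[0], v) for v in value)
        return len(value) == len(args) and all(
            matches(t, v) for t, v in zip(args, value))
    return isinstance(value, origin)


print(matches(List[int], [1, 2, 3]))
print(matches(Dict[str, float], {"a": "not a float"}))
```

The recursion is what lets nested hints such as Dict[str, List[int]] be checked element by element.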

@ggalloni
Contributor Author

ggalloni commented Sep 4, 2024

The new code raises a TypeError when max_samples="bad_value". However, test_mcmc.py (MPI case) still fails, as if the error were not being caught.

Do you have an idea why this could be happening?

@codecov-commenter

codecov-commenter commented Sep 4, 2024

Codecov Report

Attention: Patch coverage is 26.31579% with 42 lines in your changes missing coverage. Please review.

Project coverage is 74.25%. Comparing base (735f7a8) to head (19e4143).

Files with missing lines                  Patch %   Lines
cobaya/component.py                       22.64%    41 Missing ⚠️
cobaya/samplers/polychord/polychord.py    0.00%     1 Missing ⚠️

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #382      +/-   ##
==========================================
- Coverage   74.57%   74.25%   -0.33%     
==========================================
  Files         147      147              
  Lines       11200    11247      +47     
==========================================
- Hits         8352     8351       -1     
- Misses       2848     2896      +48     

@ggalloni
Contributor Author

ggalloni commented Sep 6, 2024

I left the validation of bools as it was and added an enforce_types attribute that will trigger the new code.

In this way, one can force type checking by setting enforce_types=True in any descendant class of CobayaComponent, without touching the old validation.
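
A minimal sketch of the opt-in idea follows. The classes Component, LooseLikelihood and StrictLikelihood are hypothetical stand-ins, not cobaya's actual classes, and the check is limited to plain (non-generic) annotations to keep the example short.

```python
from typing import get_type_hints


class Component:
    """Hypothetical stand-in for CobayaComponent (not cobaya's real class)."""

    enforce_types: bool = False  # opt-in flag, off by default

    def __init__(self, **info):
        hints = get_type_hints(type(self))
        for name, value in info.items():
            # Only plain classes are handled here; generics need more work
            if self.enforce_types and name in hints:
                if not isinstance(value, hints[name]):
                    raise TypeError(
                        f"Attribute '{name}' must be of type {hints[name]}, "
                        f"not {type(value)}")
            setattr(self, name, value)


class LooseLikelihood(Component):
    speed: int = 1  # not checked: the flag keeps its default


class StrictLikelihood(Component):
    enforce_types = True  # this subclass opts in
    speed: int = 1


LooseLikelihood(speed="fast")  # accepted silently
try:
    StrictLikelihood(speed="fast")
except TypeError as err:
    print(err)
```

Because the flag is an ordinary class attribute, existing subclasses that never mention it keep their old behavior.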

Collaborator

@cmbant cmbant left a comment

Interesting to see this. I have somewhat mixed feelings, though: it is a bit more complicated than I was hoping for (for something that is just checking), and it is not very clear how robust it is.

cobaya/model.py Outdated
@@ -589,7 +589,7 @@ def logpost(self,
return self.logposterior(params_values, make_finite=make_finite,
return_derived=False, cached=cached).logpost

- def get_valid_point(self, max_tries: int, ignore_fixed_ref: bool = False,
+ def get_valid_point(self, max_tries: Union[int, str], ignore_fixed_ref: bool = False,
Collaborator

Has to be int as used in the code?

Contributor Author

Hello @cmbant, thanks for the feedback on this!

Technically, I think it could also be a float, since it is only used in comparisons such as if tries<max_tries. Still, a floating-point max_tries would not feel right, would it?

Also, that number comes from the value of a NumberWithUnits that has already been parsed, so it should be either a float or an int. Thus, I would drop the Union with str (going back to how it was before).

@@ -37,7 +37,7 @@ def cosmomc_root_to_cobaya_info_dict(root: str, derived_to_input=()) -> InputDic
name = name.replace('chi2_', 'chi2__')
if name.startswith('minuslogprior') or name == 'chi2':
continue
- param_dict: ParamDict = {'latex': par.label}
+ param_dict: 'ParamDict' = {'latex': par.label}
Collaborator

Actual type should be better than deferred where possible

Contributor Author

@ggalloni ggalloni Oct 2, 2024

I completely removed the deferred 'ParamDict'. This comes at the cost of moving the definition of ParamsDict in typing a few lines down, after the definition of ParamDict. Let me know what you think.

@@ -331,6 +331,8 @@ class CobayaComponent(HasLogger, HasDefaults):
_at_resume_prefer_new: List[str] = ["version"]
_at_resume_prefer_old: List[str] = []

enforce_types: bool = False
Collaborator

Probably should start with _ as not something we want changed by yaml

Contributor Author

I am not sure about this. With the underscore, types cannot be enforced through either the yaml or the info dict (as I do in the test, for instance). So I guess this becomes a question of which default behavior we want.

If enforce_types is False by default, users can only enforce types on components they define themselves with the flag set to True. The types of built-in samplers and the like cannot be checked (unless we set the flag to True in their definitions). This may be a good choice, since I guess the types of built-in components are checked anyway, one way or another.

Conversely, with enforce_types = True by default, everything is checked explicitly and the user never needs to touch the flag.

Am I missing something?

Collaborator

Would probably have to be False by default to keep people from hitting annoying errors and to stay compatible with existing external likelihoods. I was thinking of this more as an option for likelihood developers who support accurate type hints in their code.

Contributor Author

OK great, then I'll make it private and default to False 👍

elif expected_type is int: # for numpy integers
if value == float('inf'): # for infinite values parsed as floats
return isinstance(value, float)
return isinstance(value, int) or "numpy.int" in str(type(value))
Collaborator

Can simplify some of this using generic numbers.Number types?

Contributor Author

Yes, I was thinking about using some more generic types, I will commit that soon.
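
For context, a small sketch of what the abstract base classes in the standard numbers module buy here: numpy registers its scalar types with them, so one isinstance check covers both Python and numpy numbers without inspecting type names as strings. The helper names is_int_like and is_real_like are hypothetical.

```python
from numbers import Integral, Real

import numpy as np


def is_int_like(value) -> bool:
    """True for Python ints and numpy integer scalars (note: also bools)."""
    return isinstance(value, Integral)


def is_real_like(value) -> bool:
    """True for any real scalar: int, float, numpy integer or floating."""
    return isinstance(value, Real)


print(is_int_like(np.int64(3)))        # numpy integer scalar
print(is_int_like(3.0))                # a float is not Integral
print(is_real_like(np.float32(1.5)))   # numpy floating scalar
print(is_real_like(float("inf")))      # inf is a float, hence Real
```

Infinity is a float, so it passes the Real check but not the Integral one, which is why the snippet under review special-cases it for int.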

@@ -42,7 +42,7 @@ def is_derived_param(info_param: ParamInput) -> bool:
return expand_info_param(info_param).get("derived", False) is not False


- def expand_info_param(info_param: ParamInput, default_derived=True) -> ParamDict:
+ def expand_info_param(info_param: ParamInput, default_derived=True) -> 'ParamDict':
Collaborator

?

Contributor Author

Oops, probably a search and replace gone wrong...

@@ -76,7 +76,7 @@ def expand_info_param(info_param: ParamInput, default_derived=True) -> ParamDict
return info_param


- def reduce_info_param(info_param: ParamDict) -> ParamInput:
+ def reduce_info_param(info_param: 'ParamDict') -> ParamInput:
Collaborator

?

Contributor Author

I removed all of these, see above

@@ -56,11 +56,11 @@ class MCMC(CovmatSampler):

# instance variables from yaml
burn_in: NumberWithUnits
learn_every: NumberWithUnits
output_every: NumberWithUnits
learn_every: Union[NumberWithUnits, str]
Collaborator

Not sure why only some of these changed. In terms of usage, an optional str is probably not very helpful, since it should be converted at read-in; maybe keep NumberWithUnits and make the check more flexible.

Contributor Author

I changed only the ones causing errors but, following your suggestion, I moved the problem to the type-checking side, since we know a NumberWithUnits can come in as a string.

@ggalloni
Contributor Author

ggalloni commented Oct 2, 2024

I am not sure how to test the robustness of this. One idea would be to switch the default of _enforce_types to True, let the tests here run and, if they succeed, set it back to False. Then at least we know that everything internal works as expected. What do you think?


if hasattr(expected_type, "__origin__"):
return self._validate_composite_type(expected_type, value)
else:
Collaborator

Why are only the values general here, not also the expected_types?
e.g. any Mapping type could accept any Mapping value? (e.g. empty_dict is MappingProxyType)

Note also that with numpy you can end up with "numbers" that are zero-rank arrays like np.array(1), which I suspect may not pass isinstance(Real), though I am not sure whether that is ever an issue for setting parameters.
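
Both points can be checked quickly. The sketch below assumes only the standard library and numpy; the helper is_real_like is hypothetical. It shows that a MappingProxyType is a Mapping but not a dict, and that a zero-rank array is indeed not an instance of Real, so it needs its own branch.

```python
from collections.abc import Mapping
from numbers import Real
from types import MappingProxyType

import numpy as np

empty_mapping = MappingProxyType({})
print(isinstance(empty_mapping, dict))     # False: a dict-only check rejects it
print(isinstance(empty_mapping, Mapping))  # True: the abstract type accepts it

zero_rank = np.array(1)
print(isinstance(zero_rank, Real))         # False: it is an ndarray, not a scalar


def is_real_like(value) -> bool:
    """Accept real scalars and zero-rank arrays with a real dtype."""
    if isinstance(value, Real):
        return True
    return (isinstance(value, np.ndarray) and value.ndim == 0
            and (np.issubdtype(value.dtype, np.integer)
                 or np.issubdtype(value.dtype, np.floating)))


print(is_real_like(zero_rank))             # True
```

Checking against the abstract Mapping rather than dict is the same generalization applied to the expected type instead of the value.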

@cmbant
Collaborator

cmbant commented Nov 1, 2024

The best I can come up with that works with empty_dict, Sequence, Tuple[float] and TypedDicts, and allows numpy arrays for Sequence[float] and Tuple[float], is something like this:


# Imports this snippet relies on (is_typeddict requires Python 3.10+);
# NumberWithUnits is cobaya's own class
from collections.abc import Iterable, Mapping, Sequence
from numbers import Integral, Real
from typing import Any, ClassVar, Union, get_type_hints, is_typeddict
import numpy as np

    def validate_info(self, name: str, value: Any, annotations: dict):
        # Raise TypeError if the value does not match its annotated type
        if name in annotations:
            expected_type = annotations[name]
            if not self._validate_type(expected_type, value):
                msg = (f"Attribute '{name}' must be of type {expected_type}, "
                       f"not {type(value)}(value={value})")
                raise TypeError(msg)

    def _validate_composite_type(self, expected_type, value):
        origin = expected_type.__origin__
        # Bare Callable and Sequence types have no __args__
        args = getattr(expected_type, "__args__", ())
        if origin is Union:  # also covers Optional[X], which is Union[X, None]
            return any(self._validate_type(t, value) for t in args)
        elif origin is ClassVar:  # must come before issubclass, which needs a class
            return self._validate_type(args[0], value)
        elif not args:  # nothing to check beyond the container type
            return isinstance(value, origin)
        elif issubclass(origin, Sequence) and isinstance(value, Iterable) and len(args) == 1:
            # Homogeneous case such as Sequence[float]; numpy arrays are accepted
            return all(self._validate_type(args[0], item) for item in value)
        elif issubclass(origin, Sequence):
            # Fixed-length case such as Tuple[int, str]
            return isinstance(value, Sequence) and len(args) == len(value) and all(
                self._validate_type(t, v) for t, v in zip(args, value)
            )
        elif origin is dict:
            return isinstance(value, Mapping) and all(
                self._validate_type(args[0], k) and self._validate_type(args[1], v)
                for k, v in value.items()
            )
        else:
            return isinstance(value, origin)

    def _validate_type(self, expected_type, value):
        if value is None or expected_type is Any:  # Any is always valid
            return True

        if hasattr(expected_type, "__origin__"):
            return self._validate_composite_type(expected_type, value)
        else:
            # Exceptions for some types
            if is_typeddict(expected_type):
                type_hints = get_type_hints(expected_type)
                if not isinstance(value, Mapping) or not set(value).issubset(type_hints):
                    return False
                for key, item in value.items():
                    # Raises TypeError on a mismatching entry
                    self.validate_info(key, item, type_hints)
                return True
            elif expected_type is int:
                # Infinite values are parsed as floats
                return value == float('inf') or isinstance(value, Integral)
            elif expected_type is float:
                # Also accept zero-rank numpy arrays such as np.array(1.0)
                return isinstance(value, Real) or (
                    isinstance(value, np.ndarray) and not value.ndim)
            elif expected_type is NumberWithUnits:
                return isinstance(value, (Real, str))
            return isinstance(value, expected_type)

    def validate_attributes(self):
        annotations = self.get_annotations()
        for name in annotations.keys():
            self.validate_info(name, getattr(self, name, None), annotations)

However, is_typeddict is only in core typing from 3.10.

@cmbant
Collaborator

cmbant commented Nov 4, 2024

My attempt to generalize and refactor this a bit is now in #388.
@ggalloni did you have an SOLikeT build to test against? Anything missed?

@ggalloni
Contributor Author

ggalloni commented Nov 6, 2024

Hello @cmbant, thanks for your help with this!
Yes, I was using SOLikeT/#192 to test this, so it should be sufficient to point it to the new branch of #388. I guess that would also tell us if something is missing since it was passing all tests using #382 instead.

@cmbant
Collaborator

cmbant commented Nov 6, 2024

OK great, let me know of any problems. I also just pushed a change to hopefully make it work with deferred types as well.
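
For readers unfamiliar with the issue: a deferred (string) annotation is stored as plain text, so it has to be resolved before any isinstance-style check can use it. The sketch below shows the general mechanism with typing.get_type_hints; it is not necessarily what the pushed change does, and the Component class and the ParamDict alias here are simplified, hypothetical stand-ins.

```python
from typing import Dict, get_type_hints


class Component:
    """Hypothetical class whose annotations are written as strings."""
    params: "ParamDict"  # refers to a name defined further down
    max_tries: "int"


# Stand-in alias, defined after the class that mentions it
ParamDict = Dict[str, float]

raw = Component.__annotations__
print(raw["params"])  # still the string 'ParamDict'

# Resolve the strings into real types; the namespace is passed explicitly
resolved = get_type_hints(Component, localns={"ParamDict": ParamDict})
print(resolved["params"])     # typing.Dict[str, float]
print(resolved["max_tries"])  # <class 'int'>
```

A validator that reads __annotations__ directly would end up comparing values against strings, which is why resolution has to happen first.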

@ggalloni
Contributor Author

ggalloni commented Nov 6, 2024

Currently, all non-Windows builds are failing due to CCL not building correctly...
Still, Windows is passing all tests, which is reassuring 👍

@cmbant
Collaborator

cmbant commented Nov 6, 2024

Except that you don't have _enforce_types=True, only enforce_types...

@ggalloni
Contributor Author

ggalloni commented Nov 7, 2024

I fixed that (I thought I already had...) and am now getting an error when handling ClassVar.

This seems to happen because ClassVar is only dealt with when origin and args are defined for the expected_type.
Some checks skip that part entirely and produce an error at line 248 of typing.py when trying to execute

isinstance(value, typing.ClassVar)
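
One way this can arise, offered as an assumption since the thread gives no specific example, is an annotation written as a bare ClassVar with no argument: it has no __origin__, so it bypasses the composite branch and reaches isinstance, which typing refuses. The helper unwrap_classvar below is hypothetical and only illustrates stripping the wrapper first.

```python
from typing import Any, ClassVar, get_args, get_origin


def unwrap_classvar(expected):
    """Strip a ClassVar wrapper: ClassVar[X] -> X, bare ClassVar -> Any."""
    if expected is ClassVar:
        return Any
    if get_origin(expected) is ClassVar:
        return get_args(expected)[0]
    return expected


print(hasattr(ClassVar, "__origin__"))       # False for the bare form
print(hasattr(ClassVar[int], "__origin__"))  # True once subscripted

try:
    isinstance(1, ClassVar)
except TypeError as err:
    print("TypeError:", err)

print(unwrap_classvar(ClassVar[int]))
```

After unwrapping, the inner type can go through the ordinary validation path.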

@cmbant
Collaborator

cmbant commented Nov 7, 2024

Can you give specific example?

@cmbant
Collaborator

cmbant commented Nov 7, 2024

I made a fix; it looks like it is running OK on Windows.

@cmbant
Collaborator

cmbant commented Nov 11, 2024

I merged, thanks!

@cmbant cmbant closed this Nov 11, 2024