Improve validation tests for numpy.ndarray #598

Open
m-aguena opened this issue Jul 12, 2023 · 2 comments
Labels
enhancement Tackle this someday but not immediately

Comments

@m-aguena
Collaborator

When an argument is a numpy.ndarray (typically a 0-d array produced by something like np.array(2)), the type tests in validate_argument fail.
This could be fixed by changing _is_valid to:

def _is_valid(arg, valid_type):
    if valid_type == "function":
        return callable(arg)
    if (valid_type in ("int_array", "float_array") and np.iterable(arg)):
        return isinstance(arg[0], _valid_types[valid_type])
    if isinstance(arg, np.ndarray):
        if (valid_type in (int, "int_array")):
            return arg.dtype.char in np.typecodes['AllInteger']
        if (valid_type in (float, "float_array")):
            return arg.dtype.char in np.typecodes['AllFloat']
        return False
    return isinstance(arg, _valid_types.get(valid_type, valid_type))

@m-aguena added the enhancement label (Tackle this someday but not immediately) on Jul 12, 2023

@hsinfan1996
Collaborator

hsinfan1996 commented Jul 13, 2023

Will isinstance(obj, collections.abc.Iterable) help? It behaves differently from np.iterable(obj) for the 0-d array case. See https://numpy.org/doc/stable/reference/generated/numpy.iterable.html.
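
For readers following along, the difference mentioned above can be checked with a short sketch. The variable names are illustrative only and do not come from the project:

```python
from collections.abc import Iterable

import numpy as np

zero_d = np.array(2)    # 0-d array, shape ()
one_d = np.array([2])   # 1-d array, shape (1,)

# ndarray defines __iter__, so the ABC check passes even for a 0-d array.
abc_zero_d = isinstance(zero_d, Iterable)
# np.iterable actually tries iter(obj), which raises TypeError for a 0-d array.
np_zero_d = np.iterable(zero_d)

# For a 1-d array both checks agree.
abc_one_d = isinstance(one_d, Iterable)
np_one_d = np.iterable(one_d)

print(abc_zero_d, np_zero_d)  # True False
print(abc_one_d, np_one_d)    # True True
```

So the collections.abc.Iterable check would treat np.array(2) as iterable, even though indexing it with arg[0] or looping over it fails.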

@m-aguena
Collaborator Author

@hsinfan1996, thanks for the suggestion. But I think there is an even simpler solution:

def _is_valid(arg, valid_type):
    if valid_type == "function":
        return callable(arg)
    if valid_type == "int_array":
        return np.array(arg).dtype.char in np.typecodes['AllInteger']
    if valid_type == "float_array":
        return np.array(arg).dtype.char in np.typecodes['AllFloat'] + np.typecodes['AllInteger']
    return isinstance(arg, _valid_types.get(valid_type, valid_type))
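
A minimal sketch of how the two dtype-based branches behave on a few inputs, assuming they are pulled out into standalone helpers. The names is_int_array and is_float_array are hypothetical and only mirror the expressions in the snippet above. The project's _valid_types mapping is not shown in this thread, so it is left out:

```python
import numpy as np


def is_int_array(arg):
    """Hypothetical helper mirroring the "int_array" branch above."""
    return np.array(arg).dtype.char in np.typecodes["AllInteger"]


def is_float_array(arg):
    """Hypothetical helper mirroring the "float_array" branch above.

    Integer dtypes are accepted as well, as in the proposal.
    """
    allowed = np.typecodes["AllFloat"] + np.typecodes["AllInteger"]
    return np.array(arg).dtype.char in allowed


# (int check, float check) for a 0-d array, lists, and a string.
results = {
    "np.array(2)": (is_int_array(np.array(2)), is_float_array(np.array(2))),
    "[1, 2, 3]": (is_int_array([1, 2, 3]), is_float_array([1, 2, 3])),
    "[1.0, 2.5]": (is_int_array([1.0, 2.5]), is_float_array([1.0, 2.5])),
    "'a'": (is_int_array("a"), is_float_array("a")),
}

for label, (as_int, as_float) in results.items():
    print(f"{label:12s} int_array={as_int} float_array={as_float}")
```

Because np.array(arg) also wraps plain scalars, a bare 2 or 2.5 passes these checks just like an array would. Note also that np.typecodes['AllFloat'] includes the complex type codes, so a complex array would pass the float check; np.typecodes['Float'] covers the real floating types only, if that distinction matters.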
