diff --git a/skpro/datatypes/_base/__init__.py b/skpro/datatypes/_base/__init__.py new file mode 100644 index 000000000..b6727e858 --- /dev/null +++ b/skpro/datatypes/_base/__init__.py @@ -0,0 +1,5 @@ +"""Base module for datatypes.""" + +from skpro.datatypes._base._base import BaseConverter, BaseDatatype + +__all__ = ["BaseConverter", "BaseDatatype"] diff --git a/skpro/datatypes/_base/_base.py b/skpro/datatypes/_base/_base.py new file mode 100644 index 000000000..65d36e3e8 --- /dev/null +++ b/skpro/datatypes/_base/_base.py @@ -0,0 +1,357 @@ +# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) +"""Base class for data types.""" + +__author__ = ["fkiraly"] + +from skpro.base import BaseObject +from skpro.datatypes._common import _ret +from skpro.utils.deep_equals import deep_equals + + +class BaseDatatype(BaseObject): + """Base class for data types. + + This class is the base class for all data types in skpro. + """ + + _tags = { + "object_type": "datatype", + "scitype": None, + "name": None, # any string + "name_python": None, # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": None, + } + + def __init__(self): + super().__init__() + + # call defaults to check + def __call__(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : instance of self, only returned if return_metadata is True. + Metadata dictionary. + """ + return self._check(obj=obj, return_metadata=return_metadata, var_name=var_name) + + def check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + If self has parameters set, the check will additionally + check whether the metadata of obj is equal to self's parameters. + In this case, ``return_metadata`` will always include the + metadata fields required to check the parameters. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : instance of self, only returned if return_metadata is True. + Metadata dictionary.
+ """ + self_params = self.get_params() + + need_check = [k for k in self_params if self_params[k] is not None] + self_dict = {k: self_params[k] for k in need_check} + + return_metadata_orig = return_metadata + + # update return_metadata to retrieve any self_params + # return_metadata_bool has updated condition + if not len(need_check) == 0: + if isinstance(return_metadata, bool): + if not return_metadata: + return_metadata = need_check + return_metadata_bool = True + else: + return_metadata = set(return_metadata).union(need_check) + return_metadata = list(return_metadata) + return_metadata_bool = True + elif isinstance(return_metadata, bool): + return_metadata_bool = return_metadata + else: + return_metadata_bool = True + + # call inner _check + check_res = self._check( + obj=obj, return_metadata=return_metadata, var_name=var_name + ) + + if return_metadata_bool: + valid = check_res[0] + msg = check_res[1] + metadata = check_res[2] + else: + valid = check_res + msg = "" + + if not valid: + return _ret(False, msg, None, return_metadata_orig) + + # now we know the check is valid, but we need to compare fields + metadata_sub = {k: metadata[k] for k in self_dict} + eqs, msg = deep_equals(self_dict, metadata_sub, return_msg=True) + if not eqs: + msg = f"metadata of type unequal, {msg}" + return _ret(False, msg, None, return_metadata_orig) + + self_type = type(self)(**metadata) + return _ret(True, "", self_type, return_metadata_orig) + + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + raise NotImplementedError + + def __getitem__(self, key): + """Get attribute by key. + + Parameters + ---------- + key : str + Attribute name. + + Returns + ------- + value : any + Attribute value. + """ + return getattr(self, key) + + def get(self, key, default=None): + """Get attribute by key. + + Parameters + ---------- + key : str + Attribute name. + default : any, optional (default=None) + Default value if attribute does not exist. + + Returns + ------- + value : any + Attribute value. + """ + return getattr(self, key, default) + + def _get_key(self): + """Get unique dictionary key corresponding to self. + + Private function, used in collecting a dictionary of checks. + """ + mtype = self.get_class_tag("name") + scitype = self.get_class_tag("scitype") + return (mtype, scitype) + + +class BaseConverter(BaseObject): + """Base class for data type converters. + + This class is the base class for all data type converters in sktime. 
+ """ + + _tags = { + "object_type": "converter", + "mtype_from": None, # type to convert from - BaseDatatype class or str + "mtype_to": None, # type to convert to - BaseDatatype class or str + "multiple_conversions": False, # whether converter encodes multiple conversions + "python_version": None, + "python_dependencies": None, + } + + def __init__(self, mtype_from=None, mtype_to=None): + self.mtype_from = mtype_from + self.mtype_to = mtype_to + super().__init__() + + if mtype_from is not None: + self.set_tags(**{"mtype_from": mtype_from}) + if mtype_to is not None: + self.set_tags(**{"mtype_to": mtype_to}) + + mtype_from = self.get_tag("mtype_from") + mtype_to = self.get_tag("mtype_to") + + if mtype_from is None: + raise ValueError( + f"Error in instantiating {self.__class__.__name__}: " + "mtype_from and mtype_to must be set if the class has no defaults. " + "For valid pairs of defaults, use get_conversions." + ) + if mtype_to is None: + raise ValueError( + f"Error in instantiating {self.__class__.__name__}: " + "mtype_to must be set in constructor, as the class has no defaults. " + "For valid pairs of defaults, use get_conversions." + ) + if (mtype_from, mtype_to) not in self.__class__.get_conversions(): + raise ValueError( + f"Error in instantiating {self.__class__.__name__}: " + "mtype_from and mtype_to must be a valid pair of defaults. " + "For valid pairs of defaults, use get_conversions." + ) + + # call defaults to convert + def __call__(self, obj, store=None): + """Convert obj to another machine type. + + Parameters + ---------- + obj : any + Object to convert. + store : dict, optional (default=None) + Reference of storage for lossy conversions. + + Returns + ------- + converted_obj : any + Object obj converted to another machine type. + """ + return self.convert(obj=obj, store=store) + + def convert(self, obj, store=None): + """Convert obj to another machine type. + + Parameters + ---------- + obj : any + Object to convert. + store : dict, optional (default=None) + Reference of storage for lossy conversions. + """ + return self._convert(obj, store) + + def _convert(self, obj, store=None): + """Convert obj to another machine type. + + Parameters + ---------- + obj : any + Object to convert. + store : dict, optional (default=None) + Reference of storage for lossy conversions. + """ + raise NotImplementedError + + @classmethod + def get_conversions(cls): + """Get all conversions. + + Returns + ------- + list of tuples (BaseDatatype subclass, BaseDatatype subclass) + List of all conversions in this class. + """ + cls_from = cls.get_class_tag("mtype_from") + cls_to = cls.get_class_tag("mtype_to") + + if cls_from is not None and cls_to is not None: + return [(cls_from, cls_to)] + # if multiple conversions are encoded, this should be overridden + raise NotImplementedError + + def _get_cls_from_to(self): + """Get classes from and to. + + Returns + ------- + cls_from : BaseDatatype subclass + Class to convert from. + cls_to : BaseDatatype subclass + Class to convert to. + """ + cls_from = self.get_tag("mtype_from") + cls_to = self.get_tag("mtype_to") + + cls_from = _coerce_str_to_cls(cls_from) + cls_to = _coerce_str_to_cls(cls_to) + + return cls_from, cls_to + + def _get_key(self): + """Get unique dictionary key corresponding to self. + + Private function, used in collecting a dictionary of checks. 
+ """ + cls_from, cls_to = self._get_cls_from_to() + + mtype_from = cls_from.get_class_tag("name") + mtype_to = cls_to.get_class_tag("name") + scitype = cls_to.get_class_tag("scitype") + return (mtype_from, mtype_to, scitype) + + +def _coerce_str_to_cls(cls_or_str): + """Get class from string. + + Parameters + ---------- + cls_or_str : str or class + Class or string. If string, assumed to be a unique mtype string from + one of the BaseDatatype subclasses. + + Returns + ------- + cls : cls_or_str, if was class; otherwise, class corresponding to string. + """ + if not isinstance(cls_or_str, str): + return cls_or_str + + # otherwise, we use the string to get the class from the check dict + # perhaps it is nicer to transfer this to a registry later. + from skpro.datatypes._check import get_check_dict + + cd = get_check_dict(soft_deps="all") + cls = [cd[k].__class__ for k in cd if k[0] == cls_or_str] + if len(cls) > 1: + raise ValueError(f"Error in converting string to class: {cls_or_str}") + elif len(cls) < 1: + return None + return cls[0] diff --git a/skpro/datatypes/_check.py b/skpro/datatypes/_check.py index 8b1069445..ca8065d38 100644 --- a/skpro/datatypes/_check.py +++ b/skpro/datatypes/_check.py @@ -23,21 +23,68 @@ "mtype", ] +from functools import lru_cache + import numpy as np +from skpro.datatypes._base import BaseDatatype from skpro.datatypes._common import _metadata_requested, _ret from skpro.datatypes._proba import check_dict_Proba from skpro.datatypes._registry import AMBIGUOUS_MTYPES, SCITYPE_LIST, mtype_to_scitype -from skpro.datatypes._table import check_dict_Table -# pool convert_dict-s -check_dict = dict() -check_dict.update(check_dict_Table) -check_dict.update(check_dict_Proba) + +def get_check_dict(soft_deps="present"): + """Retrieve check_dict, caches the first time it is requested. + + This is to avoid repeated, time consuming crawling in generate_check_dict, + which would otherwise be called every time check_dict is requested. 
+ + Parameters + ---------- + soft_deps : str, optional - one of "present", "all" + "present" - only checks with soft dependencies present are included + "all" - all checks are included + + Returns + ------- + check_dict : dict + Dictionary of checks, keyed by (mtype, scitype) pairs. + """ + if soft_deps not in ["present", "all"]: + raise ValueError( + "Error in get_check_dict, soft_deps argument must be 'present' or 'all', " + f"found {soft_deps}" + ) + check_dict = generate_check_dict(soft_deps=soft_deps) + return check_dict.copy() + + +@lru_cache(maxsize=1) +def generate_check_dict(soft_deps="present"): + """Generate check_dict using lookup.""" + from skbase.utils.dependencies import _check_estimator_deps + + from skpro.utils.retrieval import _all_classes + + classes = _all_classes("skpro.datatypes") + classes = [x[1] for x in classes] + classes = [x for x in classes if issubclass(x, BaseDatatype)] + classes = [x for x in classes if not x.__name__.startswith("Base")] + + # subset only to data types with soft dependencies present + if soft_deps == "present": + classes = [x for x in classes if _check_estimator_deps(x, severity="none")] + + check_dict = dict() + for cls in classes: + k = cls() + key = k._get_key() + check_dict[key] = k + + # temporary while refactoring + check_dict.update(check_dict_Proba) + + return check_dict def _check_scitype_valid(scitype: str = None): """Check validity of scitype.""" + check_dict = get_check_dict() valid_scitypes = list({x[1] for x in check_dict.keys()}) if not isinstance(scitype, str): @@ -153,6 +200,7 @@ def check_is_mtype( """ mtype = _coerce_list_of_str(mtype, var_name="mtype") + check_dict = get_check_dict() valid_keys = check_dict.keys() # we loop through individual mtypes in mtype and see whether they pass the check @@ -300,6 +348,7 @@ def mtype( for scitype in as_scitype: _check_scitype_valid(scitype) + check_dict = get_check_dict() m_plus_scitypes = [ (x[0], x[1]) for x in check_dict.keys() if x[0] not in exclude_mtypes ] @@ -409,6 +458,7 @@ def check_is_scitype( for x in scitype: _check_scitype_valid(x) + check_dict = get_check_dict() valid_keys = check_dict.keys() # find all the mtype keys corresponding to the scitypes diff --git a/skpro/datatypes/_convert.py b/skpro/datatypes/_convert.py index 8c0c720e4..2084ec3c9 100644 --- a/skpro/datatypes/_convert.py +++ b/skpro/datatypes/_convert.py @@ -65,19 +65,86 @@ ] from copy import deepcopy +from functools import lru_cache import numpy as np import pandas as pd +from skpro.datatypes._base import BaseConverter from skpro.datatypes._check import mtype as infer_mtype from skpro.datatypes._proba import convert_dict_Proba from skpro.datatypes._registry import mtype_to_scitype from skpro.datatypes._table import convert_dict_Table -# pool convert_dict-s and infer_mtype_dict-s -convert_dict = dict() -convert_dict.update(convert_dict_Table) -convert_dict.update(convert_dict_Proba) + + +def get_convert_dict(soft_deps="present"): + """Retrieve convert_dict; the result is cached on first request. + + This is to avoid repeated, time-consuming crawling in generate_convert_dict, + which would otherwise be called every time convert_dict is requested.
+ + Parameters + ---------- + soft_deps : str, optional - one of "present", "all" + "present" - only conversions with soft dependencies present are included + "all" - all conversions are included + + Returns + ------- + convert_dict : dict + Dictionary of conversions, keyed by (mtype_from, mtype_to, scitype) triples. + """ + if soft_deps not in ["present", "all"]: + raise ValueError( + "Error in get_convert_dict, soft_deps argument must be 'present' or 'all', " + f"found {soft_deps}" + ) + convert_dict = generate_convert_dict(soft_deps=soft_deps) + return convert_dict.copy() + + +@lru_cache(maxsize=1) +def generate_convert_dict(soft_deps="present"): + """Generate convert_dict using lookup.""" + from skbase.utils.dependencies import _check_estimator_deps + + from skpro.utils.retrieval import _all_classes + + classes = _all_classes("skpro.datatypes") + classes = [x[1] for x in classes] + classes = [x for x in classes if issubclass(x, BaseConverter)] + classes = [x for x in classes if not x.__name__.startswith("Base")] + + # subset only to data types with soft dependencies present + if soft_deps == "present": + classes = [x for x in classes if _check_estimator_deps(x, severity="none")] + + convert_dict = dict() + for cls in classes: + if not cls.get_class_tag("multiple_conversions", False): + k = cls() + key = k._get_key() + convert_dict[key] = k + else: + for cls_to_cls in cls.get_conversions(): + k = cls(*cls_to_cls) + + # check dependencies for both classes + # only add conversions if dependencies are satisfied for to and from + cls_from, cls_to = k._get_cls_from_to() + + # do not add conversion if dependencies are not satisfied + if cls_from is None or cls_to is None: + continue + filter_sd = soft_deps == "present" + if filter_sd and not _check_estimator_deps(cls_from, severity="none"): + continue + if filter_sd and not _check_estimator_deps(cls_to, severity="none"): + continue + + key = k._get_key() + convert_dict[key] = k + + # temporary while refactoring + convert_dict.update(convert_dict_Proba) + convert_dict.update(convert_dict_Table) + + return convert_dict def convert( @@ -159,6 +226,7 @@ def convert( key = (from_type, to_type, as_scitype) + convert_dict = get_convert_dict() if key not in convert_dict.keys(): raise NotImplementedError( "no conversion defined from type " + str(from_type) + " to " + str(to_type) @@ -304,13 +372,16 @@ def _get_first_mtype_of_same_scitype(from_mtype, to_mtypes, varname="to_mtypes") return to_type -def _conversions_defined(scitype: str): +def _conversions_defined(scitype: str, soft_deps: str = "present"): """Return an indicator matrix which conversions are defined for scitype.
Parameters ---------- scitype: str - name of scitype for which conversions are queried valid scitype strings, with explanation, are in datatypes.SCITYPE_REGISTER + soft_deps : str, optional - one of "present", "all" + "present" - only conversions with soft dependencies present are included + "all" - all conversions are included Returns ------- @@ -318,6 +389,7 @@ def _conversions_defined(scitype: str): entry of row i, col j is 1 if conversion from i to j is defined, 0 if conversion from i to j is not defined """ + convert_dict = get_convert_dict(soft_deps=soft_deps) pairs = [(x[0], x[1]) for x in list(convert_dict.keys()) if x[2] == scitype] cols0 = {x[0] for x in list(convert_dict.keys()) if x[2] == scitype} cols1 = {x[1] for x in list(convert_dict.keys()) if x[2] == scitype} diff --git a/skpro/datatypes/_proba/_check.py b/skpro/datatypes/_proba/_check.py index 4355706c9..c6c1c063f 100644 --- a/skpro/datatypes/_proba/_check.py +++ b/skpro/datatypes/_proba/_check.py @@ -1,6 +1,6 @@ -"""Machine type checkers for Series scitype. +"""Machine type checkers for Proba (probabilistic return) scitype. -Exports checkers for Series scitype: +Exports checkers for Proba scitype: check_dict: dict indexed by pairs of str 1st element = mtype - str diff --git a/skpro/datatypes/_registry.py b/skpro/datatypes/_registry.py index bc0c36c36..4da55b86a 100644 --- a/skpro/datatypes/_registry.py +++ b/skpro/datatypes/_registry.py @@ -46,7 +46,10 @@ MTYPE_REGISTER += MTYPE_REGISTER_TABLE MTYPE_REGISTER += MTYPE_REGISTER_PROBA -MTYPE_SOFT_DEPS = {} +MTYPE_SOFT_DEPS = { + "polars_eager_table": "polars", + "polars_lazy_table": "polars", +} # mtypes to exclude in checking since they are ambiguous and rare diff --git a/skpro/datatypes/_table/__init__.py b/skpro/datatypes/_table/__init__.py index ef620b0d9..0481dcaee 100644 --- a/skpro/datatypes/_table/__init__.py +++ b/skpro/datatypes/_table/__init__.py @@ -1,6 +1,5 @@ """Module exports: Series type checkers, converters and mtype inference.""" -from skpro.datatypes._table._check import check_dict as check_dict_Table from skpro.datatypes._table._convert import convert_dict as convert_dict_Table from skpro.datatypes._table._examples import example_dict as example_dict_Table from skpro.datatypes._table._examples import ( @@ -12,7 +11,6 @@ from skpro.datatypes._table._registry import MTYPE_LIST_TABLE, MTYPE_REGISTER_TABLE __all__ = [ - "check_dict_Table", "convert_dict_Table", "MTYPE_LIST_TABLE", "MTYPE_REGISTER_TABLE", diff --git a/skpro/datatypes/_table/_base.py b/skpro/datatypes/_table/_base.py new file mode 100644 index 000000000..3681e1fd0 --- /dev/null +++ b/skpro/datatypes/_table/_base.py @@ -0,0 +1,55 @@ +# copyright: skpro developers, BSD-3-Clause License (see LICENSE file) +"""Base class for data types.""" + +__author__ = ["fkiraly"] + +from skpro.datatypes._base import BaseDatatype + + +class BaseTable(BaseDatatype): + """Base class for Table data types. + + Parameters are inferred by check. 
+ + Parameters + ---------- + is_univariate: bool + True iff table has one variable + is_empty: bool + True iff table has no variables or no instances + has_nans: bool + True iff the table contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in table + feature_names: list of int or object + names of variables in table + """ + + _tags = { + "scitype": "Table", + "name": None, # any string + "name_python": None, # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": None, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + self.is_univariate = is_univariate + self.is_empty = is_empty + self.has_nans = has_nans + self.n_instances = n_instances + self.n_features = n_features + self.feature_names = feature_names + + super().__init__() diff --git a/skpro/datatypes/_table/_check.py b/skpro/datatypes/_table/_check.py index 569f6b3c9..290c1f94d 100644 --- a/skpro/datatypes/_table/_check.py +++ b/skpro/datatypes/_table/_check.py @@ -1,14 +1,6 @@ -"""Machine type checkers for Table scitype. +"""Machine type classes for Table scitype. -Exports checkers for Table scitype: - -check_dict: dict indexed by pairs of str - 1st element = mtype - str - 2nd element = scitype - str -elements are checker/validation functions for mtype - -Function signature of all elements -check_dict[(mtype, scitype)] +Checks for each class are defined in the "check" method, of signature: Parameters ---------- @@ -36,21 +28,91 @@ __author__ = ["fkiraly"] -__all__ = ["check_dict"] - import numpy as np import pandas as pd from skpro.datatypes._common import _req, _ret -from skpro.utils.validation._dependencies import _check_soft_dependencies - -check_dict = dict() - +from skpro.datatypes._table._base import BaseTable PRIMITIVE_TYPES = (float, int, str) -def check_pddataframe_table(obj, return_metadata=False, var_name="obj"): +class TablePdDataFrame(BaseTable): + """Data type: pandas.DataFrame based specification of data frame table. + + Parameters are inferred by check. + + Parameters + ---------- + is_univariate: bool + True iff table has one variable + is_empty: bool + True iff table has no variables or no instances + has_nans: bool + True iff the table contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in table + feature_names: list of int or object + names of variables in table + """ + + _tags = { + "scitype": "Table", + "name": "pd_DataFrame_Table", # any string + "name_python": "table_pd_df", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": "pandas", + "capability:multivariate": True, + "capability:missing_values": True, + "capability:index": True, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, + ) + + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. 
+ + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + return _check_pddataframe_table(obj, return_metadata, var_name) + + +def _check_pddataframe_table(obj, return_metadata=False, var_name="obj"): metadata = dict() if not isinstance(obj, pd.DataFrame): @@ -80,10 +142,82 @@ def check_pddataframe_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -check_dict[("pd_DataFrame_Table", "Table")] = check_pddataframe_table - +class TablePdSeries(BaseTable): + """Data type: pandas.Series based specification of data frame table. + + Parameters are inferred by check. + + Parameters + ---------- + is_univariate: bool + True iff table has one variable + is_empty: bool + True iff table has no variables or no instances + has_nans: bool + True iff the table contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in table + feature_names: list of int or object + names of variables in table + """ + + _tags = { + "scitype": "Table", + "name": "pd_Series_Table", # any string + "name_python": "table_pd_series", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": "pandas", + "capability:multivariate": False, + "capability:missing_values": True, + "capability:index": True, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, + ) -def check_pdseries_table(obj, return_metadata=False, var_name="obj"): + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + return _check_pdseries_table(obj, return_metadata, var_name) + + +def _check_pdseries_table(obj, return_metadata=False, var_name="obj"): metadata = dict() if not isinstance(obj, pd.Series): @@ -119,10 +253,82 @@ def check_pdseries_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -check_dict[("pd_Series_Table", "Table")] = check_pdseries_table - +class TableNp1D(BaseTable): + """Data type: 1D np.ndarray based specification of data frame table. + + Parameters are inferred by check. 
+ + Parameters + ---------- + is_univariate: bool + True iff table has one variable + is_empty: bool + True iff table has no variables or no instances + has_nans: bool + True iff the table contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in table + feature_names: list of int or object + names of variables in table + """ + + _tags = { + "scitype": "Table", + "name": "numpy1D", # any string + "name_python": "table_numpy1d", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": "numpy", + "capability:multivariate": False, + "capability:missing_values": True, + "capability:index": False, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, + ) -def check_numpy1d_table(obj, return_metadata=False, var_name="obj"): + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + return _check_numpy1d_table(obj, return_metadata, var_name) + + +def _check_numpy1d_table(obj, return_metadata=False, var_name="obj"): metadata = dict() if not isinstance(obj, np.ndarray): @@ -153,10 +359,82 @@ def check_numpy1d_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -check_dict[("numpy1D", "Table")] = check_numpy1d_table - +class TableNp2D(BaseTable): + """Data type: 2D np.ndarray based specification of data frame table. + + Parameters are inferred by check. + + Parameters + ---------- + is_univariate: bool + True iff table has one variable + is_empty: bool + True iff table has no variables or no instances + has_nans: bool + True iff the table contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in table + feature_names: list of int or object + names of variables in table + """ + + _tags = { + "scitype": "Table", + "name": "numpy2D", # any string + "name_python": "table_numpy2d", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": "numpy", + "capability:multivariate": True, + "capability:missing_values": True, + "capability:index": False, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, + ) -def check_numpy2d_table(obj, return_metadata=False, var_name="obj"): + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. 
+ return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + return _check_numpy2d_table(obj, return_metadata, var_name) + + +def _check_numpy2d_table(obj, return_metadata=False, var_name="obj"): metadata = dict() if not isinstance(obj, np.ndarray): @@ -186,10 +464,82 @@ def check_numpy2d_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -check_dict[("numpy2D", "Table")] = check_numpy2d_table - +class TableListOfDict(BaseTable): + """Data type: list of dict based specification of data frame table. + + Parameters are inferred by check. + + Parameters + ---------- + is_univariate: bool + True iff table has one variable + is_empty: bool + True iff table has no variables or no instances + has_nans: bool + True iff the table contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in table + feature_names: list of int or object + names of variables in table + """ + + _tags = { + "scitype": "Table", + "name": "list_of_dict", # any string + "name_python": "table_list_of_dict", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": "numpy", + "capability:multivariate": True, + "capability:missing_values": True, + "capability:index": False, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, + ) -def check_list_of_dict_table(obj, return_metadata=False, var_name="obj"): + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. 
+ """ + return _check_list_of_dict_table(obj, return_metadata, var_name) + + +def _check_list_of_dict_table(obj, return_metadata=False, var_name="obj"): metadata = dict() if not isinstance(obj, list): @@ -242,28 +592,155 @@ def check_list_of_dict_table(obj, return_metadata=False, var_name="obj"): return _ret(True, None, metadata, return_metadata) -check_dict[("list_of_dict", "Table")] = check_list_of_dict_table - - -if _check_soft_dependencies(["polars", "pyarrow"], severity="none"): - from skpro.datatypes._adapter.polars import check_polars_frame - - def check_polars_table(obj, return_metadata=False, var_name="obj"): - return check_polars_frame( - obj=obj, - return_metadata=return_metadata, - var_name=var_name, - lazy=False, +class TablePolarsEager(BaseTable): + """Data type: eager polars DataFrame based specification of data frame table. + + Parameters are inferred by check. + + Parameters + ---------- + is_univariate: bool + True iff table has one variable + is_empty: bool + True iff table has no variables or no instances + has_nans: bool + True iff the table contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in table + feature_names: list of int or object + names of variables in table + """ + + _tags = { + "scitype": "Table", + "name": "polars_eager_table", # any string + "name_python": "table_polars_eager", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": ["polars", "pyarrow"], + "capability:multivariate": True, + "capability:missing_values": True, + "capability:index": False, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, ) - check_dict[("polars_eager_table", "Table")] = check_polars_table - - def check_polars_table_lazy(obj, return_metadata=False, var_name="obj"): - return check_polars_frame( - obj=obj, - return_metadata=return_metadata, - var_name=var_name, - lazy=True, + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + from skpro.datatypes._adapter.polars import check_polars_frame + + return check_polars_frame(obj, return_metadata, var_name, lazy=False) + + +class TablePolarsLazy(BaseTable): + """Data type: lazy polars DataFrame based specification of data frame table. + + Parameters are inferred by check. 
+ + Parameters + ---------- + is_univariate: bool + True iff table has one variable + is_empty: bool + True iff table has no variables or no instances + has_nans: bool + True iff the table contains NaN values + n_instances: int + number of instances/rows in the table + n_features: int + number of variables in table + feature_names: list of int or object + names of variables in table + """ + + _tags = { + "scitype": "Table", + "name": "polars_lazy_table", # any string + "name_python": "table_polars_lazy", # lower_snake_case + "name_aliases": [], + "python_version": None, + "python_dependencies": ["polars", "pyarrow"], + "capability:multivariate": True, + "capability:missing_values": True, + "capability:index": False, + } + + def __init__( + self, + is_univariate=None, + is_empty=None, + has_nans=None, + n_instances=None, + n_features=None, + feature_names=None, + ): + super().__init__( + is_univariate=is_univariate, + n_instances=n_instances, + is_empty=is_empty, + has_nans=has_nans, + n_features=n_features, + feature_names=feature_names, ) - check_dict[("polars_lazy_table", "Table")] = check_polars_table_lazy + def _check(self, obj, return_metadata=False, var_name="obj"): + """Check if obj is of this data type. + + Parameters + ---------- + obj : any + Object to check. + return_metadata : bool, optional (default=False) + Whether to return metadata. + var_name : str, optional (default="obj") + Name of the variable to check, for use in error messages. + + Returns + ------- + valid : bool + Whether obj is of this data type. + msg : str, only returned if return_metadata is True. + Error message if obj is not of this data type. + metadata : dict, only returned if return_metadata is True. + Metadata dictionary. + """ + from skpro.datatypes._adapter.polars import check_polars_frame + + return check_polars_frame(obj, return_metadata, var_name, lazy=True) diff --git a/skpro/datatypes/_table/_convert.py b/skpro/datatypes/_table/_convert.py index a74aef8a7..b1ffd30db 100644 --- a/skpro/datatypes/_table/_convert.py +++ b/skpro/datatypes/_table/_convert.py @@ -33,6 +33,7 @@ import numpy as np import pandas as pd +from skpro.datatypes._base import BaseConverter from skpro.datatypes._convert_utils._convert import _extend_conversions from skpro.datatypes._table._registry import MTYPE_LIST_TABLE from skpro.utils.validation._dependencies import _check_soft_dependencies @@ -44,13 +45,70 @@ convert_dict = dict() -def convert_identity(obj, store=None): - return obj +class TableIdentity(BaseConverter): + """Identity conversion of any Table scitype mtype to itself. + + This is the identity conversion for the Table scitype; + no coercion is done, and the object is returned as is. + """ + + _tags = { + "object_type": "converter", + "mtype_from": None, + "mtype_to": None, + "multiple_conversions": True, + "python_version": None, + "python_dependencies": None, + } + + @classmethod + def get_conversions(cls): + """Get all conversions. + + Returns + ------- + list of tuples (BaseDatatype subclass, BaseDatatype subclass) + List of all conversions in this class. + """ + return [(tp, tp) for tp in MTYPE_LIST_TABLE] + + # identity conversion + def _convert(self, obj, store=None): + """Convert obj to another machine type. + + Parameters + ---------- + obj : any + Object to convert. + store : dict, optional (default=None) + Reference of storage for lossy conversions. + + Returns + ------- + converted_obj : any + Object obj converted to another machine type.
+ """ + return obj + + +class Numpy1dToNumpy2D(BaseConverter): + """Conversion: numpy1D -> numpy2D, of Table scitype.""" + _tags = { + "object_type": "converter", + "mtype_from": "numpy1D", # type to convert from - BaseDatatype class or str + "mtype_to": "numpy2D", # type to convert to - BaseDatatype class or str + "multiple_conversions": False, # whether converter encodes multiple conversions + "python_version": None, + "python_dependencies": None, + } -# assign identity function to type conversion to self -for tp in MTYPE_LIST_TABLE: - convert_dict[(tp, tp, "Table")] = convert_identity + def _convert(self, obj, store=None): + """Convert obj to another machine type. + + Parameters + ---------- + obj : any + Object to convert. + store : dict, optional (default=None) + Reference of storage for lossy conversions. + """ + return convert_1D_to_2D_numpy_as_Table(obj=obj, store=store) def convert_1D_to_2D_numpy_as_Table(obj: np.ndarray, store=None) -> np.ndarray: @@ -65,9 +123,6 @@ def convert_1D_to_2D_numpy_as_Table(obj: np.ndarray, store=None) -> np.ndarray: return res -convert_dict[("numpy1D", "numpy2D", "Table")] = convert_1D_to_2D_numpy_as_Table - - def convert_2D_to_1D_numpy_as_Table(obj: np.ndarray, store=None) -> np.ndarray: if not isinstance(obj, np.ndarray): raise TypeError("input must be a np.ndarray") diff --git a/skpro/datatypes/tests/test_check.py b/skpro/datatypes/tests/test_check.py index b703cd413..001f9749d 100644 --- a/skpro/datatypes/tests/test_check.py +++ b/skpro/datatypes/tests/test_check.py @@ -7,9 +7,9 @@ from skpro.datatypes._check import ( AMBIGUOUS_MTYPES, - check_dict, check_is_mtype, check_is_scitype, + get_check_dict, ) from skpro.datatypes._check import mtype as infer_mtype from skpro.datatypes._check import scitype as infer_scitype @@ -129,6 +129,7 @@ def test_check_positive(scitype, mtype, fixture_index): fixture = get_examples(mtype=mtype, as_scitype=scitype).get(fixture_index) # todo: possibly remove this once all checks are defined + check_dict = get_check_dict() check_is_defined = (mtype, scitype) in check_dict.keys() # check fixtures that exist against checks that exist, when full metadata is queried @@ -184,6 +185,7 @@ def test_check_positive_check_scitype(scitype, mtype, fixture_index): fixture = get_examples(mtype=mtype, as_scitype=scitype).get(fixture_index) # todo: possibly remove this once all checks are defined + check_dict = get_check_dict() check_is_defined = (mtype, scitype) in check_dict.keys() # check fixtures that exist against checks that exist, when full metadata is queried @@ -236,6 +238,7 @@ def test_check_metadata_inference(scitype, mtype, fixture_index): ).get(fixture_index) # todo: possibly remove this once all checks are defined + check_dict = get_check_dict() check_is_defined = (mtype, scitype) in check_dict.keys() # if the examples have no metadata to them, don't test metadata_provided = expected_metadata is not None @@ -358,6 +361,7 @@ def test_check_negative(scitype, mtype): fixture_wrong_type = fixtures[wrong_mtype].get(i) # todo: possibly remove this once all checks are defined + check_dict = get_check_dict() check_is_defined = (mtype, scitype) in check_dict.keys() # check fixtures that exist against checks that exist @@ -414,6 +418,7 @@ def test_mtype_infer(scitype, mtype, fixture_index): fixture = get_examples(mtype=mtype, as_scitype=scitype).get(fixture_index) # todo: possibly remove this once all checks are defined + check_dict = get_check_dict() check_is_defined = (mtype, scitype) in check_dict.keys() # check fixtures that 
exist against checks that exist @@ -466,6 +471,7 @@ def test_scitype_infer(scitype, mtype, fixture_index): fixture = get_examples(mtype=mtype, as_scitype=scitype).get(fixture_index) # todo: possibly remove this once all checks are defined + check_dict = get_check_dict() check_is_defined = (mtype, scitype) in check_dict.keys() # check fixtures that exist against checks that exist diff --git a/skpro/datatypes/tests/test_convert.py b/skpro/datatypes/tests/test_convert.py index 89db16cf0..f2fc02278 100644 --- a/skpro/datatypes/tests/test_convert.py +++ b/skpro/datatypes/tests/test_convert.py @@ -27,9 +27,9 @@ def _generate_fixture_tuples(): if scitype in SCITYPES_NO_CONVERSIONS: continue - conv_mat = _conversions_defined(scitype) + conv_mat = _conversions_defined(scitype, soft_deps="all") - mtypes = scitype_to_mtype(scitype, softdeps="exclude") + mtypes = scitype_to_mtype(scitype, softdeps="present") if len(mtypes) == 0: # if there are no mtypes, this must have been reached by mistake/bug diff --git a/skpro/datatypes/tests/test_polars.py b/skpro/datatypes/tests/test_polars.py index 27b425aa8..55b5ed573 100644 --- a/skpro/datatypes/tests/test_polars.py +++ b/skpro/datatypes/tests/test_polars.py @@ -11,8 +11,7 @@ if _check_soft_dependencies(["polars", "pyarrow"], severity="none"): import polars as pl - from skpro.datatypes._table._check import check_polars_table - from skpro.datatypes._table._convert import convert_pandas_to_polars_eager + from skpro.datatypes import check_is_mtype, convert TEST_ALPHAS = [0.05, 0.1, 0.25] @@ -43,12 +42,16 @@ def estimator(): return _estimator +def _pd_to_pl(df): + return convert(df, from_type="pd_DataFrame_Table", to_type="polars_eager_table") + + @pytest.fixture def polars_load_diabetes_polars(polars_load_diabetes_pandas): X_train, X_test, y_train = polars_load_diabetes_pandas - X_train_pl = convert_pandas_to_polars_eager(X_train) - X_test_pl = convert_pandas_to_polars_eager(X_test) - y_train_pl = convert_pandas_to_polars_eager(y_train) + X_train_pl = _pd_to_pl(X_train) + X_test_pl = _pd_to_pl(X_test) + y_train_pl = _pd_to_pl(y_train) # drop the index in the polars frame X_train_pl = X_train_pl.drop(["__index__"]) @@ -60,9 +63,9 @@ def polars_load_diabetes_polars(polars_load_diabetes_pandas): def polars_load_diabetes_polars_with_index(polars_load_diabetes_pandas): X_train, X_test, y_train = polars_load_diabetes_pandas - X_train_pl = convert_pandas_to_polars_eager(X_train) - X_test_pl = convert_pandas_to_polars_eager(X_test) - y_train_pl = convert_pandas_to_polars_eager(y_train) + X_train_pl = _pd_to_pl(X_train) + X_test_pl = _pd_to_pl(X_test) + y_train_pl = _pd_to_pl(y_train) return [X_train_pl, X_test_pl, y_train_pl] @@ -83,9 +86,9 @@ def test_polars_eager_conversion_methods( X_train, X_test, y_train = polars_load_diabetes_pandas X_train_pl, X_test_pl, y_train_pl = polars_load_diabetes_polars - assert check_polars_table(X_train_pl) - assert check_polars_table(X_test_pl) - assert check_polars_table(y_train_pl) + assert check_is_mtype(X_train_pl, "polars_eager_table") + assert check_is_mtype(X_test_pl, "polars_eager_table") + assert check_is_mtype(y_train_pl, "polars_eager_table") assert (X_train.values == X_train_pl.to_numpy()).all() assert (X_test.values == X_test_pl.to_numpy()).all() diff --git a/skpro/utils/retrieval.py b/skpro/utils/retrieval.py new file mode 100644 index 000000000..54eac58b8 --- /dev/null +++ b/skpro/utils/retrieval.py @@ -0,0 +1,97 @@ +"""Utility functions for retrieving objects from modules.""" +import importlib +import inspect +import 
pkgutil +from functools import lru_cache + +EXCLUDE_MODULES_STARTING_WITH = ("all", "test", "contrib") + + +def _all_functions(module_name): + """Get all functions from a module, including submodules. + + Excludes modules starting with 'all', 'test', or 'contrib'. + + Parameters + ---------- + module_name : str + Name of the module. + + Returns + ------- + functions_list : list + List of tuples (function_name: str, function_object: function). + """ + # copy to avoid modifying the cache + return _all_cond(module_name, inspect.isfunction).copy() + + +def _all_classes(module_name): + """Get all classes from a module, including submodules. + + Excludes modules starting with 'all', 'test', or 'contrib'. + + Parameters + ---------- + module_name : str + Name of the module. + + Returns + ------- + classes_list : list + List of tuples (class_name: str, class_ref: class). + """ + # copy to avoid modifying the cache + return _all_cond(module_name, inspect.isclass).copy() + + +@lru_cache +def _all_cond(module_name, cond): + """Get all objects from a module satisfying a condition. + + The condition should be a hashable callable, + of signature ``condition(obj) -> bool``. + + Excludes modules starting with 'all', 'test', or 'contrib'. + + Parameters + ---------- + module_name : str + Name of the module. + cond : callable + Condition to satisfy. + Signature: ``condition(obj) -> bool``, + passed as predicate to ``inspect.getmembers``. + + Returns + ------- + obj_list : list + List of tuples (obj_name: str, obj: object). + """ + # Import the package + package = importlib.import_module(module_name) + + # Initialize an empty list to hold all objects + obj_list = [] + + # Walk through the package's modules + package_path = package.__path__[0] + for _, modname, _ in pkgutil.walk_packages( + path=[package_path], prefix=package.__name__ + "." + ): + # Skip modules starting with 'all', 'test', or 'contrib' + if modname.split(".")[-1].startswith(EXCLUDE_MODULES_STARTING_WITH): + continue + + # Import the module + module = importlib.import_module(modname) + + # Get all objects from the module + for name, obj in inspect.getmembers(module, cond): + # if object is imported from another module, skip it + if obj.__module__ != module.__name__: + continue + # add the object to the list + obj_list.append((name, obj)) + + return obj_list
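
Reviewer note: a minimal usage sketch of the registry API introduced in this diff, to make the new lookup flow concrete. It uses only names added above (``get_check_dict``, ``check_is_mtype``, ``convert``); the example DataFrame and the availability of a pd_DataFrame_Table to numpy2D conversion are assumptions for illustration, not guaranteed by the diff.

import pandas as pd

from skpro.datatypes import check_is_mtype, convert
from skpro.datatypes._check import get_check_dict

# hypothetical example table, two features and three rows
df = pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})

# the check registry is now crawled lazily and cached,
# keyed by (mtype, scitype) tuples
check_dict = get_check_dict()
assert ("pd_DataFrame_Table", "Table") in check_dict

# checks dispatch to the class-based _check methods
valid, msg, metadata = check_is_mtype(
    df, mtype="pd_DataFrame_Table", return_metadata=True
)
assert valid

# conversions dispatch through BaseConverter subclasses
arr = convert(
    df, from_type="pd_DataFrame_Table", to_type="numpy2D", as_scitype="Table"
)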
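
Similarly, a sketch of what registering a new Table mtype looks like under the class-based design: define a ``BaseTable`` subclass with tags and a ``_check`` method, and ``generate_check_dict`` picks it up automatically through the ``_all_classes`` crawler, with no ``check_dict`` entry to maintain. The class and mtype string below are hypothetical and not part of this diff.

from skpro.datatypes._common import _ret
from skpro.datatypes._table._base import BaseTable


class TableListOfLists(BaseTable):
    """Hypothetical data type: list-of-lists specification of a table."""

    _tags = {
        "scitype": "Table",
        "name": "list_of_lists",  # hypothetical mtype string
        "name_python": "table_list_of_lists",
        "name_aliases": [],
        "python_version": None,
        "python_dependencies": None,
    }

    def _check(self, obj, return_metadata=False, var_name="obj"):
        # rows must be lists, all of equal length
        if not isinstance(obj, list) or not all(isinstance(r, list) for r in obj):
            msg = f"{var_name} must be a list of lists"
            return _ret(False, msg, None, return_metadata)
        if len({len(r) for r in obj}) > 1:
            msg = f"all rows of {var_name} must have equal length"
            return _ret(False, msg, None, return_metadata)

        # populate the metadata fields this sketch can infer
        metadata = {
            "is_empty": len(obj) == 0 or len(obj[0]) == 0,
            "n_instances": len(obj),
            "n_features": len(obj[0]) if len(obj) > 0 else 0,
            "is_univariate": len(obj) > 0 and len(obj[0]) == 1,
        }
        return _ret(True, None, metadata, return_metadata)

Since the constructor is inherited from ``BaseTable``, the metadata keys above stay within ``BaseTable``'s parameter set, which is what ``check`` relies on when it reconstructs ``type(self)(**metadata)``.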