diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index 303cefc4e..a2da5644f 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ -256,6 +256,7 @@ def __init__( self.input_validator: Optional[BaseInputValidator] = None self.search_space_updates = search_space_updates + if search_space_updates is not None: if not isinstance(self.search_space_updates, HyperparameterSearchSpaceUpdates): diff --git a/autoPyTorch/constants.py b/autoPyTorch/constants.py index bfd56d27f..154d562ac 100644 --- a/autoPyTorch/constants.py +++ b/autoPyTorch/constants.py @@ -78,3 +78,5 @@ # To avoid that we get a sequence that is too long to be fed to a network MAX_WINDOW_SIZE_BASE = 500 + +MIN_CATEGORIES_FOR_EMBEDDING_MAX = 7 diff --git a/autoPyTorch/data/base_feature_validator.py b/autoPyTorch/data/base_feature_validator.py index 8f65f8607..bf9ad90ed 100644 --- a/autoPyTorch/data/base_feature_validator.py +++ b/autoPyTorch/data/base_feature_validator.py @@ -46,11 +46,11 @@ def __init__( # Required for dataset properties self.num_features: Optional[int] = None - self.categories: List[List[int]] = [] self.categorical_columns: List[int] = [] self.numerical_columns: List[int] = [] self.encode_columns: List[str] = [] + self.num_categories_per_col: Optional[List[int]] = [] self.all_nan_columns: Optional[Set[Union[int, str]]] = None self._is_fitted = False diff --git a/autoPyTorch/data/tabular_feature_validator.py b/autoPyTorch/data/tabular_feature_validator.py index 3beb19cba..a34e03131 100644 --- a/autoPyTorch/data/tabular_feature_validator.py +++ b/autoPyTorch/data/tabular_feature_validator.py @@ -77,10 +77,9 @@ class TabularFeatureValidator(BaseFeatureValidator): transformer. Attributes: - categories (List[List[str]]): - List for which an element at each index is a - list containing the categories for the respective - categorical column. + num_categories_per_col (List[int]): + List for which an element at each index is the number + of categories for the respective categorical column. transformed_columns (List[str]) List of columns that were transformed. column_transformer (Optional[BaseEstimator]) @@ -202,10 +201,8 @@ def _fit( encoded_categories = self.column_transformer.\ named_transformers_['categorical_pipeline'].\ named_steps['ordinalencoder'].categories_ - self.categories = [ - list(range(len(cat))) - for cat in encoded_categories - ] + + self.num_categories_per_col = [len(cat) for cat in encoded_categories] # differently to categorical_columns and numerical_columns, # this saves the index of the column. 
@@ -283,7 +280,6 @@ def transform( X = self.numpy_to_pandas(X) if ispandas(X) and not issparse(X): - if self.all_nan_columns is None: raise ValueError('_fit must be called before calling transform') diff --git a/autoPyTorch/data/tabular_validator.py b/autoPyTorch/data/tabular_validator.py index 0f6f89e1c..0735d49b4 100644 --- a/autoPyTorch/data/tabular_validator.py +++ b/autoPyTorch/data/tabular_validator.py @@ -111,6 +111,8 @@ def _compress_dataset( y=y, is_classification=self.is_classification, random_state=self.seed, + categorical_columns=self.feature_validator.categorical_columns, + n_categories_per_cat_column=self.feature_validator.num_categories_per_col, **self.dataset_compression # type: ignore [arg-type] ) self._reduced_dtype = dict(X.dtypes) if is_dataframe else X.dtype diff --git a/autoPyTorch/data/utils.py b/autoPyTorch/data/utils.py index 20ad5612e..2a44dd5c2 100644 --- a/autoPyTorch/data/utils.py +++ b/autoPyTorch/data/utils.py @@ -25,6 +25,7 @@ from sklearn.utils import _approximate_mode, check_random_state from sklearn.utils.validation import _num_samples, check_array +from autoPyTorch.constants import MIN_CATEGORIES_FOR_EMBEDDING_MAX from autoPyTorch.data.base_target_validator import SupportedTargetTypes from autoPyTorch.utils.common import ispandas @@ -459,8 +460,8 @@ def _subsample_by_indices( return X, y -def megabytes(arr: DatasetCompressionInputType) -> float: - +def get_raw_memory_usage(arr: DatasetCompressionInputType) -> float: + memory_in_bytes: float if isinstance(arr, np.ndarray): memory_in_bytes = arr.nbytes elif issparse(arr): @@ -470,8 +471,43 @@ def megabytes(arr: DatasetCompressionInputType) -> float: else: raise ValueError(f"Unrecognised data type of X, expected data type to " f"be in (np.ndarray, spmatrix, pd.DataFrame) but got :{type(arr)}") + return memory_in_bytes + + +def get_approximate_mem_usage_in_mb( + arr: DatasetCompressionInputType, + categorical_columns: List, + n_categories_per_cat_column: Optional[List[int]] = None +) -> float: + + err_msg = "The number of categories per categorical column is required when the data has categorical columns" + if ispandas(arr): + arr_dtypes = arr.dtypes.to_dict() + multipliers = [dtype.itemsize for col, dtype in arr_dtypes.items() if col not in categorical_columns] + if len(categorical_columns) > 0: + if n_categories_per_cat_column is None: + raise ValueError(err_msg) + for col, num_cat in zip(categorical_columns, n_categories_per_cat_column): + if num_cat < MIN_CATEGORIES_FOR_EMBEDDING_MAX: + multipliers.append(num_cat * arr_dtypes[col].itemsize) + else: + multipliers.append(arr_dtypes[col].itemsize) + size_one_row = sum(multipliers) + + elif isinstance(arr, (np.ndarray, spmatrix)): + n_cols = arr.shape[-1] - len(categorical_columns) + multiplier = arr.dtype.itemsize + if len(categorical_columns) > 0: + if n_categories_per_cat_column is None: + raise ValueError(err_msg) + # multiply num categories with the size of the column to capture memory after one hot encoding + n_cols += sum(num_cat if num_cat < MIN_CATEGORIES_FOR_EMBEDDING_MAX else 1 for num_cat in n_categories_per_cat_column) + size_one_row = n_cols * multiplier + else: + raise ValueError(f"Unrecognised data type of X, expected data type to " + f"be in (np.ndarray, spmatrix, pd.DataFrame), but got :{type(arr)}") - return float(memory_in_bytes / (2**20)) + return float(arr.shape[0] * size_one_row / (2**20)) def reduce_dataset_size_if_too_large( @@ -479,10 +515,13 @@ def reduce_dataset_size_if_too_large( memory_allocation: Union[int, float],
is_classification: bool, random_state: Union[int, np.random.RandomState], + categorical_columns: List, + n_categories_per_cat_column: Optional[List[int]] = None, y: Optional[SupportedTargetTypes] = None, methods: List[str] = ['precision', 'subsample'], ) -> DatasetCompressionInputType: - f""" Reduces the size of the dataset if it's too close to the memory limit. + f""" + Reduces the size of the dataset if it's too close to the memory limit. Follows the order of the operations passed in and retains the type of its input. @@ -513,7 +552,6 @@ def reduce_dataset_size_if_too_large( Reduce the amount of samples of the dataset such that it fits into the allocated memory. Ensures stratification and that unique labels are present - memory_allocation (Union[int, float]): The amount of memory to allocate to the dataset. It should specify an absolute amount. @@ -524,7 +562,7 @@ def reduce_dataset_size_if_too_large( """ for method in methods: - if megabytes(X) <= memory_allocation: + if get_approximate_mem_usage_in_mb(X, categorical_columns, n_categories_per_cat_column) <= memory_allocation: break if method == 'precision': @@ -540,7 +578,8 @@ def reduce_dataset_size_if_too_large( # into the allocated memory, we subsample it so that it does n_samples_before = X.shape[0] - sample_percentage = memory_allocation / megabytes(X) + sample_percentage = memory_allocation / get_approximate_mem_usage_in_mb( + X, categorical_columns, n_categories_per_cat_column) # NOTE: type ignore # diff --git a/autoPyTorch/datasets/tabular_dataset.py b/autoPyTorch/datasets/tabular_dataset.py index 6cabfe525..04a5df96b 100644 --- a/autoPyTorch/datasets/tabular_dataset.py +++ b/autoPyTorch/datasets/tabular_dataset.py @@ -81,7 +81,7 @@ def __init__(self, self.categorical_columns = validator.feature_validator.categorical_columns self.numerical_columns = validator.feature_validator.numerical_columns self.num_features = validator.feature_validator.num_features - self.categories = validator.feature_validator.categories + self.num_categories_per_col = validator.feature_validator.num_categories_per_col super().__init__(train_tensors=(X, Y), test_tensors=(X_test, Y_test), shuffle=shuffle, resampling_strategy=resampling_strategy, diff --git a/autoPyTorch/datasets/time_series_dataset.py b/autoPyTorch/datasets/time_series_dataset.py index 670eb44c9..9296f47df 100644 --- a/autoPyTorch/datasets/time_series_dataset.py +++ b/autoPyTorch/datasets/time_series_dataset.py @@ -559,7 +559,7 @@ def __init__(self, self.num_features: int = self.validator.feature_validator.num_features # type: ignore[assignment] self.num_targets: int = self.validator.target_validator.out_dimensionality # type: ignore[assignment] - self.categories = self.validator.feature_validator.categories + self.num_categories_per_col = self.validator.feature_validator.num_categories_per_col self.feature_shapes = self.validator.feature_shapes self.feature_names = tuple(self.validator.feature_names) @@ -1072,7 +1072,7 @@ def get_required_dataset_info(self) -> Dict[str, Any]: 'categorical_features': self.categorical_features, 'numerical_columns': self.numerical_columns, 'categorical_columns': self.categorical_columns, - 'categories': self.categories, + 'num_categories_per_col': self.num_categories_per_col, }) return info diff --git a/autoPyTorch/evaluation/train_evaluator.py b/autoPyTorch/evaluation/train_evaluator.py index f57d5b15a..392eee418 100644 --- a/autoPyTorch/evaluation/train_evaluator.py +++ b/autoPyTorch/evaluation/train_evaluator.py @@ -1,3 +1,5 @@ +import json +import os 
from multiprocessing.queues import Queue from typing import Any, Dict, List, Optional, Tuple, Union @@ -20,7 +22,9 @@ fit_and_suppress_warnings ) from autoPyTorch.evaluation.utils import DisableFileOutputParameters +from autoPyTorch.pipeline.base_pipeline import BasePipeline from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric +from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline from autoPyTorch.utils.common import dict_repr, subsampler from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates @@ -193,6 +197,8 @@ def fit_predict_and_loss(self) -> None: additional_run_info = pipeline.get_additional_run_info() if hasattr( pipeline, 'get_additional_run_info') else {} + # self._write_run_summary(pipeline) + status = StatusType.SUCCESS self.logger.debug("In train evaluator.fit_predict_and_loss, num_run: {} loss:{}," @@ -348,6 +354,27 @@ def fit_predict_and_loss(self) -> None: status=status, ) + def _write_run_summary(self, pipeline: BasePipeline) -> None: + # add learning curve of configurations to additional_run_info + if isinstance(pipeline, TabularClassificationPipeline): + assert isinstance(self.configuration, Configuration) + if hasattr(pipeline.named_steps['trainer'], 'run_summary'): + run_summary = pipeline.named_steps['trainer'].run_summary + split_types = ['train', 'val', 'test'] + run_summary_dict = dict( + run_summary={}, + budget=self.budget, + seed=self.seed, + config_id=self.configuration.config_id, + num_run=self.num_run) + for split_type in split_types: + run_summary_dict['run_summary'][f'{split_type}_loss'] = run_summary.performance_tracker.get( + f'{split_type}_loss', None) + run_summary_dict['run_summary'][f'{split_type}_metrics'] = run_summary.performance_tracker.get( + f'{split_type}_metrics', None) + with open(os.path.join(self.backend.temporary_directory, 'run_summary.txt'), 'a') as file: + file.write(f"{json.dumps(run_summary_dict)}\n") + def _fit_and_predict(self, pipeline: BaseEstimator, fold: int, train_indices: Union[np.ndarray, List], test_indices: Union[np.ndarray, List], add_pipeline_to_self: bool diff --git a/autoPyTorch/optimizer/smbo.py b/autoPyTorch/optimizer/smbo.py index 43b2c80c8..bbdb154f9 100644 --- a/autoPyTorch/optimizer/smbo.py +++ b/autoPyTorch/optimizer/smbo.py @@ -93,7 +93,8 @@ def get_smac_object( initial_design=None, run_id=seed, intensifier=intensifier, - intensifier_kwargs=intensifier_kwargs, + intensifier_kwargs={'initial_budget': initial_budget, 'max_budget': max_budget, + 'eta': 2, 'min_chall': 1, 'instance_order': 'shuffle_once'}, dask_client=dask_client, n_jobs=n_jobs, ) diff --git a/autoPyTorch/pipeline/base_pipeline.py b/autoPyTorch/pipeline/base_pipeline.py index 6ded2adf6..e6ae1bd59 100644 --- a/autoPyTorch/pipeline/base_pipeline.py +++ b/autoPyTorch/pipeline/base_pipeline.py @@ -300,8 +300,7 @@ def _get_hyperparameter_search_space(self, def _add_forbidden_conditions(self, cs: ConfigurationSpace) -> ConfigurationSpace: """ Add forbidden conditions to ensure valid configurations. - Currently, Learned Entity Embedding is only valid when encoder is one hot encoder - and CyclicLR is disabled when using stochastic weight averaging and snapshot + Currently, CyclicLR is disabled when using stochastic weight averaging and snapshot ensembling. 
Args: @@ -314,33 +313,6 @@ def _add_forbidden_conditions(self, cs: ConfigurationSpace) -> ConfigurationSpac """ - # Learned Entity Embedding is only valid when encoder is one hot encoder - if 'network_embedding' in self.named_steps.keys() and 'encoder' in self.named_steps.keys(): - embeddings = cs.get_hyperparameter('network_embedding:__choice__').choices - if 'LearnedEntityEmbedding' in embeddings: - encoders = cs.get_hyperparameter('encoder:__choice__').choices - possible_default_embeddings = copy(list(embeddings)) - del possible_default_embeddings[possible_default_embeddings.index('LearnedEntityEmbedding')] - - for encoder in encoders: - if encoder == 'OneHotEncoder': - continue - while True: - try: - cs.add_forbidden_clause(ForbiddenAndConjunction( - ForbiddenEqualsClause(cs.get_hyperparameter( - 'network_embedding:__choice__'), 'LearnedEntityEmbedding'), - ForbiddenEqualsClause(cs.get_hyperparameter('encoder:__choice__'), encoder) - )) - break - except ValueError: - # change the default and try again - try: - default = possible_default_embeddings.pop() - except IndexError: - raise ValueError("Cannot find a legal default configuration") - cs.get_hyperparameter('network_embedding:__choice__').default_value = default - # Disable CyclicLR until todo is completed. if 'lr_scheduler' in self.named_steps.keys() and 'trainer' in self.named_steps.keys(): trainers = cs.get_hyperparameter('trainer:__choice__').choices @@ -350,16 +322,19 @@ def _add_forbidden_conditions(self, cs: ConfigurationSpace) -> ConfigurationSpac cyclic_lr_name = 'CyclicLR' if cyclic_lr_name in available_schedulers: # disable snapshot ensembles and stochastic weight averaging - cs.add_forbidden_clause(ForbiddenAndConjunction( - ForbiddenEqualsClause(cs.get_hyperparameter( - f'trainer:{trainer}:use_snapshot_ensemble'), True), - ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name) - )) - cs.add_forbidden_clause(ForbiddenAndConjunction( - ForbiddenEqualsClause(cs.get_hyperparameter( - f'trainer:{trainer}:use_stochastic_weight_averaging'), True), - ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name) - )) + snapshot_ensemble_hyperparameter = cs.get_hyperparameter(f'trainer:{trainer}:use_snapshot_ensemble') + if hasattr(snapshot_ensemble_hyperparameter, 'choices') and \ + True in snapshot_ensemble_hyperparameter.choices: + cs.add_forbidden_clause(ForbiddenAndConjunction( + ForbiddenEqualsClause(snapshot_ensemble_hyperparameter, True), + ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name) + )) + swa_hyperparameter = cs.get_hyperparameter(f'trainer:{trainer}:use_stochastic_weight_averaging') + if hasattr(swa_hyperparameter, 'choices') and True in swa_hyperparameter.choices: + cs.add_forbidden_clause(ForbiddenAndConjunction( + ForbiddenEqualsClause(swa_hyperparameter, True), + ForbiddenEqualsClause(cs.get_hyperparameter('lr_scheduler:__choice__'), cyclic_lr_name) + )) return cs def __repr__(self) -> str: diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py index 6b38b4650..58a55a1df 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/TabularColumnTransformer.py @@ -23,7 +23,10 @@ def __init__(self, random_state: 
Optional[Union[np.random.RandomState, int]] = N self.preprocessor: Optional[ColumnTransformer] = None self.add_fit_requirements([ FitRequirement('numerical_columns', (List,), user_defined=True, dataset_property=True), - FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True)]) + FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True), + FitRequirement('encode_columns', (List,), user_defined=False, dataset_property=False), + FitRequirement('embed_columns', (List,), user_defined=False, dataset_property=False)]) + def get_column_transformer(self) -> ColumnTransformer: """ @@ -52,17 +55,31 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> "TabularColumnTransformer": self.check_requirements(X, y) preprocessors = get_tabular_preprocessers(X) + column_transformers: List[Tuple[str, BaseEstimator, List[int]]] = [] + + numerical_pipeline = 'passthrough' + encode_pipeline = 'passthrough' + if len(preprocessors['numerical']) > 0: numerical_pipeline = make_pipeline(*preprocessors['numerical']) - column_transformers.append( - ('numerical_pipeline', numerical_pipeline, X['dataset_properties']['numerical_columns']) - ) - if len(preprocessors['categorical']) > 0: - categorical_pipeline = make_pipeline(*preprocessors['categorical']) - column_transformers.append( - ('categorical_pipeline', categorical_pipeline, X['dataset_properties']['categorical_columns']) - ) + + column_transformers.append( + ('numerical_pipeline', numerical_pipeline, X['dataset_properties']['numerical_columns']) + ) + + if len(preprocessors['encode']) > 0: + encode_pipeline = make_pipeline(*preprocessors['encode']) + + column_transformers.append( + ('encode_pipeline', encode_pipeline, X['encode_columns']) + ) + + # if len(preprocessors['categorical']) > 0: + # categorical_pipeline = make_pipeline(*preprocessors['categorical']) + # column_transformers.append( + # ('categorical_pipeline', categorical_pipeline, X['dataset_properties']['categorical_columns']) + # ) # in case the preprocessing steps are disabled # i.e, NoEncoder for categorical, we want to diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/base_tabular_preprocessing.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/base_tabular_preprocessing.py index aefe9ddf8..74b1a4d58 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/base_tabular_preprocessing.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/base_tabular_preprocessing.py @@ -14,19 +14,19 @@ class autoPyTorchTabularPreprocessingComponent(autoPyTorchPreprocessingComponent def __init__(self) -> None: super().__init__() self.preprocessor: Union[Dict[str, Optional[BaseEstimator]], BaseEstimator] = dict( - numerical=None, categorical=None) + numerical=None, encode=None, categorical=None) def get_preprocessor_dict(self) -> Dict[str, BaseEstimator]: """ - Returns early_preprocessor dictionary containing the sklearn numerical - and categorical early_preprocessor with "numerical" and "categorical" - keys. May contain None for a key if early_preprocessor does not + Returns early_preprocessor dictionary containing the sklearn numerical, + categorical and encode early_preprocessor with "numerical", "categorical" + "encode" keys. 
May contain None for a key if early_preprocessor does not handle the datatype defined by key Returns: Dict[str, BaseEstimator]: early_preprocessor dictionary """ - if (self.preprocessor['numerical'] and self.preprocessor['categorical']) is None: + if (self.preprocessor['numerical'] and self.preprocessor['categorical'] and self.preprocessor['encode']) is None: raise AttributeError("{} can't return early_preprocessor dict without fitting first" .format(self.__class__.__name__)) return self.preprocessor diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/base_coalescer.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/base_coalescer.py index b572f8343..59918f62c 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/base_coalescer.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/coalescer/base_coalescer.py @@ -12,7 +12,6 @@ def __init__(self) -> None: self._processing = True self.add_fit_requirements([ FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True), - FitRequirement('categories', (List,), user_defined=True, dataset_property=True) ]) def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/column_splitting/ColumnSplitter.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/column_splitting/ColumnSplitter.py new file mode 100644 index 000000000..6902fb1bb --- /dev/null +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/column_splitting/ColumnSplitter.py @@ -0,0 +1,83 @@ +from typing import Any, Dict, List, Optional, Union + +from ConfigSpace.configuration_space import ConfigurationSpace +from ConfigSpace.hyperparameters import ( + UniformIntegerHyperparameter, +) + +import numpy as np + +from autoPyTorch.constants import MIN_CATEGORIES_FOR_EMBEDDING_MAX +from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.base_tabular_preprocessing import \ + autoPyTorchTabularPreprocessingComponent +from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter + + +class ColumnSplitter(autoPyTorchTabularPreprocessingComponent): + """ + Splits categorical columns into embed or encode columns based on a hyperparameter. 
+ """ + def __init__( + self, + min_categories_for_embedding: float = 5, + random_state: Optional[np.random.RandomState] = None + ): + self.min_categories_for_embedding = min_categories_for_embedding + self.random_state = random_state + + self.special_feature_types: Dict[str, List] = dict(encode_columns=[], embed_columns=[]) + self.num_categories_per_col: Optional[List] = None + super().__init__() + + def fit(self, X: Dict[str, Any], y: Optional[Any] = None) -> 'ColumnSplitter': + + self.check_requirements(X, y) + + if len(X['dataset_properties']['categorical_columns']) > 0: + self.num_categories_per_col = [] + for categories_per_column, column in zip(X['dataset_properties']['num_categories_per_col'], + X['dataset_properties']['categorical_columns']): + if ( + categories_per_column >= self.min_categories_for_embedding + ): + self.special_feature_types['embed_columns'].append(column) + # we only care about the categories for columns to be embedded + self.num_categories_per_col.append(categories_per_column) + else: + self.special_feature_types['encode_columns'].append(column) + + return self + + def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: + if self.num_categories_per_col is not None: + # update such that only n categories for embedding columns is passed + X['dataset_properties']['num_categories_per_col'] = self.num_categories_per_col + X.update(self.special_feature_types) + return X + + @staticmethod + def get_properties( + dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None + ) -> Dict[str, Union[str, bool]]: + + return { + 'shortname': 'ColumnSplitter', + 'name': 'Column Splitter', + 'handles_sparse': False, + } + + @staticmethod + def get_hyperparameter_search_space( + dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None, + min_categories_for_embedding: HyperparameterSearchSpace = HyperparameterSearchSpace( + hyperparameter="min_categories_for_embedding", + value_range=(3, MIN_CATEGORIES_FOR_EMBEDDING_MAX), + default_value=3, + log=True), + ) -> ConfigurationSpace: + cs = ConfigurationSpace() + + add_hyperparameter(cs, min_categories_for_embedding, UniformIntegerHyperparameter) + + return cs diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/column_splitting/__init__.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/column_splitting/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/OneHotEncoder.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/OneHotEncoder.py index 5c9281891..2f382a574 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/OneHotEncoder.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/OneHotEncoder.py @@ -20,12 +20,12 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEncoder: self.check_requirements(X, y) - self.preprocessor['categorical'] = OHE( - # It is safer to have the OHE produce a 0 array than to crash a good configuration - categories=X['dataset_properties']['categories'] - if len(X['dataset_properties']['categories']) > 0 else 'auto', - sparse=False, - handle_unknown='ignore') + if self._has_encode_columns(X): + self.preprocessor['encode'] = OHE( + # It is safer to have the OHE produce a 0 array than to crash a good configuration + sparse=False, + handle_unknown='ignore', + dtype=np.float32) return self @staticmethod diff --git 
a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/__init__.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/__init__.py index bca525781..eca46acb2 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/__init__.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/__init__.py @@ -86,7 +86,7 @@ def get_hyperparameter_search_space(self, "choices in {} got {}".format(self.__class__.__name__, available_preprocessors, choice_hyperparameter.value_range)) - if len(choice_hyperparameter) == 0: + if len(categorical_columns) == 0: assert len(choice_hyperparameter.value_range) == 1 assert 'NoEncoder' in choice_hyperparameter.value_range, \ "Provided {} in choices, however, the dataset " \ diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/base_encoder.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/base_encoder.py index eadc0a188..b62822107 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/base_encoder.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/encoding/base_encoder.py @@ -13,8 +13,11 @@ class BaseEncoder(autoPyTorchTabularPreprocessingComponent): def __init__(self) -> None: super().__init__() self.add_fit_requirements([ - FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True), - FitRequirement('categories', (List,), user_defined=True, dataset_property=True)]) + FitRequirement('encode_columns', (List,), user_defined=True, dataset_property=False)]) + + @staticmethod + def _has_encode_columns(X: Dict[str, Any]): + return len(X.get('encode_columns', [])) > 0 def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: """ @@ -25,8 +28,5 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: Returns: (Dict[str, Any]): the updated 'X' dictionary """ - if self.preprocessor['numerical'] is None and self.preprocessor['categorical'] is None: - raise ValueError("cant call transform on {} without fitting first." 
- .format(self.__class__.__name__)) X.update({'encoder': self.preprocessor}) return X diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/feature_preprocessing/utils.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/feature_preprocessing/utils.py index a8c57959e..1968e9f3e 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/feature_preprocessing/utils.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/feature_preprocessing/utils.py @@ -1,6 +1,6 @@ import warnings from math import ceil, floor -from typing import Dict, List, Optional, Sequence +from typing import Dict, List, Optional, Sequence, Tuple from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType from autoPyTorch.utils.common import HyperparameterSearchSpace, HyperparameterValueType @@ -82,17 +82,16 @@ def percentage_value_range_to_integer_range( else: log = hyperparameter_search_space.log - min_hyperparameter_value = hyperparameter_search_space.value_range[0] - if len(hyperparameter_search_space.value_range) > 1: - max_hyperparameter_value = hyperparameter_search_space.value_range[1] + value_range: Tuple + if len(hyperparameter_search_space.value_range) == 2: + value_range = (floor(float(hyperparameter_search_space.value_range[0]) * n_features), + floor(float(hyperparameter_search_space.value_range[-1]) * n_features)) else: - max_hyperparameter_value = hyperparameter_search_space.value_range[0] + value_range = (floor(float(hyperparameter_search_space.value_range[0]) * n_features),) hyperparameter_search_space = HyperparameterSearchSpace( hyperparameter=hyperparameter_name, - value_range=( - floor(float(min_hyperparameter_value) * n_features), - floor(float(max_hyperparameter_value) * n_features)), + value_range=value_range, default_value=ceil(float(hyperparameter_search_space.default_value) * n_features), log=log) else: diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/utils.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/utils.py index e71583e3e..20f0e0320 100644 --- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/utils.py +++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/utils.py @@ -21,7 +21,7 @@ def get_tabular_preprocessers(X: Dict[str, Any]) -> Dict[str, List[BaseEstimator Returns: (Dict[str, List[BaseEstimator]]): dictionary with list of numerical and categorical preprocessors """ - preprocessor: Dict[str, List[BaseEstimator]] = dict(numerical=list(), categorical=list()) + preprocessor: Dict[str, List[BaseEstimator]] = dict(numerical=list(), categorical=list(), encode=list()) for key, value in X.items(): if isinstance(value, dict): # as each preprocessor is child of BaseEstimator @@ -29,5 +29,7 @@ def get_tabular_preprocessers(X: Dict[str, Any]) -> Dict[str, List[BaseEstimator preprocessor['numerical'].append(value['numerical']) if 'categorical' in value and isinstance(value['categorical'], BaseEstimator): preprocessor['categorical'].append(value['categorical']) + if 'encode' in value and isinstance(value['encode'], BaseEstimator): + preprocessor['encode'].append(value['encode']) return preprocessor diff --git a/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/TimeSeriesTransformer.py b/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/TimeSeriesTransformer.py index ecca60570..3ee2e6227 100644 --- 
a/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/TimeSeriesTransformer.py +++ b/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/TimeSeriesTransformer.py @@ -12,7 +12,7 @@ autoPyTorchTimeSeriesPreprocessingComponent, autoPyTorchTimeSeriesTargetPreprocessingComponent) from autoPyTorch.pipeline.components.preprocessing.time_series_preprocessing.utils import ( - get_time_series_preprocessers, get_time_series_target_preprocessers) + get_time_series_preprocessors, get_time_series_target_preprocessers) from autoPyTorch.utils.common import FitRequirement @@ -24,6 +24,7 @@ def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = N self.add_fit_requirements([ FitRequirement('numerical_features', (List,), user_defined=True, dataset_property=True), FitRequirement('categorical_features', (List,), user_defined=True, dataset_property=True)]) + self.output_feature_order = None def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator: """ @@ -38,18 +39,25 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator: """ self.check_requirements(X, y) - preprocessors = get_time_series_preprocessers(X) + preprocessors = get_time_series_preprocessors(X) column_transformers: List[Tuple[str, BaseEstimator, List[int]]] = [] + + numerical_pipeline = 'passthrough' + encode_pipeline = 'passthrough' + if len(preprocessors['numerical']) > 0: numerical_pipeline = make_pipeline(*preprocessors['numerical']) - column_transformers.append( - ('numerical_pipeline', numerical_pipeline, X['dataset_properties']['numerical_columns']) - ) - if len(preprocessors['categorical']) > 0: - categorical_pipeline = make_pipeline(*preprocessors['categorical']) - column_transformers.append( - ('categorical_pipeline', categorical_pipeline, X['dataset_properties']['categorical_columns']) - ) + + column_transformers.append( + ('numerical_pipeline', numerical_pipeline, X['dataset_properties']['numerical_columns']) + ) + + if len(preprocessors['encode']) > 0: + encode_pipeline = make_pipeline(*preprocessors['encode']) + + column_transformers.append( + ('encode_pipeline', encode_pipeline, X['encode_columns']) + ) # in case the preprocessing steps are disabled # i.e, NoEncoder for categorical, we want to @@ -67,6 +75,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator: X_train = X['backend'].load_datamanager().train_tensors[0] self.preprocessor.fit(X_train) + self.output_feature_order = self.get_output_column_orders(len(X['dataset_properties']['feature_names'])) return self def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: @@ -79,14 +88,14 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: Returns: X (Dict[str, Any]): updated fit dictionary """ - X.update({'time_series_feature_transformer': self}) + X.update({'time_series_feature_transformer': self, + 'feature_order_after_preprocessing': self.output_feature_order}) return X def __call__(self, X: pd.DataFrame) -> pd.DataFrame: if self.preprocessor is None: raise ValueError("cant call {} without fitting the column transformer first." 
.format(self.__class__.__name__)) - return self.preprocessor.transform(X) def get_column_transformer(self) -> ColumnTransformer: @@ -102,6 +111,33 @@ def get_column_transformer(self) -> ColumnTransformer: .format(self.__class__.__name__)) return self.preprocessor + def get_output_column_orders(self, n_input_columns: int) -> List[int]: + """ + Get the order of the output features transformed by self.preprocessor. + TODO: replace this function with self.preprocessor.get_feature_names_out() when switching to sklearn 1.0! + + Args: + n_input_columns (int): number of input columns that will be transformed + + Returns: + List[int]: a list of indices indicating the order of each column after transformation. Its length should + be equal to n_input_columns + """ + if self.preprocessor is None: + raise ValueError("cant call {} without fitting the column transformer first." + .format(self.__class__.__name__)) + transformers = self.preprocessor.transformers + + n_reordered_input = np.arange(n_input_columns) + processed_columns = np.asarray([], dtype=int) + + for tran in transformers: + trans_columns = np.array(tran[-1], dtype=int) + unprocessed_columns = np.setdiff1d(processed_columns, trans_columns) + processed_columns = np.hstack([unprocessed_columns, trans_columns]) + unprocessed_columns = np.setdiff1d(n_reordered_input, processed_columns) + return np.hstack([processed_columns, unprocessed_columns]).tolist() # type: ignore[return-value] + class TimeSeriesTargetTransformer(autoPyTorchTimeSeriesTargetPreprocessingComponent): def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = None): diff --git a/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/base_time_series_preprocessing.py b/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/base_time_series_preprocessing.py index e924d360d..4e83174ab 100644 --- a/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/base_time_series_preprocessing.py +++ b/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/base_time_series_preprocessing.py @@ -2,25 +2,17 @@ from sklearn.base import BaseEstimator -from autoPyTorch.pipeline.components.preprocessing.base_preprocessing import ( - autoPyTorchPreprocessingComponent, autoPyTorchTargetPreprocessingComponent) +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.base_tabular_preprocessing import ( + autoPyTorchTabularPreprocessingComponent +) +from autoPyTorch.pipeline.components.preprocessing.base_preprocessing import autoPyTorchTargetPreprocessingComponent -class autoPyTorchTimeSeriesPreprocessingComponent(autoPyTorchPreprocessingComponent): +class autoPyTorchTimeSeriesPreprocessingComponent(autoPyTorchTabularPreprocessingComponent): """ Provides abstract interface for time series preprocessing algorithms in AutoPyTorch.
""" - def __init__(self) -> None: - super().__init__() - self.preprocessor: Union[Dict[str, Optional[BaseEstimator]], BaseEstimator] = dict( - numerical=None, categorical=None) - - def __str__(self) -> str: - """ Allow a nice understanding of what components where used """ - string = self.__class__.__name__ - return string - class autoPyTorchTimeSeriesTargetPreprocessingComponent(autoPyTorchTargetPreprocessingComponent): """ diff --git a/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/column_spliting/ColumnSplitter.py b/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/column_spliting/ColumnSplitter.py new file mode 100644 index 000000000..bf18ed479 --- /dev/null +++ b/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/column_spliting/ColumnSplitter.py @@ -0,0 +1,38 @@ +from typing import Any, Dict, Optional + +import numpy as np + +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.column_splitting.ColumnSplitter import ( + ColumnSplitter +) +from autoPyTorch.pipeline.components.preprocessing.time_series_preprocessing.base_time_series_preprocessing import \ + autoPyTorchTimeSeriesPreprocessingComponent + + +class TimeSeriesColumnSplitter(ColumnSplitter, autoPyTorchTimeSeriesPreprocessingComponent): + """ + Splits categorical columns into embed or encode columns based on a hyperparameter. + The splitter for time series is quite similar to the tabular splitter. However, we need to reserve the raw + number of categorical features for later use + """ + def __init__( + self, + min_categories_for_embedding: float = 5, + random_state: Optional[np.random.RandomState] = None + ): + super(TimeSeriesColumnSplitter, self).__init__(min_categories_for_embedding, random_state) + self.num_categories_per_col_encoded = None + + def fit(self, X: Dict[str, Any], y: Optional[Any] = None) -> 'TimeSeriesColumnSplitter': + super(TimeSeriesColumnSplitter, self).fit(X, y) + + self.num_categories_per_col_encoded = X['dataset_properties']['num_categories_per_col'] + for i in range(len(self.num_categories_per_col_encoded)): + if i in self.special_feature_types['embed_columns']: + self.num_categories_per_col_encoded[i] = 1 + return self + + def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: + X = super(TimeSeriesColumnSplitter, self).transform(X) + X['dataset_properties']['num_categories_per_col_encoded'] = self.num_categories_per_col_encoded + return X diff --git a/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/column_spliting/__init__.py b/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/column_spliting/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/encoding/OneHotEncoder.py b/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/encoding/OneHotEncoder.py index 5ac5e2550..274d05c1a 100644 --- a/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/encoding/OneHotEncoder.py +++ b/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/encoding/OneHotEncoder.py @@ -19,14 +19,15 @@ def __init__(self, def fit(self, X: Dict[str, Any], y: Any = None) -> TimeSeriesBaseEncoder: OneHotEncoder.fit(self, X, y) categorical_columns = X['dataset_properties']['categorical_columns'] - n_features_cat = X['dataset_properties']['categories'] + if 'num_categories_per_col_encoded' in X['dataset_properties']: + num_categories_per_col = 
X['dataset_properties']['num_categories_per_col_encoded'] + else: + num_categories_per_col = X['dataset_properties']['num_categories_per_col'] feature_names = X['dataset_properties']['feature_names'] feature_shapes = X['dataset_properties']['feature_shapes'] - if len(n_features_cat) == 0: - n_features_cat = self.preprocessor['categorical'].categories # type: ignore for i, cat_column in enumerate(categorical_columns): - feature_shapes[feature_names[cat_column]] = len(n_features_cat[i]) + feature_shapes[feature_names[cat_column]] = num_categories_per_col[i] self.feature_shapes = feature_shapes return self diff --git a/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/encoding/__init__.py b/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/encoding/__init__.py index 4170fff8e..da666957e 100644 --- a/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/encoding/__init__.py +++ b/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/encoding/__init__.py @@ -4,8 +4,7 @@ from autoPyTorch.pipeline.components.base_component import ( ThirdPartyComponents, autoPyTorchComponent, find_components) -from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.encoding import \ - EncoderChoice +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.encoding import EncoderChoice from autoPyTorch.pipeline.components.preprocessing.time_series_preprocessing.encoding.time_series_base_encoder import \ TimeSeriesBaseEncoder diff --git a/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/encoding/time_series_base_encoder.py b/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/encoding/time_series_base_encoder.py index a3d64ee92..d456534a7 100644 --- a/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/encoding/time_series_base_encoder.py +++ b/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/encoding/time_series_base_encoder.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Union +from typing import Any, Dict, List from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.encoding.base_encoder import \ BaseEncoder @@ -7,7 +7,7 @@ from autoPyTorch.utils.common import FitRequirement -class TimeSeriesBaseEncoder(autoPyTorchTimeSeriesPreprocessingComponent): +class TimeSeriesBaseEncoder(autoPyTorchTimeSeriesPreprocessingComponent, BaseEncoder): """ Base class for encoder """ @@ -15,11 +15,11 @@ def __init__(self) -> None: super(TimeSeriesBaseEncoder, self).__init__() self.add_fit_requirements([ FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True), - FitRequirement('categories', (List,), user_defined=True, dataset_property=True), + FitRequirement('num_categories_per_col', (List,), user_defined=True, dataset_property=True), FitRequirement('feature_names', (tuple,), user_defined=True, dataset_property=True), FitRequirement('feature_shapes', (Dict, ), user_defined=True, dataset_property=True), ]) - self.feature_shapes: Union[Dict[str, int]] = {} + self.feature_shapes: Dict[str, int] = {} def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: """ diff --git a/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/utils.py b/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/utils.py index 22252f0dd..66fc49529 100644 --- a/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/utils.py +++ 
b/autoPyTorch/pipeline/components/preprocessing/time_series_preprocessing/utils.py @@ -2,32 +2,14 @@ from sklearn.base import BaseEstimator +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.utils import get_tabular_preprocessers -def get_time_series_preprocessers(X: Dict[str, Any]) -> Dict[str, List[BaseEstimator]]: - """ - Expects fit_dictionary(X) to have numerical/categorical preprocessors - (fitted numerical/categorical preprocessing nodes) that will build a pipeline in the TimeSeriesTransformer. - This function parses X and extracts such components. - Creates a dictionary with two keys, - numerical- containing list of numerical preprocessors - categorical- containing list of categorical preprocessors - - Args: - X: fit dictionary - Returns: - (Dict[str, List[BaseEstimator]]): dictionary with list of numerical and categorical preprocessors +def get_time_series_preprocessors(X: Dict[str, Any]) -> Dict[str, List[BaseEstimator]]: """ - preprocessor = dict(numerical=list(), categorical=list()) # type: Dict[str, List[BaseEstimator]] - for key, value in X.items(): - if isinstance(value, dict): - # as each preprocessor is child of BaseEstimator - if 'numerical' in value and isinstance(value['numerical'], BaseEstimator): - preprocessor['numerical'].append(value['numerical']) - if 'categorical' in value and isinstance(value['categorical'], BaseEstimator): - preprocessor['categorical'].append(value['categorical']) - - return preprocessor + This function simply renames the tabular preprocessors to time series preprocessors. + """ + return get_tabular_preprocessers(X) def get_time_series_target_preprocessers(X: Dict[str, Any]) -> Dict[str, List[BaseEstimator]]: diff --git a/autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py b/autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py index 597f14ca6..f912b07c1 100644 --- a/autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py +++ b/autoPyTorch/pipeline/components/setup/early_preprocessor/EarlyPreprocessing.py @@ -10,7 +10,7 @@ from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType from autoPyTorch.pipeline.components.setup.base_setup import autoPyTorchSetupComponent -from autoPyTorch.pipeline.components.setup.early_preprocessor.utils import get_preprocess_transforms, preprocess +from autoPyTorch.pipeline.components.setup.early_preprocessor.utils import get_preprocess_transforms, get_preprocessed_dtype, preprocess from autoPyTorch.utils.common import FitRequirement @@ -39,8 +39,14 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: X['X_train'] = preprocess(dataset=X_train, transforms=transforms) + preprocessed_dtype = get_preprocessed_dtype(X['X_train']) + # We need to also save the preprocess transforms for inference - X.update({'preprocess_transforms': transforms}) + X.update({ + 'preprocess_transforms': transforms, + 'shape_after_preprocessing': X['X_train'].shape[1:], + 'preprocessed_dtype': preprocessed_dtype + }) return X @staticmethod diff --git a/autoPyTorch/pipeline/components/setup/early_preprocessor/TimeSeriesEarlyPreProcessing.py b/autoPyTorch/pipeline/components/setup/early_preprocessor/TimeSeriesEarlyPreProcessing.py index 59035869e..67cd4ce39 100644 --- a/autoPyTorch/pipeline/components/setup/early_preprocessor/TimeSeriesEarlyPreProcessing.py +++ b/autoPyTorch/pipeline/components/setup/early_preprocessor/TimeSeriesEarlyPreProcessing.py @@ -10,7 +10,7 @@ from
autoPyTorch.pipeline.components.setup.early_preprocessor.EarlyPreprocessing import \ EarlyPreprocessing from autoPyTorch.pipeline.components.setup.early_preprocessor.utils import ( - get_preprocess_transforms, time_series_preprocess) + get_preprocess_transforms, get_preprocessed_dtype, time_series_preprocess) from autoPyTorch.utils.common import FitRequirement @@ -19,22 +19,15 @@ def __init__(self, random_state: Optional[np.random.RandomState] = None) -> None super(EarlyPreprocessing, self).__init__() self.random_state = random_state self.add_fit_requirements([ - FitRequirement('is_small_preprocess', (bool,), user_defined=True, dataset_property=True), FitRequirement('X_train', (pd.DataFrame, ), user_defined=True, dataset_property=False), FitRequirement('feature_names', (tuple,), user_defined=True, dataset_property=True), - FitRequirement('numerical_columns', (List,), user_defined=True, dataset_property=True), - FitRequirement('categorical_columns', (List,), user_defined=True, dataset_property=True), + FitRequirement('feature_order_after_preprocessing', (List,), user_defined=False, dataset_property=False) ]) def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: """ if dataset is small process, we transform the entire dataset here. - Before transformation, the order of the dataset is: - [(unknown_columns), categorical_columns, numerical_columns] - While after transformation, the order of the dataset is: - [numerical_columns, categorical_columns, unknown_columns] - we need to change feature_names and feature_shapes accordingly Args: X(Dict): fit dictionary @@ -44,28 +37,27 @@ def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: """ transforms = get_preprocess_transforms(X) - if X['dataset_properties']['is_small_preprocess']: - if 'X_train' in X: - X_train = X['X_train'] - else: - # Incorporate the transform to the dataset - X_train = X['backend'].load_datamanager().train_tensors[0] + if 'X_train' in X: + X_train = X['X_train'] + else: + # Incorporate the transform to the dataset + X_train = X['backend'].load_datamanager().train_tensors[0] - X['X_train'] = time_series_preprocess(dataset=X_train, transforms=transforms) + X['X_train'] = time_series_preprocess(dataset=X_train, transforms=transforms) feature_names = X['dataset_properties']['feature_names'] - numerical_columns = X['dataset_properties']['numerical_columns'] - categorical_columns = X['dataset_properties']['categorical_columns'] - - # resort feature_names - new_feature_names = [feature_names[num_col] for num_col in numerical_columns] - new_feature_names += [feature_names[cat_col] for cat_col in categorical_columns] - if set(feature_names) != set(new_feature_names): - new_feature_names += list(set(feature_names) - set(new_feature_names)) + + feature_order_after_preprocessing = X['feature_order_after_preprocessing'] + new_feature_names = (feature_names[i] for i in feature_order_after_preprocessing) X['dataset_properties']['feature_names'] = tuple(new_feature_names) + preprocessed_dtype = get_preprocessed_dtype(X['X_train']) # We need to also save the preprocess transforms for inference - X.update({'preprocess_transforms': transforms}) + X.update({ + 'preprocess_transforms': transforms, + 'shape_after_preprocessing': X['X_train'].shape[1:], + 'preprocessed_dtype': preprocessed_dtype + }) return X @staticmethod @@ -90,14 +82,13 @@ def __init__(self, random_state: Optional[np.random.RandomState] = None) -> None def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: # TODO consider inverse transformation transforms = 
get_preprocess_transforms(X, preprocess_type=autoPyTorchTargetPreprocessingComponent) - if X['dataset_properties']['is_small_preprocess']: - if 'y_train' in X: - y_train = X['y_train'] - else: - # Incorporate the transform to the dataset - y_train = X['backend'].load_datamanager().train_tensors[1] - - X['y_train'] = time_series_preprocess(dataset=y_train, transforms=transforms) + if 'y_train' in X: + y_train = X['y_train'] + else: + # Incorporate the transform to the dataset + y_train = X['backend'].load_datamanager().train_tensors[1] + + X['y_train'] = time_series_preprocess(dataset=y_train, transforms=transforms) # We need to also save the preprocess transforms for inference X.update({'preprocess_target_transforms': transforms}) diff --git a/autoPyTorch/pipeline/components/setup/early_preprocessor/utils.py b/autoPyTorch/pipeline/components/setup/early_preprocessor/utils.py index 830beced9..667e9c008 100644 --- a/autoPyTorch/pipeline/components/setup/early_preprocessor/utils.py +++ b/autoPyTorch/pipeline/components/setup/early_preprocessor/utils.py @@ -13,6 +13,7 @@ autoPyTorchPreprocessingComponent as aPTPre, autoPyTorchTargetPreprocessingComponent as aPTTPre ) +from .....utils.common import ispandas def get_preprocess_transforms(X: Dict[str, Any], @@ -71,3 +72,10 @@ def time_series_preprocess(dataset: pd.DataFrame, transforms: torchvision.transf sub_dataset = composite_transforms(sub_dataset) dataset.iloc[:, indices] = sub_dataset return dataset + + +def get_preprocessed_dtype(X_train: Union[np.ndarray, pd.DataFrame]): + if ispandas(X_train): + return X_train.dtypes[X_train.columns].name + else: + return X_train.dtype.name \ No newline at end of file diff --git a/autoPyTorch/pipeline/components/setup/network/forecasting_architecture.py b/autoPyTorch/pipeline/components/setup/network/forecasting_architecture.py index 0f3fb9875..6577e9b78 100644 --- a/autoPyTorch/pipeline/components/setup/network/forecasting_architecture.py +++ b/autoPyTorch/pipeline/components/setup/network/forecasting_architecture.py @@ -2,6 +2,7 @@ from abc import abstractmethod from typing import Any, Dict, List, Optional, Tuple, Union +import numpy as np import torch from torch import nn from torch.distributions import AffineTransform, TransformedDistribution @@ -205,6 +206,7 @@ def __init__(self, auto_regressive: bool, feature_names: Union[Tuple[str], Tuple[()]] = (), known_future_features: Union[Tuple[str], Tuple[()]] = (), + embed_features_idx: Tuple[int] = (), feature_shapes: Dict[str, int] = {}, static_features: Union[Tuple[str], Tuple[()]] = (), time_feature_names: Union[Tuple[str], Tuple[()]] = (), @@ -218,7 +220,16 @@ def __init__(self, self.embedding = network_embedding if len(known_future_features) > 0: known_future_features_idx = [feature_names.index(kff) for kff in known_future_features] - self.decoder_embedding = self.embedding.get_partial_models(known_future_features_idx) + known_future_embed_features = np.where( + np.in1d(embed_features_idx, known_future_features_idx, assume_unique=True) + )[0] + idx_excl_embed_future_features = np.setdiff1d(known_future_features_idx, embed_features_idx) + n_excl_embed_features = sum(feature_shapes[feature_names[i]] for i in idx_excl_embed_future_features) + + self.decoder_embedding = self.embedding.get_partial_models( + n_excl_embed_features=n_excl_embed_features, + idx_embed_feat_partial=known_future_embed_features + ) else: self.decoder_embedding = _NoEmbedding() # modules that generate tensors while doing forward pass @@ -558,7 +569,7 @@ def pre_processing(self, 
return x_past, x_future, x_static, loc, scale, static_context_initial_hidden, past_targets else: if past_features is not None: - x_past = torch.cat([truncated_past_targets, past_features], dim=-1).to(device=self.device) + x_past = torch.cat([past_features, truncated_past_targets], dim=-1).to(device=self.device) x_past = self.embedding(x_past.to(device=self.device)) else: x_past = self.embedding(truncated_past_targets.to(device=self.device)) @@ -615,8 +626,8 @@ def forward(self, return self.rescale_output(output, loc, scale, self.device) def _unwrap_past_targets( - self, - past_targets: dict + self, + past_targets: dict ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], diff --git a/autoPyTorch/pipeline/components/setup/network/forecasting_network.py b/autoPyTorch/pipeline/components/setup/network/forecasting_network.py index 2750348a5..4ab7120e9 100644 --- a/autoPyTorch/pipeline/components/setup/network/forecasting_network.py +++ b/autoPyTorch/pipeline/components/setup/network/forecasting_network.py @@ -44,6 +44,7 @@ def __init__( FitRequirement("auto_regressive", (bool,), user_defined=False, dataset_property=False), FitRequirement("target_scaler", (BaseTargetScaler,), user_defined=False, dataset_property=False), FitRequirement("net_output_type", (str,), user_defined=False, dataset_property=False), + FitRequirement('embed_features_idx', (tuple,), user_defined=False, dataset_property=False), FitRequirement("feature_names", (Iterable,), user_defined=False, dataset_property=True), FitRequirement("feature_shapes", (Iterable,), user_defined=False, dataset_property=True), FitRequirement('transform_time_features', (bool,), user_defined=False, dataset_property=False), @@ -85,6 +86,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchTrainingComponent: feature_names=feature_names, feature_shapes=feature_shapes, known_future_features=known_future_features, + embed_features_idx=X['embed_features_idx'], time_feature_names=time_feature_names, static_features=X['dataset_properties']['static_features'] ) diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py b/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py index ef3cc1768..f63ebd578 100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/base_network_backbone.py @@ -30,8 +30,7 @@ def __init__(self, self.add_fit_requirements([ FitRequirement('X_train', (np.ndarray, pd.DataFrame, spmatrix), user_defined=True, dataset_property=False), - FitRequirement('input_shape', (Iterable,), user_defined=True, dataset_property=True), - FitRequirement('tabular_transformer', (BaseEstimator,), user_defined=False, dataset_property=False), + FitRequirement('shape_after_preprocessing', (Iterable,), user_defined=False, dataset_property=False), FitRequirement('network_embedding', (nn.Module,), user_defined=False, dataset_property=False) ]) self.backbone: nn.Module = None @@ -49,9 +48,8 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator: Self """ self.check_requirements(X, y) - X_train = X['X_train'] - input_shape = X_train.shape[1:] + input_shape = X['shape_after_preprocessing'] input_shape = get_output_shape(X['network_embedding'], input_shape=input_shape) self.input_shape = input_shape diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/utils.py b/autoPyTorch/pipeline/components/setup/network_backbone/utils.py index a3216c7c1..ae8ea57e7 
100644 --- a/autoPyTorch/pipeline/components/setup/network_backbone/utils.py +++ b/autoPyTorch/pipeline/components/setup/network_backbone/utils.py @@ -27,7 +27,9 @@ def get_output_shape(network: torch.nn.Module, input_shape: Tuple[int, ...], has the network will return a Tuple, we will then only consider the first item :return: output_shape """ - placeholder = torch.randn((2, *input_shape), dtype=torch.float) + # as we are using nn.Embedding, 2 is a safe upper limit as 3 + # is the lowest `min_values_for_embedding` can be + placeholder = torch.randint(high=2, size=(2, *input_shape), dtype=torch.float) with torch.no_grad(): if has_hidden_states: output = network(placeholder)[0] diff --git a/autoPyTorch/pipeline/components/setup/network_embedding/LearnedEntityEmbedding.py b/autoPyTorch/pipeline/components/setup/network_embedding/LearnedEntityEmbedding.py index fdcf051bd..f3ca60634 100644 --- a/autoPyTorch/pipeline/components/setup/network_embedding/LearnedEntityEmbedding.py +++ b/autoPyTorch/pipeline/components/setup/network_embedding/LearnedEntityEmbedding.py @@ -1,117 +1,152 @@ +from math import ceil from typing import Any, Dict, List, Optional, Tuple, Union from ConfigSpace.configuration_space import ConfigurationSpace from ConfigSpace.hyperparameters import ( UniformFloatHyperparameter, - UniformIntegerHyperparameter + UniformIntegerHyperparameter, + CategoricalHyperparameter ) import numpy as np import torch -from torch import nn +from torch import embedding, nn from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType from autoPyTorch.pipeline.components.setup.network_embedding.base_network_embedding import NetworkEmbeddingComponent from autoPyTorch.utils.common import HyperparameterSearchSpace, add_hyperparameter +def get_num_output_dimensions(config: Dict[str, Any], num_categs_per_feature: List[int]) -> List[int]: + """ + Returns a list of embedding sizes for each categorical variable. + The sizes are selected adaptively based on the training dataset. + Note: Assumes there is at least one embed feature.
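Worked example (editorial note, not part of the original patch): with the default search-space values embed_exponent=0.56, embedding_size_factor=1.0 and max_embedding_dim=100, num_categs_per_feature = [0, 3, 12, 10000] maps to [1, 2, 6, 100]; a 0 entry marks a column that is not embedded, so it keeps a single dimension.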
+ Args: + config (Dict[str, Any]): + contains the hyperparameters required to calculate the `num_output_dimensions` + num_categs_per_feature (List[int]): + list containing number of categories for each feature that is to be embedded, + 0 if the column is not an embed column + Returns: + List[int]: + list containing the output embedding size for each column, + 1 if the column is not an embed column + """ + + max_embedding_dim = config['max_embedding_dim'] + embed_exponent = config['embed_exponent'] + size_factor = config['embedding_size_factor'] + num_output_dimensions = [int(size_factor*max( + 2, + min(max_embedding_dim, + 1.6 * num_categories**embed_exponent))) + if num_categories > 0 else 1 for num_categories in num_categs_per_feature] + return num_output_dimensions + + class _LearnedEntityEmbedding(nn.Module): """ Learned entity embedding module for categorical features""" - def __init__(self, config: Dict[str, Any], num_input_features: np.ndarray, num_numerical_features: int): + def __init__(self, config: Dict[str, Any], num_categories_per_col: np.ndarray, num_features_excl_embed: int): """ Args: config (Dict[str, Any]): The configuration sampled by the hyperparameter optimizer - num_input_features (np.ndarray): column wise information of number of output columns after transformation - for each categorical column and 0 for numerical columns - num_numerical_features (int): number of numerical features in X + num_categories_per_col (np.ndarray): number of categories per categorical columns that will be embedded + num_features_excl_embed (int): number of features in X excluding the features that need to be embedded """ super().__init__() self.config = config - - self.num_numerical = num_numerical_features # list of number of categories of categorical data # or 0 for numerical data - self.num_input_features = num_input_features - categorical_features: np.ndarray = self.num_input_features > 0 - - self.num_categorical_features = self.num_input_features[categorical_features] - - self.embed_features = [num_in >= config["min_unique_values_for_embedding"] for num_in in - self.num_input_features] - self.num_output_dimensions = [0] * num_numerical_features - self.num_output_dimensions.extend([config["dimension_reduction_" + str(i)] * num_in for i, num_in in - enumerate(self.num_categorical_features)]) - self.num_output_dimensions = [int(np.clip(num_out, 1, num_in - 1)) for num_out, num_in in - zip(self.num_output_dimensions, self.num_input_features)] - self.num_output_dimensions = [num_out if embed else num_in for num_out, embed, num_in in - zip(self.num_output_dimensions, self.embed_features, - self.num_input_features)] - self.num_out_feats = self.num_numerical + sum(self.num_output_dimensions) + self.num_categories_per_col = num_categories_per_col + self.embed_features = self.num_categories_per_col > 0 + self.num_features_excl_embed = num_features_excl_embed + + self.num_embed_features = self.num_categories_per_col[self.embed_features] + + self.num_output_dimensions = get_num_output_dimensions(config, self.num_categories_per_col) + + self.num_out_feats = num_features_excl_embed + sum(self.num_output_dimensions) self.ee_layers = self._create_ee_layers() - def get_partial_models(self, subset_features: List[int]) -> "_LearnedEntityEmbedding": + def get_partial_models(self, + n_excl_embed_features: int, + idx_embed_feat_partial: List[int]) -> "_LearnedEntityEmbedding": """ extract a partial models that only works on a subset of the data that ought to be passed to the embedding network, this function 
is implemented for time series forecasting tasks where the known future features is only a subset of the past features + Args: - subset_features (List[int]): - a set of index identifying which features will pass through the partial model + n_excl_embed_features (int): + number of unembedded features + idx_embed_feat_partial (List[int]): + a set of index identifying the which embedding features will be inherited by the partial model. This + index is used to extract self.ee_layers Returns: partial_model (_LearnedEntityEmbedding) a new partial model """ - num_input_features = self.num_input_features[subset_features] - num_numerical_features = sum([sf < self.num_numerical for sf in subset_features]) + n_partial_features = n_excl_embed_features + len(idx_embed_feat_partial) - num_output_dimensions = [self.num_output_dimensions[sf] for sf in subset_features] - embed_features = [self.embed_features[sf] for sf in subset_features] + num_categories_per_col = np.zeros(n_partial_features, dtype=np.int16) + num_output_dimensions = [1] * n_partial_features ee_layers = [] - ee_layer_tracker = 0 - for sf in subset_features: - if self.embed_features[sf]: - ee_layers.append(self.ee_layers[ee_layer_tracker]) - ee_layer_tracker += 1 + for idx, idx_embed in enumerate(idx_embed_feat_partial): + idx_raw = self.num_features_excl_embed + idx_embed + n_embed = self.num_categories_per_col[idx_raw] + n_output = self.num_output_dimensions[idx_raw] + + idx_new = n_excl_embed_features + idx + num_categories_per_col[idx_new] = n_embed + num_output_dimensions[idx_new] = n_output + + ee_layers.append(self.ee_layers[idx_embed]) + ee_layers = nn.ModuleList(ee_layers) - return PartialLearnedEntityEmbedding(num_input_features, num_numerical_features, embed_features, + embed_features = num_categories_per_col > 0 + + return PartialLearnedEntityEmbedding(num_categories_per_col, n_excl_embed_features, embed_features, num_output_dimensions, ee_layers) def forward(self, x: torch.Tensor) -> torch.Tensor: # pass the columns of each categorical feature through entity embedding layer # before passing it through the model concat_seq = [] - last_concat = 0 - x_pointer = 0 + layer_pointer = 0 - for num_in, embed in zip(self.num_input_features, self.embed_features): + x_pointer = 0 + # For forcasting architectures,besides the input features, we might also need to feed targets and time features + # to the embedding layers, which are not counted by self.embed_features. 
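# Editorial example (not part of the patch): with embed_features = [False, False, True],
# a batch row [x0, x1, c, t0, t1] is rebuilt as [x0, x1, *Embedding(int(c)), t0, t1]:
# non-embedded columns are passed through unchanged, each embedded column is cast to int
# and routed through its own nn.Embedding layer, and any trailing columns beyond
# len(embed_features) (e.g. targets/time features in forecasting) are appended as-is
# by the final `concat_seq.append(x[..., x_pointer + 1:])` below.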
+ for x_pointer, embed in enumerate(self.embed_features): if not embed: - x_pointer += 1 + current_feature_slice = x[..., [x_pointer]] + concat_seq.append(current_feature_slice) continue - if x_pointer > last_concat: - concat_seq.append(x[..., last_concat: x_pointer]) - categorical_feature_slice = x[..., x_pointer: x_pointer + num_in] - concat_seq.append(self.ee_layers[layer_pointer](categorical_feature_slice)) + current_feature_slice = x[..., x_pointer] + current_feature_slice = current_feature_slice.to(torch.int) + concat_seq.append(self.ee_layers[layer_pointer](current_feature_slice)) + layer_pointer += 1 - x_pointer += num_in - last_concat = x_pointer + concat_seq.append(x[..., x_pointer + 1:]) - concat_seq.append(x[..., last_concat:]) return torch.cat(concat_seq, dim=-1) def _create_ee_layers(self) -> nn.ModuleList: # entity embeding layers are Linear Layers layers = nn.ModuleList() - for i, (num_in, embed, num_out) in enumerate(zip(self.num_input_features, self.embed_features, - self.num_output_dimensions)): + for num_cat, embed, num_out in zip(self.num_categories_per_col, + self.embed_features, + self.num_output_dimensions): if not embed: continue - layers.append(nn.Linear(num_in, num_out)) + layers.append(nn.Embedding(num_cat, num_out)) return layers @@ -121,28 +156,27 @@ class PartialLearnedEntityEmbedding(_LearnedEntityEmbedding): of the input features. This is applied to forecasting tasks where not all the features might be known beforehand """ def __init__(self, - num_input_features: np.ndarray, - num_numerical_features: int, + num_categories_per_col: np.ndarray, + num_features_excl_embed: int, embed_features: List[bool], num_output_dimensions: List[int], ee_layers: nn.Module ): super(_LearnedEntityEmbedding, self).__init__() - self.num_numerical = num_numerical_features + self.num_features_excl_embed = num_features_excl_embed # list of number of categories of categorical data # or 0 for numerical data - self.num_input_features = num_input_features - categorical_features: np.ndarray = self.num_input_features > 0 - - self.num_categorical_features = self.num_input_features[categorical_features] + self.num_categories_per_col = num_categories_per_col self.embed_features = embed_features self.num_output_dimensions = num_output_dimensions - self.num_out_feats = self.num_numerical + sum(self.num_output_dimensions) + self.num_out_feats = self.num_features_excl_embed + sum(self.num_output_dimensions) self.ee_layers = ee_layers + self.num_embed_features = self.num_categories_per_col[self.embed_features] + class LearnedEntityEmbedding(NetworkEmbeddingComponent): """ @@ -153,38 +187,35 @@ def __init__(self, random_state: Optional[np.random.RandomState] = None, **kwarg super().__init__(random_state=random_state) self.config = kwargs - def build_embedding(self, - num_input_features: np.ndarray, - num_numerical_features: int) -> Tuple[nn.Module, List[int]]: + def build_embedding(self, num_categories_per_col: np.ndarray, num_features_excl_embed: int) -> nn.Module: embedding = _LearnedEntityEmbedding(config=self.config, - num_input_features=num_input_features, - num_numerical_features=num_numerical_features) + num_categories_per_col=num_categories_per_col, + num_features_excl_embed=num_features_excl_embed) + return embedding, embedding.num_output_dimensions @staticmethod def get_hyperparameter_search_space( dataset_properties: Optional[Dict[str, BaseDatasetPropertiesType]] = None, - min_unique_values_for_embedding: HyperparameterSearchSpace = HyperparameterSearchSpace( - 
hyperparameter="min_unique_values_for_embedding", - value_range=(3, 7), - default_value=5, - log=True), - dimension_reduction: HyperparameterSearchSpace = HyperparameterSearchSpace( - hyperparameter="dimension_reduction", - value_range=(0, 1), - default_value=0.5), + embed_exponent: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="embed_exponent", + value_range=(0.56,), + default_value=0.56), + max_embedding_dim: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="max_embedding_dim", + value_range=(100,), + default_value=100), + embedding_size_factor: HyperparameterSearchSpace = HyperparameterSearchSpace(hyperparameter="embedding_size_factor", + value_range=(0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5), + default_value=1, + ), ) -> ConfigurationSpace: cs = ConfigurationSpace() - add_hyperparameter(cs, min_unique_values_for_embedding, UniformIntegerHyperparameter) if dataset_properties is not None: - for i in range(len(dataset_properties['categorical_columns']) - if isinstance(dataset_properties['categorical_columns'], List) else 0): - ee_dimensions_search_space = HyperparameterSearchSpace(hyperparameter="dimension_reduction_" + str(i), - value_range=dimension_reduction.value_range, - default_value=dimension_reduction.default_value, - log=dimension_reduction.log) - add_hyperparameter(cs, ee_dimensions_search_space, UniformFloatHyperparameter) + if len(dataset_properties['categorical_columns']) > 0: + add_hyperparameter(cs, embed_exponent, UniformFloatHyperparameter) + add_hyperparameter(cs, max_embedding_dim, UniformIntegerHyperparameter) + add_hyperparameter(cs, embedding_size_factor, CategoricalHyperparameter) + return cs @staticmethod diff --git a/autoPyTorch/pipeline/components/setup/network_embedding/NoEmbedding.py b/autoPyTorch/pipeline/components/setup/network_embedding/NoEmbedding.py index 8fa03a65e..bcd782954 100644 --- a/autoPyTorch/pipeline/components/setup/network_embedding/NoEmbedding.py +++ b/autoPyTorch/pipeline/components/setup/network_embedding/NoEmbedding.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from ConfigSpace.configuration_space import ConfigurationSpace @@ -12,7 +12,7 @@ class _NoEmbedding(nn.Module): - def get_partial_models(self, subset_features: List[int]) -> "_NoEmbedding": + def get_partial_models(self, *args, **kwargs) -> "_NoEmbedding": return self def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -28,8 +28,8 @@ def __init__(self, random_state: Optional[np.random.RandomState] = None): super().__init__(random_state=random_state) def build_embedding(self, - num_input_features: np.ndarray, - num_numerical_features: int) -> Tuple[nn.Module, Optional[List[int]]]: + num_categories_per_col: np.ndarray, + num_features_excl_embed: int) -> Tuple[nn.Module, Optional[List[int]]]: return _NoEmbedding(), None @staticmethod diff --git a/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py b/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py index 1ff5df13e..b7c8f9206 100644 --- a/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py +++ b/autoPyTorch/pipeline/components/setup/network_embedding/base_network_embedding.py @@ -1,5 +1,4 @@ -import copy -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple import numpy as np @@ -8,72 +7,92 @@ from torch import nn from 
autoPyTorch.pipeline.components.setup.base_setup import autoPyTorchSetupComponent +from autoPyTorch.utils.common import FitRequirement class NetworkEmbeddingComponent(autoPyTorchSetupComponent): - def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = None): - super().__init__() + def __init__(self, random_state: Optional[np.random.RandomState] = None): + super().__init__(random_state=random_state) + self.add_fit_requirements([ + FitRequirement('num_categories_per_col', (List,), user_defined=True, dataset_property=True), + FitRequirement('shape_after_preprocessing', (Tuple[int],), user_defined=False, dataset_property=False)]) + self.embedding: Optional[nn.Module] = None self.random_state = random_state self.feature_shapes: Dict[str, int] = {} + self.embed_features_idx: Optional[Tuple] = None def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator: - num_numerical_columns, num_input_features = self._get_args(X) + num_features_excl_embed, num_categories_per_col = self._get_required_info_from_data(X) self.embedding, num_output_features = self.build_embedding( - num_input_features=num_input_features, - num_numerical_features=num_numerical_columns + num_categories_per_col=num_categories_per_col, + num_features_excl_embed=num_features_excl_embed ) if "feature_shapes" in X['dataset_properties']: + n_features_embedded = len(num_categories_per_col) - num_features_excl_embed if num_output_features is not None: feature_shapes = X['dataset_properties']['feature_shapes'] # forecasting tasks feature_names = X['dataset_properties']['feature_names'] - for idx_cat, n_output_cat in enumerate(num_output_features[num_numerical_columns:]): - cat_feature_name = feature_names[idx_cat + num_numerical_columns] - feature_shapes[cat_feature_name] = n_output_cat + n_features_all = len(feature_names) + # embedded feature index + embed_features_idx = tuple(range(n_features_all - n_features_embedded, n_features_all)) + for idx, n_output_embedded in zip(embed_features_idx, num_output_features[-n_features_embedded:]): + feat_name = feature_names[idx] + feature_shapes[feat_name] = n_output_embedded + self.embed_features_idx = embed_features_idx self.feature_shapes = feature_shapes else: self.feature_shapes = X['dataset_properties']['feature_shapes'] + self.embed_features_idx = [] return self def transform(self, X: Dict[str, Any]) -> Dict[str, Any]: X.update({'network_embedding': self.embedding}) if "feature_shapes" in X['dataset_properties']: X['dataset_properties'].update({"feature_shapes": self.feature_shapes}) + X['embed_features_idx'] = self.embed_features_idx return X def build_embedding(self, - num_input_features: np.ndarray, + num_categories_per_col: np.ndarray, num_numerical_features: int) -> Tuple[nn.Module, Optional[List[int]]]: raise NotImplementedError - def _get_args(self, X: Dict[str, Any]) -> Tuple[int, np.ndarray]: - # Feature preprocessors can alter numerical columns - if len(X['dataset_properties']['numerical_columns']) == 0: - num_numerical_columns = 0 - else: - X_train = copy.deepcopy(X['backend'].load_datamanager().train_tensors[0][:2]) - - if 'tabular_transformer' in X: - numerical_column_transformer = X['tabular_transformer'].preprocessor. \ - named_transformers_['numerical_pipeline'] - elif 'time_series_feature_transformer' in X: - numerical_column_transformer = X['time_series_feature_transformer'].preprocessor. 
\ - named_transformers_['numerical_pipeline'] - else: - raise ValueError("Either a tabular or time_series transformer must be contained!") - if hasattr(X_train, 'iloc'): - num_numerical_columns = numerical_column_transformer.transform( - X_train.iloc[:, X['dataset_properties']['numerical_columns']]).shape[1] - else: - num_numerical_columns = numerical_column_transformer.transform( - X_train[:, X['dataset_properties']['numerical_columns']]).shape[1] - num_input_features = np.zeros((num_numerical_columns + len(X['dataset_properties']['categorical_columns'])), - dtype=np.int32) - categories = X['dataset_properties']['categories'] - - for i, category in enumerate(categories): - num_input_features[num_numerical_columns + i, ] = len(category) - return num_numerical_columns, num_input_features + def _get_required_info_from_data(self, X: Dict[str, Any]) -> Tuple[int, np.ndarray]: + """ + Returns the number of numerical columns after preprocessing and + an array of size equal to the number of input features + containing zeros for numerical data and number of categories + for categorical data. This is required to build the embedding. + + Args: + X (Dict[str, Any]): + Fit dictionary + + Returns: + Tuple[int, np.ndarray]: + number of numerical columns and array indicating + number of categories for categorical columns and + 0 for numerical columns + """ + if X['dataset_properties']['target_type'] == 'time_series_forecasting' \ + and X['dataset_properties'].get('uni_variant', False): + # For uni_variant time series forecasting tasks, we don't have the related information for embeddings + return 0, np.asarray([]) + + num_cols = X['shape_after_preprocessing'] + # only works for 2D(rows, features) tabular data + num_features_excl_embed = num_cols[0] - len(X['embed_columns']) + + num_categories_per_col = np.zeros(num_cols, dtype=np.int16) + + categories_per_embed_col = X['dataset_properties']['num_categories_per_col'] + + # only fill num categories for embedding columns + for idx, cats in enumerate(categories_per_embed_col, start=num_features_excl_embed): + num_categories_per_col[idx] = cats + + return num_features_excl_embed, num_categories_per_col diff --git a/autoPyTorch/pipeline/components/training/trainer/__init__.py b/autoPyTorch/pipeline/components/training/trainer/__init__.py index b70467837..13a106de6 100755 --- a/autoPyTorch/pipeline/components/training/trainer/__init__.py +++ b/autoPyTorch/pipeline/components/training/trainer/__init__.py @@ -447,15 +447,23 @@ def _fit(self, X: Dict[str, Any], y: Any = None, **kwargs: Any) -> 'TrainerChoic raise RuntimeError("Budget exhausted without finishing an epoch.") if self.choice.use_stochastic_weight_averaging and self.choice.swa_updated: + # By default, we assume the data is double. 
Only if the data was preprocessed, + # we check the dtype and use it accordingly + preprocessed_dtype = X.get('preprocessed_dtype', None) + if preprocessed_dtype is None: + use_double = True + else: + use_double = 'float64' in preprocessed_dtype or 'int64' in preprocessed_dtype # update batch norm statistics - swa_utils.update_bn(loader=X['train_data_loader'], model=self.choice.swa_model.double()) - + swa_model = self.choice.swa_model.double() if use_double else self.choice.swa_model + swa_utils.update_bn(loader=X['train_data_loader'], model=swa_model) # change model update_model_state_dict_from_swa(X['network'], self.choice.swa_model.state_dict()) if self.choice.use_snapshot_ensemble: # we update only the last network which pertains to the stochastic weight averaging model - swa_utils.update_bn(X['train_data_loader'], self.choice.model_snapshots[-1].double()) + snapshot_model = self.choice.model_snapshots[-1].double() if use_double else self.choice.model_snapshots[-1] + swa_utils.update_bn(X['train_data_loader'], snapshot_model) # wrap up -- add score if not evaluating every epoch if not self.eval_valid_each_epoch(X): @@ -492,7 +500,7 @@ def _get_train_label(self, X: Dict[str, Any]) -> List[int]: Verifies and validates the labels from train split. """ # Ensure that the split is not missing any class. - labels: List[int] = X['y_train'][X['backend'].load_datamanager().splits[X['split_id']][0]] + labels: List[int] = X['y_train'][X['train_indices']] if STRING_TO_TASK_TYPES[X['dataset_properties']['task_type']] in CLASSIFICATION_TASKS: unique_labels = len(np.unique(labels)) if unique_labels < X['dataset_properties']['output_shape']: diff --git a/autoPyTorch/pipeline/tabular_classification.py b/autoPyTorch/pipeline/tabular_classification.py index 09eb47485..1b49f0d36 100644 --- a/autoPyTorch/pipeline/tabular_classification.py +++ b/autoPyTorch/pipeline/tabular_classification.py @@ -17,8 +17,8 @@ from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.TabularColumnTransformer import ( TabularColumnTransformer ) -from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.coalescer import ( - CoalescerChoice +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.column_splitting.ColumnSplitter import ( + ColumnSplitter ) from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.encoding import ( EncoderChoice @@ -28,8 +28,6 @@ ) from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.imputation.SimpleImputer import SimpleImputer from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling import ScalerChoice -from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.variance_thresholding. \ - VarianceThreshold import VarianceThreshold from autoPyTorch.pipeline.components.setup.early_preprocessor.EarlyPreprocessing import EarlyPreprocessing from autoPyTorch.pipeline.components.setup.lr_scheduler import SchedulerChoice from autoPyTorch.pipeline.components.setup.network.base_network import NetworkComponent @@ -53,20 +51,21 @@ class TabularClassificationPipeline(ClassifierMixin, BasePipeline): It implements a pipeline, which includes the following as steps: 1. `imputer` - 2. `encoder` - 3. `scaler` - 4. `feature_preprocessor` - 5. `tabular_transformer` - 6. `preprocessing` - 7. `network_embedding` - 8. `network_backbone` - 9. `network_head` - 10. `network` - 11. `network_init` - 12. `optimizer` - 13. `lr_scheduler` - 14. `data_loader` - 15. `trainer` + 2. `column_splitter + 3. 
`encoder` + 4. `scaler` + 5. `feature_preprocessor` + 6. `tabular_transformer` + 7. `preprocessing` + 8. `network_embedding` + 9. `network_backbone` + 10. `network_head` + 11. `network` + 12. `network_init` + 13. `optimizer` + 14. `lr_scheduler` + 15. `data_loader` + 16. `trainer` Contrary to the sklearn API it is not possible to enumerate the possible parameters in the __init__ function because we only know the @@ -132,21 +131,23 @@ def __init__( # model, so we comply with https://pytorch.org/docs/stable/notes/randomness.html torch.manual_seed(self.random_state.get_state()[1][0]) - def _predict_proba(self, X: np.ndarray) -> np.ndarray: - # Pre-process X - loader = self.named_steps['data_loader'].get_loader(X=X) - pred = self.named_steps['network'].predict(loader) - if isinstance(self.dataset_properties['output_shape'], int): - # The final layer is always softmax now (`pred` already gives pseudo proba) - return pred - else: - raise ValueError("Expected output_shape to be integer, got {}," - "Tabular Classification only supports 'binary' and 'multiclass' outputs" - "got {}".format(type(self.dataset_properties['output_shape']), - self.dataset_properties['output_type'])) + def predict(self, X: np.ndarray, batch_size: Optional[int] = None) -> np.ndarray: + """Predict the output using the selected model. + + Args: + X (np.ndarray): input data to the array + batch_size (Optional[int]): batch_size controls whether the pipeline will be + called on small chunks of the data. Useful when calling the + predict method on the whole array X results in a MemoryError. + + Returns: + np.ndarray: the predicted values given input X + """ + probas = super().predict(X=X, batch_size=batch_size) + return np.argmax(probas, axis=1) def predict_proba(self, X: np.ndarray, batch_size: Optional[int] = None) -> np.ndarray: - """predict_proba. + """predict probabilities. Args: X (np.ndarray): @@ -160,30 +161,19 @@ def predict_proba(self, X: np.ndarray, batch_size: Optional[int] = None) -> np.n Probabilities of the target being certain class """ if batch_size is None: - y = self._predict_proba(X) - + warnings.warn("Batch size not provided. 
" + "Will predict on the whole data in a single iteration") + batch_size = X.shape[0] + loader = self.named_steps['data_loader'].get_loader(X=X, batch_size=batch_size) + pred = self.named_steps['network'].predict(loader) + if isinstance(self.dataset_properties['output_shape'], int): + # The final layer is always softmax now (`pred` already gives pseudo proba) + return pred else: - if not isinstance(batch_size, int): - raise ValueError("Argument 'batch_size' must be of type int, " - "but is '%s'" % type(batch_size)) - if batch_size <= 0: - raise ValueError("Argument 'batch_size' must be positive, " - "but is %d" % batch_size) - - else: - # Probe for the target array dimensions - target = self.predict_proba(X[0:2].copy()) - - y = np.zeros((X.shape[0], target.shape[1]), - dtype=np.float32) - - for k in range(max(1, int(np.ceil(float(X.shape[0]) / batch_size)))): - batch_from = k * batch_size - batch_to = min([(k + 1) * batch_size, X.shape[0]]) - pred_prob = self.predict_proba(X[batch_from:batch_to], batch_size=None) - y[batch_from:batch_to] = pred_prob.astype(np.float32) - - return y + raise ValueError("Expected output_shape to be integer, got {}," + "Tabular Classification only supports 'binary' and 'multiclass' outputs" + "got {}".format(type(self.dataset_properties['output_shape']), + self.dataset_properties['output_type'])) def score(self, X: np.ndarray, y: np.ndarray, batch_size: Optional[int] = None, @@ -207,7 +197,7 @@ def score(self, X: np.ndarray, y: np.ndarray, """ from autoPyTorch.pipeline.components.training.metrics.utils import get_metrics, calculate_score metrics = get_metrics(self.dataset_properties, [metric_name]) - y_pred = self.predict(X, batch_size=batch_size) + y_pred = self.predict_proba(X, batch_size=batch_size) score = calculate_score(y, y_pred, task_type=STRING_TO_TASK_TYPES[str(self.dataset_properties['task_type'])], metrics=metrics)[metric_name] return score @@ -286,8 +276,9 @@ def _get_pipeline_steps( steps.extend([ ("imputer", SimpleImputer(random_state=self.random_state)), - ("variance_threshold", VarianceThreshold(random_state=self.random_state)), - ("coalescer", CoalescerChoice(default_dataset_properties, random_state=self.random_state)), + # ("variance_threshold", VarianceThreshold(random_state=self.random_state)), + # ("coalescer", CoalescerChoice(default_dataset_properties, random_state=self.random_state)), + ("column_splitter", ColumnSplitter(random_state=self.random_state)), ("encoder", EncoderChoice(default_dataset_properties, random_state=self.random_state)), ("scaler", ScalerChoice(default_dataset_properties, random_state=self.random_state)), ("feature_preprocessor", FeatureProprocessorChoice(default_dataset_properties, diff --git a/autoPyTorch/pipeline/tabular_regression.py b/autoPyTorch/pipeline/tabular_regression.py index 4cd67bb9f..1cf60e561 100644 --- a/autoPyTorch/pipeline/tabular_regression.py +++ b/autoPyTorch/pipeline/tabular_regression.py @@ -17,8 +17,8 @@ from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.TabularColumnTransformer import ( TabularColumnTransformer ) -from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.coalescer import ( - CoalescerChoice +from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.column_splitting.ColumnSplitter import ( + ColumnSplitter ) from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.encoding import ( EncoderChoice @@ -28,8 +28,6 @@ ) from 
autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.imputation.SimpleImputer import SimpleImputer from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling import ScalerChoice -from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.variance_thresholding. \ - VarianceThreshold import VarianceThreshold from autoPyTorch.pipeline.components.setup.early_preprocessor.EarlyPreprocessing import EarlyPreprocessing from autoPyTorch.pipeline.components.setup.lr_scheduler import SchedulerChoice from autoPyTorch.pipeline.components.setup.network.base_network import NetworkComponent @@ -55,20 +53,21 @@ class TabularRegressionPipeline(RegressorMixin, BasePipeline): It implements a pipeline, which includes the following as steps: 1. `imputer` - 2. `encoder` - 3. `scaler` - 4. `feature_preprocessor` - 5. `tabular_transformer` - 6. `preprocessing` - 7. `network_embedding` - 8. `network_backbone` - 9. `network_head` - 10. `network` - 11. `network_init` - 12. `optimizer` - 13. `lr_scheduler` - 14. `data_loader` - 15. `trainer` + 2. `column_splitter + 3. `encoder` + 4. `scaler` + 5. `feature_preprocessor` + 6. `tabular_transformer` + 7. `preprocessing` + 8. `network_embedding` + 9. `network_backbone` + 10. `network_head` + 11. `network` + 12. `network_init` + 13. `optimizer` + 14. `lr_scheduler` + 15. `data_loader` + 16. `trainer` Contrary to the sklearn API it is not possible to enumerate the possible parameters in the __init__ function because we only know the @@ -234,8 +233,9 @@ def _get_pipeline_steps( steps.extend([ ("imputer", SimpleImputer(random_state=self.random_state)), - ("variance_threshold", VarianceThreshold(random_state=self.random_state)), - ("coalescer", CoalescerChoice(default_dataset_properties, random_state=self.random_state)), + # ("variance_threshold", VarianceThreshold(random_state=self.random_state)), + # ("coalescer", CoalescerChoice(default_dataset_properties, random_state=self.random_state)), + ("column_splitter", ColumnSplitter(random_state=self.random_state)), ("encoder", EncoderChoice(default_dataset_properties, random_state=self.random_state)), ("scaler", ScalerChoice(default_dataset_properties, random_state=self.random_state)), ("feature_preprocessor", FeatureProprocessorChoice(default_dataset_properties, diff --git a/autoPyTorch/pipeline/time_series_forecasting.py b/autoPyTorch/pipeline/time_series_forecasting.py index 53143e4df..bf2f53e95 100644 --- a/autoPyTorch/pipeline/time_series_forecasting.py +++ b/autoPyTorch/pipeline/time_series_forecasting.py @@ -26,6 +26,9 @@ from autoPyTorch.pipeline.components.preprocessing.time_series_preprocessing.TimeSeriesTransformer import ( TimeSeriesFeatureTransformer ) +from autoPyTorch.pipeline.components.preprocessing.time_series_preprocessing.column_spliting.ColumnSplitter import ( + TimeSeriesColumnSplitter +) from autoPyTorch.pipeline.components.preprocessing.time_series_preprocessing.encoding import TimeSeriesEncoderChoice from autoPyTorch.pipeline.components.preprocessing.time_series_preprocessing.imputation.TimeSeriesImputer import ( TimeSeriesFeatureImputer, @@ -333,6 +336,7 @@ def _get_pipeline_steps(self, dataset_properties: Optional[Dict[str, Any]]) -> L if not default_dataset_properties.get("uni_variant", False): steps.extend([("impute", TimeSeriesFeatureImputer(random_state=self.random_state)), ("scaler", BaseScaler(random_state=self.random_state)), + ("column_splitter", TimeSeriesColumnSplitter(random_state=self.random_state)), ('feature_encoding', 
TimeSeriesEncoderChoice(default_dataset_properties, random_state=self.random_state)), ("time_series_transformer", TimeSeriesFeatureTransformer(random_state=self.random_state)), diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py index 12b12c3ad..f71ad3f5f 100644 --- a/test/test_api/test_api.py +++ b/test/test_api/test_api.py @@ -4,6 +4,7 @@ import pickle import tempfile import unittest +import unittest.mock from test.test_api.utils import ( dummy_do_dummy_prediction, dummy_eval_train_function, @@ -684,6 +685,7 @@ def test_do_dummy_prediction(dask_client, fit_dictionary_tabular): @unittest.mock.patch('autoPyTorch.evaluation.tae.eval_train_function', new=dummy_eval_train_function) @pytest.mark.parametrize('openml_id', (40981, )) +@pytest.mark.skip(reason="Fix with new portfolio PR") def test_portfolio_selection(openml_id, backend, n_samples): # Get the data and check that contents of data-manager make sense @@ -723,6 +725,7 @@ def test_portfolio_selection(openml_id, backend, n_samples): assert any(successful_config in portfolio_configs for successful_config in successful_configs) +@pytest.mark.skip(reason="Fix with new portfolio PR") @unittest.mock.patch('autoPyTorch.evaluation.tae.eval_train_function', new=dummy_eval_train_function) @pytest.mark.parametrize('openml_id', (40981, )) @@ -871,7 +874,7 @@ def test_pipeline_fit(openml_id, configuration = estimator.get_search_space(dataset).get_default_configuration() pipeline, run_info, run_value, dataset = estimator.fit_pipeline(dataset=dataset, configuration=configuration, - run_time_limit_secs=50, + run_time_limit_secs=70, disable_file_output=disable_file_output, budget_type='epochs', budget=budget diff --git a/test/test_data/test_feature_validator.py b/test/test_data/test_feature_validator.py index 099ee691f..1ccb91b2f 100644 --- a/test/test_data/test_feature_validator.py +++ b/test/test_data/test_feature_validator.py @@ -288,7 +288,7 @@ def test_features_unsupported_calls_are_raised(): expected """ validator = TabularFeatureValidator() - with pytest.raises(TypeError, match=r"Valid types are `numerical`, `categorical` or `boolean`, but input column"): + with pytest.raises(TypeError, match=r"Valid types are .*"): validator.fit( pd.DataFrame({'datetime': [pd.Timestamp('20180310')]}) ) @@ -298,7 +298,7 @@ def test_features_unsupported_calls_are_raised(): validator.fit({'input1': 1, 'input2': 2}) validator = TabularFeatureValidator() - with pytest.raises(TypeError, match=r"Valid types are `numerical`, `categorical` or `boolean`, but input column"): + with pytest.raises(TypeError, match=r"Valid types are .*"): validator.fit(pd.DataFrame([{'A': 1, 'B': 2}], dtype='string')) validator = TabularFeatureValidator() @@ -430,7 +430,7 @@ def test_unknown_encode_value(): assert expected_row == x_t[0].tolist() # Notice how there is only one column 'c' to encode - assert validator.categories == [list(range(2)) for i in range(1)] + assert validator.num_categories_per_col == [2] # Actual checks for the features @@ -485,13 +485,13 @@ def test_feature_validator_new_data_after_fit( if train_data_type == 'pandas': old_dtypes = copy.deepcopy(validator.dtypes) validator.dtypes = ['dummy' for dtype in X_train.dtypes] - with pytest.raises(ValueError, match=r"The dtype of the features must not be changed after fit()"): + with pytest.raises(ValueError, match=r"The dtype of the features must not be changed after fit.*"): transformed_X = validator.transform(X_test) validator.dtypes = old_dtypes if test_data_type == 'pandas': columns = 
X_test.columns.tolist() X_test = X_test[reversed(columns)] - with pytest.raises(ValueError, match=r"The column order of the features"): + with pytest.raises(ValueError, match=r"The column order of the features must not be changed after fit.*"): transformed_X = validator.transform(X_test) diff --git a/test/test_data/test_utils.py b/test/test_data/test_utils.py index 4269c4e5f..6228740b0 100644 --- a/test/test_data/test_utils.py +++ b/test/test_data/test_utils.py @@ -25,7 +25,7 @@ from autoPyTorch.data.utils import ( default_dataset_compression_arg, get_dataset_compression_mapping, - megabytes, + get_raw_memory_usage, reduce_dataset_size_if_too_large, reduce_precision, subsample, @@ -45,13 +45,14 @@ def test_reduce_dataset_if_too_large(openmlid, as_frame, n_samples): X.copy(), y=y.copy(), is_classification=True, + categorical_columns=[], random_state=1, - memory_allocation=0.001) + memory_allocation=0.01) assert X_converted.shape[0] < X.shape[0] assert y_converted.shape[0] < y.shape[0] - assert megabytes(X_converted) < megabytes(X) + assert get_raw_memory_usage(X_converted) < get_raw_memory_usage(X) @pytest.mark.parametrize("X", [np.asarray([[1, 1, 1]] * 30)]) @@ -211,8 +212,18 @@ def test_unsupported_errors(): ['a', 'b', 'c', 'a', 'b', 'c'], ['a', 'b', 'd', 'r', 'b', 'c']]) with pytest.raises(ValueError, match=r'X.dtype = .*'): - reduce_dataset_size_if_too_large(X, is_classification=True, random_state=1, memory_allocation=0) + reduce_dataset_size_if_too_large( + X, + is_classification=True, + categorical_columns=[], + random_state=1, + memory_allocation=0) X = [[1, 2], [2, 3]] with pytest.raises(ValueError, match=r'Unrecognised data type of X, expected data type to be in .*'): - reduce_dataset_size_if_too_large(X, is_classification=True, random_state=1, memory_allocation=0) + reduce_dataset_size_if_too_large( + X, + is_classification=True, + categorical_columns=[], + random_state=1, + memory_allocation=0) diff --git a/test/test_data/test_validation.py b/test/test_data/test_validation.py index af46be55f..b6f05f7ba 100644 --- a/test/test_data/test_validation.py +++ b/test/test_data/test_validation.py @@ -8,7 +8,8 @@ import sklearn.model_selection from autoPyTorch.data.tabular_validator import TabularInputValidator -from autoPyTorch.data.utils import megabytes +from autoPyTorch.data.utils import get_approximate_mem_usage_in_mb +from autoPyTorch.utils.common import ispandas @pytest.mark.parametrize('openmlid', [2, 40975, 40984]) @@ -148,16 +149,36 @@ def test_featurevalidator_dataset_compression(input_data_featuretest): X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( input_data_featuretest, input_data_targets, test_size=0.1, random_state=1) validator = TabularInputValidator( - dataset_compression={'memory_allocation': 0.8 * megabytes(X_train), 'methods': ['precision', 'subsample']} + dataset_compression={ + 'memory_allocation': 0.8 * get_approximate_mem_usage_in_mb(X_train, [], None), + 'methods': ['precision', 'subsample']} ) validator.fit(X_train=X_train, y_train=y_train) transformed_X_train, _ = validator.transform(X_train.copy(), y_train.copy()) + if ispandas(X_train): + # input validator converts transformed_X_train to numpy and the cat columns are chosen as column indices + columns = X_train.columns + categorical_columns = [columns[col] for col in validator.feature_validator.categorical_columns] + else: + categorical_columns = validator.feature_validator.categorical_columns + assert validator._reduced_dtype is not None - assert megabytes(transformed_X_train) < 
megabytes(X_train) + assert get_approximate_mem_usage_in_mb( + transformed_X_train, + validator.feature_validator.categorical_columns, + validator.feature_validator.num_categories_per_col + ) < get_approximate_mem_usage_in_mb( + X_train, categorical_columns, validator.feature_validator.num_categories_per_col) transformed_X_test, _ = validator.transform(X_test.copy(), y_test.copy()) - assert megabytes(transformed_X_test) < megabytes(X_test) + assert get_approximate_mem_usage_in_mb( + transformed_X_test, + validator.feature_validator.categorical_columns, + validator.feature_validator.num_categories_per_col + ) < get_approximate_mem_usage_in_mb( + X_test, categorical_columns, validator.feature_validator.num_categories_per_col) + if hasattr(transformed_X_train, 'iloc'): assert all(transformed_X_train.dtypes == transformed_X_test.dtypes) assert all(transformed_X_train.dtypes == validator._precision) diff --git a/test/test_pipeline/components/preprocessing/test_tabular_column_transformer.py b/test/test_pipeline/components/preprocessing/test_tabular_column_transformer.py index a81eb34a2..f5f928bd8 100644 --- a/test/test_pipeline/components/preprocessing/test_tabular_column_transformer.py +++ b/test/test_pipeline/components/preprocessing/test_tabular_column_transformer.py @@ -13,8 +13,6 @@ ) -# TODO: fix in preprocessing PR -# @pytest.mark.skip("Skipping tests as preprocessing is not finalised") @pytest.mark.parametrize("fit_dictionary_tabular", ['classification_numerical_only', 'classification_categorical_only', 'classification_numerical_and_categorical'], indirect=True) diff --git a/test/test_pipeline/components/setup/forecasting/forecasting_networks/test_forecasting_architecture.py b/test/test_pipeline/components/setup/forecasting/forecasting_networks/test_forecasting_architecture.py index 252fe7d1d..815594010 100644 --- a/test/test_pipeline/components/setup/forecasting/forecasting_networks/test_forecasting_architecture.py +++ b/test/test_pipeline/components/setup/forecasting/forecasting_networks/test_forecasting_architecture.py @@ -24,22 +24,22 @@ from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdate -class ReducedEmbedding(torch.nn.Module): +class IncrementalEmbedding(torch.nn.Module): # a dummy reduced embedding, it simply cut row for each categorical features - def __init__(self, num_input_features, num_numerical_features: int): - super(ReducedEmbedding, self).__init__() - self.num_input_features = num_input_features - self.num_numerical_features = num_numerical_features - self.n_cat_features = len(num_input_features) - num_numerical_features + def __init__(self, n_excl_embed_features, embed_feat_idx): + super(IncrementalEmbedding, self).__init__() + self.n_excl_embed_features = n_excl_embed_features + self.embed_feat_idx = embed_feat_idx def forward(self, x): - x = x[..., :-self.n_cat_features] + if len(self.embed_feat_idx) > 0: + x = torch.cat([x, x[..., -len(self.embed_feat_idx):]], dim=-1) return x - def get_partial_models(self, subset_features): - num_numerical_features = sum([sf < self.num_numerical_features for sf in subset_features]) - num_input_features = [self.num_input_features[sf] for sf in subset_features] - return ReducedEmbedding(num_input_features, num_numerical_features) + def get_partial_models(self, n_excl_embed_features, idx_embed_feat_partial): + n_excl_embed_features = n_excl_embed_features + embed_feat_idx = [self.embed_feat_idx[idx] for idx in idx_embed_feat_partial] + return IncrementalEmbedding(n_excl_embed_features, 
embed_feat_idx) @pytest.fixture(params=['ForecastingNet', 'ForecastingSeq2SeqNet', 'ForecastingDeepARNet', 'NBEATSNet']) @@ -52,7 +52,7 @@ def network_encoder(request): return request.param -@pytest.fixture(params=['ReducedEmbedding', 'NoEmbedding']) +@pytest.fixture(params=['IncrementalEmbedding', 'NoEmbedding']) def embedding(request): return request.param @@ -110,7 +110,7 @@ def test_network_forward(self, dataset_properties['known_future_features'] = ('f1', 'f3', 'f5') if with_static_features: - dataset_properties['static_features'] = (0, 4) + dataset_properties['static_features'] = (0, 3) else: dataset_properties['static_features'] = tuple() @@ -130,10 +130,14 @@ def test_network_forward(self, fit_dictionary['net_output_type'] = net_output_type if embedding == 'NoEmbedding': + embed_features_idx = () fit_dictionary['network_embedding'] = _NoEmbedding() + fit_dictionary['embed_features_idx'] = embed_features_idx else: - fit_dictionary['network_embedding'] = ReducedEmbedding([10] * 5, 2) - dataset_properties['feature_shapes'] = {'f1': 10, 'f2': 10, 'f3': 9, 'f4': 9, 'f5': 9} + embed_features_idx = (3, 4) + fit_dictionary['network_embedding'] = IncrementalEmbedding(50, embed_features_idx) + fit_dictionary['embed_features_idx'] = embed_features_idx + dataset_properties['feature_shapes'] = {'f1': 10, 'f2': 10, 'f3': 10, 'f4': 11, 'f5': 11} if uni_variant_data: fit_dictionary['X_train'] = None diff --git a/test/test_pipeline/components/setup/test_setup_networks.py b/test/test_pipeline/components/setup/test_setup_networks.py index f5e9b1bb7..8fa77560f 100644 --- a/test/test_pipeline/components/setup/test_setup_networks.py +++ b/test/test_pipeline/components/setup/test_setup_networks.py @@ -19,8 +19,7 @@ def head(request): return request.param -# TODO: add 'LearnedEntityEmbedding' after preprocessing dix -@pytest.fixture(params=['NoEmbedding']) +@pytest.fixture(params=['NoEmbedding', 'LearnedEntityEmbedding']) def embedding(request): return request.param diff --git a/test/test_pipeline/components/setup/test_setup_preprocessing_node.py b/test/test_pipeline/components/setup/test_setup_preprocessing_node.py index 1ec858864..5d3b49923 100644 --- a/test/test_pipeline/components/setup/test_setup_preprocessing_node.py +++ b/test/test_pipeline/components/setup/test_setup_preprocessing_node.py @@ -37,7 +37,7 @@ def test_tabular_preprocess(self): 'is_small_preprocess': True, 'input_shape': (15,), 'output_shape': 2, - 'categories': [], + 'num_categories_per_col': [], 'issparse': False } X = dict(X_train=np.random.random((10, 15)), @@ -64,43 +64,6 @@ def test_tabular_preprocess(self): # We expect the transformation always for inference self.assertIn('preprocess_transforms', X.keys()) - def test_tabular_no_preprocess(self): - dataset_properties = { - 'numerical_columns': list(range(15)), - 'categorical_columns': [], - 'task_type': TASK_TYPES_TO_STRING[TABULAR_CLASSIFICATION], - 'output_type': OUTPUT_TYPES_TO_STRING[MULTICLASS], - 'is_small_preprocess': False, - 'input_shape': (15,), - 'output_shape': 2, - 'categories': [], - 'issparse': False - } - X = dict(X_train=np.random.random((10, 15)), - y_train=np.random.random(10), - train_indices=[0, 1, 2, 3, 4, 5], - val_indices=[6, 7, 8, 9], - dataset_properties=dataset_properties, - # Training configuration - num_run=16, - device='cpu', - budget_type='epochs', - epochs=10, - torch_num_threads=1, - early_stopping=20, - split_id=0, - backend=self.backend, - ) - - pipeline = TabularClassificationPipeline(dataset_properties=dataset_properties) - # Remove the 
trainer - pipeline.steps.pop() - pipeline = pipeline.fit(X) - X = pipeline.transform(X) - self.assertIn('preprocess_transforms', X.keys()) - self.assertIsInstance(X['preprocess_transforms'], list) - self.assertIsInstance(X['preprocess_transforms'][-1].preprocessor, BaseEstimator) - class ImagePreprocessingTest(unittest.TestCase): def setUp(self): diff --git a/test/test_pipeline/components/training/test_image_data_loader.py b/test/test_pipeline/components/training/test_image_data_loader.py index af70cf77b..98a10373b 100644 --- a/test/test_pipeline/components/training/test_image_data_loader.py +++ b/test/test_pipeline/components/training/test_image_data_loader.py @@ -16,7 +16,6 @@ def test_imageloader_build_transform(): fit_dictionary = dict() fit_dictionary['dataset_properties'] = dict() - fit_dictionary['dataset_properties']['is_small_preprocess'] = unittest.mock.Mock(()) fit_dictionary['image_augmenter'] = unittest.mock.Mock() fit_dictionary['preprocess_transforms'] = unittest.mock.Mock() diff --git a/test/test_pipeline/components/training/test_training.py b/test/test_pipeline/components/training/test_training.py index ae85cad4d..397488468 100644 --- a/test/test_pipeline/components/training/test_training.py +++ b/test/test_pipeline/components/training/test_training.py @@ -101,7 +101,7 @@ def test_fit_transform(self): 'y_train': np.array([0, 1, 0]), 'train_indices': [0, 1], 'val_indices': [2], - 'dataset_properties': {'is_small_preprocess': True}, + 'dataset_properties': {}, 'working_dir': '/tmp', 'split_id': 0, 'backend': backend, @@ -513,7 +513,7 @@ def dummy_performance(*args, **kwargs): 'step_interval': StepIntervalUnit.batch } for item in ['backend', 'lr_scheduler', 'network', 'optimizer', 'train_data_loader', 'val_data_loader', - 'device', 'y_train', 'network_snapshots']: + 'device', 'y_train', 'network_snapshots', 'train_indices']: fit_dictionary[item] = unittest.mock.MagicMock() fit_dictionary['backend'].temporary_directory = tempfile.mkdtemp() diff --git a/test/test_pipeline/test_tabular_classification.py b/test/test_pipeline/test_tabular_classification.py index 3e4e3bde5..c3f7f49f8 100644 --- a/test/test_pipeline/test_tabular_classification.py +++ b/test/test_pipeline/test_tabular_classification.py @@ -123,8 +123,8 @@ def test_pipeline_predict(self, fit_dictionary_tabular, exclude): pipeline.fit(fit_dictionary_tabular) # we expect the output to have the same batch size as the test input, - # and number of outputs per batch sample equal to the number of outputs - expected_output_shape = (X.shape[0], fit_dictionary_tabular["dataset_properties"]["output_shape"]) + # and number of outputs per batch sample equal to 1 + expected_output_shape = (X.shape[0], ) prediction = pipeline.predict(X) assert isinstance(prediction, np.ndarray) @@ -205,15 +205,12 @@ def test_pipeline_transform(self, fit_dictionary_tabular, exclude): # We expect the transformations to be in the pipeline at anytime for inference assert 'preprocess_transforms' in transformed_fit_dictionary_tabular.keys() - @pytest.mark.parametrize("is_small_preprocess", [True, False]) - def test_default_configuration(self, fit_dictionary_tabular, is_small_preprocess, exclude): + def test_default_configuration(self, fit_dictionary_tabular, exclude): """Makes sure that when no config is set, we can trust the default configuration from the space""" fit_dictionary_tabular['epochs'] = 5 - fit_dictionary_tabular['is_small_preprocess'] = is_small_preprocess - pipeline = TabularClassificationPipeline( 
dataset_properties=fit_dictionary_tabular['dataset_properties'], exclude=exclude) @@ -435,9 +432,9 @@ def test_trainer_cocktails(self, fit_dictionary_tabular, mocker, lr_scheduler, t len(X['network_snapshots']) == config.get(f'trainer:{trainer}:se_lastk') mocker.patch("autoPyTorch.pipeline.components.setup.network.base_network.NetworkComponent._predict", - return_value=torch.Tensor([1])) + return_value=torch.Tensor([[1, 0]])) # Assert that predict gives no error when swa and se are on - assert isinstance(pipeline.predict(fit_dictionary_tabular['X_train']), np.ndarray) + assert isinstance(pipeline.predict(X['X_train']), np.ndarray) # As SE is True, _predict should be called 3 times assert pipeline.named_steps['network']._predict.call_count == 3 diff --git a/test/test_pipeline/test_tabular_regression.py b/test/test_pipeline/test_tabular_regression.py index a2c3b695e..e2e770a24 100644 --- a/test/test_pipeline/test_tabular_regression.py +++ b/test/test_pipeline/test_tabular_regression.py @@ -61,11 +61,9 @@ def test_pipeline_fit(self, fit_dictionary_tabular): """This test makes sure that the pipeline is able to fit given random combinations of hyperparameters across the pipeline""" # TODO: fix issue where adversarial also works for regression - # TODO: Fix issue with learned entity embedding after preprocessing PR pipeline = TabularRegressionPipeline( dataset_properties=fit_dictionary_tabular['dataset_properties'], - exclude={'trainer': ['AdversarialTrainer'], - 'network_embedding': ['LearnedEntityEmbedding']}) + exclude={'trainer': ['AdversarialTrainer']}) cs = pipeline.get_hyperparameter_search_space() config = cs.sample_configuration() @@ -91,8 +89,7 @@ def test_pipeline_predict(self, fit_dictionary_tabular): X = fit_dictionary_tabular['X_train'].copy() pipeline = TabularRegressionPipeline( dataset_properties=fit_dictionary_tabular['dataset_properties'], - exclude={'trainer': ['AdversarialTrainer'], - 'network_embedding': ['LearnedEntityEmbedding']}) + exclude={'trainer': ['AdversarialTrainer']}) cs = pipeline.get_hyperparameter_search_space() config = cs.sample_configuration() @@ -121,8 +118,7 @@ def test_pipeline_transform(self, fit_dictionary_tabular): pipeline = TabularRegressionPipeline( dataset_properties=fit_dictionary_tabular['dataset_properties'], - exclude={'trainer': ['AdversarialTrainer'], - 'network_embedding': ['LearnedEntityEmbedding']}) + exclude={'trainer': ['AdversarialTrainer']}) cs = pipeline.get_hyperparameter_search_space() config = cs.sample_configuration() pipeline.set_hyperparameters(config) @@ -139,11 +135,10 @@ def test_pipeline_transform(self, fit_dictionary_tabular): assert fit_dictionary_tabular.items() <= transformed_fit_dictionary_tabular.items() # Then the pipeline should have added the following keys - # Removing 'imputer', 'encoder', 'scaler', these will be - # TODO: added back after a PR fixing preprocessing expected_keys = {'tabular_transformer', 'preprocess_transforms', 'network', 'optimizer', 'lr_scheduler', 'train_data_loader', - 'val_data_loader', 'run_summary', 'feature_preprocessor'} + 'val_data_loader', 'run_summary', 'feature_preprocessor', + 'imputer', 'encoder', 'scaler'} assert expected_keys.issubset(set(transformed_fit_dictionary_tabular.keys())) # Then we need to have transformations being created. 
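# Editorial sketch (not part of the patch): the prediction path the updated tests above rely on.
# `pipeline` is assumed to be a fitted TabularClassificationPipeline; with this patch,
# predict_proba() batches the data itself and predict() reduces the class probabilities to hard
# labels, so predictions have shape (n_samples,) rather than (n_samples, n_classes).
# `predict_labels` is a hypothetical helper mirroring the new predict().
import numpy as np

def predict_labels(pipeline, X, batch_size=None):
    probas = pipeline.predict_proba(X, batch_size=batch_size)  # (n_samples, n_classes)
    return np.argmax(probas, axis=1)                           # (n_samples,)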
@@ -152,13 +147,10 @@ def test_pipeline_transform(self, fit_dictionary_tabular): # We expect the transformations to be in the pipeline at anytime for inference assert 'preprocess_transforms' in transformed_fit_dictionary_tabular.keys() - @pytest.mark.parametrize("is_small_preprocess", [True, False]) - def test_default_configuration(self, fit_dictionary_tabular, is_small_preprocess): + def test_default_configuration(self, fit_dictionary_tabular): """Makes sure that when no config is set, we can trust the default configuration from the space""" - fit_dictionary_tabular['is_small_preprocess'] = is_small_preprocess - pipeline = TabularRegressionPipeline( dataset_properties=fit_dictionary_tabular['dataset_properties'], exclude={'trainer': ['AdversarialTrainer']}) diff --git a/test/test_pipeline/test_time_series_forecasting_pipeline.py b/test/test_pipeline/test_time_series_forecasting_pipeline.py index 3e34b71b7..09cc6b5f0 100644 --- a/test/test_pipeline/test_time_series_forecasting_pipeline.py +++ b/test/test_pipeline/test_time_series_forecasting_pipeline.py @@ -46,7 +46,7 @@ class TestTimeSeriesForecastingPipeline: "multi_variant_only_num"], indirect=True) def test_fit_predict(self, fit_dictionary_forecasting, forecasting_budgets): dataset_properties = fit_dictionary_forecasting['dataset_properties'] - if not dataset_properties['uni_variant'] and len(dataset_properties['categories']) > 0: + if not dataset_properties['uni_variant'] and len(dataset_properties['num_categories_per_col']) > 0: include = {'network_embedding': ['LearnedEntityEmbedding']} else: include = None
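# Editorial usage sketch (not part of the patch): putting the reworked embedding together for a
# toy dataset with four preprocessed columns, the last two of which are embedded. Assumes the
# patched module is importable; the numbers are made up for illustration.
import numpy as np
import torch

from autoPyTorch.pipeline.components.setup.network_embedding.LearnedEntityEmbedding import (
    _LearnedEntityEmbedding
)

config = {'embed_exponent': 0.56, 'embedding_size_factor': 1.0, 'max_embedding_dim': 100}
num_categories_per_col = np.array([0, 0, 3, 12])  # 0 -> numerical, >0 -> number of categories
embedding = _LearnedEntityEmbedding(config=config,
                                    num_categories_per_col=num_categories_per_col,
                                    num_features_excl_embed=2)

x = torch.tensor([[0.5, -1.2, 2.0, 7.0]])  # categorical entries are ordinal codes
out = embedding(x)
# the two numerical columns pass through unchanged, the categorical columns are expanded to
# their embedding sizes (2 and 6 with these defaults), so out.shape == (1, 10)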