diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py
index c4fa0e7ce..3c712efa9 100644
--- a/autoPyTorch/api/base_task.py
+++ b/autoPyTorch/api/base_task.py
@@ -10,7 +10,6 @@
 import time
 import typing
 import unittest.mock
-import uuid
 import warnings
 from abc import abstractmethod
 from typing import Any, Callable, Dict, List, Optional, Union, cast
@@ -782,13 +781,15 @@ def _search(
                              ":{}".format(self.task_type, dataset.task_type))
 
         # Initialise information needed for the experiment
-        experiment_task_name = 'runSearch'
+        experiment_task_name: str = 'runSearch'
         dataset_requirements = get_dataset_requirements(
             info=self._get_required_dataset_properties(dataset))
         self._dataset_requirements = dataset_requirements
         dataset_properties = dataset.get_dataset_properties(dataset_requirements)
         self._stopwatch.start_task(experiment_task_name)
         self.dataset_name = dataset.dataset_name
+        assert self.dataset_name is not None
+
         if self._logger is None:
             self._logger = self._get_logger(self.dataset_name)
         self._all_supported_metrics = all_supported_metrics
@@ -897,7 +898,7 @@ def _search(
                 start_time=time.time(),
                 time_left_for_ensembles=time_left_for_ensembles,
                 backend=copy.deepcopy(self._backend),
-                dataset_name=dataset.dataset_name,
+                dataset_name=str(dataset.dataset_name),
                 output_type=STRING_TO_OUTPUT_TYPES[dataset.output_type],
                 task_type=STRING_TO_TASK_TYPES[self.task_type],
                 metrics=[self._metric],
@@ -916,7 +917,7 @@ def _search(
             self._stopwatch.stop_task(ensemble_task_name)
 
         # ==> Run SMAC
-        smac_task_name = 'runSMAC'
+        smac_task_name: str = 'runSMAC'
         self._stopwatch.start_task(smac_task_name)
         elapsed_time = self._stopwatch.wall_elapsed(experiment_task_name)
         time_left_for_smac = max(0, total_walltime_limit - elapsed_time)
@@ -928,7 +929,7 @@ def _search(
 
             _proc_smac = AutoMLSMBO(
                 config_space=self.search_space,
-                dataset_name=dataset.dataset_name,
+                dataset_name=str(dataset.dataset_name),
                 backend=self._backend,
                 total_walltime_limit=total_walltime_limit,
                 func_eval_time_limit_secs=func_eval_time_limit_secs,
@@ -1035,11 +1036,11 @@ def refit(
         Returns:
             self
         """
-        if self.dataset_name is None:
-            self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
+
+        self.dataset_name = dataset.dataset_name
 
         if self._logger is None:
-            self._logger = self._get_logger(self.dataset_name)
+            self._logger = self._get_logger(str(self.dataset_name))
 
         dataset_requirements = get_dataset_requirements(
             info=self._get_required_dataset_properties(dataset))
@@ -1105,11 +1106,10 @@ def fit(self,
         Returns:
             (BasePipeline): fitted pipeline
         """
-        if self.dataset_name is None:
-            self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
+        self.dataset_name = dataset.dataset_name
 
         if self._logger is None:
-            self._logger = self._get_logger(self.dataset_name)
+            self._logger = self._get_logger(str(self.dataset_name))
 
         # get dataset properties
         dataset_requirements = get_dataset_requirements(
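
For context, here is a minimal sketch (not part of the patch) of the naming fallback this change relocates from `BaseTask.refit`/`fit` into `BaseDataset.__init__`; `resolve_dataset_name` is a hypothetical helper for illustration only:

```python
import os
import uuid
from typing import Optional


def resolve_dataset_name(dataset_name: Optional[str] = None) -> str:
    # Mirrors the fallback the diff adds to BaseDataset.__init__: a uuid1
    # seeded with the current PID, which helps keep names distinct across
    # worker processes started at the same time.
    if dataset_name is None:
        dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
    return dataset_name


print(resolve_dataset_name("my_dataset"))  # my_dataset
print(resolve_dataset_name())              # e.g. '2d8f3a1c-...'
```

With the fallback living in the dataset, `BaseTask` can simply assert that `dataset.dataset_name` is never `None` instead of generating names itself.
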
diff --git a/autoPyTorch/datasets/base_dataset.py b/autoPyTorch/datasets/base_dataset.py
index 4c19fa17d..2f99e54f7 100644
--- a/autoPyTorch/datasets/base_dataset.py
+++ b/autoPyTorch/datasets/base_dataset.py
@@ -1,3 +1,5 @@
+import os
+import uuid
 from abc import ABCMeta
 from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
 
@@ -13,18 +15,17 @@
 
 from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
 from autoPyTorch.datasets.resampling_strategy import (
-    CROSS_VAL_FN,
+    CrossValFunc,
+    CrossValFuncs,
     CrossValTypes,
     DEFAULT_RESAMPLING_PARAMETERS,
-    HOLDOUT_FN,
-    HoldoutValTypes,
-    get_cross_validators,
-    get_holdout_validators,
-    is_stratified,
+    HoldOutFunc,
+    HoldOutFuncs,
+    HoldoutValTypes
 )
-from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix
+from autoPyTorch.utils.common import FitRequirement
 
-BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
+BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
 
 
 def check_valid_data(data: Any) -> None:
@@ -33,7 +34,8 @@ def check_valid_data(data: Any) -> None:
             'The specified Data for Dataset must have both __getitem__ and __len__ attribute.')
 
 
-def type_check(train_tensors: BaseDatasetType, val_tensors: Optional[BaseDatasetType] = None) -> None:
+def type_check(train_tensors: BaseDatasetInputType,
+               val_tensors: Optional[BaseDatasetInputType] = None) -> None:
     """To avoid unexpected behavior, we use loops over indices."""
     for i in range(len(train_tensors)):
         check_valid_data(train_tensors[i])
@@ -49,8 +51,8 @@ class TransformSubset(Subset):
     we require different transformation for each data point.
     This class helps to take the subset of the dataset
     with either training or validation transformation.
-
-    We achieve so by adding a train flag to the pytorch subset
+    TransformSubset achieves this by carrying a train flag
+    that is applied when indexing the main dataset.
 
     Attributes:
         dataset (BaseDataset/Dataset): Dataset to sample the subset
@@ -71,10 +73,10 @@ def __getitem__(self, idx: int) -> np.ndarray:
 class BaseDataset(Dataset, metaclass=ABCMeta):
     def __init__(
         self,
-        train_tensors: BaseDatasetType,
+        train_tensors: BaseDatasetInputType,
         dataset_name: Optional[str] = None,
-        val_tensors: Optional[BaseDatasetType] = None,
-        test_tensors: Optional[BaseDatasetType] = None,
+        val_tensors: Optional[BaseDatasetInputType] = None,
+        test_tensors: Optional[BaseDatasetInputType] = None,
         resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         shuffle: Optional[bool] = True,
@@ -106,14 +108,16 @@ def __init__(
             val_transforms (Optional[torchvision.transforms.Compose]):
                 Additional Transforms to be applied to the validation/test data
         """
-        self.dataset_name = dataset_name if dataset_name is not None \
-            else hash_array_or_matrix(train_tensors[0])
+        self.dataset_name = dataset_name
+
+        if self.dataset_name is None:
+            self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
 
         if not hasattr(train_tensors[0], 'shape'):
             type_check(train_tensors, val_tensors)
         self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
-        self.cross_validators: Dict[str, CROSS_VAL_FN] = {}
-        self.holdout_validators: Dict[str, HOLDOUT_FN] = {}
+        self.cross_validators: Dict[str, CrossValFunc] = {}
+        self.holdout_validators: Dict[str, HoldOutFunc] = {}
         self.rng = np.random.RandomState(seed=seed)
         self.shuffle = shuffle
         self.resampling_strategy = resampling_strategy
@@ -134,8 +138,8 @@ def __init__(
         self.is_small_preprocess = True
 
         # Make sure cross validation splits are created once
-        self.cross_validators = get_cross_validators(*CrossValTypes)
-        self.holdout_validators = get_holdout_validators(*HoldoutValTypes)
+        self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
+        self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
         self.splits = self.get_splits_from_resampling_strategy()
 
         # We also need to be able to transform the data, be it for pre-processing
@@ -263,7 +267,7 @@ def create_cross_val_splits(
         if not isinstance(cross_val_type, CrossValTypes):
             raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.')
         kwargs = {}
-        if is_stratified(cross_val_type):
+        if cross_val_type.is_stratified():
             # we need additional information about the data for stratification
             kwargs["stratify"] = self.train_tensors[-1]
         splits = self.cross_validators[cross_val_type.name](
@@ -298,7 +302,7 @@ def create_holdout_val_split(
         if not isinstance(holdout_val_type, HoldoutValTypes):
             raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.')
         kwargs = {}
-        if is_stratified(holdout_val_type):
+        if holdout_val_type.is_stratified():
             # we need additional information about the data for stratification
             kwargs["stratify"] = self.train_tensors[-1]
         train, val = self.holdout_validators[holdout_val_type.name](val_share, self._get_indices(), **kwargs)
@@ -321,7 +325,8 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
         return (TransformSubset(self, self.splits[split_id][0], train=True),
                 TransformSubset(self, self.splits[split_id][1], train=False))
 
-    def replace_data(self, X_train: BaseDatasetType, X_test: Optional[BaseDatasetType]) -> 'BaseDataset':
+    def replace_data(self, X_train: BaseDatasetInputType,
+                     X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset':
         """
         To speed up the training of small dataset, early pre-processing of the data
         can be made on the fly by the pipeline.
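
As a usage sketch of the refactoring above (assuming the `CrossValFuncs`/`CrossValTypes` API this diff introduces), this is roughly how `create_cross_val_splits` now dispatches: the enum member itself reports whether stratification labels are required, and the validator is looked up by name from the class-based registry:

```python
import numpy as np

from autoPyTorch.datasets.resampling_strategy import CrossValFuncs, CrossValTypes

# Build the registry once, as BaseDataset.__init__ does.
cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)

indices = np.arange(10)
labels = np.array([0, 1] * 5)  # stand-in for train_tensors[-1]

cv_type = CrossValTypes.stratified_k_fold_cross_validation
kwargs = {}
if cv_type.is_stratified():  # method replaces the old free function is_stratified()
    kwargs["stratify"] = labels

splits = cross_validators[cv_type.name](num_splits=5, indices=indices, **kwargs)
print(len(splits))  # 5 (train, val) index tuples
```
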
diff --git a/autoPyTorch/datasets/resampling_strategy.py b/autoPyTorch/datasets/resampling_strategy.py
index b853fac0a..765a31cdb 100644
--- a/autoPyTorch/datasets/resampling_strategy.py
+++ b/autoPyTorch/datasets/resampling_strategy.py
@@ -16,7 +16,7 @@
 
 
 # Use callback protocol as workaround, since callable with function fields count 'self' as argument
-class CROSS_VAL_FN(Protocol):
+class CrossValFunc(Protocol):
     def __call__(self,
                  num_splits: int,
                  indices: np.ndarray,
@@ -24,25 +24,57 @@ def __call__(self,
         ...
 
 
-class HOLDOUT_FN(Protocol):
+class HoldOutFunc(Protocol):
     def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any]
                  ) -> Tuple[np.ndarray, np.ndarray]:
         ...
 
 
 class CrossValTypes(IntEnum):
+    """The type of cross validation
+
+    This class is used to specify the cross validation function
+    and is not supposed to be instantiated.
+
+    Examples: This class is supposed to be used as follows
+    >>> cv_type = CrossValTypes.k_fold_cross_validation
+    >>> print(cv_type.name)
+    k_fold_cross_validation
+    >>> for cross_val_type in CrossValTypes:
+    ...     print(cross_val_type.name, cross_val_type.value)
+    stratified_k_fold_cross_validation 1
+    k_fold_cross_validation 2
+    stratified_shuffle_split_cross_validation 3
+    shuffle_split_cross_validation 4
+    time_series_cross_validation 5
+    """
     stratified_k_fold_cross_validation = 1
     k_fold_cross_validation = 2
     stratified_shuffle_split_cross_validation = 3
     shuffle_split_cross_validation = 4
     time_series_cross_validation = 5
 
+    def is_stratified(self) -> bool:
+        stratified = [CrossValTypes.stratified_k_fold_cross_validation,
+                      CrossValTypes.stratified_shuffle_split_cross_validation]
+        return self in stratified
+
 
 class HoldoutValTypes(IntEnum):
+    """TODO: change to enum using functools.partial"""
+    """The type of hold out validation (refer to CrossValTypes' doc-string)"""
     holdout_validation = 6
     stratified_holdout_validation = 7
 
+    def is_stratified(self) -> bool:
+        stratified = [HoldoutValTypes.stratified_holdout_validation]
+        return self in stratified
+
 
+# TODO: replace this plain list with a better mechanism
 RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes]
 
 DEFAULT_RESAMPLING_PARAMETERS = {
@@ -67,87 +99,111 @@ class HoldoutValTypes(IntEnum):
 }  # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
 
 
-def get_cross_validators(*cross_val_types: CrossValTypes) -> Dict[str, CROSS_VAL_FN]:
-    cross_validators = {}  # type: Dict[str, CROSS_VAL_FN]
-    for cross_val_type in cross_val_types:
-        cross_val_fn = globals()[cross_val_type.name]
-        cross_validators[cross_val_type.name] = cross_val_fn
-    return cross_validators
-
-
-def get_holdout_validators(*holdout_val_types: HoldoutValTypes) -> Dict[str, HOLDOUT_FN]:
-    holdout_validators = {}  # type: Dict[str, HOLDOUT_FN]
-    for holdout_val_type in holdout_val_types:
-        holdout_val_fn = globals()[holdout_val_type.name]
-        holdout_validators[holdout_val_type.name] = holdout_val_fn
-    return holdout_validators
-
-
-def is_stratified(val_type: Union[str, CrossValTypes, HoldoutValTypes]) -> bool:
-    if isinstance(val_type, str):
-        return val_type.lower().startswith("stratified")
-    else:
-        return val_type.name.lower().startswith("stratified")
-
-
-def holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) -> Tuple[np.ndarray, np.ndarray]:
-    train, val = train_test_split(indices, test_size=val_share, shuffle=False)
-    return train, val
-
-
-def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) \
-        -> Tuple[np.ndarray, np.ndarray]:
-    train, val = train_test_split(indices, test_size=val_share, shuffle=True, stratify=kwargs["stratify"])
-    return train, val
-
-
-def shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
-        -> List[Tuple[np.ndarray, np.ndarray]]:
-    cv = ShuffleSplit(n_splits=num_splits)
-    splits = list(cv.split(indices))
-    return splits
-
-
-def stratified_shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
-        -> List[Tuple[np.ndarray, np.ndarray]]:
-    cv = StratifiedShuffleSplit(n_splits=num_splits)
-    splits = list(cv.split(indices, kwargs["stratify"]))
-    return splits
-
-
-def stratified_k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
-        -> List[Tuple[np.ndarray, np.ndarray]]:
-    cv = StratifiedKFold(n_splits=num_splits)
-    splits = list(cv.split(indices, kwargs["stratify"]))
-    return splits
-
-
-def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) -> List[Tuple[np.ndarray, np.ndarray]]:
-    """
-    Standard k fold cross validation.
-
-    :param indices: array of indices to be split
-    :param num_splits: number of cross validation splits
-    :return: list of tuples of training and validation indices
-    """
-    cv = KFold(n_splits=num_splits)
-    splits = list(cv.split(indices))
-    return splits
-
-
-def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
-        -> List[Tuple[np.ndarray, np.ndarray]]:
-    """
-    Returns train and validation indices respecting the temporal ordering of the data.
-    Dummy example: [0, 1, 2, 3] with 3 folds yields
-        [0] [1]
-        [0, 1] [2]
-        [0, 1, 2] [3]
-
-    :param indices: array of indices to be split
-    :param num_splits: number of cross validation splits
-    :return: list of tuples of training and validation indices
-    """
-    cv = TimeSeriesSplit(n_splits=num_splits)
-    splits = list(cv.split(indices))
-    return splits
+class HoldOutFuncs():
+    @staticmethod
+    def holdout_validation(val_share: float,
+                           indices: np.ndarray,
+                           **kwargs: Any
+                           ) -> Tuple[np.ndarray, np.ndarray]:
+        train, val = train_test_split(indices, test_size=val_share, shuffle=False)
+        return train, val
+
+    @staticmethod
+    def stratified_holdout_validation(val_share: float,
+                                      indices: np.ndarray,
+                                      **kwargs: Any
+                                      ) -> Tuple[np.ndarray, np.ndarray]:
+        train, val = train_test_split(indices, test_size=val_share, shuffle=True, stratify=kwargs["stratify"])
+        return train, val
+
+    @classmethod
+    def get_holdout_validators(cls, *holdout_val_types: HoldoutValTypes) -> Dict[str, HoldOutFunc]:
+
+        holdout_validators = {
+            holdout_val_type.name: getattr(cls, holdout_val_type.name)
+            for holdout_val_type in holdout_val_types
+        }
+        return holdout_validators
+
+
+class CrossValFuncs():
+    @staticmethod
+    def shuffle_split_cross_validation(num_splits: int,
+                                       indices: np.ndarray,
+                                       **kwargs: Any
+                                       ) -> List[Tuple[np.ndarray, np.ndarray]]:
+        cv = ShuffleSplit(n_splits=num_splits)
+        splits = list(cv.split(indices))
+        return splits
+
+    @staticmethod
+    def stratified_shuffle_split_cross_validation(num_splits: int,
+                                                  indices: np.ndarray,
+                                                  **kwargs: Any
+                                                  ) -> List[Tuple[np.ndarray, np.ndarray]]:
+        cv = StratifiedShuffleSplit(n_splits=num_splits)
+        splits = list(cv.split(indices, kwargs["stratify"]))
+        return splits
+
+    @staticmethod
+    def stratified_k_fold_cross_validation(num_splits: int,
+                                           indices: np.ndarray,
+                                           **kwargs: Any
+                                           ) -> List[Tuple[np.ndarray, np.ndarray]]:
+        cv = StratifiedKFold(n_splits=num_splits)
+        splits = list(cv.split(indices, kwargs["stratify"]))
+        return splits
+
+    @staticmethod
+    def k_fold_cross_validation(num_splits: int,
+                                indices: np.ndarray,
+                                **kwargs: Any
+                                ) -> List[Tuple[np.ndarray, np.ndarray]]:
+        """
+        Standard k fold cross validation.
+
+        Args:
+            indices (np.ndarray): array of indices to be split
+            num_splits (int): number of cross validation splits
+
+        Returns:
+            splits (List[Tuple[List, List]]): list of tuples of training and validation indices
+        """
+        cv = KFold(n_splits=num_splits)
+        splits = list(cv.split(indices))
+        return splits
+
+    @staticmethod
+    def time_series_cross_validation(num_splits: int,
+                                     indices: np.ndarray,
+                                     **kwargs: Any
+                                     ) -> List[Tuple[np.ndarray, np.ndarray]]:
+        """
+        Returns train and validation indices respecting the temporal ordering of the data.
+
+        Args:
+            indices (np.ndarray): array of indices to be split
+            num_splits (int): number of cross validation splits
+
+        Returns:
+            splits (List[Tuple[List, List]]): list of tuples of training and validation indices
+
+        Examples:
+            >>> indices = np.array([0, 1, 2, 3])
+            >>> CrossValFuncs.time_series_cross_validation(3, indices)
+            [(array([0]), array([1])),
+             (array([0, 1]), array([2])),
+             (array([0, 1, 2]), array([3]))]
+
+        """
+        cv = TimeSeriesSplit(n_splits=num_splits)
+        splits = list(cv.split(indices))
+        return splits
+
+    @classmethod
+    def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, CrossValFunc]:
+        cross_validators = {
+            cross_val_type.name: getattr(cls, cross_val_type.name)
+            for cross_val_type in cross_val_types
+        }
+        return cross_validators
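
The holdout side works the same way; a short sketch, assuming the refactored module above:

```python
import numpy as np

from autoPyTorch.datasets.resampling_strategy import HoldOutFuncs, HoldoutValTypes

holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)

indices = np.arange(100)
train, val = holdout_validators[HoldoutValTypes.holdout_validation.name](
    val_share=0.33, indices=indices
)
print(len(train), len(val))  # 67 33 (no shuffling for the plain holdout split)
```
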
diff --git a/autoPyTorch/datasets/time_series_dataset.py b/autoPyTorch/datasets/time_series_dataset.py
index 7b0435d19..edd07a80e 100644
--- a/autoPyTorch/datasets/time_series_dataset.py
+++ b/autoPyTorch/datasets/time_series_dataset.py
@@ -6,10 +6,10 @@
 
 from autoPyTorch.datasets.base_dataset import BaseDataset
 from autoPyTorch.datasets.resampling_strategy import (
+    CrossValFuncs,
     CrossValTypes,
-    HoldoutValTypes,
-    get_cross_validators,
-    get_holdout_validators
+    HoldOutFuncs,
+    HoldoutValTypes
 )
 
 TIME_SERIES_FORECASTING_INPUT = Tuple[np.ndarray, np.ndarray]  # currently only numpy arrays are supported
@@ -60,8 +60,8 @@ def __init__(self,
                          train_transforms=train_transforms,
                          val_transforms=val_transforms,
                          )
-        self.cross_validators = get_cross_validators(CrossValTypes.time_series_cross_validation)
-        self.holdout_validators = get_holdout_validators(HoldoutValTypes.holdout_validation)
+        self.cross_validators = CrossValFuncs.get_cross_validators(CrossValTypes.time_series_cross_validation)
+        self.holdout_validators = HoldOutFuncs.get_holdout_validators(HoldoutValTypes.holdout_validation)
 
 
 def _check_time_series_forecasting_inputs(target_variables: Tuple[int],
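
For reference, a standalone sketch of the expanding-window behaviour that `time_series_cross_validation` inherits from scikit-learn's `TimeSeriesSplit`, which it wraps:

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit

indices = np.array([0, 1, 2, 3])
for train, val in TimeSeriesSplit(n_splits=3).split(indices):
    print(train, val)
# [0] [1]
# [0 1] [2]
# [0 1 2] [3]
```
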
@@ -117,13 +117,13 @@ def __init__(self,
                                   val=val,
                                   task_type="time_series_classification")
         super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
-        self.cross_validators = get_cross_validators(
+        self.cross_validators = CrossValFuncs.get_cross_validators(
             CrossValTypes.stratified_k_fold_cross_validation,
             CrossValTypes.k_fold_cross_validation,
             CrossValTypes.shuffle_split_cross_validation,
             CrossValTypes.stratified_shuffle_split_cross_validation
         )
-        self.holdout_validators = get_holdout_validators(
+        self.holdout_validators = HoldOutFuncs.get_holdout_validators(
             HoldoutValTypes.holdout_validation,
             HoldoutValTypes.stratified_holdout_validation
         )
@@ -135,11 +135,11 @@ def __init__(self, train: Tuple[np.ndarray, np.ndarray], val: Optional[Tuple[np.
                                   val=val,
                                   task_type="time_series_regression")
         super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
-        self.cross_validators = get_cross_validators(
+        self.cross_validators = CrossValFuncs.get_cross_validators(
             CrossValTypes.k_fold_cross_validation,
             CrossValTypes.shuffle_split_cross_validation
         )
-        self.holdout_validators = get_holdout_validators(
+        self.holdout_validators = HoldOutFuncs.get_holdout_validators(
             HoldoutValTypes.holdout_validation
         )
 
diff --git a/autoPyTorch/ensemble/ensemble_builder.py b/autoPyTorch/ensemble/ensemble_builder.py
index 434849ef1..e236f091b 100644
--- a/autoPyTorch/ensemble/ensemble_builder.py
+++ b/autoPyTorch/ensemble/ensemble_builder.py
@@ -66,57 +66,56 @@ def __init__(
         logger_port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT,
     ):
         """ SMAC callback to handle ensemble building
-        Parameters
-        ----------
-        start_time: int
-            the time when this job was started, to account for any latency in job allocation
-        time_left_for_ensemble: int
-            How much time is left for the task. Job should finish within this allocated time
-        backend: util.backend.Backend
-            backend to write and read files
-        dataset_name: str
-            name of dataset
-        task_type: int
-            what type of output is expected. If Binary, we need to argmax the one hot encoding.
-        metrics: List[autoPyTorchMetric],
-            A set of metrics that will be used to get performance estimates
-        opt_metric: str
-            name of the optimization metrics
-        ensemble_size: int
-            maximal size of ensemble (passed to ensemble_selection)
-        ensemble_nbest: int/float
-            if int: consider only the n best prediction
-            if float: consider only this fraction of the best models
-            Both wrt to validation predictions
-            If performance_range_threshold > 0, might return less models
-        max_models_on_disc: Union[float, int]
-           Defines the maximum number of models that are kept in the disc.
-           If int, it must be greater or equal than 1, and dictates the max number of
-           models to keep.
-           If float, it will be interpreted as the max megabytes allowed of disc space. That
-           is, if the number of ensemble candidates require more disc space than this float
-           value, the worst models will be deleted to keep within this budget.
-           Models and predictions of the worst-performing models will be deleted then.
-           If None, the feature is disabled.
-           It defines an upper bound on the models that can be used in the ensemble.
-        seed: int
-            random seed
-        max_iterations: int
-            maximal number of iterations to run this script
-            (default None --> deactivated)
-        precision: [16,32,64,128]
-            precision of floats to read the predictions
-        memory_limit: Optional[int]
-            memory limit in mb. If ``None``, no memory limit is enforced.
-        read_at_most: int
-            read at most n new prediction files in each iteration
-        logger_port: int
-            port in where to publish a msg
-    Returns
-    -------
-        List[Tuple[int, float, float, float]]:
-            A list with the performance history of this ensemble, of the form
-            [[pandas_timestamp, train_performance, val_performance, test_performance], ...]
+        Args:
+            start_time: int
+                the time when this job was started, to account for any latency in job allocation
+            time_left_for_ensembles: int
+                How much time is left for the task. Job should finish within this allocated time
+            backend: util.backend.Backend
+                backend to write and read files
+            dataset_name: str
+                name of dataset
+            task_type: int
+                what type of output is expected. If Binary, we need to argmax the one hot encoding.
+            metrics: List[autoPyTorchMetric],
+                A set of metrics that will be used to get performance estimates
+            opt_metric: str
+                name of the optimization metrics
+            ensemble_size: int
+                maximal size of ensemble (passed to ensemble_selection)
+            ensemble_nbest: int/float
+                if int: consider only the n best predictions
+                if float: consider only this fraction of the best models
+                Both with respect to the validation predictions.
+                If performance_range_threshold > 0, might return fewer models
+            max_models_on_disc: Union[float, int]
+                Defines the maximum number of models that are kept on disk.
+                If int, it must be greater than or equal to 1 and dictates the maximum
+                number of models to keep.
+                If float, it is interpreted as the maximum megabytes of disk space
+                allowed. If the ensemble candidates require more disk space than this
+                value, the worst-performing models and their predictions are deleted
+                to stay within the budget.
+                If None, the feature is disabled.
+                It defines an upper bound on the models that can be used in the ensemble.
+            seed: int
+                random seed
+            max_iterations: Optional[int]
+                maximal number of iterations to run this script
+                (default None --> deactivated)
+            precision: int in [16, 32, 64, 128]
+                precision of the floats used to read the predictions
+            memory_limit: Optional[int]
+                memory limit in MB. If ``None``, no memory limit is enforced.
+            read_at_most: int
+                read at most n new prediction files in each iteration
+            logger_port: int
+                port on which to publish log messages
+
+        Returns:
+            List[Tuple[int, float, float, float]]:
+                A list with the performance history of this ensemble, of the form
+                [[pandas_timestamp, train_performance, val_performance, test_performance], ...]
         """
         self.start_time = start_time
         self.time_left_for_ensembles = time_left_for_ensembles
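
To make the int-versus-float semantics of `max_models_on_disc` concrete, here is a hedged toy illustration (not the builder's actual pruning code; `models_to_keep` and `model_sizes_mb` are invented for this sketch):

```python
from typing import List, Optional, Union


def models_to_keep(model_sizes_mb: List[float],
                   max_models_on_disc: Optional[Union[int, float]]) -> int:
    """How many of the (already sorted, best-first) models survive pruning."""
    if max_models_on_disc is None:
        return len(model_sizes_mb)  # feature disabled
    if isinstance(max_models_on_disc, int):
        return min(len(model_sizes_mb), max_models_on_disc)  # count cap
    kept, used_mb = 0, 0.0
    for size in model_sizes_mb:  # float: megabyte budget, worst models dropped
        if used_mb + size > max_models_on_disc:
            break
        kept, used_mb = kept + 1, used_mb + size
    return kept


print(models_to_keep([10.0, 10.0, 10.0], 2))     # 2 (int: max model count)
print(models_to_keep([10.0, 10.0, 10.0], 25.0))  # 2 (float: 25 MB budget)
```
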