Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
kasyanovse committed Aug 9, 2023
1 parent 3bdfe6e commit 6c440e3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 6 additions & 0 deletions fedot/core/data/data_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def _are_stratification_allowed(data: Union[InputData, MultiModalData], split_ra
# check that there are enough labels for two samples
if not all(x > 1 for x in classes[1]):
if __debug__:
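# tests often use very small datasets that are not suitable for data splitting,
# so stratification is silently disabled in that case instead of raising
# NOTE: `__debug__` is True unless Python is started with -O, so this branch is not limited to test runs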
return False
else:
raise ValueError(("There is the only value for some classes:"
Expand All @@ -144,6 +146,7 @@ def _are_stratification_allowed(data: Union[InputData, MultiModalData], split_ra
def train_test_data_setup(data: Union[InputData, MultiModalData],
split_ratio: float = 0.8,
shuffle: bool = False,
shuffle_flag: bool = False,
stratify: bool = True,
random_seed: int = 42,
validation_blocks: Optional[int] = None) -> Tuple[Union[InputData, MultiModalData],
Expand All @@ -153,13 +156,16 @@ def train_test_data_setup(data: Union[InputData, MultiModalData],
:param data: InputData object to split
:param split_ratio: share of train data between 0 and 1
:param shuffle: whether the data should be shuffled
:param shuffle_flag: same as ``shuffle``; kept for backward compatibility
:param stratify: whether to make a stratified sample
:param random_seed: random_seed for shuffle
:param validation_blocks: validation blocks are used for test
:return: data for train, data for validation
"""

# for backward compatibility
shuffle |= shuffle_flag
# check that stratification may be done
stratify &= _are_stratification_allowed(data, split_ratio)
# stratification is allowed only with shuffle
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,6 @@ def _find_binary_features(self, numerical_features: np.array):
# Calculate unique values per column (excluding nans)
for column_id, col in enumerate(df):
unique_values = df[col].dropna().unique()
# TODO: test data processed without information about train data
# it may lead to errors
if len(unique_values) == 2:
# Current numerical column has only two values
column_info = {column_id: {'min': min(unique_values),
Expand Down

0 comments on commit 6c440e3

Please sign in to comment.