Skip to content

Commit

Permalink
fix parentheses
Browse files Browse the repository at this point in the history
  • Loading branch information
DRMPN authored Feb 8, 2024
1 parent 5f98ed8 commit 459041e
Showing 1 changed file with 21 additions and 19 deletions.
40 changes: 21 additions & 19 deletions test/unit/data/test_data_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,25 +202,27 @@ def test_multivariate_time_series_splitting_correct():
assert np.allclose(test_series_data.target, np.array([16, 17, 18, 19]))


@pytest.mark.parametrize(('datas_funs', 'cv_folds', 'shuffle', 'stratify'),
[
# classification + stratify + shuffle + cv_folds
([partial(get_tabular_classification_data, 100, 5)] * 3, 4, True, True),
# classification + shuffle + cv_folds
([partial(get_tabular_classification_data, 100, 5)] * 3, 4, True, False),
# classification + cv_folds
([partial(get_tabular_classification_data, 100, 5)] * 3, 4, False, False),
# classification + stratify + shuffle
([partial(get_tabular_classification_data, 100, 5)] * 3, None, True, True),
# classification + shuffle
([partial(get_tabular_classification_data, 100, 5)] * 3, None, True, False),
# classification
([partial(get_tabular_classification_data, 100, 5)] * 3, None, False, False),
# timeseries + cv_folds
([partial(get_ts_data_to_forecast, 10, 100)] * 3, 3, False, False),
# timeseries
([partial(get_ts_data_to_forecast, 10, 100)] * 3, None, False, False),
])
@pytest.mark.parametrize(
("datas_funs", "cv_folds", "shuffle", "stratify"),
[
# classification + stratify + shuffle + cv_folds
([partial(get_tabular_classification_data, 100, 5)] * 3, 4, True, True),
# classification + shuffle + cv_folds
([partial(get_tabular_classification_data, 100, 5)] * 3, 4, True, False),
# classification + cv_folds
([partial(get_tabular_classification_data, 100, 5)] * 3, 4, False, False),
# classification + stratify + shuffle
([partial(get_tabular_classification_data, 100, 5)] * 3, None, True, True),
# classification + shuffle
([partial(get_tabular_classification_data, 100, 5)] * 3, None, True, False),
# classification
([partial(get_tabular_classification_data, 100, 5)] * 3, None, False, False),
# timeseries + cv_folds
([partial(get_ts_data_to_forecast, 10, 100)] * 3, 3, False, False),
# timeseries
([partial(get_ts_data_to_forecast, 10, 100)] * 3, None, False, False),
],
)
def test_multimodal_data_splitting_is_correct(datas_funs, cv_folds, shuffle, stratify):
mdata = MultiModalData({f'data_{i}': data_fun() for i, data_fun in enumerate(datas_funs)})
data_splitter = DataSourceSplitter(cv_folds=cv_folds, shuffle=shuffle, stratify=stratify)
Expand Down

0 comments on commit 459041e

Please sign in to comment.