From 459041ed9827c52cadafbe62dcaf32cec239bf60 Mon Sep 17 00:00:00 2001 From: ilyushka <61294398+DRMPN@users.noreply.github.com> Date: Thu, 8 Feb 2024 11:42:34 +0300 Subject: [PATCH] fix parentheses --- test/unit/data/test_data_split.py | 40 ++++++++++++++++--------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/test/unit/data/test_data_split.py b/test/unit/data/test_data_split.py index f8a87f4d86..29e0367f6c 100644 --- a/test/unit/data/test_data_split.py +++ b/test/unit/data/test_data_split.py @@ -202,25 +202,27 @@ def test_multivariate_time_series_splitting_correct(): assert np.allclose(test_series_data.target, np.array([16, 17, 18, 19])) -@pytest.mark.parametrize(('datas_funs', 'cv_folds', 'shuffle', 'stratify'), - [ - # classification + stratify + shuffle + cv_folds - ([partial(get_tabular_classification_data, 100, 5)] * 3, 4, True, True), - # classification + shuffle + cv_folds - ([partial(get_tabular_classification_data, 100, 5)] * 3, 4, True, False), - # classification + cv_folds - ([partial(get_tabular_classification_data, 100, 5)] * 3, 4, False, False), - # classification + stratify + shuffle - ([partial(get_tabular_classification_data, 100, 5)] * 3, None, True, True), - # classification + shuffle - ([partial(get_tabular_classification_data, 100, 5)] * 3, None, True, False), - # classification - ([partial(get_tabular_classification_data, 100, 5)] * 3, None, False, False), - # timeseries + cv_folds - ([partial(get_ts_data_to_forecast, 10, 100)] * 3, 3, False, False), - # timeseries - ([partial(get_ts_data_to_forecast, 10, 100)] * 3, None, False, False), -]) +@pytest.mark.parametrize( + ("datas_funs", "cv_folds", "shuffle", "stratify"), + [ + # classification + stratify + shuffle + cv_folds + ([partial(get_tabular_classification_data, 100, 5)] * 3, 4, True, True), + # classification + shuffle + cv_folds + ([partial(get_tabular_classification_data, 100, 5)] * 3, 4, True, False), + # classification + cv_folds + ([partial(get_tabular_classification_data, 100, 5)] * 3, 4, False, False), + # classification + stratify + shuffle + ([partial(get_tabular_classification_data, 100, 5)] * 3, None, True, True), + # classification + shuffle + ([partial(get_tabular_classification_data, 100, 5)] * 3, None, True, False), + # classification + ([partial(get_tabular_classification_data, 100, 5)] * 3, None, False, False), + # timeseries + cv_folds + ([partial(get_ts_data_to_forecast, 10, 100)] * 3, 3, False, False), + # timeseries + ([partial(get_ts_data_to_forecast, 10, 100)] * 3, None, False, False), + ], +) def test_multimodal_data_splitting_is_correct(datas_funs, cv_folds, shuffle, stratify): mdata = MultiModalData({f'data_{i}': data_fun() for i, data_fun in enumerate(datas_funs)}) data_splitter = DataSourceSplitter(cv_folds=cv_folds, shuffle=shuffle, stratify=stratify)