diff --git a/rectools/model_selection/cross_validate.py b/rectools/model_selection/cross_validate.py index f23d778c..c02b142e 100644 --- a/rectools/model_selection/cross_validate.py +++ b/rectools/model_selection/cross_validate.py @@ -145,15 +145,15 @@ def cross_validate( # pylint: disable=too-many-locals # ### Train ref models if any ref_reco = {} - ref_res = {} + ref_timings = {} for model_name in ref_models or []: model = models[model_name] - ref_timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings and validate_ref_models else None + model_timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings and validate_ref_models else None - with compute_timing("fit_time", ref_timings): + with compute_timing("fit_time", model_timings): model.fit(fold_dataset) - with compute_timing("recommend_time", ref_timings): + with compute_timing("recommend_time", model_timings): ref_reco[model_name] = model.recommend( users=test_users, dataset=fold_dataset, @@ -163,7 +163,7 @@ def cross_validate( # pylint: disable=too-many-locals on_unsupported_targets=on_unsupported_targets, ) - ref_res[model_name] = ref_timings or {} + ref_timings[model_name] = model_timings or {} # ### Generate recommendations and calc metrics for model_name, model in models.items(): @@ -171,14 +171,14 @@ def cross_validate( # pylint: disable=too-many-locals continue if model_name in ref_reco: reco = ref_reco[model_name] - model_res = ref_res[model_name] + model_timing = ref_timings[model_name] else: - timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings else None + model_timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings else None # type: ignore - with compute_timing("fit_time", timings): + with compute_timing("fit_time", model_timings): model.fit(fold_dataset) - with compute_timing("recommend_time", timings): + with compute_timing("recommend_time", model_timings): reco = model.recommend( users=test_users, dataset=fold_dataset, @@ -188,7 +188,7 @@ def cross_validate( # pylint: disable=too-many-locals on_unsupported_targets=on_unsupported_targets, ) - model_res = timings or {} + model_timing = model_timings or {} metric_values = calc_metrics( metrics, @@ -200,7 +200,7 @@ def cross_validate( # pylint: disable=too-many-locals ) res = {"model": model_name, "i_split": split_info["i_split"]} res.update(metric_values) - res.update(model_res) + res.update(model_timing) metrics_all.append(res) result = {"splits": split_infos, "metrics": metrics_all} diff --git a/tests/model_selection/test_cross_validate.py b/tests/model_selection/test_cross_validate.py index 12c9d50e..61bb4e0d 100644 --- a/tests/model_selection/test_cross_validate.py +++ b/tests/model_selection/test_cross_validate.py @@ -376,21 +376,6 @@ def test_happy_path_with_intersection( @pytest.mark.parametrize( "ref_models,validate_ref_models,expected_metrics,compute_timings", ( - ( - ["popular"], - False, - [ - {"model": "random", "i_split": 0, "precision@2": 0.5, "recall@1": 0.0, "intersection_popular": 0.5}, - { - "model": "random", - "i_split": 1, - "precision@2": 0.375, - "recall@1": 0.5, - "intersection_popular": 0.75, - }, - ], - False, - ), ( ["popular"], False, @@ -401,8 +386,8 @@ def test_happy_path_with_intersection( "precision@2": 0.5, "recall@1": 0.0, "intersection_popular": 0.5, - "fit_time": 0.5, - "recommend_time": 0.5, + "fit_time": 0.0, + "recommend_time": 0.0, }, { "model": "random", @@ -410,8 +395,8 @@ def test_happy_path_with_intersection( "precision@2": 0.375, "recall@1": 0.5, "intersection_popular": 0.75, - "fit_time": 0.5, - "recommend_time": 0.5, + "fit_time": 0.0, + "recommend_time": 0.0, }, ], True, @@ -426,37 +411,8 @@ def test_happy_path_with_intersection( "precision@2": 0.5, "recall@1": 0.5, "intersection_popular": 1.0, - }, - {"model": "random", "i_split": 0, "precision@2": 0.5, "recall@1": 0.0, "intersection_popular": 0.5}, - { - "model": "popular", - "i_split": 1, - "precision@2": 0.375, - "recall@1": 0.25, - "intersection_popular": 1.0, - }, - { - "model": "random", - "i_split": 1, - "precision@2": 0.375, - "recall@1": 0.5, - "intersection_popular": 0.75, - }, - ], - False, - ), - ( - ["popular"], - True, - [ - { - "model": "popular", - "i_split": 0, - "precision@2": 0.5, - "recall@1": 0.5, - "intersection_popular": 1.0, - "fit_time": 0.5, - "recommend_time": 0.5, + "fit_time": 0.0, + "recommend_time": 0.0, }, { "model": "random", @@ -464,8 +420,8 @@ def test_happy_path_with_intersection( "precision@2": 0.5, "recall@1": 0.0, "intersection_popular": 0.5, - "fit_time": 0.5, - "recommend_time": 0.5, + "fit_time": 0.0, + "recommend_time": 0.0, }, { "model": "popular", @@ -473,8 +429,8 @@ def test_happy_path_with_intersection( "precision@2": 0.375, "recall@1": 0.25, "intersection_popular": 1.0, - "fit_time": 0.5, - "recommend_time": 0.5, + "fit_time": 0.0, + "recommend_time": 0.0, }, { "model": "random", @@ -482,27 +438,12 @@ def test_happy_path_with_intersection( "precision@2": 0.375, "recall@1": 0.5, "intersection_popular": 0.75, - "fit_time": 0.5, - "recommend_time": 0.5, + "fit_time": 0.0, + "recommend_time": 0.0, }, ], True, ), - ( - ["random"], - False, - [ - {"model": "popular", "i_split": 0, "precision@2": 0.5, "recall@1": 0.5, "intersection_random": 0.5}, - { - "model": "popular", - "i_split": 1, - "precision@2": 0.375, - "recall@1": 0.25, - "intersection_random": 0.75, - }, - ], - False, - ), ( ["random"], False, @@ -513,8 +454,8 @@ def test_happy_path_with_intersection( "precision@2": 0.5, "recall@1": 0.5, "intersection_random": 0.5, - "fit_time": 0.5, - "recommend_time": 0.5, + "fit_time": 0.0, + "recommend_time": 0.0, }, { "model": "popular", @@ -522,35 +463,12 @@ def test_happy_path_with_intersection( "precision@2": 0.375, "recall@1": 0.25, "intersection_random": 0.75, - "fit_time": 0.5, - "recommend_time": 0.5, + "fit_time": 0.0, + "recommend_time": 0.0, }, ], True, ), - ( - ["random"], - True, - [ - {"model": "popular", "i_split": 0, "precision@2": 0.5, "recall@1": 0.5, "intersection_random": 0.5}, - {"model": "random", "i_split": 0, "precision@2": 0.5, "recall@1": 0.0, "intersection_random": 1.0}, - { - "model": "popular", - "i_split": 1, - "precision@2": 0.375, - "recall@1": 0.25, - "intersection_random": 0.75, - }, - { - "model": "random", - "i_split": 1, - "precision@2": 0.375, - "recall@1": 0.5, - "intersection_random": 1.0, - }, - ], - False, - ), ( ["random"], True, @@ -561,8 +479,8 @@ def test_happy_path_with_intersection( "precision@2": 0.5, "recall@1": 0.5, "intersection_random": 0.5, - "fit_time": 0.5, - "recommend_time": 0.5, + "fit_time": 0.0, + "recommend_time": 0.0, }, { "model": "random", @@ -570,8 +488,8 @@ def test_happy_path_with_intersection( "precision@2": 0.5, "recall@1": 0.0, "intersection_random": 1.0, - "fit_time": 0.5, - "recommend_time": 0.5, + "fit_time": 0.0, + "recommend_time": 0.0, }, { "model": "popular", @@ -579,8 +497,8 @@ def test_happy_path_with_intersection( "precision@2": 0.375, "recall@1": 0.25, "intersection_random": 0.75, - "fit_time": 0.5, - "recommend_time": 0.5, + "fit_time": 0.0, + "recommend_time": 0.0, }, { "model": "random", @@ -588,13 +506,12 @@ def test_happy_path_with_intersection( "precision@2": 0.375, "recall@1": 0.5, "intersection_random": 1.0, - "fit_time": 0.5, - "recommend_time": 0.5, + "fit_time": 0.0, + "recommend_time": 0.0, }, ], True, ), - (["random", "popular"], False, [], False), (["random", "popular"], False, [], True), ( ["random", "popular"], @@ -607,6 +524,8 @@ def test_happy_path_with_intersection( "recall@1": 0.5, "intersection_random": 0.5, "intersection_popular": 1.0, + "fit_time": 0.0, + "recommend_time": 0.0, }, { "model": "random", @@ -615,6 +534,8 @@ def test_happy_path_with_intersection( "recall@1": 0.0, "intersection_random": 1.0, "intersection_popular": 0.5, + "fit_time": 0.0, + "recommend_time": 0.0, }, { "model": "popular", @@ -623,6 +544,8 @@ def test_happy_path_with_intersection( "recall@1": 0.25, "intersection_random": 0.75, "intersection_popular": 1.0, + "fit_time": 0.0, + "recommend_time": 0.0, }, { "model": "random", @@ -631,53 +554,8 @@ def test_happy_path_with_intersection( "recall@1": 0.5, "intersection_random": 1.0, "intersection_popular": 0.75, - }, - ], - False, - ), - ( - ["random", "popular"], - True, - [ - { - "model": "popular", - "i_split": 0, - "precision@2": 0.5, - "recall@1": 0.5, - "intersection_random": 0.5, - "intersection_popular": 1.0, - "fit_time": 0.5, - "recommend_time": 0.5, - }, - { - "model": "random", - "i_split": 0, - "precision@2": 0.5, - "recall@1": 0.0, - "intersection_random": 1.0, - "intersection_popular": 0.5, - "fit_time": 0.5, - "recommend_time": 0.5, - }, - { - "model": "popular", - "i_split": 1, - "precision@2": 0.375, - "recall@1": 0.25, - "intersection_random": 0.75, - "intersection_popular": 1.0, - "fit_time": 0.5, - "recommend_time": 0.5, - }, - { - "model": "random", - "i_split": 1, - "precision@2": 0.375, - "recall@1": 0.5, - "intersection_random": 1.0, - "intersection_popular": 0.75, - "fit_time": 0.5, - "recommend_time": 0.5, + "fit_time": 0.0, + "recommend_time": 0.0, }, ], True, @@ -705,33 +583,35 @@ def test_happy_path_with_intersection_timings( compute_timings=compute_timings, ) - expected_keys = {"fit_time", "recommend_time"} + time_threshold = 0.5 - if compute_timings: - for data in actual["metrics"]: - assert len(expected_keys.intersection(set(data.keys()))) == 2 - else: - expected = { - "splits": [ - { - "i_split": 0, - "test": 2, - "test_items": 2, - "test_users": 2, - "train": 2, - "train_items": 2, - "train_users": 2, - }, - { - "i_split": 1, - "test": 4, - "test_items": 3, - "test_users": 4, - "train": 6, - "train_items": 2, - "train_users": 4, - }, - ], - "metrics": expected_metrics, - } - assert actual == expected + for data in actual["metrics"]: + print(data["fit_time"]) + print(data["recommend_time"]) + assert data["fit_time"] < time_threshold + assert data["recommend_time"] < time_threshold + + expected = { + "splits": [ + { + "i_split": 0, + "test": 2, + "test_items": 2, + "test_users": 2, + "train": 2, + "train_items": 2, + "train_users": 2, + }, + { + "i_split": 1, + "test": 4, + "test_items": 3, + "test_users": 4, + "train": 6, + "train_items": 2, + "train_users": 4, + }, + ], + "metrics": expected_metrics, + } + assert actual == expected