@contextmanager
def compute_timing(label: str, timings: tp.Optional[tp.Dict[str, float]] = None) -> tp.Iterator[None]:
    """
    Context manager to compute timing for a code block.

    The elapsed time is measured in seconds with a monotonic clock and stored
    rounded to 5 decimal places.

    Parameters
    ----------
    label : str
        Label to store the timing result in the timings dictionary.
    timings : dict, optional
        Dictionary to store the timing results. If None, timing is not recorded.
        An existing value under the same label is overwritten.

    Notes
    -----
    If the wrapped code block raises, the exception propagates and nothing
    is written to `timings`.
    """
    if timings is None:
        # Timing is disabled: just run the wrapped block.
        yield
        return

    # `perf_counter` is monotonic and high-resolution, so the measured duration
    # cannot be distorted (or become negative) by system clock adjustments,
    # unlike the wall-clock `time.time`.
    start_time = time.perf_counter()
    yield
    timings[label] = round(time.perf_counter() - start_time, 5)
@@ -123,28 +145,16 @@ def cross_validate( # pylint: disable=too-many-locals # ### Train ref models if any ref_reco = {} + ref_timings = {} for model_name in ref_models or []: model = models[model_name] - model.fit(fold_dataset) - ref_reco[model_name] = model.recommend( - users=test_users, - dataset=fold_dataset, - k=k, - filter_viewed=filter_viewed, - items_to_recommend=items_to_recommend, - on_unsupported_targets=on_unsupported_targets, - ) + model_timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings and validate_ref_models else None - # ### Generate recommendations and calc metrics - for model_name, model in models.items(): - if model_name in ref_reco and not validate_ref_models: - continue - - if model_name in ref_reco: - reco = ref_reco[model_name] - else: + with compute_timing("fit_time", model_timings): model.fit(fold_dataset) - reco = model.recommend( + + with compute_timing("recommend_time", model_timings): + ref_reco[model_name] = model.recommend( users=test_users, dataset=fold_dataset, k=k, @@ -153,6 +163,33 @@ def cross_validate( # pylint: disable=too-many-locals on_unsupported_targets=on_unsupported_targets, ) + ref_timings[model_name] = model_timings or {} + + # ### Generate recommendations and calc metrics + for model_name, model in models.items(): + if model_name in ref_reco and not validate_ref_models: + continue + if model_name in ref_reco: + reco = ref_reco[model_name] + model_timing = ref_timings[model_name] + else: + model_timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings else None # type: ignore + + with compute_timing("fit_time", model_timings): + model.fit(fold_dataset) + + with compute_timing("recommend_time", model_timings): + reco = model.recommend( + users=test_users, + dataset=fold_dataset, + k=k, + filter_viewed=filter_viewed, + items_to_recommend=items_to_recommend, + on_unsupported_targets=on_unsupported_targets, + ) + + model_timing = model_timings or {} + metric_values = calc_metrics( metrics, reco=reco, 
@pytest.mark.parametrize(
    "validate_ref_models,expected_metrics",
    (
        (
            False,
            [
                {"model": "random", "i_split": 0, "precision@2": 0.5, "recall@1": 0.0, "intersection_popular": 0.5},
                {"model": "random", "i_split": 1, "precision@2": 0.375, "recall@1": 0.5, "intersection_popular": 0.75},
            ],
        ),
        (
            True,
            [
                {"model": "popular", "i_split": 0, "precision@2": 0.5, "recall@1": 0.5, "intersection_popular": 1.0},
                {"model": "random", "i_split": 0, "precision@2": 0.5, "recall@1": 0.0, "intersection_popular": 0.5},
                {"model": "popular", "i_split": 1, "precision@2": 0.375, "recall@1": 0.25, "intersection_popular": 1.0},
                {"model": "random", "i_split": 1, "precision@2": 0.375, "recall@1": 0.5, "intersection_popular": 0.75},
            ],
        ),
    ),
)
@pytest.mark.parametrize("compute_timings", (False, True))
def test_happy_path_with_intersection_timings(
    self,
    validate_ref_models: bool,
    expected_metrics: tp.List[tp.Dict[str, tp.Any]],
    compute_timings: bool,
) -> None:
    """
    Check `cross_validate` output with a reference model, with and without timings.

    When `compute_timings` is True every metrics row must contain `fit_time`
    and `recommend_time` below the threshold; those keys are then removed so
    the remaining result can be compared with the expected one exactly.
    When `compute_timings` is False the exact comparison also verifies that
    no timing keys are present.
    """
    max_seconds = 0.5
    timing_keys = ("fit_time", "recommend_time")

    expected = {
        "splits": [
            {
                "i_split": 0,
                "test": 2,
                "test_items": 2,
                "test_users": 2,
                "train": 2,
                "train_items": 2,
                "train_users": 2,
            },
            {
                "i_split": 1,
                "test": 4,
                "test_items": 3,
                "test_users": 4,
                "train": 6,
                "train_items": 2,
                "train_users": 4,
            },
        ],
        "metrics": expected_metrics,
    }

    actual = cross_validate(
        dataset=self.dataset,
        splitter=LastNSplitter(n=1, n_splits=2, filter_cold_items=False, filter_already_seen=False),
        metrics=self.metrics_intersection,
        models=self.models,
        k=2,
        filter_viewed=False,
        ref_models=["popular"],
        validate_ref_models=validate_ref_models,
        compute_timings=compute_timings,
    )

    if compute_timings:
        for row in actual["metrics"]:
            # Check both timings first, then strip them: their values are not deterministic.
            for key in timing_keys:
                assert row[key] < max_seconds
            for key in timing_keys:
                del row[key]

    assert actual == expected