From 85c1604292d37e5670efc4d15e06c0cbdbcd8c2c Mon Sep 17 00:00:00 2001 From: alekseykalyagin Date: Fri, 13 Dec 2024 16:56:31 +0300 Subject: [PATCH] Remove fit_recommend function --- rectools/model_selection/cross_validate.py | 77 +++++++++------------- 1 file changed, 32 insertions(+), 45 deletions(-) diff --git a/rectools/model_selection/cross_validate.py b/rectools/model_selection/cross_validate.py index 3107ce79..f23d778c 100644 --- a/rectools/model_selection/cross_validate.py +++ b/rectools/model_selection/cross_validate.py @@ -15,8 +15,6 @@ import typing as tp from contextlib import contextmanager -import pandas as pd - from rectools.columns import Columns from rectools.dataset import Dataset from rectools.metrics import calc_metrics @@ -124,47 +122,6 @@ def cross_validate( # pylint: disable=too-many-locals ] } """ - - def fit_recommend(model: ModelBase, ref_model: bool = False) -> tp.Tuple[pd.DataFrame, tp.Dict[str, tp.Any]]: - """ - Trains the given recommendation model on a dataset split and generates recommendations. - - Parameters - ---------- - model : ModelBase - The recommendation model to be trained and used for generating recommendations. - Must be an instance of a subclass of `rectools.models.base.ModelBase`. - ref_model : bool, optional, default False - Indicates whether the model is a reference model used for comparison. If True, - and `validate_ref_models` is False, this model's recommendations may be reused - across splits without being refitted. - - Returns - ------- - tuple(pd.DataFrame, dict) - - A DataFrame with recommendations. - - A dictionary containing timing metrics (`fit_time` and `recommend_time`), if - `compute_timings` is enabled; otherwise, an empty dictionary. - """ - timings: tp.Optional[tp.Dict[str, float]] = ( - {} if compute_timings and (validate_ref_models or not ref_model) else None - ) - - with compute_timing("fit_time", timings): - model.fit(fold_dataset) - - with compute_timing("recommend_time", timings): - reco = model.recommend( - users=test_users, - dataset=fold_dataset, - k=k, - filter_viewed=filter_viewed, - items_to_recommend=items_to_recommend, - on_unsupported_targets=on_unsupported_targets, - ) - - return reco, (timings or {}) - split_iterator = splitter.split(dataset.interactions, collect_fold_stats=True) split_infos = [] @@ -191,7 +148,22 @@ def fit_recommend(model: ModelBase, ref_model: bool = False) -> tp.Tuple[pd.Data ref_res = {} for model_name in ref_models or []: model = models[model_name] - ref_reco[model_name], ref_res[model_name] = fit_recommend(model, ref_model=True) + ref_timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings and validate_ref_models else None + + with compute_timing("fit_time", ref_timings): + model.fit(fold_dataset) + + with compute_timing("recommend_time", ref_timings): + ref_reco[model_name] = model.recommend( + users=test_users, + dataset=fold_dataset, + k=k, + filter_viewed=filter_viewed, + items_to_recommend=items_to_recommend, + on_unsupported_targets=on_unsupported_targets, + ) + + ref_res[model_name] = ref_timings or {} # ### Generate recommendations and calc metrics for model_name, model in models.items(): @@ -201,7 +173,22 @@ def fit_recommend(model: ModelBase, ref_model: bool = False) -> tp.Tuple[pd.Data reco = ref_reco[model_name] model_res = ref_res[model_name] else: - reco, model_res = fit_recommend(model) + timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings else None + + with compute_timing("fit_time", timings): + model.fit(fold_dataset) + + with compute_timing("recommend_time", timings): + reco = model.recommend( + users=test_users, + dataset=fold_dataset, + k=k, + filter_viewed=filter_viewed, + items_to_recommend=items_to_recommend, + on_unsupported_targets=on_unsupported_targets, + ) + + model_res = timings or {} metric_values = calc_metrics( metrics,