Skip to content

Commit

Permalink
Remove fit_recommend function
Browse files Browse the repository at this point in the history
  • Loading branch information
alekseykalyagin committed Dec 13, 2024
1 parent a64db58 commit 85c1604
Showing 1 changed file with 32 additions and 45 deletions.
77 changes: 32 additions & 45 deletions rectools/model_selection/cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import typing as tp
from contextlib import contextmanager

import pandas as pd

from rectools.columns import Columns
from rectools.dataset import Dataset
from rectools.metrics import calc_metrics
Expand Down Expand Up @@ -124,47 +122,6 @@ def cross_validate( # pylint: disable=too-many-locals
]
}
"""

def fit_recommend(model: ModelBase, ref_model: bool = False) -> tp.Tuple[pd.DataFrame, tp.Dict[str, tp.Any]]:
"""
Trains the given recommendation model on a dataset split and generates recommendations.
Parameters
----------
model : ModelBase
The recommendation model to be trained and used for generating recommendations.
Must be an instance of a subclass of `rectools.models.base.ModelBase`.
ref_model : bool, optional, default False
Indicates whether the model is a reference model used for comparison. If True,
and `validate_ref_models` is False, this model's recommendations may be reused
across splits without being refitted.
Returns
-------
tuple(pd.DataFrame, dict)
- A DataFrame with recommendations.
- A dictionary containing timing metrics (`fit_time` and `recommend_time`), if
`compute_timings` is enabled; otherwise, an empty dictionary.
"""
timings: tp.Optional[tp.Dict[str, float]] = (
{} if compute_timings and (validate_ref_models or not ref_model) else None
)

with compute_timing("fit_time", timings):
model.fit(fold_dataset)

with compute_timing("recommend_time", timings):
reco = model.recommend(
users=test_users,
dataset=fold_dataset,
k=k,
filter_viewed=filter_viewed,
items_to_recommend=items_to_recommend,
on_unsupported_targets=on_unsupported_targets,
)

return reco, (timings or {})

split_iterator = splitter.split(dataset.interactions, collect_fold_stats=True)

split_infos = []
Expand All @@ -191,7 +148,22 @@ def fit_recommend(model: ModelBase, ref_model: bool = False) -> tp.Tuple[pd.Data
ref_res = {}
for model_name in ref_models or []:
model = models[model_name]
ref_reco[model_name], ref_res[model_name] = fit_recommend(model, ref_model=True)
ref_timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings and validate_ref_models else None

with compute_timing("fit_time", ref_timings):
model.fit(fold_dataset)

with compute_timing("recommend_time", ref_timings):
ref_reco[model_name] = model.recommend(
users=test_users,
dataset=fold_dataset,
k=k,
filter_viewed=filter_viewed,
items_to_recommend=items_to_recommend,
on_unsupported_targets=on_unsupported_targets,
)

ref_res[model_name] = ref_timings or {}

# ### Generate recommendations and calc metrics
for model_name, model in models.items():
Expand All @@ -201,7 +173,22 @@ def fit_recommend(model: ModelBase, ref_model: bool = False) -> tp.Tuple[pd.Data
reco = ref_reco[model_name]
model_res = ref_res[model_name]
else:
reco, model_res = fit_recommend(model)
timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings else None

with compute_timing("fit_time", timings):
model.fit(fold_dataset)

with compute_timing("recommend_time", timings):
reco = model.recommend(
users=test_users,
dataset=fold_dataset,
k=k,
filter_viewed=filter_viewed,
items_to_recommend=items_to_recommend,
on_unsupported_targets=on_unsupported_targets,
)

model_res = timings or {}

metric_values = calc_metrics(
metrics,
Expand Down

0 comments on commit 85c1604

Please sign in to comment.