Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
alekseykalyagin committed Dec 19, 2024
1 parent 85c1604 commit 5afc0e4
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 194 deletions.
22 changes: 11 additions & 11 deletions rectools/model_selection/cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,15 @@ def cross_validate( # pylint: disable=too-many-locals

# ### Train ref models if any
ref_reco = {}
ref_res = {}
ref_timings = {}
for model_name in ref_models or []:
model = models[model_name]
ref_timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings and validate_ref_models else None
model_timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings and validate_ref_models else None

with compute_timing("fit_time", ref_timings):
with compute_timing("fit_time", model_timings):
model.fit(fold_dataset)

with compute_timing("recommend_time", ref_timings):
with compute_timing("recommend_time", model_timings):
ref_reco[model_name] = model.recommend(
users=test_users,
dataset=fold_dataset,
Expand All @@ -163,22 +163,22 @@ def cross_validate( # pylint: disable=too-many-locals
on_unsupported_targets=on_unsupported_targets,
)

ref_res[model_name] = ref_timings or {}
ref_timings[model_name] = model_timings or {}

# ### Generate recommendations and calc metrics
for model_name, model in models.items():
if model_name in ref_reco and not validate_ref_models:
continue
if model_name in ref_reco:
reco = ref_reco[model_name]
model_res = ref_res[model_name]
model_timing = ref_timings[model_name]
else:
timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings else None
model_timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings else None # type: ignore

with compute_timing("fit_time", timings):
with compute_timing("fit_time", model_timings):
model.fit(fold_dataset)

with compute_timing("recommend_time", timings):
with compute_timing("recommend_time", model_timings):
reco = model.recommend(
users=test_users,
dataset=fold_dataset,
Expand All @@ -188,7 +188,7 @@ def cross_validate( # pylint: disable=too-many-locals
on_unsupported_targets=on_unsupported_targets,
)

model_res = timings or {}
model_timing = model_timings or {}

metric_values = calc_metrics(
metrics,
Expand All @@ -200,7 +200,7 @@ def cross_validate( # pylint: disable=too-many-locals
)
res = {"model": model_name, "i_split": split_info["i_split"]}
res.update(metric_values)
res.update(model_res)
res.update(model_timing)
metrics_all.append(res)

result = {"splits": split_infos, "metrics": metrics_all}
Expand Down
Loading

0 comments on commit 5afc0e4

Please sign in to comment.