-
Notifications
You must be signed in to change notification settings - Fork 40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature: cross validate timings #233
base: main
Are you sure you want to change the base?
Changes from 5 commits
2b95c08
e474430
b05fbe1
a64db58
85c1604
5afc0e4
ce5f1c0
46e05c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,8 +11,9 @@ | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import time | ||
import typing as tp | ||
from contextlib import contextmanager | ||
|
||
from rectools.columns import Columns | ||
from rectools.dataset import Dataset | ||
|
@@ -24,6 +25,26 @@ | |
from .splitter import Splitter | ||
|
||
|
||
@contextmanager | ||
def compute_timing(label: str, timings: tp.Optional[tp.Dict[str, float]] = None) -> tp.Iterator[None]: | ||
""" | ||
Context manager to compute timing for a code block. | ||
|
||
Parameters | ||
---------- | ||
label : str | ||
Label to store the timing result in the timings dictionary. | ||
timings : dict, optional | ||
Dictionary to store the timing results. If None, timing is not recorded. | ||
""" | ||
if timings is not None: | ||
start_time = time.time() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please use |
||
yield | ||
timings[label] = round(time.time() - start_time, 2) | ||
blondered marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
yield | ||
|
||
|
||
def cross_validate( # pylint: disable=too-many-locals | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please update CHANGELOG.MD |
||
dataset: Dataset, | ||
splitter: Splitter, | ||
|
@@ -36,6 +57,7 @@ def cross_validate( # pylint: disable=too-many-locals | |
ref_models: tp.Optional[tp.List[str]] = None, | ||
validate_ref_models: bool = False, | ||
on_unsupported_targets: ErrorBehaviour = "warn", | ||
compute_timings: bool = False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add new argument to docstring There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we really need this param? what's wrong if we always measure the time? |
||
) -> tp.Dict[str, tp.Any]: | ||
""" | ||
Run cross validation on multiple models with multiple metrics. | ||
|
@@ -123,28 +145,16 @@ def cross_validate( # pylint: disable=too-many-locals | |
|
||
# ### Train ref models if any | ||
ref_reco = {} | ||
ref_res = {} | ||
blondered marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for model_name in ref_models or []: | ||
model = models[model_name] | ||
model.fit(fold_dataset) | ||
ref_reco[model_name] = model.recommend( | ||
users=test_users, | ||
dataset=fold_dataset, | ||
k=k, | ||
filter_viewed=filter_viewed, | ||
items_to_recommend=items_to_recommend, | ||
on_unsupported_targets=on_unsupported_targets, | ||
) | ||
ref_timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings and validate_ref_models else None | ||
blondered marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# ### Generate recommendations and calc metrics | ||
for model_name, model in models.items(): | ||
if model_name in ref_reco and not validate_ref_models: | ||
continue | ||
|
||
if model_name in ref_reco: | ||
reco = ref_reco[model_name] | ||
else: | ||
with compute_timing("fit_time", ref_timings): | ||
model.fit(fold_dataset) | ||
reco = model.recommend( | ||
|
||
with compute_timing("recommend_time", ref_timings): | ||
ref_reco[model_name] = model.recommend( | ||
users=test_users, | ||
dataset=fold_dataset, | ||
k=k, | ||
|
@@ -153,6 +163,33 @@ def cross_validate( # pylint: disable=too-many-locals | |
on_unsupported_targets=on_unsupported_targets, | ||
) | ||
|
||
ref_res[model_name] = ref_timings or {} | ||
|
||
# ### Generate recommendations and calc metrics | ||
for model_name, model in models.items(): | ||
if model_name in ref_reco and not validate_ref_models: | ||
continue | ||
if model_name in ref_reco: | ||
reco = ref_reco[model_name] | ||
model_res = ref_res[model_name] | ||
else: | ||
timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings else None | ||
|
||
with compute_timing("fit_time", timings): | ||
model.fit(fold_dataset) | ||
|
||
with compute_timing("recommend_time", timings): | ||
reco = model.recommend( | ||
users=test_users, | ||
dataset=fold_dataset, | ||
k=k, | ||
filter_viewed=filter_viewed, | ||
items_to_recommend=items_to_recommend, | ||
on_unsupported_targets=on_unsupported_targets, | ||
) | ||
|
||
model_res = timings or {} | ||
blondered marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
metric_values = calc_metrics( | ||
metrics, | ||
reco=reco, | ||
|
@@ -163,6 +200,7 @@ def cross_validate( # pylint: disable=too-many-locals | |
) | ||
res = {"model": model_name, "i_split": split_info["i_split"]} | ||
res.update(metric_values) | ||
res.update(model_res) | ||
metrics_all.append(res) | ||
|
||
result = {"splits": split_infos, "metrics": metrics_all} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems this function accepts the
label
timings
param but don't really need itLet's please rewrite it in one of the following ways:
I personally prefer the second option since it's clearer.
But anyway let's not use this labels and dictionary inside. We can easily fill them out of the class
And example (it's simplified a bit, please add init if required by linters, also types)