Feature: cross validate timings #233

Open · wants to merge 8 commits into base: main

76 changes: 57 additions & 19 deletions rectools/model_selection/cross_validate.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import typing as tp
from contextlib import contextmanager

from rectools.columns import Columns
from rectools.dataset import Dataset
@@ -24,6 +25,26 @@
from .splitter import Splitter


@contextmanager
def compute_timing(label: str, timings: tp.Optional[tp.Dict[str, float]] = None) -> tp.Iterator[None]:
Collaborator

It seems this function accepts the label and timings params but doesn't really need them.

Let's please rewrite it in one of the following ways:

  1. Remove both params and simply return the elapsed time without dictionaries
  2. Rewrite it as a class

I personally prefer the second option since it's clearer.
Either way, let's not use these labels and the dictionary inside. We can easily fill them outside the class.

An example (it's simplified a bit; please add __init__ if required by linters, and add types):

class Timer:        
    def __enter__(self):
        self._start = time.perf_counter()
        self._end = None
        return self

    def __exit__(self, *args):
        self._end = time.perf_counter()

    @property
    def elapsed(self):
        return self._end - self._start
    
    
with Timer() as timer:
    # code
    pass

fit_time = timer.elapsed
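
A fuller variant of the same sketch, with __init__ and type annotations added to satisfy linters (it reuses the module's `import time` and `import typing as tp`; the guard in `elapsed` and the attribute names are assumptions, not part of the suggestion):

class Timer:
    """Measure wall-clock time of a code block via time.perf_counter."""

    def __init__(self) -> None:
        self._start: float = 0.0
        self._end: tp.Optional[float] = None

    def __enter__(self) -> "Timer":
        self._start = time.perf_counter()
        self._end = None
        return self

    def __exit__(self, *args: tp.Any) -> None:
        self._end = time.perf_counter()

    @property
    def elapsed(self) -> float:
        # Guard against reading the timing before the block has finished
        if self._end is None:
            raise RuntimeError("Timer context has not exited yet")
        return self._end - self._start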

"""
Context manager to compute timing for a code block.

Parameters
----------
label : str
Label to store the timing result in the timings dictionary.
timings : dict, optional
Dictionary to store the timing results. If None, timing is not recorded.
"""
if timings is not None:
start_time = time.time()
Collaborator

Please use time.perf_counter instead; it's more correct for measuring time intervals.
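
For illustration, the measuring branch with perf_counter would look roughly like this (a sketch, not the final code):

if timings is not None:
    start_time = time.perf_counter()
    yield
    timings[label] = time.perf_counter() - start_time
else:
    yield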

yield
timings[label] = round(time.time() - start_time, 5)
Collaborator

Please don't do this:

  1. If we need to format the values somehow it should always be done separately. We should separate the computing level and presentation level. From the computing level we should always return the raw values.
  2. In this specific case I think we shouldn't format the value at all. I don't see much sense in it, also we're not doing this for other metrics.
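
For illustration, the split between the two levels could look like this (the print line is only an example of presentation-level formatting):

# computing level: store the raw measurement
timings[label] = time.perf_counter() - start_time

# presentation level (e.g. a report or notebook), only if formatting is wanted at all
print(f"fit_time: {timings['fit_time']:.5f} s")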

else:
yield


def cross_validate( # pylint: disable=too-many-locals
Collaborator

Please update CHANGELOG.MD

dataset: Dataset,
splitter: Splitter,
@@ -36,6 +57,7 @@ def cross_validate( # pylint: disable=too-many-locals
ref_models: tp.Optional[tp.List[str]] = None,
validate_ref_models: bool = False,
on_unsupported_targets: ErrorBehaviour = "warn",
compute_timings: bool = False,
Collaborator

Please add new argument to docstring
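
A possible numpy-style entry for the new parameter (wording is a suggestion only):

compute_timings : bool, default ``False``
    Whether to measure ``fit`` and ``recommend`` wall-clock time for every validated model
    and add ``fit_time`` / ``recommend_time`` columns to the ``"metrics"`` part of the result.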

Collaborator

Do we really need this param? What's wrong with always measuring the time?

) -> tp.Dict[str, tp.Any]:
"""
Run cross validation on multiple models with multiple metrics.
@@ -123,28 +145,16 @@ def cross_validate( # pylint: disable=too-many-locals

# ### Train ref models if any
ref_reco = {}
ref_timings = {}
for model_name in ref_models or []:
model = models[model_name]
model.fit(fold_dataset)
ref_reco[model_name] = model.recommend(
users=test_users,
dataset=fold_dataset,
k=k,
filter_viewed=filter_viewed,
items_to_recommend=items_to_recommend,
on_unsupported_targets=on_unsupported_targets,
)
model_timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings and validate_ref_models else None

# ### Generate recommendations and calc metrics
for model_name, model in models.items():
if model_name in ref_reco and not validate_ref_models:
continue

if model_name in ref_reco:
reco = ref_reco[model_name]
else:
with compute_timing("fit_time", model_timings):
model.fit(fold_dataset)
reco = model.recommend(

with compute_timing("recommend_time", model_timings):
ref_reco[model_name] = model.recommend(
users=test_users,
dataset=fold_dataset,
k=k,
@@ -153,6 +163,33 @@
on_unsupported_targets=on_unsupported_targets,
)

ref_timings[model_name] = model_timings or {}

# ### Generate recommendations and calc metrics
for model_name, model in models.items():
if model_name in ref_reco and not validate_ref_models:
continue
if model_name in ref_reco:
reco = ref_reco[model_name]
model_timing = ref_timings[model_name]
else:
model_timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings else None # type: ignore

with compute_timing("fit_time", model_timings):
model.fit(fold_dataset)

with compute_timing("recommend_time", model_timings):
reco = model.recommend(
users=test_users,
dataset=fold_dataset,
k=k,
filter_viewed=filter_viewed,
items_to_recommend=items_to_recommend,
on_unsupported_targets=on_unsupported_targets,
)

model_timing = model_timings or {}

metric_values = calc_metrics(
metrics,
reco=reco,
Expand All @@ -163,6 +200,7 @@ def cross_validate( # pylint: disable=too-many-locals
)
res = {"model": model_name, "i_split": split_info["i_split"]}
res.update(metric_values)
res.update(model_timing)
metrics_all.append(res)

result = {"splits": split_infos, "metrics": metrics_all}
112 changes: 112 additions & 0 deletions tests/model_selection/test_cross_validate.py
@@ -371,5 +371,117 @@ def test_happy_path_with_intersection(
],
"metrics": expected_metrics,
}
assert actual == expected

@pytest.mark.parametrize(
"validate_ref_models,expected_metrics",
(
(
False,
[
{
"model": "random",
"i_split": 0,
"precision@2": 0.5,
"recall@1": 0.0,
"intersection_popular": 0.5,
},
{
"model": "random",
"i_split": 1,
"precision@2": 0.375,
"recall@1": 0.5,
"intersection_popular": 0.75,
},
],
),
(
True,
[
{
"model": "popular",
"i_split": 0,
"precision@2": 0.5,
"recall@1": 0.5,
"intersection_popular": 1.0,
},
{
"model": "random",
"i_split": 0,
"precision@2": 0.5,
"recall@1": 0.0,
"intersection_popular": 0.5,
},
{
"model": "popular",
"i_split": 1,
"precision@2": 0.375,
"recall@1": 0.25,
"intersection_popular": 1.0,
},
{
"model": "random",
"i_split": 1,
"precision@2": 0.375,
"recall@1": 0.5,
"intersection_popular": 0.75,
},
],
),
),
)
@pytest.mark.parametrize("compute_timings", (False, True))
def test_happy_path_with_intersection_timings(
self,
validate_ref_models: bool,
expected_metrics: tp.List[tp.Dict[str, tp.Any]],
compute_timings: bool,
) -> None:
splitter = LastNSplitter(n=1, n_splits=2, filter_cold_items=False, filter_already_seen=False)

actual = cross_validate(
dataset=self.dataset,
splitter=splitter,
metrics=self.metrics_intersection,
models=self.models,
k=2,
filter_viewed=False,
ref_models=["popular"],
validate_ref_models=validate_ref_models,
compute_timings=compute_timings,
)

time_threshold = 0.5

if compute_timings:
for data in actual["metrics"]:
assert data["fit_time"] < time_threshold
assert data["recommend_time"] < time_threshold

del data["fit_time"]
del data["recommend_time"]

expected = {
"splits": [
{
"i_split": 0,
"test": 2,
"test_items": 2,
"test_users": 2,
"train": 2,
"train_items": 2,
"train_users": 2,
},
{
"i_split": 1,
"test": 4,
"test_items": 3,
"test_users": 4,
"train": 6,
"train_items": 2,
"train_users": 4,
},
],
"metrics": expected_metrics,
}
assert actual == expected