Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: cross validate timings #233

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 57 additions & 19 deletions rectools/model_selection/cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import typing as tp
from contextlib import contextmanager

from rectools.columns import Columns
from rectools.dataset import Dataset
Expand All @@ -24,6 +25,26 @@
from .splitter import Splitter


@contextmanager
def compute_timing(label: str, timings: tp.Optional[tp.Dict[str, float]] = None) -> tp.Iterator[None]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems this function accepts the label timings param but don't really need it

Let's please rewrite it in one of the following ways:

  1. Remove both params and simply return the elapsed time without dictionaries
  2. Rewrite it as a class

I personally prefer the second option since it's clearer.
But anyway let's not use this labels and dictionary inside. We can easily fill them out of the class

And example (it's simplified a bit, please add init if required by linters, also types)

class Timer:        
    def __enter__(self):
        self._start = time.perf_counter()
        self._end = None
        return self

    def __exit__(self, *args):
        self._end = time.perf_counter()

    @property
    def elapsed(self):
        return self._end - self._start
    
    
with Timer() as timer:
    # code
    pass

fit_time = timer.elapsed

"""
Context manager to compute timing for a code block.

Parameters
----------
label : str
Label to store the timing result in the timings dictionary.
timings : dict, optional
Dictionary to store the timing results. If None, timing is not recorded.
"""
if timings is not None:
start_time = time.time()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use time.perf_counter instead, it's more correct for measuring time intervals

yield
timings[label] = round(time.time() - start_time, 2)
blondered marked this conversation as resolved.
Show resolved Hide resolved
else:
yield


def cross_validate( # pylint: disable=too-many-locals
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please update CHANGELOG.MD

dataset: Dataset,
splitter: Splitter,
Expand All @@ -36,6 +57,7 @@ def cross_validate( # pylint: disable=too-many-locals
ref_models: tp.Optional[tp.List[str]] = None,
validate_ref_models: bool = False,
on_unsupported_targets: ErrorBehaviour = "warn",
compute_timings: bool = False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add new argument to docstring

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we really need this param? what's wrong if we always measure the time?

) -> tp.Dict[str, tp.Any]:
"""
Run cross validation on multiple models with multiple metrics.
Expand Down Expand Up @@ -123,28 +145,16 @@ def cross_validate( # pylint: disable=too-many-locals

# ### Train ref models if any
ref_reco = {}
ref_res = {}
blondered marked this conversation as resolved.
Show resolved Hide resolved
for model_name in ref_models or []:
model = models[model_name]
model.fit(fold_dataset)
ref_reco[model_name] = model.recommend(
users=test_users,
dataset=fold_dataset,
k=k,
filter_viewed=filter_viewed,
items_to_recommend=items_to_recommend,
on_unsupported_targets=on_unsupported_targets,
)
ref_timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings and validate_ref_models else None
blondered marked this conversation as resolved.
Show resolved Hide resolved

# ### Generate recommendations and calc metrics
for model_name, model in models.items():
if model_name in ref_reco and not validate_ref_models:
continue

if model_name in ref_reco:
reco = ref_reco[model_name]
else:
with compute_timing("fit_time", ref_timings):
model.fit(fold_dataset)
reco = model.recommend(

with compute_timing("recommend_time", ref_timings):
ref_reco[model_name] = model.recommend(
users=test_users,
dataset=fold_dataset,
k=k,
Expand All @@ -153,6 +163,33 @@ def cross_validate( # pylint: disable=too-many-locals
on_unsupported_targets=on_unsupported_targets,
)

ref_res[model_name] = ref_timings or {}

# ### Generate recommendations and calc metrics
for model_name, model in models.items():
if model_name in ref_reco and not validate_ref_models:
continue
if model_name in ref_reco:
reco = ref_reco[model_name]
model_res = ref_res[model_name]
else:
timings: tp.Optional[tp.Dict[str, float]] = {} if compute_timings else None

with compute_timing("fit_time", timings):
model.fit(fold_dataset)

with compute_timing("recommend_time", timings):
reco = model.recommend(
users=test_users,
dataset=fold_dataset,
k=k,
filter_viewed=filter_viewed,
items_to_recommend=items_to_recommend,
on_unsupported_targets=on_unsupported_targets,
)

model_res = timings or {}
blondered marked this conversation as resolved.
Show resolved Hide resolved

metric_values = calc_metrics(
metrics,
reco=reco,
Expand All @@ -163,6 +200,7 @@ def cross_validate( # pylint: disable=too-many-locals
)
res = {"model": model_name, "i_split": split_info["i_split"]}
res.update(metric_values)
res.update(model_res)
metrics_all.append(res)

result = {"splits": split_infos, "metrics": metrics_all}
Expand Down
Loading