Added Repeat - Explore metrics #2083
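This PR adds repeat/explore variants of the top-k metrics (RepeatRecall, ExploreRecall, RepeatNDCG, ExploreNDCG, RepeatPrecision, ExplorePrecision), following the repetition-vs-exploration breakdown of https://arxiv.org/abs/2109.14233, together with the "rec.topk_repeat" and "rec.topk_explore" resources that the collector builds for them. Below is a minimal usage sketch; it assumes the new classes are picked up by RecBole's metric registry under their class names (registry or config-validation changes, if any, are not part of this diff) and that a sequential model is used so the ITEM_ID_FIELD + "_list" history is available.

# Hypothetical usage sketch (not part of this diff): enable the new metrics
# for a sequential model via config_dict. Assumes the classes are discoverable
# by RecBole's metric registry under their class names.
from recbole.quick_start import run_recbole

config_dict = {
    "metrics": [
        "Recall", "NDCG", "Precision",
        "RepeatRecall", "ExploreRecall",
        "RepeatNDCG", "ExploreNDCG",
        "RepeatPrecision", "ExplorePrecision",
    ],
    "topk": [10, 20],
}

run_recbole(model="GRU4Rec", dataset="ml-100k", config_dict=config_dict)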

Open · wants to merge 3 commits into base: master
30 changes: 29 additions & 1 deletion recbole/evaluator/collector.py
@@ -168,6 +168,32 @@ def eval_batch_collect(
result = torch.cat((pos_idx, pos_len_list), dim=1)
self.data_struct.update_tensor("rec.topk", result)

if self.register.need("rec.topk_repeat"):
_, topk_idx = torch.topk(
scores_tensor, max(self.topk), dim=-1
) # n_users x k

repeat_mask = (interaction[self.config["ITEM_ID_FIELD"] + "_list"] == positive_i[..., None]).sum(dim=1) > 0
pos_matrix = torch.zeros_like(scores_tensor, dtype=torch.int)
pos_matrix[positive_u[repeat_mask], positive_i[repeat_mask]] = 1
pos_len_list = pos_matrix.sum(dim=1, keepdim=True)
pos_idx = torch.gather(pos_matrix, dim=1, index=topk_idx)
result = torch.cat((pos_idx, pos_len_list), dim=1)
self.data_struct.update_tensor("rec.topk_repeat", result)

if self.register.need("rec.topk_explore"):
_, topk_idx = torch.topk(
scores_tensor, max(self.topk), dim=-1
) # n_users x k

explore_mask = (interaction[self.config["ITEM_ID_FIELD"] + "_list"] == positive_i[..., None]).sum(dim=1) < 1
pos_matrix = torch.zeros_like(scores_tensor, dtype=torch.int)
pos_matrix[positive_u[explore_mask], positive_i[explore_mask]] = 1
pos_len_list = pos_matrix.sum(dim=1, keepdim=True)
pos_idx = torch.gather(pos_matrix, dim=1, index=topk_idx)
result = torch.cat((pos_idx, pos_len_list), dim=1)
self.data_struct.update_tensor("rec.topk_explore", result)

if self.register.need("rec.meanrank"):

desc_scores, desc_index = torch.sort(scores_tensor, dim=-1, descending=True)
@@ -196,6 +222,8 @@ def eval_batch_collect(
self.data_struct.update_tensor(
"data.label", interaction[self.label_field].to(self.device)
)
if self.register.need("data.item_list"):
self.data_struct.update_tensor("data.item_list", interaction[self.config["ITEM_ID_FIELD"] + "_list"])

def model_collect(self, model: torch.nn.Module):
"""Collect the evaluation resource from model.
@@ -226,7 +254,7 @@ def get_data_struct(self):
for key in self.data_struct._data_dict:
self.data_struct._data_dict[key] = self.data_struct._data_dict[key].cpu()
returned_struct = copy.deepcopy(self.data_struct)
for key in ["rec.topk", "rec.meanrank", "rec.score", "rec.items", "data.label"]:
for key in ["rec.topk", "rec.meanrank", "rec.score", "rec.items", "data.label", "rec.topk_repeat", "rec.topk_explore"]:
if key in self.data_struct:
del self.data_struct[key]
return returned_struct
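As a sanity check on the collector change above, here is a toy example (made-up tensors, with the shapes of sequential full-sort evaluation where each row has exactly one positive) showing how the two masks split positives into repeat and explore items:

# Toy check of the repeat/explore masks built in eval_batch_collect.
# Assumes one positive item per evaluated row, so the history tensor and
# positive_i broadcast row by row.
import torch

item_list = torch.tensor([[3, 5, 0],      # user 0 history: items 3, 5 (0 = padding)
                          [7, 7, 2]])     # user 1 history: items 7, 2
positive_i = torch.tensor([5, 9])          # ground-truth next items
positive_u = torch.tensor([0, 1])

repeat_mask = (item_list == positive_i[..., None]).sum(dim=1) > 0
explore_mask = (item_list == positive_i[..., None]).sum(dim=1) < 1
print(repeat_mask)   # tensor([ True, False]) -> item 5 was seen before by user 0
print(explore_mask)  # tensor([False,  True]) -> item 9 is new for user 1

# Only the masked positives are written into pos_matrix, so "rec.topk_repeat"
# counts hits on repeat positives and "rec.topk_explore" hits on new items.
n_items = 10
pos_matrix = torch.zeros((2, n_items), dtype=torch.int)
pos_matrix[positive_u[repeat_mask], positive_i[repeat_mask]] = 1
print(pos_matrix.sum(dim=1))  # tensor([1, 0]) -> user 1 has no repeat positive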
182 changes: 179 additions & 3 deletions recbole/evaluator/metrics.py
@@ -21,10 +21,11 @@
:math:`{r}_{u i}` represents the ground-truth labels.

"""

from logging import getLogger

import numpy as np
import torch

from collections import Counter
from sklearn.metrics import auc as sk_auc
from sklearn.metrics import mean_absolute_error, mean_squared_error
@@ -33,6 +34,7 @@
from recbole.evaluator.base_metric import AbstractMetric, TopkMetric, LossMetric
from recbole.utils import EvaluatorType


# TopK Metrics


@@ -157,7 +159,68 @@ def calculate_metric(self, dataobject):
return metric_dict

def metric_info(self, pos_index, pos_len):
return np.cumsum(pos_index, axis=1) / pos_len.reshape(-1, 1)
result = np.cumsum(pos_index, axis=1) / pos_len.reshape(-1, 1)
result[(np.isnan(result)) | (result == np.inf) | (result == -np.inf)] = 0

return result
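A note on the inf/NaN masking added to Recall.metric_info (the same masking is also added to Precision.metric_info further down): with the repeat/explore split, a user may have zero repeat (or explore) positives, so pos_len can be 0 and the division yields NaN or inf. The masking zeroes those entries, which keeps such users in the |U| average with a contribution of 0 rather than excluding them. A minimal NumPy sketch with made-up values:

# Why the inf/NaN masking matters for the repeat/explore variants: a user may
# have no repeat (or explore) positives at all, so pos_len is 0 and the
# division produces NaN (0/0) or inf (hit/0). Values below are illustrative.
import numpy as np

pos_index = np.array([[1, 0, 1],   # 2 repeat positives, hits at ranks 1 and 3
                      [0, 0, 0]])  # no repeat positives for this user
pos_len = np.array([2, 0])

with np.errstate(divide="ignore", invalid="ignore"):
    result = np.cumsum(pos_index, axis=1) / pos_len.reshape(-1, 1)
# second row is [nan, nan, nan] before masking
result[(np.isnan(result)) | (result == np.inf) | (result == -np.inf)] = 0
print(result)
# [[0.5 0.5 1. ]
#  [0.  0.  0. ]]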


class RepeatRecall(Recall):
r"""RepeatRecall calculates the contribution of repeat items (items that already exist in the
user's history) to the recall.

.. _repeat_recall:
https://arxiv.org/abs/2109.14233 describes this in Section 5.4 as the relative contribution
of repetition and exploration.

.. math::
\mathrm{RepeatRecall@K} = \frac{1}{|U|}\sum_{u \in U} \frac{|\hat{R}(u) \cap R(u)^{repeat}|}{|R(u)|}

:math:`|R(u)|` represents the item count of :math:`R(u)`, and :math:`R(u)^{repeat}` represents
the set of ground-truth items that already appear in the user's history (repeat items).
"""

metric_need = ["rec.topk_repeat"]

def used_info(self, dataobject):
"""Get the bool matrix indicating whether the corresponding item is positive and a repeat item
and number of positive items for each user.
"""
rec_mat = dataobject.get("rec.topk_repeat")
topk_idx, pos_len_list = torch.split(rec_mat, [max(self.topk), 1], dim=1)
return topk_idx.to(torch.bool).numpy(), pos_len_list.squeeze(-1).numpy()

def calculate_metric(self, dataobject):
pos_index, pos_len = self.used_info(dataobject)
result = self.metric_info(pos_index, pos_len)
metric_dict = self.topk_result("repeat_recall", result)
return metric_dict


class ExploreRecall(Recall):
r"""ExploreRecall calculates the contribution of explore items (items that do not exist in the
user's history) to the recall.

.. _explore_recall:
https://arxiv.org/abs/2109.14233 describes this in Section 5.4 as the relative contribution
of repetition and exploration.

.. math::
\mathrm{ExploreRecall@K} = \frac{1}{|U|}\sum_{u \in U} \frac{|\hat{R}(u) \cap R(u)^{explore}|}{|R(u)|}

:math:`|R(u)|` represents the item count of :math:`R(u)`, and :math:`R(u)^{explore}` represents
the set of ground-truth items that do not appear in the user's history (explore items).
"""

metric_need = ["rec.topk_explore"]

def used_info(self, dataobject):
rec_mat = dataobject.get("rec.topk_explore")
topk_idx, pos_len_list = torch.split(rec_mat, [max(self.topk), 1], dim=1)
return topk_idx.to(torch.bool).numpy(), pos_len_list.squeeze(-1).numpy()

def calculate_metric(self, dataobject):
pos_index, pos_len = self.used_info(dataobject)
result = self.metric_info(pos_index, pos_len)
metric_dict = self.topk_result("explore_recall", result)
return metric_dict
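For reference, the layout of the tensors that the new used_info implementations unpack: "rec.topk_repeat" and "rec.topk_explore" stack max(topk) 0/1 hit columns with a final column holding the per-user count of repeat (or explore) positives, as built in eval_batch_collect above. A tiny illustration with made-up values:

# Layout of the "rec.topk_repeat" / "rec.topk_explore" tensors consumed by
# used_info: first max(topk) columns are 0/1 hit flags for the ranked items,
# the last column is the count of repeat (or explore) positives per user.
import torch

topk = [3]
rec_mat = torch.tensor([[1, 0, 1, 2],    # hits at ranks 1 and 3, 2 repeat positives
                        [0, 0, 0, 0]])   # no repeat positives for this user
topk_idx, pos_len_list = torch.split(rec_mat, [max(topk), 1], dim=1)
print(topk_idx.to(torch.bool))           # [[ True, False,  True], [False, False, False]]
print(pos_len_list.squeeze(-1))          # tensor([2, 0])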


class NDCG(TopkMetric):
Expand Down Expand Up @@ -202,6 +265,63 @@ def metric_info(self, pos_index, pos_len):
return result


class RepeatNDCG(NDCG):
r"""RepeatNDCG_ measures the performance contribution of repeat items (items that already appear
in the user's history) to the NDCG metric.

.. _repeat_NDCG:
https://arxiv.org/abs/2109.14233 describes this in Section 5.4 as the relative contribution
of repetition and exploration.

.. math::
\mathrm{RepeatNDCG@K} = \frac{1}{|U|}\sum_{u \in U} (\frac{1}{\sum_{i=1}^{\min (|R(u)^{repeat}|, K)}
\frac{1}{\log _{2}(i+1)}} \sum_{i=1}^{K} \delta(i \in R(u)^{repeat}) \frac{1}{\log _{2}(i+1)})

:math:`\delta(·)` is an indicator function.
"""
metric_need = ["rec.topk_repeat"]


def used_info(self, dataobject):
rec_mat = dataobject.get("rec.topk_repeat")
topk_idx, pos_len_list = torch.split(rec_mat, [max(self.topk), 1], dim=1)
return topk_idx.to(torch.bool).numpy(), pos_len_list.squeeze(-1).numpy()

def calculate_metric(self, dataobject):
pos_index, pos_len = self.used_info(dataobject)
result = self.metric_info(pos_index, pos_len)
metric_dict = self.topk_result("repeat_ndcg", result)
return metric_dict


class ExploreNDCG(NDCG):
r"""ExploreNDCG_ measures the performance contribution of explore items (items that were not in the
user's history) to the NDCG metric.

.. _explore_NDCG:
https://arxiv.org/abs/2109.14233 describes this in Section 5.4 as the relative contribution
of repetition and exploration.

.. math::
\mathrm{ExploreNDCG@K} = \frac{1}{|U|}\sum_{u \in U} (\frac{1}{\sum_{i=1}^{\min (|R(u)^{explore}|, K)}
\frac{1}{\log _{2}(i+1)}} \sum_{i=1}^{K} \delta(i \in R(u)^{explore}) \frac{1}{\log _{2}(i+1)})

:math:`\delta(·)` is an indicator function.
"""
metric_need = ["rec.topk_explore"]


def used_info(self, dataobject):
rec_mat = dataobject.get("rec.topk_explore")
topk_idx, pos_len_list = torch.split(rec_mat, [max(self.topk), 1], dim=1)
return topk_idx.to(torch.bool).numpy(), pos_len_list.squeeze(-1).numpy()

def calculate_metric(self, dataobject):
pos_index, pos_len = self.used_info(dataobject)
result = self.metric_info(pos_index, pos_len)
metric_dict = self.topk_result("explore_ndcg", result)
return metric_dict




class Precision(TopkMetric):
r"""Precision_ (also called positive predictive value) is a measure for computing the fraction of relevant items
out of all the recommended items. We average the metric for each user :math:`u` get the final result.
@@ -224,7 +344,63 @@ def calculate_metric(self, dataobject):
return metric_dict

def metric_info(self, pos_index):
return pos_index.cumsum(axis=1) / np.arange(1, pos_index.shape[1] + 1)
result = pos_index.cumsum(axis=1) / np.arange(1, pos_index.shape[1] + 1)
result[(np.isnan(result)) | (result == np.inf) | (result == -np.inf)] = 0
return result


class RepeatPrecision(Precision):
r"""RepeatPrecision_ measures the performance contribution of repeat items (items that were in the
user's history) to the precision.

.. _repeat_precision:
https://arxiv.org/abs/2109.14233 describes this in Section 5.4 as the relative contribution
of repetition and exploration.

.. math::
\mathrm{RepeatPrecision@K} = \frac{1}{|U|}\sum_{u \in U} \frac{|\hat{R}(u) \cap R(u)^{repeat}|}{|\hat{R}(u)|}

:math:`|\hat{R}(u)|` represents the item count of :math:`\hat{R}(u)`, and :math:`R(u)^{repeat}`
represents the set of ground-truth items that already appear in the user's history.
"""

metric_need = ["rec.topk_repeat"]

def used_info(self, dataobject):
rec_mat = dataobject.get("rec.topk_repeat")
topk_idx, pos_len_list = torch.split(rec_mat, [max(self.topk), 1], dim=1)
return topk_idx.to(torch.bool).numpy(), pos_len_list.squeeze(-1).numpy()

def calculate_metric(self, dataobject):
pos_index, _ = self.used_info(dataobject)
result = self.metric_info(pos_index)
metric_dict = self.topk_result("repeat_precision", result)
return metric_dict


class ExplorePrecision(Precision):
r"""ExplorePrecision_ measures the performance contribution of explore items (items that were not
in the user's history) to the precision.

.. _explore_precision:
https://arxiv.org/abs/2109.14233 describes this in Section 5.4 as the relative contribution
of repetition and exploration.

.. math::
\mathrm{ExplorePrecision@K} = \frac{1}{|U|}\sum_{u \in U} \frac{|\hat{R}(u) \cap R(u)^{explore}|}{|\hat{R}(u)|}

:math:`|\hat{R}(u)|` represents the item count of :math:`\hat{R}(u)`, and :math:`R(u)^{explore}`
represents the set of ground-truth items that do not appear in the user's history.
"""

metric_need = ["rec.topk_explore"]

def used_info(self, dataobject):
rec_mat = dataobject.get("rec.topk_explore")
topk_idx, pos_len_list = torch.split(rec_mat, [max(self.topk), 1], dim=1)
return topk_idx.to(torch.bool).numpy(), pos_len_list.squeeze(-1).numpy()

def calculate_metric(self, dataobject):
pos_index, _ = self.used_info(dataobject)
result = self.metric_info(pos_index)
metric_dict = self.topk_result("explore_precision", result)
return metric_dict


# CTR Metrics