Skip to content

Commit

Permalink
📝
Browse files Browse the repository at this point in the history
Signed-off-by: miguelgfierro <[email protected]>
  • Loading branch information
miguelgfierro committed Dec 30, 2023
1 parent c5a0a12 commit 0467655
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 20 deletions.
4 changes: 0 additions & 4 deletions docs/_toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ root: intro
defaults:
numbered: false
parts:
- caption: Getting Started
chapters:
- file: ../SETUP.md
- file: TEMP
- caption: Recommenders API Documentation
chapters:
- file: datasets
Expand Down
47 changes: 35 additions & 12 deletions recommenders/evaluation/python_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,28 @@


class ColumnMismatchError(Exception):
    """Exception raised when there is a mismatch in columns.

    This exception is raised when an operation involving columns
    encounters a mismatch or inconsistency, e.g. an expected column
    is missing from an input DataFrame.

    The explanatory message passed at construction time is available
    through ``str(exc)`` / ``exc.args``, as with any ``Exception``;
    no extra attributes are defined on this class.
    """
    # A docstring is a valid class body, so no redundant `pass` is needed.


class ColumnTypeMismatchError(Exception):
    """Exception raised when there is a mismatch in column types.

    This exception is raised when an operation involving column types
    encounters a mismatch or inconsistency, e.g. the same column has
    different base dtypes in two DataFrames being compared.

    The explanatory message passed at construction time is available
    through ``str(exc)`` / ``exc.args``, as with any ``Exception``;
    no extra attributes are defined on this class.
    """
    # A docstring is a valid class body, so no redundant `pass` is needed.


Expand All @@ -63,7 +81,7 @@ def check_column_dtypes_wrapper(
col_item=DEFAULT_ITEM_COL,
col_prediction=DEFAULT_PREDICTION_COL,
*args,
**kwargs
**kwargs,
):
"""Check columns of DataFrame inputs
Expand All @@ -81,12 +99,16 @@ def check_column_dtypes_wrapper(
expected_true_columns.add(kwargs["col_rating"])
if not has_columns(rating_true, expected_true_columns):
raise ColumnMismatchError("Missing columns in true rating DataFrame")

if not has_columns(rating_pred, {col_user, col_item, col_prediction}):
raise ColumnMismatchError("Missing columns in predicted rating DataFrame")

if not has_same_base_dtype(rating_true, rating_pred, columns=[col_user, col_item]):
raise ColumnTypeMismatchError("Columns in provided DataFrames are not the same datatype")

if not has_same_base_dtype(
rating_true, rating_pred, columns=[col_user, col_item]
):
raise ColumnTypeMismatchError(
"Columns in provided DataFrames are not the same datatype"
)

return func(
rating_true=rating_true,
Expand All @@ -95,7 +117,7 @@ def check_column_dtypes_wrapper(
col_item=col_item,
col_prediction=col_prediction,
*args,
**kwargs
**kwargs,
)

return check_column_dtypes_wrapper
Expand Down Expand Up @@ -750,7 +772,9 @@ def map_at_k(
if df_merge is None:
return 0.0
else:
return (df_merge["rr"] / df_merge["actual"].apply(lambda x: min(x, k))).sum() / n_users
return (
df_merge["rr"] / df_merge["actual"].apply(lambda x: min(x, k))
).sum() / n_users


def get_top_k_items(
Expand Down Expand Up @@ -837,7 +861,7 @@ def check_column_dtypes_diversity_serendipity_wrapper(
col_sim=DEFAULT_SIMILARITY_COL,
col_relevance=None,
*args,
**kwargs
**kwargs,
):
"""Check columns of DataFrame inputs
Expand Down Expand Up @@ -904,7 +928,7 @@ def check_column_dtypes_diversity_serendipity_wrapper(
col_sim=col_sim,
col_relevance=col_relevance,
*args,
**kwargs
**kwargs,
)

return check_column_dtypes_diversity_serendipity_wrapper
Expand Down Expand Up @@ -933,7 +957,7 @@ def check_column_dtypes_novelty_coverage_wrapper(
col_user=DEFAULT_USER_COL,
col_item=DEFAULT_ITEM_COL,
*args,
**kwargs
**kwargs,
):
"""Check columns of DataFrame inputs
Expand Down Expand Up @@ -969,7 +993,7 @@ def check_column_dtypes_novelty_coverage_wrapper(
col_user=col_user,
col_item=col_item,
*args,
**kwargs
**kwargs,
)

return check_column_dtypes_novelty_coverage_wrapper
Expand Down Expand Up @@ -1006,7 +1030,6 @@ def _get_cosine_similarity(
col_item=DEFAULT_ITEM_COL,
col_sim=DEFAULT_SIMILARITY_COL,
):

if item_sim_measure == "item_cooccurrence_count":
# calculate item-item similarity based on item co-occurrence count
df_cosine_similarity = _get_cooccurrence_similarity(
Expand Down
8 changes: 4 additions & 4 deletions recommenders/evaluation/spark_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def precision_at_k(self):
Note:
More details can be found
`on this website <http://spark.apache.org/docs/2.1.1/api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RankingMetrics.precisionAt>`_.
`on the precisionAt PySpark documentation <http://spark.apache.org/docs/3.0.0/api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RankingMetrics.precisionAt>`_.
Return:
float: precision at k (min=0, max=1)
Expand All @@ -318,7 +318,7 @@ def recall_at_k(self):
Note:
More details can be found
`here <http://spark.apache.org/docs/2.1.1/api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RankingMetrics.meanAveragePrecision>`_.
`on the recallAt PySpark documentation <http://spark.apache.org/docs/3.0.0/api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RankingMetrics.recallAt>`_.
Return:
float: recall at k (min=0, max=1).
Expand All @@ -330,7 +330,7 @@ def ndcg_at_k(self):
Note:
More details can be found
`on <http://spark.apache.org/docs/2.1.1/api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RankingMetrics.ndcgAt>`_.
`on the ndcgAt PySpark documentation <http://spark.apache.org/docs/3.0.0/api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RankingMetrics.ndcgAt>`_.
Return:
float: nDCG at k (min=0, max=1).
Expand All @@ -349,7 +349,7 @@ def map_at_k(self):
"""Get mean average precision at k.
Note:
More details `on this link <http://spark.apache.org/docs/2.1.1/api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RankingMetrics.meanAveragePrecision>`_.
More details `on the meanAveragePrecision PySpark documentation <http://spark.apache.org/docs/3.0.0/api/python/pyspark.mllib.html#pyspark.mllib.evaluation.RankingMetrics.meanAveragePrecision>`_.
Return:
float: MAP at k (min=0, max=1).
Expand Down

0 comments on commit 0467655

Please sign in to comment.