Skip to content

Commit

Permalink
Merge pull request #380 from monarch-initiative/379-infinity-values-c…
Browse files Browse the repository at this point in the history
…reate-an-error-with-roc-curves

379 infinity values create an error with roc curves
  • Loading branch information
yaseminbridges authored Jan 7, 2025
2 parents eb7ba11 + 80386bc commit d28e702
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pheval"
version = "0.4.2"
version = "0.4.3"
description = ""
authors = ["Yasemin Bridges <[email protected]>",
"Julius Jacobsen <[email protected]>",
Expand Down
19 changes: 17 additions & 2 deletions src/pheval/analyse/generate_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List

import matplotlib
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
Expand Down Expand Up @@ -368,9 +369,16 @@ def generate_roc_curve(
"""
plt.clf()
for i, benchmark_result in enumerate(benchmarking_results):
y_score = np.array(benchmark_result.binary_classification_stats.scores)
y_score = np.nan_to_num(
y_score,
nan=0.0,
posinf=max(y_score[np.isfinite(y_score)]),
neginf=min(y_score[np.isfinite(y_score)]),
)
fpr, tpr, thresh = roc_curve(
benchmark_result.binary_classification_stats.labels,
benchmark_result.binary_classification_stats.scores,
y_score,
pos_label=1,
)
roc_auc = auc(fpr, tpr)
Expand Down Expand Up @@ -411,9 +419,16 @@ def generate_precision_recall(
plt.clf()
plt.figure()
for i, benchmark_result in enumerate(benchmarking_results):
y_score = np.array(benchmark_result.binary_classification_stats.scores)
y_score = np.nan_to_num(
y_score,
nan=0.0,
posinf=max(y_score[np.isfinite(y_score)]),
neginf=min(y_score[np.isfinite(y_score)]),
)
precision, recall, thresh = precision_recall_curve(
benchmark_result.binary_classification_stats.labels,
benchmark_result.binary_classification_stats.scores,
y_score,
)
precision_recall_auc = auc(recall, precision)
plt.plot(
Expand Down

0 comments on commit d28e702

Please sign in to comment.