-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
inclusion of double corrected variance estimator
- Loading branch information
Showing
1 changed file
with
78 additions
and
0 deletions.
There are no files selected for viewing
78 changes: 78 additions & 0 deletions
78
fairlib/src/evaluators/double_corrected_variance_estimator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import pandas as pd | ||
import numpy as np | ||
from .evaluator import confusion_matrix_based_scores | ||
|
||
def group_level_metrics(confusion_matrices, metric, class_id=1): | ||
"""organizing the evaluation scores for each group. | ||
Args: | ||
confusion_matrices (dict): a dictionary of confusion_matrices of each group. | ||
metric (str): metic name, such as TPR and FPR. | ||
class_id (str): the class id of the metrics. Default to 1. | ||
Returns: | ||
pd.DataFrame: evaluation scores for each group with group size. | ||
""" | ||
metric_df = [] | ||
|
||
group_keys = list(confusion_matrices.keys()) | ||
group_keys.remove("overall") | ||
|
||
for gid in group_keys: | ||
cnf_k = confusion_matrices[gid] | ||
if metric in ["TPR", "TNR"]: | ||
n_k = np.sum(cnf_k, axis=1)[class_id] | ||
elif metric in ["FPR", "FNR"]: | ||
n_k = np.sum(cnf_k) - np.sum(cnf_k, axis=1)[class_id] | ||
metric_k = confusion_matrix_based_scores(cnf_k)[metric][class_id] | ||
metric_df.append({ | ||
"gid":gid, | ||
"metric_k":metric_k, | ||
"n_k":n_k, | ||
}) | ||
metric_df = pd.DataFrame(metric_df) | ||
|
||
return metric_df | ||
|
||
def double_correction(metric_df, n_sample = 1000, threshold = False, sample_variance = True): | ||
"""Calculated corrected variance estimation. | ||
Args: | ||
metric_df (pd.DataFrame): results of group_level_metrics. | ||
n_sample (int, optional): number of trails in bootstrapping. Defaults to 1000. | ||
threshold (bool, optional): whether or not replace negative corrected variance to 0. Defaults to False. | ||
sample_variance (bool, optional): Use sample variance if true, and population variance otherwise. Defaults to True. | ||
Returns: | ||
pd.DataFrame: A DataFrame containing uncorrected_var, corrected_var, and double_corrected_var of all bootstrapping samples. | ||
""" | ||
bootstrapping = [] | ||
for _value in metric_df.values: | ||
bootstrapping.append( | ||
np.random.binomial(n=_value[2], p=_value[1], size=n_sample)/_value[2] | ||
) | ||
mu_hats = np.stack(bootstrapping) | ||
n_groups = len(metric_df.values) | ||
nks = metric_df.values[:,2] | ||
|
||
|
||
if sample_variance: | ||
uncorrected_var = np.var(mu_hats, axis=0)*n_groups/(n_groups-1) | ||
else: | ||
uncorrected_var = np.var(mu_hats, axis=0) | ||
sigma2_hats = (mu_hats*(1-mu_hats)/nks.reshape(-1,1)) | ||
sigma2_hats_mean = np.mean(sigma2_hats, axis=0) | ||
corrected_var = uncorrected_var - sigma2_hats_mean | ||
double_corrected_var = uncorrected_var - 2*sigma2_hats_mean + np.mean(sigma2_hats/nks.reshape(-1,1), axis=0) | ||
|
||
if threshold: | ||
corrected_var = np.maximum(0, corrected_var) | ||
double_corrected_var = np.maximum(0, double_corrected_var) | ||
|
||
results_df = pd.DataFrame({ | ||
"uncorrected_var":uncorrected_var, | ||
"corrected_var":corrected_var, | ||
"double_corrected_var":double_corrected_var, | ||
}) | ||
|
||
return results_df |