-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #12 from HanXudong/dataset
update tutorials
- Loading branch information
Showing
14 changed files
with
10,005 additions
and
98 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,7 +36,7 @@ | |
if __name__ == '__main__': | ||
setup( | ||
name='fairlib', | ||
version="0.0.8", | ||
version="0.0.9", | ||
author="Xudong Han", | ||
author_email="[email protected]", | ||
description='Unified framework for assessing and improving fairness.', | ||
|
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Large diffs are not rendered by default.
Oops, something went wrong.
216 changes: 216 additions & 0 deletions
216
tutorial/archive/fairlib_v0.0.6/interactive_plots.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
348 changes: 348 additions & 0 deletions
348
tutorial/archive/fairlib_v0.0.6/manipulate_data_distribution.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,348 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Manilulating Data Distribution\n", | ||
"\n", | ||
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/HanXudong/fairlib/blob/main/tutorial/manipulate_data_distribution.ipynb)\n", | ||
"\n", | ||
"In this notebook, we will introduce the built-in function for manipulating data distributions.\n", | ||
"\n", | ||
"Overall, the process can be described as:\n", | ||
"1. Load data from files\n", | ||
"2. Analysis the loaded dataset distribution\n", | ||
"3. Resample isntacnes based on their target labels and protected labels\n", | ||
"\n", | ||
"Impelemtations for steps 1 and 4 are the same with others. Please see [dataloader](https://hanxudong.github.io/fairlib/reference_api_dataloaders.html) for detailed description." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Load data\n", | ||
"\n", | ||
"Similar to long-tail leaning, the data distribution is essentially the label distributions. However, besides the target label in long-tail leanring, we also need to consider the demographic labels.\n", | ||
"\n", | ||
"Since the input distribution is not required for manipulating data distributions, we just create synthetic labels for demostration purpose." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"\n", | ||
"y = []\n", | ||
"g = []\n", | ||
"for _y, _g, _n in zip(\n", | ||
" [1,1,1,0,0,0], # Target labels, y\n", | ||
" [2,1,0,1,2,0], # Protected labels, g\n", | ||
" [4,5,6,3,7,9] # Number of instances with target label _y and group label _g\n", | ||
" ):\n", | ||
" y = y + [_y]*_n\n", | ||
" g = g + [_g]*_n\n", | ||
"\n", | ||
"y = np.array(y)\n", | ||
"g = np.array(g)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Analysis the loaded dataset distribution\n", | ||
"\n", | ||
"Essentialy, we are instereseted with the joint distribution of target label *y* and protected label *g*, which are both discrete random variables.\n", | ||
"\n", | ||
"We save the corresponding results in probability tables:\n", | ||
"\n", | ||
"Given target label and protected labels, calculate emprical distributions.\n", | ||
"\n", | ||
"- joint_dist: n_class * n_groups matrix, where each element refers to the joint probability, i.e., proportion size.\n", | ||
"- g_dist: n_groups array, indicating the prob of each group\n", | ||
"- y_dist: n_class array, indicating the prob of each class\n", | ||
"- g_cond_y_dist: n_class * n_groups matrix, g_cond_y_dit[y_id,:] refers to the group distribution within class y_id\n", | ||
"- y_cond_g_dist: n_class * n_groups matrix, y_cond_g_dit[:,g_id] refers to the class distribution within group g_id" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from fairlib.src.dataloaders.generalized_BT import get_data_distribution" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"dict_keys(['joint_dist', 'g_dist', 'y_dist', 'g_cond_y_dist', 'y_cond_g_dist', 'yg_index', 'N'])\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"synthetic_data_distribution = get_data_distribution(y_data = y, g_data = g)\n", | ||
"print(synthetic_data_distribution.keys())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"array([[0.26470588, 0.08823529, 0.20588235],\n", | ||
" [0.17647059, 0.14705882, 0.11764706]])" | ||
] | ||
}, | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# n_class * n_groups matrix, where each element refers to the joint probability\n", | ||
"synthetic_data_distribution[\"joint_dist\"]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Resample isntacnes based on their target labels and protected labels\n", | ||
"\n", | ||
"In order to specify a target distribution of *y* and *g*, there are 5 options in the `generalized_sampling` function, which can be classified into 3 types as:\n", | ||
"1. specify the target joint distribution of *y* and *g*, i.e., *p(y,g)*.\n", | ||
"2. specify the y_dist (*p(y)*) and g_cond_y_dist (*p(g|y)*), and *p(y,g) = p(g|y)p(y)*\n", | ||
"3. specify the g_dist (*p(g)*) and y_cond_g_dist (*p(y|g)*), and *p(y,g) = p(y|g)p(g)*" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from fairlib.src.dataloaders.generalized_BT import generalized_sampling" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Taking the group label distribution as an example, we would like the group label is uniformaly distributed, i.e., *p(g=0)=p(g=1)=p(g=2)*" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"array([0.33333333, 0.33333333, 0.33333333])" | ||
] | ||
}, | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"target_g_dist = np.ones(3)/3\n", | ||
"target_g_dist" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[25, 26, 27, 28, 15, 16, 17, 18, 19, 20, 21, 9, 10, 11, 4, 5, 6, 7, 0, 1]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"balanced_g_indices = generalized_sampling(\n", | ||
" default_distribution_dict = synthetic_data_distribution,\n", | ||
" N = 20,\n", | ||
" g_dist=target_g_dist,)\n", | ||
"print(balanced_g_indices)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"balanced_g_distribution = get_data_distribution(\n", | ||
" y_data = y[balanced_g_indices], \n", | ||
" g_data = g[balanced_g_indices])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'joint_dist': array([[0.2 , 0.15, 0.2 ],\n", | ||
" [0.15, 0.2 , 0.1 ]]),\n", | ||
" 'g_dist': array([0.35, 0.35, 0.3 ]),\n", | ||
" 'y_dist': array([0.55, 0.45]),\n", | ||
" 'g_cond_y_dist': array([[0.36363636, 0.27272727, 0.36363636],\n", | ||
" [0.33333333, 0.44444444, 0.22222222]]),\n", | ||
" 'y_cond_g_dist': array([[0.57142857, 0.42857143, 0.66666667],\n", | ||
" [0.42857143, 0.57142857, 0.33333333]]),\n", | ||
" 'yg_index': {(0, 0): array([0, 1, 2, 3], dtype=int64),\n", | ||
" (0, 1): array([4, 5, 6], dtype=int64),\n", | ||
" (0, 2): array([ 7, 8, 9, 10], dtype=int64),\n", | ||
" (1, 0): array([11, 12, 13], dtype=int64),\n", | ||
" (1, 1): array([14, 15, 16, 17], dtype=int64),\n", | ||
" (1, 2): array([18, 19], dtype=int64)},\n", | ||
" 'N': 20.0}" | ||
] | ||
}, | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"balanced_g_distribution" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Theoritically, we could sample to any data distributions with the `generalized_sampling` function. But it is inconvinent to specify distribution arraies every time. We further provide the `manipulate_data_distribution` function to simplify the manipulation process.\n", | ||
"\n", | ||
"Specifically, the manipulation is based on the iterpolation between the original data distribution and a specific balanced target distrition. Essentially, these balanced target distributions are identical to [**BT**](https://hanxudong.github.io/fairlib/reference_api_debiasing/BT.html).\n", | ||
"\n", | ||
"The *alpha* refers to the interpolation extent as $final\\_dist = \\alpha*balanced\\_dist + (1-\\alpha)*origianl\\_dist$\n", | ||
"\n", | ||
"- default_distribution_dict (dict): a dict of distribution information of the original dataset.\n", | ||
"- N (int, optional): The total number of returned indices. Defaults to None.\n", | ||
"- GBTObj (str, optional): original | joint | g | y | g_cond_y | y_cond_g. Defaults to \"original\".\n", | ||
"- alpha (int, optional): interpolation between the original distribution and the target distribution. Defaults to 1." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from fairlib.src.dataloaders.generalized_BT import manipulate_data_distribution" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The following examples show to what extent can we manilutale the distribution by tuning the alpha value\n", | ||
"- aphla > 1: anto-imbalance.\n", | ||
"- aphla = 1: perfectly balanced distribution as corresponding BT objectives.\n", | ||
"- 0 < aphla < 1: interpolation between original distribution and the perfectly balanced distribution.\n", | ||
"- alpha = 0: original distribution.\n", | ||
"- alpha < 0: amplify imbalance." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"g_indices_10 = manipulate_data_distribution(default_distribution_dict = synthetic_data_distribution, N = 20, GBTObj = \"g\", alpha =1)\n", | ||
"g_indices_00 = manipulate_data_distribution(default_distribution_dict = synthetic_data_distribution, N = 20, GBTObj = \"g\", alpha =0)\n", | ||
"g_indices_05 = manipulate_data_distribution(default_distribution_dict = synthetic_data_distribution, N = 20, GBTObj = \"g\", alpha =0.5)\n", | ||
"g_indices_n10 = manipulate_data_distribution(default_distribution_dict = synthetic_data_distribution, N = 20, GBTObj = \"g\", alpha =-1)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Original: [0.44117647 0.23529412 0.32352941]\n", | ||
"Balanced g: [0.35 0.35 0.3 ]\n", | ||
"aplha 0.0: [0.45 0.25 0.3 ]\n", | ||
"aplha 0.5: [0.4 0.3 0.3]\n", | ||
"aplha 1.0: [0.35 0.35 0.3 ]\n", | ||
"aplha -1.0: [0.55 0.15 0.3 ]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(\"Original:\", synthetic_data_distribution[\"g_dist\"])\n", | ||
"print(\"Balanced g:\", get_data_distribution(y_data = y[balanced_g_indices], g_data = g[balanced_g_indices])[\"g_dist\"])\n", | ||
"print(\"aplha 0.0:\", get_data_distribution(y_data = y[g_indices_00], g_data = g[g_indices_00])[\"g_dist\"])\n", | ||
"print(\"aplha 0.5:\", get_data_distribution(y_data = y[g_indices_05], g_data = g[g_indices_05])[\"g_dist\"])\n", | ||
"print(\"aplha 1.0:\", get_data_distribution(y_data = y[g_indices_10], g_data = g[g_indices_10])[\"g_dist\"])\n", | ||
"print(\"aplha -1.0:\", get_data_distribution(y_data = y[g_indices_n10], g_data = g[g_indices_n10])[\"g_dist\"])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Limitation and Extension\n", | ||
"\n", | ||
"The `manipulate_data_distribution` function has an strong assumption that the original distribution is **NOT** identical to the target perfectly balanced distribution. For those who has a perfectly balanced dataset, the target distributions will need to the manully speficied and use the `generalized_sampling` for resampling." | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"interpreter": { | ||
"hash": "eb4eca40f7710e7c7146430b0424b585ee7a07b7e7498958627feae1c8ad8261" | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3.7.12 ('py37')", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.12" | ||
}, | ||
"orig_nbformat": 4 | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.