diff --git a/apps/widget-e2e/src/describer/modelAssessment/modelAssessmentDatasets.ts b/apps/widget-e2e/src/describer/modelAssessment/modelAssessmentDatasets.ts index c6548aeb25..d62e13b7de 100644 --- a/apps/widget-e2e/src/describer/modelAssessment/modelAssessmentDatasets.ts +++ b/apps/widget-e2e/src/describer/modelAssessment/modelAssessmentDatasets.ts @@ -123,7 +123,14 @@ const modelAssessmentDatasets = { "s6" ], modelStatisticsData: { - cohortDropDownValues: ["All data"], + cohortDropDownValues: [ + "All data", + "Cohort Index", + "Cohort Predicted Y", + "Cohort True Y", + "Cohort Regression Error", + "Cohort Age and BMI" + ], defaultXAxis: "Error", defaultXAxisPanelValue: "Error", defaultYAxis: "Cohort", diff --git a/notebooks/responsibleaidashboard/responsibleaidashboard-diabetes-decision-making.ipynb b/notebooks/responsibleaidashboard/responsibleaidashboard-diabetes-decision-making.ipynb index 960585f6e9..bf2df40fab 100644 --- a/notebooks/responsibleaidashboard/responsibleaidashboard-diabetes-decision-making.ipynb +++ b/notebooks/responsibleaidashboard/responsibleaidashboard-diabetes-decision-making.ipynb @@ -2,6 +2,7 @@ "cells": [ { "cell_type": "markdown", + "id": "75018d5c", "metadata": {}, "source": [ "# Plan real-world action using counterfactual example analysis and causal analysis" @@ -9,6 +10,7 @@ }, { "cell_type": "markdown", + "id": "d4939847", "metadata": {}, "source": [ "This notebook demonstrates the use of the Responsible AI Toolbox to make decisions from diabetes progression data. It walks through the API calls necessary to create a widget with causal inferencing insights, then guides a visual analysis of the data." 
@@ -16,6 +18,7 @@ }, { "cell_type": "markdown", + "id": "231caa35", "metadata": {}, "source": [ "* [Launch Responsible AI Toolbox](#Launch-Responsible-AI-Toolbox)\n", @@ -28,6 +31,7 @@ }, { "cell_type": "markdown", + "id": "8cfa82d1", "metadata": {}, "source": [ "## Launch Responsible AI Toolbox" @@ -35,6 +39,7 @@ }, { "cell_type": "markdown", + "id": "789b30d1", "metadata": {}, "source": [ "The following section examines the code necessary to create the dataset. It then generates insights using the `responsibleai` API that can be visually analyzed." @@ -42,6 +47,7 @@ }, { "cell_type": "markdown", + "id": "3e43e464", "metadata": {}, "source": [ "### Train a Model\n", @@ -51,6 +57,7 @@ { "cell_type": "code", "execution_count": null, + "id": "a670ba8c", "metadata": {}, "outputs": [], "source": [ @@ -64,6 +71,7 @@ }, { "cell_type": "markdown", + "id": "a4f53194", "metadata": {}, "source": [ "First, load the diabetes dataset and specify the different types of features. Then, clean it and put it into a DataFrame with named columns." @@ -72,6 +80,7 @@ { "cell_type": "code", "execution_count": null, + "id": "479ad4f8", "metadata": {}, "outputs": [], "source": [ @@ -83,6 +92,7 @@ }, { "cell_type": "markdown", + "id": "c7cdd8ae", "metadata": {}, "source": [ "After loading and cleaning the data, split the datapoints into training and test sets. Assemble separate datasets for the training and test data." @@ -91,6 +101,7 @@ { "cell_type": "code", "execution_count": null, + "id": "4e02d132", "metadata": {}, "outputs": [], "source": [ @@ -105,6 +116,7 @@ }, { "cell_type": "markdown", + "id": "59853607", "metadata": {}, "source": [ "Train a nearest-neighbors classifier on the training data." 
@@ -113,6 +125,7 @@ { "cell_type": "code", "execution_count": null, + "id": "6612038f", "metadata": {}, "outputs": [], "source": [ @@ -122,6 +135,7 @@ }, { "cell_type": "markdown", + "id": "29805164", "metadata": {}, "source": [ "### Create Model and Data Insights" @@ -130,6 +144,7 @@ { "cell_type": "code", "execution_count": null, + "id": "c65f788f", "metadata": {}, "outputs": [], "source": [ @@ -139,6 +154,7 @@ }, { "cell_type": "markdown", + "id": "400de1d9", "metadata": {}, "source": [ "To use Responsible AI Toolbox, initialize a RAIInsights object upon which different components can be loaded.\n", @@ -149,6 +165,7 @@ { "cell_type": "code", "execution_count": null, + "id": "d965f769", "metadata": {}, "outputs": [], "source": [ @@ -158,6 +175,7 @@ }, { "cell_type": "markdown", + "id": "38fbbe06", "metadata": {}, "source": [ "Add the components of the toolbox that are focused on decision-making." @@ -166,6 +184,7 @@ { "cell_type": "code", "execution_count": null, + "id": "24567d8d", "metadata": {}, "outputs": [], "source": [ @@ -178,6 +197,7 @@ }, { "cell_type": "markdown", + "id": "571b2235", "metadata": {}, "source": [ "Once all the desired components have been loaded, compute insights on the test set." @@ -186,6 +206,7 @@ { "cell_type": "code", "execution_count": null, + "id": "a7dec636", "metadata": {}, "outputs": [], "source": [ @@ -194,6 +215,81 @@ }, { "cell_type": "markdown", + "id": "0ad206fd", + "metadata": {}, + "source": [ + "Compose some cohorts which can be injected into the `ResponsibleAIDashboard`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a039b34", + "metadata": {}, + "outputs": [], + "source": [ + "from raiwidgets.cohort import Cohort, CohortFilter, CohortFilterMethods\n", + "\n", + "# Cohort on age and bmi features in the dataset\n", + "cohort_filter_age = CohortFilter(\n", + " method=CohortFilterMethods.METHOD_LESS,\n", + " arg=[40],\n", + " column='age')\n", + "cohort_filter_bmi = CohortFilter(\n", + " method=CohortFilterMethods.METHOD_GREATER,\n", + " arg=[0],\n", + " column='bmi')\n", + "\n", + "user_cohort_age_and_bmi = Cohort(name='Cohort Age and BMI')\n", + "user_cohort_age_and_bmi.add_cohort_filter(cohort_filter_age)\n", + "user_cohort_age_and_bmi.add_cohort_filter(cohort_filter_bmi)\n", + "\n", + "# Cohort on index\n", + "cohort_filter_index = CohortFilter(\n", + " method=CohortFilterMethods.METHOD_LESS,\n", + " arg=[20],\n", + " column='Index')\n", + "\n", + "user_cohort_index = Cohort(name='Cohort Index')\n", + "user_cohort_index.add_cohort_filter(cohort_filter_index)\n", + "\n", + "# Cohort on predicted y values\n", + "cohort_filter_predicted_y = CohortFilter(\n", + " method=CohortFilterMethods.METHOD_LESS,\n", + " arg=[165.0],\n", + " column='Predicted Y')\n", + "\n", + "user_cohort_predicted_y = Cohort(name='Cohort Predicted Y')\n", + "user_cohort_predicted_y.add_cohort_filter(cohort_filter_predicted_y)\n", + "\n", + "# Cohort on true y values\n", + "cohort_filter_true_y = CohortFilter(\n", + " method=CohortFilterMethods.METHOD_GREATER,\n", + " arg=[45.0],\n", + " column='True Y')\n", + "\n", + "user_cohort_true_y = Cohort(name='Cohort True Y')\n", + "user_cohort_true_y.add_cohort_filter(cohort_filter_true_y)\n", + "\n", + "# Cohort on regression error\n", + "cohort_filter_regression_error = CohortFilter(\n", + " method=CohortFilterMethods.METHOD_GREATER,\n", + " arg=[20.0],\n", + " column='Error')\n", + "\n", + "user_cohort_regression_error = Cohort(name='Cohort Regression Error')\n", +
"user_cohort_regression_error.add_cohort_filter(cohort_filter_regression_error)\n", + "\n", + "cohort_list = [user_cohort_age_and_bmi,\n", + " user_cohort_index,\n", + " user_cohort_predicted_y,\n", + " user_cohort_true_y,\n", + " user_cohort_regression_error]" + ] + }, + { + "cell_type": "markdown", + "id": "54a43b5c", "metadata": {}, "source": [ "Finally, visualize and explore the model insights. Use the resulting widget or follow the link to view this in a new tab." @@ -202,14 +298,16 @@ { "cell_type": "code", "execution_count": null, + "id": "ad84c884", "metadata": {}, "outputs": [], "source": [ - "ResponsibleAIDashboard(rai_insights)" + "ResponsibleAIDashboard(rai_insights, cohort_list=cohort_list)" ] }, { "cell_type": "markdown", + "id": "fb2ab57e", "metadata": {}, "source": [ "## Take Real-World Action" @@ -217,6 +315,7 @@ }, { "cell_type": "markdown", + "id": "84325421", "metadata": {}, "source": [ "### What-If Counterfactuals Analysis" @@ -224,6 +323,7 @@ }, { "cell_type": "markdown", + "id": "d292d247", "metadata": {}, "source": [ "Let's imagine that the diabetes progression scores predicted by the model are used to determine medical insurance rates. If the score is greater than 120, there is a higher rate. Patient 43's model score of 268.08 results in this increased rate, and they want to know how they should change their health to get a lower rate prediction from the model (leading to lower insurance price).\n", @@ -234,6 +334,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "d459156b", "metadata": {}, "source": [ "![What-If Counterfactuals component with datapoint 43 selected on the scatter plot with axes \"Predicted Y\" and \"Index\"](./img/regression-decision-making-1.png)" @@ -241,6 +342,7 @@ }, { "cell_type": "markdown", + "id": "d7b86696", "metadata": {}, "source": [ "What can Patient 43 do to create the desired change? 
The top ranked features bar plot shows that `bmi` and `s5` are the best to perturb to bring the model score within 120." @@ -249,6 +351,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "b16d1a6c", "metadata": {}, "source": [ "![Top-ranked features (descending) for datapoint 43 to perturb to reduce model prediction below 120: bmi, s5, s4, s3, age, bp, sex, s1, s2, s6](./img/regression-decision-making-2.png)" @@ -256,6 +359,7 @@ }, { "cell_type": "markdown", + "id": "709c3019", "metadata": {}, "source": [ "Let's see how that can be achieved. Change `bmi` to -0.04 and `s5` to -0.042 and see what the result is." @@ -264,6 +368,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "5faa62ea", "metadata": {}, "source": [ "![Counterfactual creation panel. BMI has been changed to -0.04 and s5 has been changed to -0.042](./img/regression-decision-making-3.png)" @@ -271,6 +376,7 @@ }, { "cell_type": "markdown", + "id": "a9f67339", "metadata": {}, "source": [ "As we can see, the model's prediction has dropped to 131.22. Thus, Patient 43 should work on reducing their [body mass index and serum triglycerides level](https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset) to bring the model score under the insurance threshold." @@ -279,6 +385,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "22d445d7", "metadata": {}, "source": [ "![Counterfactual of datapoint 43 selected on the counterfactuals scatter plot with axes \"Predicted Y\" and \"Index\". Predicted Y is 115.4](./img/regression-decision-making-4.png)" @@ -286,6 +393,7 @@ }, { "cell_type": "markdown", + "id": "b4f78fd8", "metadata": {}, "source": [ "Note that this result does not mean that reducing `bmi` and `s5` *causes* the diabetes progression score to go down. It simply decreases the model prediction. 
To investigate causal relationships, continue reading:" @@ -293,6 +401,7 @@ }, { "cell_type": "markdown", + "id": "b134cdb5", "metadata": {}, "source": [ "### Causal Analysis" @@ -300,6 +409,7 @@ }, { "cell_type": "markdown", + "id": "da76466d", "metadata": {}, "source": [ "Now suppose that a doctor wishes to know how to reduce the progression of diabetes in her patients. This can be explored in the Causal Inference component of the Responsible AI Toolbox.\n", @@ -310,6 +420,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "90b838d8", "metadata": {}, "source": [ "![Overall causal analysis table](./img/regression-decision-making-5.png)\n", @@ -318,6 +429,7 @@ }, { "cell_type": "markdown", + "id": "f6078481", "metadata": {}, "source": [ "Let's revisit Patient 43. Instead of simply reducing the model score, they've decided to focus on actually improving their health to manage their diabetes better. In the \"Individual causal what-if\" tab, it shows that decreasing his/her bmi to 0.05 reduces diabetes progression from 242 to 237.982." @@ -326,6 +438,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "93105414", "metadata": {}, "source": [ "![individual causal analysis table](./img/regression-decision-making-7.png)" @@ -333,6 +446,7 @@ }, { "cell_type": "markdown", + "id": "a6fa7384", "metadata": {}, "source": [ "To put that into a formal intervention policy, switch to the \"Treatment policy\" tab. This view helps build policies for future interventions. You can identify what parts of your sample experience the largest responses to changes in causal features, or treatments, and construct rules to define which future populations should be targeted for particular interventions." 
@@ -341,6 +455,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "d1af0772", "metadata": {}, "source": [ "![treatment_policy](./img/regression-decision-making-8.png)" @@ -348,6 +463,7 @@ }, { "cell_type": "markdown", + "id": "ac8025e4", "metadata": {}, "source": [ "Is that change the best overall treatment for them? Let's investigate different policies. Going back to the \"Treatment policy\" tab, we see that going with the above intervention of s2 feature outperforms perturbing that with a \"always increase\" intervention." @@ -356,6 +472,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "ce677d35", "metadata": {}, "source": [ "![image.png](./img/regression-decision-making-9.png)" @@ -363,6 +480,7 @@ }, { "cell_type": "markdown", + "id": "3355ea1c", "metadata": {}, "source": [ "Finally, you can see a list demonstrating which datapoints (patients) in the current data sample have the largest causal response to the selected treatment (s2 feature change), based on all features included in the estimated causal model." @@ -371,6 +489,7 @@ { "attachments": {}, "cell_type": "markdown", + "id": "3cb02322", "metadata": {}, "source": [ "![causal-table](./img/regression-decision-making-10.png)" @@ -393,7 +512,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.7.11" } }, "nbformat": 4,