-
Notifications
You must be signed in to change notification settings - Fork 146
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
594 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,303 @@ | ||
--- | ||
title: "Untitled" | ||
output: html_document | ||
--- | ||
|
||
```{r setup, include=FALSE} | ||
knitr::opts_chunk$set(echo = TRUE) | ||
``` | ||
|
||
## ID and response | ||
|
||
* 'id': unique identifier for the customer | ||
* 'attrition_flag': whether the customer is churned (0 = no; 1 = yes) | ||
|
||
## Categorical | ||
|
||
* 'gender': gender of the customer | ||
* 'education_level': education level of the customer | ||
* 'income_category': income range of the customer | ||
|
||
# Numeric | ||
|
||
* 'total_relationship_count': number of relationships | ||
* 'months_inactive_12_mon': number of months the customer is inactive in the past 12 months | ||
* 'credit_limit': the customer's credit limit | ||
* 'total_revolving_bal': the customer's total revolving balance | ||
* 'total_amt_chng_q4_q1': the amount the balance changed from Q4 to Q1 | ||
* 'total_trans_amt': the value of all the customer's transactions in the period | ||
* 'total_trans_ct': the number of all of the customer's transactions | ||
* 'total_ct_chng_q4_q1': the difference in number of the customer's transactions from Q4 to Q1 | ||
* 'avg_utilization_ratio': the customer's average utilization ratio during the period | ||
|
||
Strategy: | ||
|
||
* change attrition flag to "churned" yes/no | ||
* create summarize_churn function early on | ||
* for linear, log on credit limit (maybe others?) | ||
* Lot of numeric, so xgboost early on | ||
* Try turning income category and education level into ordinal numbers (unknown as either 0 or as mean impute) | ||
|
||
EDA: | ||
|
||
* Density of all numeric, comparing attrition and not | ||
* AUC of attrition vs all predictors | ||
|
||
* Relationship between relationships/months inactive (line plot) | ||
* Credit limit, age, total revolving balance, total amt change, (density plot or bucketed line plot) | ||
* gender, Education + income and churn (bar plot) | ||
* Total relationships and churn (line plot) | ||
|
||
```{r} | ||
library(tidymodels) | ||
library(tidyverse) | ||
library(scales) | ||
library(lubridate) | ||
library(stacks) | ||
theme_set(theme_light()) | ||
doParallel::registerDoParallel(cores = 4) | ||
``` | ||
|
||
```{r} | ||
education_levels <- c("Unknown", "Uneducated", "High School", "College", "Graduate", "Doctorate", "Post-Graduate") | ||
``` | ||
|
||
```{r} | ||
process_data <- function(tbl) { | ||
tbl %>% | ||
mutate(education_level = fct_relevel(education_level, education_levels), | ||
income_category = fct_reorder(income_category, parse_number(income_category)), | ||
income_category = fct_relevel(income_category, "Less than $40K")) | ||
} | ||
``` | ||
|
||
|
||
```{r} | ||
dataset <- read_csv("~/Downloads/sliced-s01e07-HmPsw2/train.csv") %>% | ||
mutate(churned = factor(ifelse(attrition_flag == 1, "yes", "no"))) %>% | ||
select(-attrition_flag) %>% | ||
process_data() | ||
holdout <- read_csv("~/Downloads/sliced-s01e07-HmPsw2/test.csv") %>% | ||
process_data() | ||
set.seed(2021) | ||
spl <- initial_split(dataset) | ||
train <- training(spl) | ||
test <- testing(spl) | ||
train_5fold <- train %>% | ||
vfold_cv(5) | ||
``` | ||
|
||
```{r} | ||
summarize_churn <- function(tbl) { | ||
tbl %>% | ||
summarize(n = n(), | ||
n_churned = sum(churned == "yes"), | ||
pct_churned = n_churned / n, | ||
low = qbeta(.025, n_churned + .5, n - n_churned + .5), | ||
high = qbeta(.975, n_churned + .5, n - n_churned + .5)) %>% | ||
arrange(desc(n)) | ||
} | ||
plot_categorical <- function(tbl, category, ...) { | ||
if (!is.factor(pull(tbl, {{ category }}))) { | ||
tbl <- tbl %>% | ||
mutate({{ category }} := fct_reorder({{ category }}, pct_churned)) | ||
} | ||
tbl %>% | ||
ggplot(aes(pct_churned, {{ category }}, ...)) + | ||
geom_col(position = position_dodge()) + | ||
geom_errorbarh(position = position_dodge(width = 1), aes(xmin = low, xmax = high), height = .2) + | ||
scale_x_continuous(labels = percent) + | ||
labs(x = "% in category that churned") | ||
} | ||
train %>% | ||
group_by(gender) %>% | ||
summarize_churn() %>% | ||
plot_categorical(gender) | ||
train %>% | ||
group_by(education_level, gender) %>% | ||
summarize_churn() %>% | ||
plot_categorical(education_level, fill = gender, group = gender) | ||
train %>% | ||
group_by(education_level) %>% | ||
summarize_churn() %>% | ||
plot_categorical(education_level) + | ||
coord_flip() + | ||
labs(x = "% churned", | ||
y = "Education", | ||
title = "How does churn risk differ by education") | ||
train %>% | ||
group_by(income_category) %>% | ||
summarize_churn() %>% | ||
plot_categorical(income_category) + | ||
coord_flip() + | ||
labs(x = "% churned", | ||
y = "Income category", | ||
title = "How does churn risk differ by income category") | ||
qplot(train$customer_age) | ||
train %>% | ||
group_by(age = cut(customer_age, c(20, 40, 50, 60, 70))) %>% | ||
summarize_churn() %>% | ||
plot_categorical(age) + | ||
coord_flip() | ||
``` | ||
|
||
Let's look at all numerics! | ||
|
||
```{r} | ||
gathered <- train %>% | ||
select(customer_age, total_relationship_count:churned) %>% | ||
gather(metric, value, -churned) %>% | ||
mutate(churned = str_to_title(churned)) | ||
gathered %>% | ||
ggplot(aes(value, fill = churned)) + | ||
geom_density(alpha = .5) + | ||
facet_wrap(~ metric, scales = "free") + | ||
scale_x_log10() + | ||
labs(title = "How do numeric metrics differ between churn/didn't churn?", | ||
subtitle = "All predictors on a log scale", | ||
x = "Log(predictor)", | ||
fill = "Churned?") | ||
gathered %>% | ||
group_by(metric) %>% | ||
mutate(rank = percent_rank(value)) %>% | ||
ggplot(aes(rank, fill = churned)) + | ||
geom_density(alpha = .5) + | ||
facet_wrap(~ metric, scales = "free") + | ||
labs(title = "How do numeric metrics differ between churn/didn't churn?", | ||
subtitle = "All predictors on a percentile scale", | ||
x = "Percentile of predictor", | ||
fill = "Churned?") | ||
``` | ||
|
||
Let's start building models! | ||
|
||
XGBoost on ones that are clearly predictive or low cardinality. | ||
|
||
```{r} | ||
mset <- metric_set(mn_log_loss) | ||
control <- control_grid(save_workflow = TRUE, | ||
save_pred = TRUE, | ||
extract = extract_model) | ||
``` | ||
|
||
Trained a model on all the variables, with income and education turned into ordinal variables and "unknown" imputed to the mean (revisit that choice). | ||
|
||
```{r} | ||
# Let's try 10-fold CV | ||
set.seed(2021-07-13) | ||
train_10fold <- train %>% | ||
vfold_cv(10) | ||
factor_to_ordinal <- function(x) { | ||
ifelse(x == "Unknown", NA, as.integer(x)) | ||
} | ||
``` | ||
|
||
```{r} | ||
# Try all numeric, and turn categorical into ordinal | ||
xg_rec <- recipe(churned ~ ., | ||
data = train) %>% | ||
step_mutate(avg_transaction_amt = total_trans_amt / total_trans_ct) %>% | ||
step_mutate(income_category = factor_to_ordinal(income_category), | ||
education_level = factor_to_ordinal(education_level)) %>% | ||
step_impute_mean(all_numeric_predictors()) %>% | ||
step_dummy(all_nominal_predictors()) %>% | ||
step_rm(id) %>% | ||
# For now, remove the formerly categorical ones (low importance) | ||
step_rm(income_category, education_level) | ||
xg_wf <- workflow(xg_rec, | ||
boost_tree("classification", | ||
trees = tune(), | ||
mtry = tune(), | ||
tree_depth = tune(), | ||
learn_rate = tune())) | ||
# Trying out tree depth | ||
xg_tune <- xg_wf %>% | ||
tune_grid(train_10fold, | ||
metrics = mset, | ||
control = control, | ||
grid = crossing(trees = seq(400, 1300, 20), | ||
mtry = c(2, 3), | ||
tree_depth = c(5), | ||
learn_rate = c(.018, .02))) | ||
autoplot(xg_tune) | ||
xg_tune %>% | ||
collect_metrics() %>% | ||
arrange(mean) | ||
``` | ||
|
||
Next model: try removing total_trans_amt (now that we have avg_transaction_amt) | ||
|
||
```{r} | ||
xg_fit <- xg_wf %>% | ||
finalize_workflow(select_best(xg_tune)) %>% | ||
fit(train) | ||
xg_fit %>% | ||
augment(test, type.predict = "prob") %>% | ||
mn_log_loss(churned, .pred_no) | ||
importances <- xgboost::xgb.importance(model = extract_fit_engine(xg_fit)) | ||
importances %>% | ||
mutate(Feature = fct_reorder(Feature, Gain)) %>% | ||
ggplot(aes(Gain, Feature)) + | ||
geom_col() | ||
``` | ||
|
||
Training our first xgboost model! A few selected predictors, mostly numeric. | ||
|
||
CV log loss on: | ||
|
||
* selected ~7 predictors is about .11 | ||
* Adding in rest of predictors, around .087 | ||
* Removing the (formerly) categorical ones, .0853 | ||
* Adding in average transaction amount: 0.085 (on holdout 0.0849) | ||
|
||
Everything except education + income, added average transaction amount, mtry = 3, 780 trees, learn rate = .02 | ||
|
||
Attempt 2: | ||
|
||
* Everything except education + income, added average transaction amount, mtry = 3, 1040 trees, learn rate = .02, tree depth = 5. On test set, .0835. | ||
|
||
```{r} | ||
xg_fit_full <- xg_wf %>% | ||
finalize_workflow(select_best(xg_tune)) %>% | ||
fit(dataset) | ||
# Last check before out the door | ||
xgboost::xgb.importance(model = extract_fit_engine(xg_fit_full)) | ||
xg_fit_full %>% | ||
augment(holdout) %>% | ||
select(id, attrition_flag = .pred_yes) %>% | ||
write_csv("~/Desktop/attempt2.csv") | ||
``` | ||
|
||
Plan for second hour: more tuning on learn rate and max depth, and more EDA while that trains. | ||
|
||
Soon, will do an variable importance plot on the full prediction. | ||
|
||
```{r} | ||
``` | ||
|
||
|
||
|
||
|
||
|
Oops, something went wrong.
19f8c0c
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good question- I’m using the development (GitHub) version of the workflows package! As of just this week,
workflow()
can take arguments for recipe and model.You can install the dev version with:
For the second one, I'm sorry but I don't recognize the issue 😦 . Is it possible you overwrote the
test
at some point? It had been set up astest <- testing(spl)
.