Commit
Initial commit of ML classification workflow
Sydeaka Watson authored and Sydeaka Watson committed Feb 21, 2023
0 parents commit 3194afe
Showing 9 changed files with 2,636 additions and 0 deletions.
1 change: 1 addition & 0 deletions .Rprofile
@@ -0,0 +1 @@
source("renv/activate.R")
6 changes: 6 additions & 0 deletions .gitignore
@@ -0,0 +1,6 @@
.Rproj.user
.Rhistory
.RData
.Ruserdata
model_results/*
renv/*
22 changes: 22 additions & 0 deletions R/load_libraries.R
@@ -0,0 +1,22 @@
# Load the packages used throughout the workflow
suppressPackageStartupMessages(library(dplyr))
suppressPackageStartupMessages(library(ggplot2))
suppressPackageStartupMessages(library(testthat))
suppressPackageStartupMessages(library(tidymodels))
suppressPackageStartupMessages(library(yardstick))
suppressPackageStartupMessages(library(DT))
suppressPackageStartupMessages(library(pROC))
suppressPackageStartupMessages(library(DALEXtra))

suppressPackageStartupMessages(library(here))
suppressPackageStartupMessages(library(roxygen2))
suppressPackageStartupMessages(library(openxlsx))
suppressPackageStartupMessages(library(gridExtra))
suppressPackageStartupMessages(library(vip))
suppressPackageStartupMessages(library(ranger))
suppressPackageStartupMessages(library(lightgbm))
suppressPackageStartupMessages(library(bonsai))
suppressPackageStartupMessages(library(xgboost))
suppressPackageStartupMessages(library(DataExplorer))
suppressPackageStartupMessages(library(logger))  # for log_info(), used in R/ml_helper_functions.R
338 changes: 338 additions & 0 deletions R/ml_helper_functions.R
@@ -0,0 +1,338 @@

###############################################################################################
### These tidymodels machine learning utilities were created by
### Sydeaka P. Watson, PhD of Korelasi Data Insights, LLC
### These functions are open source and are available in the following GitHub repository:
### https://github.com/korelasidata/tidymodels-ML-workflow
###############################################################################################




# A field has zero variance if all of its non-missing values are identical
# (an all-NA field also counts as zero variance)
zero_variance <- function(vals) {
  vals_non_missing <- vals[!is.na(vals)]
  length(unique(vals_non_missing)) <= 1
}

remove_zero_variance_fields <- function(dat) {
  zv_fields <- sapply(dat, zero_variance) %>% .[. == TRUE] %>% names

  if (length(zv_fields) == 0) {
    log_info("All fields had some variability. Returning dataframe with no changes.")
    return(dat)
  } else {
    log_info("The following fields were identified as having zero variance: {paste(zv_fields, collapse=', ')}")
    fields_to_keep <- colnames(dat)[!(colnames(dat) %in% zv_fields)]
    dat <- dat %>% select(all_of(fields_to_keep))
    log_info("Fields successfully removed")
  }

  return(dat)
}
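
# Usage sketch (illustrative; the toy data frame below is hypothetical):
# a constant column and an all-NA column both count as zero variance.
# df <- data.frame(a = c(1, 2, 3), b = c("x", "x", "x"), c = c(NA, NA, NA))
# remove_zero_variance_fields(df)   # returns a data frame with only column `a`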


# Baseline model specifications; hyperparameters are flagged for tuning
model_specifications <- list(
  "xgboost" = boost_tree(engine = 'xgboost', trees = tune(),
                         tree_depth = tune(), min_n = tune(), learn_rate = tune(),
                         mtry = tune()),

  "gbm" = boost_tree(engine = 'lightgbm', trees = tune(),
                     tree_depth = tune(), min_n = tune(), learn_rate = tune(),
                     mtry = tune()),

  "random_forest" = rand_forest(trees = tune(), min_n = tune(), mtry = tune()) %>%
    set_engine("ranger", importance = "impurity")
    # set_engine("randomForest", importance = TRUE)
)



get_model_config <- function(model_formula, model_specifications, selected_algorithm, model_mode) {

  model_spec <- model_specifications[[selected_algorithm]] %>%
    set_mode(model_mode)

  model_wflow <- workflow(model_formula, model_spec)

  # xgboost and lightgbm share the same tuning ranges
  if (selected_algorithm %in% c("xgboost", "gbm")) {
    model_param_grid <- model_wflow %>%
      extract_parameter_set_dials() %>%
      update(
        trees = trees(c(100, 1500)),
        learn_rate = learn_rate(c(.00005, .5), trans = NULL),
        tree_depth = tree_depth(c(6, 20)),
        min_n = min_n(c(10, 60)),
        mtry = mtry(c(5, 40))
      )
  }

  if (selected_algorithm == "random_forest") {
    model_param_grid <- model_wflow %>%
      extract_parameter_set_dials() %>%
      update(
        trees = trees(c(100, 1500)),
        min_n = min_n(c(10, 60)),
        mtry = mtry(c(5, 40))
      )
  }

  rtn <- list(
    model_spec = model_spec,
    model_wflow = model_wflow,
    model_param_grid = model_param_grid
  )

  return(rtn)

}
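
# Usage sketch (illustrative): tune one of the configured algorithms.
# `dat_train` and the outcome `Category` are placeholders for your data.
# cfg <- get_model_config(Category ~ ., model_specifications,
#                         selected_algorithm = "xgboost", model_mode = "classification")
# folds <- rsample::vfold_cv(dat_train, v = 5, strata = Category)
# tuned <- tune::tune_grid(cfg$model_wflow, resamples = folds,
#                          grid = 20, param_info = cfg$model_param_grid)
# best_params <- tune::select_best(tuned, metric = "roc_auc")
# final_wflow <- tune::finalize_workflow(cfg$model_wflow, best_params)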

get_varimp <- function(selected_algorithm, final_model_fit, engine_specific_model_fit=NULL) {
  if (selected_algorithm %in% c("xgboost")) {
    df_varimp <- final_model_fit %>%
      extract_fit_parsnip() %>%
      vip::vi(object=.) %>%
      mutate(PctImportance = round(Importance / sum(Importance) * 100, 2))

    plot_varimp <- final_model_fit %>%
      extract_fit_parsnip() %>%
      vip::vip(geom = "col") +
      theme_bw()
  }


  if (selected_algorithm %in% c("random_forest")) {
    # ranger varimp
    df_varimp <- final_model_fit %>%
      extract_fit_parsnip() %>%
      vip::vi(object=.) %>%
      mutate(PctImportance = round(Importance / sum(Importance) * 100, 2))

    plot_varimp <- final_model_fit %>%
      extract_fit_parsnip() %>%
      vip::vip(geom = "col") +
      theme_bw()


    # randomForest varimp
    # type = either 1 or 2, specifying the type of importance measure
    # (1 = mean decrease in accuracy, 2 = mean decrease in node impurity).
    # df_varimp <- engine_specific_model_fit %>%
    #   importance(type=2) %>%
    #   data.frame(Variable = rownames(.), .) %>%
    #   set_colnames(c("Variable", "Importance")) %>%
    #   mutate(PctImportance = round(Importance / sum(Importance) * 100, 2)) %>%
    #   arrange(desc(PctImportance))
    #
    # plot_varimp <- df_varimp %>%
    #   head(10) %>%
    #   ggplot(aes(x = reorder(Variable, PctImportance), y = PctImportance)) +
    #   geom_bar(stat = "identity", col = "black", show.legend = F) +
    #   coord_flip() +
    #   scale_fill_grey() +
    #   theme_bw() +
    #   ggtitle("Top 10 attributes") +
    #   xlab("") + ylab("% importance")
  }




  if (selected_algorithm == "gbm") {
    # Applying varimp utils specific to lightgbm
    tree_imp <- engine_specific_model_fit %>%
      lgb.importance(percentage = TRUE)

    df_varimp <- tree_imp %>%
      rename(Variable = Feature, Importance = Gain) %>%
      select(Variable, Importance) %>%
      mutate(PctImportance = round(Importance / sum(Importance) * 100, 2)) %>%
      arrange(desc(PctImportance))

    plot_varimp <- df_varimp %>%
      head(10) %>%
      ggplot(aes(x = reorder(Variable, PctImportance), y = PctImportance)) +
      geom_bar(stat = "identity", col = "black", show.legend = F) +
      coord_flip() +
      scale_fill_grey() +
      theme_bw() +
      ggtitle("Top 10 attributes") +
      xlab("") + ylab("% importance")
  }

  return(list(
    df_varimp = df_varimp,
    plot_varimp = plot_varimp
  ))

}
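
# Usage sketch (illustrative): `final_model_fit` is a fitted workflow, e.g.
# final_model_fit <- fit(final_wflow, data = dat_train)
# varimp <- get_varimp("xgboost", final_model_fit)
# varimp$df_varimp     # importance table with PctImportance
# varimp$plot_varimp   # bar chart of top attributes
# For the "gbm" branch, also pass the underlying lightgbm booster, e.g.
# engine_specific_model_fit <- extract_fit_engine(final_model_fit)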


# Plot a confusion matrix (cell counts alongside overall statistics) for a
# set of class predictions. `pred_df` must contain the truth (`Category`)
# and the predicted class (`.pred_class`), both as factors.
plot_confusion_matrix <- function(pred_df) {

  # Note: to average the confusion matrix across resamples (when `pred_df`
  # carries a fold id column), see the `conf_mat` example in the yardstick
  # documentation, e.g.
  #   pred_df %>%
  #     group_by(id) %>%
  #     conf_mat(truth = Category, estimate = .pred_class) %>%
  #     mutate(tidied = lapply(conf_mat, tidy)) %>%
  #     tidyr::unnest(tidied)
  # A single conf_mat object can also be visualized directly:
  #   cm <- pred_df %>% yardstick::conf_mat(Category, .pred_class)
  #   autoplot(cm, type = "mosaic")
  #   autoplot(cm, type = "heatmap")

  cm <- caret::confusionMatrix(pred_df$.pred_class, pred_df$Category)
  cm_d <- as.data.frame(cm$table)    # confusion matrix cell counts as data.frame
  cm_st <- data.frame(cm$overall)    # overall statistics (accuracy, kappa, ...)
  cm_st$cm.overall <- round(cm_st$cm.overall, 2)     # round the values
  cm_d$diag <- cm_d$Prediction == cm_d$Reference     # diagonal (correct) cells
  cm_d$ndiag <- cm_d$Prediction != cm_d$Reference    # off-diagonal cells
  cm_d[cm_d == 0] <- NA              # replace 0 with NA for white tiles
  #cm_d$Reference <- reverse.levels(cm_d$Reference)  # diagonal starts at top left
  cm_d$ref_freq <- cm_d$Freq * ifelse(is.na(cm_d$diag), -1, 1)

  plt1 <- ggplot(data = cm_d, aes(x = Prediction, y = Reference, fill = Freq)) +
    scale_x_discrete(position = "top") +
    geom_tile(data = cm_d, aes(fill = ref_freq)) +
    scale_fill_gradient2(guide = "none", low = "red3", high = "orchid4",
                         midpoint = 0, na.value = 'white') +
    geom_text(aes(label = Freq), color = 'black', size = 3) +
    theme_bw() +
    theme(panel.grid.major = element_blank(), panel.grid.minor = element_blank(),
          legend.position = "none",
          panel.border = element_blank(),
          plot.background = element_blank(),
          axis.line = element_blank())

  plt2 <- tableGrob(cm_st)

  # TO DO: Need to export this plot somehow. `grid.arrange` draws to the
  # active device; its return value is a gtable.
  plot_predictions <- grid.arrange(plt1, plt2, nrow = 1, ncol = 2,
                                   top = grid::textGrob("Confusion Matrix",
                                                        gp = grid::gpar(fontsize = 25, font = 1)))

  return(plot_predictions)
}
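
# Usage sketch (illustrative): `dat_test$Category` is a factor outcome.
# pred_df <- dplyr::bind_cols(dat_test["Category"], predict(final_model_fit, dat_test))
# plot_confusion_matrix(pred_df)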











# Plot the mean of a tuning metric against a single hyperparameter
plot_param <- function(metric_df, param, metric_name='rmse') {
  metric_df %>%
    filter(.metric == metric_name) %>%
    arrange(mean) %>%
    ggplot(aes(x = .data[[param]], y = mean)) +
    geom_point() +
    xlab(param) +
    ylab('') +
    ggtitle(glue::glue("{metric_name} vs {param}"))
}
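
# Usage sketch (illustrative): inspect tuning results per hyperparameter.
# metric_df <- tune::collect_metrics(tuned)
# plot_param(metric_df, param = "trees", metric_name = "roc_auc")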









# Helper function to get a single model fit on a bootstrap resample
fit_model_on_bootstrap <- function(split, best_wflow) {
  best_wflow %>%
    fit(data = analysis(split))
}



# Helper function to get prediction intervals
# boot_model <- boot_models$model[[1]]
# input_data <- dat_train_and_val[1:3,]
# predict(boot_model, new_data = input_data)
bootstrap_pred_intervals <- function(boot_models, input_data, lower_pct = .05, upper_pct = 0.95) {
  # Get predictions on all input cases using all bootstrap models
  pred_df <- boot_models %>%
    mutate(preds = map(model, \(mod) predict(mod, new_data = input_data)))

  # Combine predictions across bootstraps into a matrix
  # (rows = bootstrap models, columns = input cases)
  pred_matrix <- bind_cols(pred_df$preds, .name_repair = "minimal") %>%
    as.matrix %>% t

  # Compute upper and lower confidence bounds per case
  pred_intervals <- pred_matrix %>% apply(2, quantile, probs = c(lower_pct, upper_pct)) %>% t

  return(pred_intervals)
}

# bootstrap_pred_intervals(boot_models, input_data, lower_pct = .05, upper_pct = 0.95)
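
# Usage sketch (illustrative): build the `boot_models` tibble assumed above
# from a finalized workflow and bootstrap resamples of the training data.
# Note the quantile-based intervals assume a numeric prediction (regression mode).
# boot_models <- rsample::bootstraps(dat_train_and_val, times = 200) %>%
#   mutate(model = purrr::map(splits, fit_model_on_bootstrap, best_wflow = final_wflow))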



