-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_plot.R
97 lines (84 loc) · 2.61 KB
/
model_plot.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
library(doParallel)
library(parallel)
library(caret)
library(ggplot2)
library(ggbeeswarm)
library(patchwork)
library(RColorBrewer)
# Register a parallel backend so caret can spread resampling over the cores.
# detectCores() is documented to return NA when the number of cores cannot
# be determined; fall back to a single worker instead of passing NA on.
n_cores <- detectCores()
if (is.na(n_cores)) {
  n_cores <- 1L
}
registerDoParallel(cores = n_cores)

# Resampling specification shared by every model trained below, so that all
# of them are scored under the same scheme.
# NOTE(review): neither `number` nor `repeats` is set, so caret's defaults for
# "repeatedcv" apply (presumably a single repeat of 10-fold CV -- confirm
# against the caret docs and against what was intended).
cv <- trainControl(method = "repeatedcv",
                   allowParallel = TRUE)
# Build the predictor matrix and the response vector handed to caret::train().
# model.matrix() expands `avg_rating ~ .` into a numeric design matrix, but it
# also prepends a constant "(Intercept)" column. That column carries no
# information for the learners below, so it is removed before training.
design <- model.matrix(avg_rating ~ ., data = movie_clean)
x <- design[, colnames(design) != "(Intercept)", drop = FALSE]
y <- movie_clean[["avg_rating"]]

# Under the default na.action, model.matrix() drops rows containing missing
# values, whereas `y` is taken from the full data frame. Fail loudly here
# rather than train on misaligned features and targets.
stopifnot(nrow(x) == length(y))
# Fit one caret model per learner on the same features, target and resampling
# specification.
# NOTE: caret is only a wrapper; the package behind each `method` must be
# installed separately.

# Reset the RNG to the same seed before each fit, then train `method` on the
# script-level `x`, `y` and `cv`.
train_with_seed <- function(method) {
  set.seed(123)
  train(x, y,
        method = method,
        trControl = cv)
}

elastic <- train_with_seed("glmnet")
xgb <- train_with_seed("xgbTree")
knn <- train_with_seed("knn")
mars <- train_with_seed("earth")
rf <- train_with_seed("rf")
# Baseline for comparison: the mean absolute error obtained by predicting the
# overall mean rating for every row (auto-printed when run at top level).
mean(abs(movie_clean$avg_rating - mean(movie_clean$avg_rating)))
# Collect the MAE and R-squared rows from every fitted model into one long
# table, tagged with the model label used on the plots.
# Each caret fit's `$results` presumably holds one row per candidate from the
# default hyperparameter search -- confirm against the caret docs.
fitted_models <- list(
  "Elastic Net" = elastic,
  "XGBoost" = xgb,
  "KNN" = knn,
  "MARS" = mars,
  "Random Forest" = rf
)

# One small data frame per model: MAE, Rsquared and the model label
per_model <- lapply(names(fitted_models), function(label) {
  fitted_models[[label]]$results %>%
    select(MAE, Rsquared) %>%
    mutate(model = label)
})

# Stack the per-model tables and attach each model's mean metrics to its rows.
# na.rm = TRUE: a metric can be missing for an individual row, and without it
# a single NA would turn the whole model's mean into NA and break the
# reorder() used for the plot axes.
results <- do.call(rbind, per_model) %>%
  group_by(model) %>%
  mutate(mean_Rsquared = mean(Rsquared, na.rm = TRUE),
         mean_MAE = mean(MAE, na.rm = TRUE)) %>%
  ungroup()
# Layers common to both panels: quasirandom (beeswarm-style) points, a plain
# theme without a legend, the "Dark2" qualitative palette and an empty x-axis
# title. Kept in one list so the two panels cannot drift apart.
panel_layers <- list(
  geom_quasirandom(),
  theme_bw(),
  theme(legend.position = "none"),
  scale_colour_brewer(type = "qual", palette = "Dark2"),
  xlab("")
)

# Top panel: R-squared for each row of `results`, with the models ordered
# along the x axis by their mean R-squared
p1 <- ggplot(results,
             aes(x = reorder(model, mean_Rsquared), y = Rsquared,
                 color = model)) +
  panel_layers +
  ggtitle("Cross-validated accuracy measures for different models",
          subtitle = "Using default hyperparameter search by the caret library") +
  scale_y_continuous(limits = c(0, 1))

# Bottom panel: MAE, using the same mean-R-squared ordering as the top panel
# so each model sits in the same column in both.
# na.rm = TRUE keeps one missing MAE from turning the upper axis limit into NA.
p2 <- ggplot(results,
             aes(x = reorder(model, mean_Rsquared), y = MAE,
                 color = model)) +
  panel_layers +
  scale_y_continuous(limits = c(0, max(results$MAE, na.rm = TRUE)))

# Stack the two panels vertically using patchwork
p1 / p2