-
Notifications
You must be signed in to change notification settings - Fork 0
/
02b_LME_ModelValidation.R
91 lines (75 loc) · 3.49 KB
/
02b_LME_ModelValidation.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
# library(dplyr)
# library(ggplot2)
# library(lme4)
# library(reshape2)
# Read data and models ####
# The fitted models are NOT read from disk here: glm.model and lmm.model must
# already exist in the session (presumably created by an earlier script in
# this pipeline -- TODO confirm). Fail fast with a clear message instead of an
# "object not found" error partway through the script.
required.models <- c("glm.model", "lmm.model")
missing.models <- required.models[
  !vapply(required.models, function(nm) exists(nm), logical(1))
]
if (length(missing.models) > 0) {
  stop("Model object(s) not found in the session: ",
       paste(missing.models, collapse = ", "), call. = FALSE)
}
# Data with model predictions attached; later sections use the columns
# Sample, Surr, DistCode, lmm.pred, glm.pred and glm.pred.base.
modeldata <- readRDS("SampleDataset_withLMMPreds.RDS")
# Assess models ####
# GLM fitted values reach the boundaries 0 and 1 (per the original analysis,
# because the GLM overfits the per-distributor DistCode effects).
summary(fitted(glm.model))
# LMM fitted values stay within a more believable range.
summary(fitted(lmm.model))
# The fixed effects (coefficients) are very similar between the two models.
# Show the GLM coefficient table without the DistCode dummy rows. Rows are
# selected by NAME: summary() drops aliased (NA) coefficients, so positions
# taken from coef() need not line up with the rows of the summary table.
glm.coef.table <- summary(glm.model)$coefficients
glm.coef.table[!grepl("DistCode", rownames(glm.coef.table)), , drop = FALSE]
summary(lmm.model)
# Validation ####
# Log-loss metric comparison on the holdout sample (lower is better).
# Per the original analysis, there is some edge to using the mixed model with
# DistCode over ignoring DistCode altogether, i.e. the LMM fits the holdout
# data better than the GLM without DistCode.
#
# binary.logloss: summed binary cross-entropy of 0/1 outcomes `actual` against
# probabilities `predicted`. Predictions are clipped to [eps, 1 - eps] first:
# the GLM's fitted values reach exactly 0 and 1, and its predictions may too.
# An exact 0 or 1 would give log(0) = -Inf and an Inf/NaN log-loss.
binary.logloss <- function(actual, predicted, eps = 1e-15) {
  predicted <- pmin(pmax(predicted, eps), 1 - eps)
  -sum(actual * log(predicted) + (1 - actual) * log(1 - predicted))
}
modeldata %>%
  filter(Sample == "holdout") %>%
  summarize(lmm.logloss = binary.logloss(Surr, lmm.pred),
            glm.logloss = binary.logloss(Surr, glm.pred),
            glm.base.logloss = binary.logloss(Surr, glm.pred.base))
# Two-way lift chart
# Compare the actual-over-expected (A/E) ratio of two competing models where
# they disagree. Holdout records are split into 20 equal-count buckets by the
# ratio of the LMM prediction to the base-GLM prediction. The buckets at
# either end of the x-axis are where the two models disagree most.
# Per the original analysis, the LMM (blue) line has an A/E closer to 1
# (dashed reference) in those extreme buckets.
twoway.plot <- modeldata %>%
  filter(Sample == "holdout") %>%
  mutate(PredRatio = lmm.pred / glm.pred.base) %>%
  group_by(PredRatio.bucket = ntile(PredRatio, 20)) %>%
  summarize(AoverE.glm.base = sum(Surr) / sum(glm.pred.base),
            AoverE.lmm = sum(Surr) / sum(lmm.pred))
twoway.plot %>%
  melt(id.vars = "PredRatio.bucket",
       measure.vars = c("AoverE.glm.base", "AoverE.lmm")) %>%
  ggplot(aes(x = PredRatio.bucket, y = value)) +
  geom_line(aes(color = variable)) +
  # Constant A/E = 1 reference drawn directly, not as a line through the data
  geom_hline(yintercept = 1, linetype = 2, color = "black") +
  labs(x = "Bucket of LMM / base GLM prediction ratio",
       y = "Actual / expected",
       color = "Model")
# Create shrinkage plot ####
# Line up the per-distributor (DistCode) coefficients of both models on the
# log-odds scale:
#   Coef_GLM - GLM intercept plus the distributor's DistCode dummy coefficient
#   Coef_LMM - the LMM's per-distributor intercept, coef(lmm.model)$DistCode
# The GLM has no dummy for its baseline (reference) distributor, whose effect
# is the intercept alone. A left join starting from the GLM dummies would
# silently drop that distributor, so a full join is used and the missing dummy
# is treated as 0. Distributors present only in the GLM keep an NA Coef_LMM.
# NOTE(review): assumes GLM term names look like "DistCode<level>" (treatment
# contrasts), matching the LMM's DistCode row names -- confirm.
glm.coefs <- coef(glm.model)
glm.intercept <- if ("(Intercept)" %in% names(glm.coefs)) {
  glm.coefs[["(Intercept)"]]
} else {
  0
}
is.dist.term <- startsWith(names(glm.coefs), "DistCode")
glm.dist <- data.frame(
  DistCode = sub("^DistCode", "", names(glm.coefs)[is.dist.term]),
  Coef_GLM = unname(glm.coefs[is.dist.term]),
  stringsAsFactors = FALSE
)
lmm.dist <- coef(lmm.model)$DistCode
coef.df <- glm.dist %>%
  full_join(data.frame(DistCode = rownames(lmm.dist),
                       Coef_LMM = lmm.dist[["(Intercept)"]],
                       stringsAsFactors = FALSE),
            by = "DistCode") %>%
  # Distributors with no GLM dummy (the baseline) get a dummy effect of 0;
  # aliased dummies (NA in coef()) stay NA rather than being zeroed.
  mutate(Coef_GLM = glm.intercept +
           if_else(DistCode %in% glm.dist$DistCode, Coef_GLM, 0))
# Record count per distributor across all of modeldata (every Sample),
# returned as an ungrouped table with columns DistCode and N.
dist.N <- ungroup(summarize(group_by(modeldata, DistCode), N = n()))
# Long-format plotting data: one row per (distributor, model) pair, with the
# coefficient in `value` and the model name in `variable`. `y` is 1 for the
# GLM coefficient and 0 for the LMM coefficient, giving each model its own
# row in the plot; N (record count) is attached from dist.N.
C <- local({
  long <- melt(coef.df,
               id.vars = "DistCode",
               measure.vars = c("Coef_GLM", "Coef_LMM"))
  long$y <- as.numeric(long$variable == "Coef_GLM")
  long$DistCode <- factor(long$DistCode)
  left_join(long, dist.N, by = "DistCode")
})
# Shrinkage plot: each distributor is a line joining its GLM coefficient
# (upper row, y = 1) to its LMM coefficient (lower row, y = 0) on the
# log-odds scale. Line width grows with log(N), the distributor's record count.
# The title labels the upper (GLM) row and the x-axis title labels the lower
# (LMM) row, so the y-axis text is blanked out.
shrinkage.plot <- C %>%
  ggplot(aes(x = value, y = y, group = DistCode)) +
  geom_point() +
  # NOTE(review): ggplot2 >= 3.4.0 prefers `linewidth` (with
  # scale_linewidth_continuous) over `size` for lines; `size` is kept here
  # for compatibility with older ggplot2 versions.
  geom_line(aes(size = log(N)), alpha = 0.25, color = "navy") +
  xlab("Log-odds coefficient (LMM)") +
  ylab("") +
  ggtitle("Log-odds coefficient (GLM)") +
  theme(plot.title = element_text(hjust = 0.5, size = 12),
        axis.title.x = element_text(size = 12),
        axis.text.y = element_blank()) +
  # guide = "none" replaces the deprecated guide = FALSE (written as `F`)
  scale_size_continuous(guide = "none", range = c(0.25, 1))
shrinkage.plot
# Save this plot explicitly rather than relying on last_plot()
ggsave("ShrinkagePlot.png", plot = shrinkage.plot, width = 7, height = 4)
# Distributor summary ####
# Spread of the per-distributor LMM coefficients (log-odds scale).
# NA entries (e.g. a distributor with no LMM coefficient) are ignored below;
# without na.rm, max/min would return NA and quantile() would error.
summary(coef.df$Coef_LMM)
# Odds multiplier between the distributors with the largest and smallest
# coefficients
exp(diff(range(coef.df$Coef_LMM, na.rm = TRUE)))
# Odds multiplier between the 97.5th percentile distributor and the 2.5th
# percentile distributor
exp(diff(quantile(coef.df$Coef_LMM, c(0.025, 0.975), na.rm = TRUE)))