-
Notifications
You must be signed in to change notification settings - Fork 1
/
multilevel_small_example.Rmd
138 lines (107 loc) · 4.31 KB
/
multilevel_small_example.Rmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
---
title: "R Notebook"
output: html_notebook
---
```{r setup}
library(rstanarm)
library(tidybayes)
library(tidyverse)
library(cowplot)
library(scico)
# Use the cowplot theme for every figure in this notebook.
theme_set(theme_cowplot())

# Figures are saved into a scratch directory under the project root;
# create it on first use.
output_dir <- here::here("local_temp_data")
if (!dir.exists(output_dir)) {
  dir.create(output_dir)
}

# Set the mc.cores option to the number of detected cores.
options(mc.cores = parallel::detectCores())
```
```{r, fig.width = 3, fig.height = 5}
# Fix the RNG so the simulated data set is the same on every run.
set.seed(842223)
batches <- c("A", "B", "C")

# Simulated experiment: every batch x group x replicate combination.
# The response is built on the log scale from the batch intercept, a 0.25
# shift for the treatment group (b = 1) and Gaussian noise with sd 0.25.
# Note: `var_effect` is drawn per batch but is not referenced when the response
# is computed; the draw still advances the RNG stream.
base_data <- tibble(
  batch = batches,
  var_intercept = c(0.1, 0, -0.1),
  var_effect = rnorm(length(batches), 0, 0.1)
) %>%
  crossing(tibble(group = c("control", "treatment"), b = c(0, 1))) %>%
  crossing(tibble(replicate = 1:3)) %>%
  # Optional: uncomment to keep all replicates only for batch "A" and drop
  # replicate 3 elsewhere (unbalanced design).
  # filter(batch == "A" | replicate < 3) %>%
  mutate(
    log_relative_expression = var_intercept + b * 0.25 + rnorm(n(), 0, 0.25),
    relative_expression = exp(log_relative_expression)
  )
# Start a ggplot of log relative expression by group, without any geoms.
#
# data: data frame with `group` and `log_relative_expression` columns.
# Returns a ggplot object mapping group to x, color and fill.
# Side effect: resets the global RNG seed on every call so that jitter added
# by later layers is identical from plot to plot.
base_plot <- function(data) {
  set.seed(9882264)

  group_aes <- aes(
    x = group,
    y = log_relative_expression,
    color = group,
    fill = group
  )

  ggplot(data, group_aes) +
    scale_y_continuous("Log(Relative expression)") +
    theme(axis.title.x = element_blank())
}
# Add the jittered raw data points to a plot.
#
# plot: a ggplot object whose default aesthetics map `color` and `fill`
#   (as produced by base_plot()).
# Returns the plot with a point layer and manual color/fill scales added.
add_my_point <- function(plot) {
  plot +
    geom_point(
      position = position_jitter(width = 0.05),
      size = 4,
      shape = 21
    ) +
    # guide = "none" suppresses the legend; the logical form guide = FALSE
    # is deprecated in ggplot2.
    scale_color_manual(values = c("#c46e9b", "#05a8aa"), guide = "none") +
    scale_fill_manual(values = c("#f49ac8", "#5fc7c8"), guide = "none")
}
# Mean / standard-error summary computed on the log scale and mapped back
# with exp().
#
# x: numeric vector; log(x) is taken, so values are presumably positive.
# Returns the result of mean_se(log(x)), exponentiated.
mean_se_log <- function(x) {
  log_scale_summary <- mean_se(log(x))
  exp(log_scale_summary)
}
# Add a mean +/- standard-error summary next to the raw points.
#
# plot:  a ggplot object with x and y aesthetics already mapped.
# nudge: horizontal offset of the summary relative to the group position.
# width: width of the error-bar caps.
# Returns the plot with an error-bar layer and a cross marking the mean.
add_error_bars <- function(plot, nudge = -0.3, width = 0.2) {
  nudged <- position_nudge(x = nudge)

  plot +
    stat_summary(
      fun.data = "mean_se",
      geom = "errorbar",
      width = width,
      color = "#707070",
      position = nudged
    ) +
    # `fun` replaces the deprecated `fun.y` argument (ggplot2 >= 3.3.0).
    stat_summary(
      fun = "mean",
      geom = "point",
      shape = 3,
      size = 3,
      stroke = 3,
      color = "#707070",
      position = nudged
    )
}
# Raw points only, all batches pooled.
plot_together_points <- base_data %>%
  base_plot() %>%
  add_my_point()
plot_together_points
ggsave(
  file.path(output_dir, "together_points.png"),
  plot_together_points,
  width = 3,
  height = 5
)

# Same plot with the mean +/- SE summary drawn beside the points.
plot_together <- base_data %>%
  base_plot() %>%
  add_error_bars() %>%
  add_my_point()
plot_together
ggsave(
  file.path(output_dir, "together.png"),
  plot_together,
  width = 3,
  height = 5
)
```
```{r, fig.width = 6, fig.height = 5}
# One panel per batch. The parentheses make explicit that the whole pipeline
# is built first and the facets are added to its result.
plot_separate <- (
  base_data %>%
    base_plot() %>%
    add_error_bars() %>%
    add_my_point()
) +
  facet_wrap(~batch)
plot_separate
ggsave(
  file.path(output_dir, "separate.png"),
  plot_separate,
  width = 6,
  height = 5
)
```
```{r}
# Multilevel model: overall intercept, effect of group, and a varying
# intercept for each batch.
fit <- stan_lmer(
  log_relative_expression ~ 1 + group + (1 | batch),
  data = base_data
)
```
```{r}
# Posterior summary reporting the 2.5%, 50% and 97.5% quantiles.
summary(
  fit,
  probs = c(0.025, 0.5, 0.975)
)
```
```{r}
# Two-sample t-test between groups, ignoring batch.
t.test(
  log_relative_expression ~ group,
  data = base_data
)

# ANOVA table for a linear model that includes batch as a fixed effect.
car::Anova(
  lm(log_relative_expression ~ batch + group, data = base_data)
)

# Fraction of posterior draws in which the treatment coefficient is negative.
mean(as.array(fit)[, , "grouptreatment"] < 0)
```
```{r}
# Every batch x group combination to predict for.
predict_data <- tibble(batch = batches) %>%
  crossing(group = unique(base_data$group))

# Posterior draws of the linear predictor, one column per row of predict_data.
s <- posterior_linpred(fit, newdata = predict_data)

# Per-row posterior mean and sd on the log scale; low/high span +/- 1 sd.
# `mean` is the exponentiated log-scale mean.
fit_to_plot <- predict_data %>%
  mutate(
    log_mean = colMeans(s),
    mean = exp(log_mean),
    sd = sqrt(diag(cov(s))),
    low = log_mean - sd,
    high = log_mean + sd
  )
fit_to_plot
```
```{r, fig.width = 6, fig.height = 5}
# Add model-based estimates as a crossbar to the right of the raw points.
#
# plot:          a ggplot object with an x aesthetic already mapped.
# crossbar_data: data frame with `log_mean`, `low` and `high` columns, plus
#   the variables the plot maps to x (and facets on, if any).
# Returns the plot with the crossbar layer added.
# NOTE(review): newer ggplot2 releases prefer `linewidth` over `size` for
# line thickness - confirm against the installed version before changing.
add_my_crossbar <- function(plot, crossbar_data) {
  plot +
    geom_crossbar(
      aes(y = log_mean, ymin = low, ymax = high),
      data = crossbar_data,
      color = "#5d2e8c",
      fill = "white",
      width = 0.2,
      size = 2,
      position = position_nudge(x = 0.3)
    )
}
# Per-batch panels showing raw points, mean +/- SE, and the multilevel model
# estimate for each batch.
plot_multilevel <- (
  base_data %>%
    base_plot() %>%
    add_error_bars() %>%
    add_my_crossbar(fit_to_plot) %>%
    add_my_point()
) +
  facet_wrap(~batch)
plot_multilevel
ggsave(
  file.path(output_dir, "separate_multilevel.png"),
  plot_multilevel,
  width = 6,
  height = 5
)
```
```{r, fig.width=3, fig.height=5}
# Predict for a batch label that does not occur in the fitted data.
# NOTE(review): presumably this yields estimates for a new, unseen batch -
# confirm how posterior_linpred() treats new grouping levels.
predict_data_new_level <- tibble(
  group = unique(base_data$group),
  batch = "NEW_BATCH"
)
s_n <- posterior_linpred(fit, newdata = predict_data_new_level)

# Posterior mean and +/- 1 sd band on the log scale for each group.
fit_to_plot_n <- predict_data_new_level %>%
  mutate(
    log_mean = colMeans(s_n),
    mean = exp(log_mean),
    sd = sqrt(diag(cov(s_n))),
    low = log_mean - sd,
    high = log_mean + sd
  )

# Pooled plot with the new-batch model estimates overlaid.
plot_multilevel_intercept <- base_data %>%
  base_plot() %>%
  add_error_bars() %>%
  add_my_crossbar(fit_to_plot_n) %>%
  add_my_point()
plot_multilevel_intercept
ggsave(
  file.path(output_dir, "separate_together.png"),
  plot_multilevel_intercept,
  width = 3,
  height = 5
)
```