forked from lei-zhang/socialRL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
reinforcement_learning_HBA.R
99 lines (81 loc) · 3.28 KB
/
reinforcement_learning_HBA.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# =============================================================================
#### Info ####
# =============================================================================
# simple reinforcement learning model
#
# true parameters: lr = rnorm(10, mean=0.6, sd=0.12); tau = rnorm(10, mean=1.5, sd=0.2)
#
# Lei Zhang
# =============================================================================
#### Construct Data ####
# =============================================================================
# load packages
library(rstan)
library(ggplot2)
library(R.matlab)
# Load the simulated data. The .RData file must define `rl_mp`, which is used
# below as rl_mp[subject, trial, k] with k = 1 for choices and k = 2 for
# rewards. Fail early with a clear message if that is not the case.
loaded_objects <- load('data/PPC/rl_mp.RData')
if (!"rl_mp" %in% loaded_objects) {
  stop("data/PPC/rl_mp.RData does not define an object named 'rl_mp'",
       call. = FALSE)
}
stopifnot(length(dim(rl_mp)) == 3, dim(rl_mp)[3] >= 2)

sz <- dim(rl_mp)
nSubjects <- sz[1]
nTrials <- sz[2]

# `[` drops dimensions when nSubjects or nTrials is 1, which would turn the
# subjects x trials matrices into plain vectors; matrix() restores the shape.
# (Presumably the Stan data block declares choice/reward as
# [nSubjects, nTrials] arrays -- confirm against code/rl_ppc.stan.)
dataList <- list(nSubjects = nSubjects,
                 nTrials   = nTrials,
                 choice    = matrix(rl_mp[, , 1], nrow = nSubjects, ncol = nTrials),
                 reward    = matrix(rl_mp[, , 2], nrow = nSubjects, ncol = nTrials))
# =============================================================================
#### Running Stan ####
# =============================================================================
# rstan settings. stan() uses getOption("mc.cores") as its default number of
# cores, so at most 2 of the chains below run at the same time.
rstan_options(auto_write = TRUE)
options(mc.cores = 2)

# Sampler settings: the first half of each chain's iterations is warmup.
modelFile <- 'code/rl_ppc.stan'
nIter     <- 2000
nChains   <- 4
nWarmup   <- floor(nIter/2)
nThin     <- 1

cat("Estimating", modelFile, "model... \n")
startTime <- Sys.time()
print(startTime)

cat("Calling", nChains, "simulations in Stan... \n")
fit_rl <- stan(modelFile,
               data    = dataList,
               chains  = nChains,
               iter    = nIter,
               warmup  = nWarmup,
               thin    = nThin,
               init    = "random",
               seed    = 1450154637)

cat("Finishing", modelFile, "model simulation ... \n")
endTime <- Sys.time()
print(endTime)

# endTime - startTime is a difftime; format() renders it with its units
# (e.g. "3.2 mins"). as.character.Date() is a Date-class method and should
# not be called on a difftime.
cat("It took", format(endTime - startTime), "\n")
# =============================================================================
#### Model Summary and Diagnostics ####
# =============================================================================
# Print the posterior summary of every parameter, then draw interval plots
# with posterior densities for the group-level mean (*_mu) and the
# per-subject values of each parameter.
print(fit_rl)
# TRUE rather than T: T is an ordinary variable and can be reassigned.
plot_dens_lr  <- stan_plot(fit_rl, pars = c('lr_mu', 'lr'),
                           show_density = TRUE, fill_color = 'skyblue')
plot_dens_tau <- stan_plot(fit_rl, pars = c('tau_mu', 'tau'),
                           show_density = TRUE, fill_color = 'skyblue')
print(plot_dens_lr)
print(plot_dens_tau)
# =============================================================================
#### Violin plot of posterior means ####
# =============================================================================
# Per-subject posterior means of lr and tau. get_posterior_mean() returns one
# column per chain followed by the mean pooled over all chains; take the LAST
# column so this does not depend on nChains (the previous hard-coded [,5] was
# the pooled column only when nChains == 4).
postMeans  <- get_posterior_mean(fit_rl, pars = c('lr', 'tau'))
pars_value <- postMeans[, ncol(postMeans)]
# Label each value with its parameter name by stripping the "[i]" index from
# row names such as "lr[3]", so the labels follow the actual number of
# subjects instead of a hard-coded 10.
pars_name <- as.factor(sub("\\[.*$", "", rownames(postMeans)))
df <- data.frame(pars_value = pars_value, pars_name = pars_name)
# Shared plot theme: theme_bw() at base font size 20 with the major and minor
# grid lines and the panel background blanked out.
myconfig <- theme_bw(base_size = 20) +
  theme(
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    panel.background = element_blank()
  )
# Summary function for stat_summary(fun.data = ...): the mean of `x` plus the
# bounds one standard deviation below and above it, returned as a named
# vector c(y, ymin, ymax). With a single value sd() is NA, so both bounds are NA.
data_summary <- function(x) {
  centre <- mean(x)
  spread <- sd(x)
  c(y = centre, ymin = centre - spread, ymax = centre + spread)
}
# Violin of the per-subject posterior means for each parameter, overlaid with
# mean +/- 1 SD (from data_summary) as a black point-range.
g1 <- ggplot(df, aes(x = pars_name, y = pars_value,
                     color = pars_name, fill = pars_name))
g1 <- g1 + geom_violin(trim = TRUE, size = 2)
g1 <- g1 + stat_summary(fun.data = data_summary, geom = "pointrange",
                        color = "black", size = 1.5)
g1 <- g1 + scale_fill_manual(values = c("#2179b5", "#c60256"))
g1 <- g1 + scale_color_manual(values = c("#2179b5", "#c60256"))
g1 <- g1 + myconfig + theme(legend.position = "none")
# Zoom with coord_cartesian() instead of ylim(): scale limits replace values
# outside [0.3, 2.2] with NA before the violin densities and the summary
# statistics are computed, silently dropping subjects. Coordinate limits only
# crop the visible region.
g1 <- g1 + labs(x = '', y = 'parameter value') +
  coord_cartesian(ylim = c(0.3, 2.2))
print(g1)