-
Notifications
You must be signed in to change notification settings - Fork 284
/
18_ANOVAonewayPyMC.py
138 lines (115 loc) · 3.94 KB
/
18_ANOVAonewayPyMC.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
One way BANOVA
"""
from __future__ import division
import numpy as np
import pymc3 as pm
import pandas as pd
import matplotlib.pyplot as plt
plt.style.use('seaborn-darkgrid')
from scipy.stats import norm
from hpd import *
from theano import tensor as T
# THE DATA.
# Specify data source: the trailing index picks one of the three options
# (0 = "McDonaldSK1991", 1 = "SolariLS2008", 2 = "Random").
dataSource = ["McDonaldSK1991", "SolariLS2008", "Random"][0]
# Load the data:
if dataSource == "McDonaldSK1991":
    # Whitespace-delimited table; the first 18 and the last 25 lines of the
    # file are skipped.  `skipfooter` is only supported by the python parser
    # engine, so it is requested explicitly instead of relying on pandas'
    # warn-and-fall-back behaviour.  The separator is a raw string so that
    # `\s` is not treated as an (invalid) string escape.
    datarecord = pd.read_csv("McDonaldSK1991data.txt", sep=r'\s+',
                             skiprows=18, skipfooter=25, engine='python')
    y = datarecord['Size']
    Ntotal = len(y)
    # Group codes shifted down by one; x is later used to index the
    # per-level effects in the model.
    x = (datarecord['Group'] - 1).values
    xnames = pd.unique(datarecord['Site'])
    NxLvl = len(xnames)
    # Named contrasts: each vector holds one weight per group level and is
    # later dotted with the posterior samples of the level effects.
    contrast_dict = {'BIGvSMALL': [-1/3, -1/3, 1/2, -1/3, 1/2],
                     'ORE1vORE2': [1, -1, 0, 0, 0],
                     'ALAvORE': [-1/2, -1/2, 1, 0, 0],
                     'NPACvORE': [-1/2, -1/2, 1/2, 1/2, 0],
                     'USAvRUS': [1/3, 1/3, 1/3, -1, 0],
                     'FINvPAC': [-1/4, -1/4, -1/4, -1/4, 1],
                     'ENGvOTH': [1/3, 1/3, 1/3, -1/2, -1/2],
                     'FINvRUS': [0, 0, 0, -1, 1]}
if dataSource == "SolariLS2008":
    # Whitespace-delimited table; the first 21 lines of the file are skipped.
    # The separator is a raw string so `\s` is not parsed as a string escape.
    datarecord = pd.read_csv("SolariLS2008data.txt", sep=r'\s+', skiprows=21)
    y = datarecord['Acid']
    Ntotal = len(y)
    # Type codes shifted down by one; x is later used to index the
    # per-level effects in the model.
    x = (datarecord['Type'] - 1).values
    # No label column is read here, so the level "names" are the codes.
    xnames = pd.unique(x)
    NxLvl = len(xnames)
    # Single contrast: the third level (weight 1) against all others (-1/8 each).
    contrast_dict = {'G3vOTHER': [-1/8, -1/8, 1, -1/8, -1/8, -1/8, -1/8, -1/8, -1/8]}
if dataSource == "Random":
    # Simulated data: `npercell` observations per group, generated as
    # baseline + group deflection + Gaussian noise.
    np.random.seed(47405)
    ysdtrue = 4.0
    a0true = 100
    atrue = [2, -2]  # sum to zero
    npercell = 8
    x = []
    y = []
    for xidx, deflection in enumerate(atrue):
        for _ in range(npercell):
            x.append(xidx)
            # Zero-mean noise with standard deviation ysdtrue.  The first
            # positional argument of norm.rvs is `loc`, so the previous
            # norm.rvs(1, ysdtrue) shifted every observation by +1.
            y.append(a0true + deflection + norm.rvs(loc=0, scale=ysdtrue))
    # Arrays, like the file-based branches, so x works as an index vector
    # and y supports element-wise arithmetic.
    x = np.asarray(x)
    y = np.asarray(y)
    Ntotal = len(y)
    NxLvl = len(set(x.tolist()))
    # Construct all pairwise comparisons, to compare with NHST TukeyHSD.
    # Stored as a name -> weight-vector dict like the other data sources, so
    # the later len(contrast_dict) / contrast_dict.items() calls work
    # (the previous nested-tuple accumulation starting from None did not).
    contrast_dict = {}
    for g1idx in range(NxLvl):
        for g2idx in range(g1idx + 1, NxLvl):
            cmpVec = np.zeros(NxLvl)
            cmpVec[g1idx] = -1
            cmpVec[g2idx] = 1
            contrast_dict['G{}vG{}'.format(g2idx, g1idx)] = cmpVec
# Standardize y (zero mean, unit standard deviation); the priors below
# assume standardized data.
z = (y - np.mean(y))/np.std(y)
## THE MODEL.
# All model statements must live inside the `with` block so the random
# variables are registered on `model`.
with pm.Model() as model:
    # define the hyperpriors
    # StudentT with nu=1, folded with abs() and offset by 0.1 so that the
    # standard deviation of the group effects stays strictly positive.
    a_SD_unabs = pm.StudentT('a_SD_unabs', mu=0, lam=0.001, nu=1)
    a_SD = abs(a_SD_unabs) + 0.1
    atau = 1 / a_SD**2  # precision corresponding to a_SD
    # define the priors
    sigma = pm.Uniform('sigma', 0, 10)  # y values are assumed to be standardized
    tau = 1 / sigma**2  # precision corresponding to sigma
    a0 = pm.Normal('a0', mu=0, tau=0.001)  # y values are assumed to be standardized
    a = pm.Normal('a', mu=0, tau=atau, shape=NxLvl)
    # Re-center the group effects so they sum to zero.
    b = pm.Deterministic('b', a - T.mean(a))
    # Per-observation mean: baseline plus the effect of that observation's group.
    mu = a0 + b[x]
    # define the likelihood
    yl = pm.Normal('yl', mu, tau=tau, observed=z)
    # Generate a MCMC chain
    trace = pm.sample(2000, progressbar=False)
# EXAMINE THE RESULTS
burnin = 1000
thin = 10
# Print summary for each trace
#pm.summary(trace[burnin::thin])
#pm.summary(trace)
# Check for mixing and autocorrelation
#pm.autocorrplot(trace[burnin::thin], vars=model.unobserved_RVs[:-1])
## Plot KDE and sampled values for each parameter.
#pm.traceplot(trace[burnin::thin])
pm.traceplot(trace)
# Apply burn-in and thinning per chain.  Indexing trace['a0'] first returns
# the chains concatenated, so slicing that array with [burnin::thin] would
# only drop the burn-in of the first chain.
a0_sample = trace.get_values('a0', burn=burnin, thin=thin)
b_sample = trace.get_values('b', burn=burnin, thin=thin)
# Convert from standardized units back to the original scale of y.
y_mean = np.mean(y)
y_sd = np.std(y)
b0_sample = a0_sample * y_sd + y_mean
b_sample = b_sample * y_sd
# Posterior of each group effect, one panel per level.  The number of
# panels follows the number of columns in b_sample instead of a fixed 5,
# so data sources with a different number of levels are plotted completely.
plt.figure(figsize=(20, 4))
n_levels = b_sample.shape[1]
for i in range(n_levels):
    ax = plt.subplot(1, n_levels, i+1)
    pm.plot_posterior(b_sample[:, i], bins=50, ax=ax)
    ax.set_xlabel(r'$\beta1_{}$'.format(i))
    ax.set_title('x:{}'.format(i))
plt.tight_layout()
plt.savefig('Figure_18.2a.png')
# Posterior of each named contrast (weighted combination of group effects).
# A missing (None) contrast specification is treated like an empty one.
contrasts = contrast_dict if contrast_dict is not None else {}
nContrasts = len(contrasts)
if nContrasts > 0:
    # Four panels per row; the row count grows with the number of contrasts
    # (8 contrasts give the 2x4 grid on a 20x8 figure).
    n_cols = 4
    n_rows = (nContrasts + n_cols - 1) // n_cols
    plt.figure(figsize=(20, 4 * n_rows))
    for count, (key, value) in enumerate(contrasts.items(), start=1):
        # One contrast value per posterior sample.
        contrast = np.dot(b_sample, value)
        ax = plt.subplot(n_rows, n_cols, count)
        pm.plot_posterior(contrast, ref_val=0.0, bins=50, ax=ax)
        ax.set_title('Contrast {}'.format(key))
    plt.tight_layout()
    plt.savefig('Figure_18.2b.png')
plt.show()