-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlabel_switching.R
100 lines (88 loc) · 2.42 KB
/
label_switching.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# Posterior probability that observation i belongs to mixture component j
# under one parameter draw `theta`:
#   weight[j] * f(x_i | mu[j], kappa[j]) / sum_s weight[s] * f(x_i | mu[s], kappa[s])
# where f is `dvonmises` (presumably circular::dvonmises; the package is not
# loaded in this file).
#
# i     : index of the observation in `obs`.
# j     : component index in 1..length(theta$weight).
# theta : list with per-component vectors `mu`, `kappa` and `weight`.
# obs   : observation vector; defaults to the global `data`, which is what
#         this function always read implicitly.
# Returns a single probability (NaN if every weighted density is 0).
prob <- function(i, j, theta, obs = data) {
  x <- obs[i]
  # Take the number of components from the draw itself instead of relying
  # on a global `k` that may not exist or may disagree with `theta`.
  n.comp <- length(theta$weight)
  dens <- vapply(seq_len(n.comp), function(s) {
    theta$weight[s] * dvonmises(x, mu = theta$mu[s], kappa = theta$kappa[s])
  }, numeric(1))
  dens[j] / sum(dens)
}
# Relabel the components of one parameter draw.
#
# New component s takes the parameters of old component permutation[s]
# (e.g. permutation = c(2, 1) swaps components 1 and 2). If the draw has an
# allocation vector `z`, it is relabelled consistently: an observation that
# was assigned to old component c is assigned to the new label s with
# permutation[s] == c, i.e. the inverse permutation, computed by match().
#
# theta       : list with per-component vectors `mu`, `kappa`, `weight`
#               and optionally an allocation vector `z` of labels 1..k.
# permutation : integer vector, a permutation of 1..k.
# Returns `theta` with those fields relabelled. Any other fields (e.g. `k`)
# are carried over unchanged; further per-component fields, if any, would
# NOT be permuted.
apply.permutation <- function(theta, permutation) {
  theta$mu <- theta$mu[permutation]
  theta$kappa <- theta$kappa[permutation]
  theta$weight <- theta$weight[permutation]
  # [["z"]] rather than $z: `$` would partially match another field name.
  if (!is.null(theta[["z"]])) {
    theta$z <- match(theta[["z"]], permutation)
  }
  theta
}
# Relabel every draw in a list: row t of `permutation` is applied to draw t
# via apply.permutation().
#
# theta       : list of N parameter draws.
# permutation : matrix with (at least) N rows; row t is a permutation of 1..k.
# Returns a list of N relabelled draws, in the same order as `theta`.
apply.permutation.list <- function(theta, permutation) {
  if (NROW(permutation) < length(theta)) {
    stop("apply.permutation.list: need one permutation row per draw",
         call. = FALSE)
  }
  # Every draw is relabelled; the last one used to be skipped by iterating
  # over 1:(length(theta) - 1).
  lapply(seq_along(theta), function(t) {
    apply.permutation(theta[[t]], permutation[t, ])
  })
}
# Step 1 of Stephens' relabelling algorithm: given the current per-draw
# permutations, estimate the n x k matrix of classification probabilities
#   q[i, j] = (1 / N) * sum_t prob(i, j, draw t relabelled by row t)
# (expression (17) in Stephens).
#
# samples     : list of N parameter draws; n is length(samples[[1]]$z).
# permutation : N x k matrix; row t is the permutation applied to draw t.
# k           : number of mixture components.
# Returns an n x k numeric matrix (also when k == 1 or n == 1).
# Note: prob() reads the observations from the global `data` by default.
LS.UpdateAssignmentProb <- function(samples, permutation, k) { # Step 1
  n <- length(samples[[1]]$z)
  N <- length(samples)
  # Relabel each draw once, outside the (i, j) loops.
  relabelled <- lapply(seq_len(N), function(t) {
    apply.permutation(samples[[t]], permutation[t, ])
  })
  # Preallocate so the result is n x k even when k == 1.
  q <- matrix(NA_real_, nrow = n, ncol = k)
  for (i in seq_len(n)) {
    for (j in seq_len(k)) {
      # Expression (17) Stephens
      som <- sum(vapply(relabelled, function(theta) prob(i, j, theta),
                        numeric(1)))
      q[i, j] <- som / N
    }
  }
  q
}
# Step 2 of Stephens' relabelling algorithm: for every draw t choose the
# permutation of 1..k that minimises expression (16),
#   sum_i sum_j p_ij * log(p_ij / q[i, j]),
# where p_ij = prob(i, j, draw t relabelled by that permutation), i.e. the
# summed Kullback-Leibler divergence of the relabelled classification
# probabilities from q. All k! permutations are enumerated with `permn`
# (presumably combinat::permn), so this is only practical for small k.
#
# samples : list of N parameter draws; n is length(samples[[1]]$z).
# q       : n x k matrix from Step 1 (LS.UpdateAssignmentProb).
# k       : number of mixture components.
# Returns an N x k integer matrix whose row t is the selected permutation.
LS.UpdatePermutation <- function(samples, q, k) { # Step 2
  n <- length(samples[[1]]$z)
  N <- length(samples)
  # The candidate relabellings are the same for every draw.
  candidates <- permn(seq_len(k))
  selected.permutation <- matrix(NA_integer_, nrow = N, ncol = k)
  for (t in seq_len(N)) {
    best.som <- NA_real_
    have.best <- FALSE
    for (permutation in candidates) {
      # Relabel the draw once per candidate, not once per (i, j) term.
      theta <- apply.permutation(samples[[t]], permutation)
      # Expression (16) Stephens
      som <- 0
      for (i in seq_len(n)) {
        for (j in seq_len(k)) {
          p <- prob(i, j, theta)
          # Convention 0 * log(0 / q) = 0; evaluated literally it is
          # 0 * -Inf (or 0 * log(0 / 0)) = NaN and poisons the whole sum.
          if (!is.na(p) && p == 0) {
            next
          }
          som <- som + p * log(p / q[i, j])
        }
      }
      # The first candidate is taken provisionally; a later one replaces it
      # only with a strictly smaller non-missing sum (ties keep the earlier
      # candidate). A missing (NA/NaN) best is replaced by any valid sum.
      better <- !is.na(som) && (is.na(best.som) || som < best.som)
      if (!have.best || better) {
        selected.permutation[t, ] <- permutation
        best.som <- som
        have.best <- TRUE
      }
    }
  }
  selected.permutation
}
# Undo label switching in a list of mixture-model draws with Stephens'
# iterative relabelling algorithm. Starting from the identity permutation for
# every draw, it alternates Step 1 (LS.UpdateAssignmentProb) and Step 2
# (LS.UpdatePermutation) until the permutations stop changing or `max.iter`
# rounds have run.
#
# samples  : list of N draws. The number of components is read from
#            samples[[1]][["k"]], falling back to length(samples[[1]]$weight).
# max.iter : maximum number of Step 1 / Step 2 rounds.
# verbose  : if TRUE, print progress with cat().
# Returns list(samples = relabelled draws, permutation = N x k matrix whose
# row t is the permutation applied to draw t). Warns if not converged.
PostProcessLabels <- function(samples, max.iter = 100, verbose = TRUE) {
  # [["k"]] rather than $k: `$` would partially match `kappa` if `k` is absent.
  k <- samples[[1]][["k"]]
  if (is.null(k)) {
    k <- length(samples[[1]]$weight)
  }
  N <- length(samples)
  # Row t is the permutation applied to draw t; start from the identity.
  permutation <- matrix(rep(seq_len(k), each = N), ncol = k)
  converged <- FALSE
  for (repetition in seq_len(max.iter)) {
    if (verbose) cat("Repetition", repetition, "\n")
    q <- LS.UpdateAssignmentProb(samples, permutation, k)
    if (verbose) cat("Step 1 done\n")
    new.permutation <- LS.UpdatePermutation(samples, q, k)
    if (verbose) cat("Step 2 done\n")
    # Compare with the permutation q was built from: if Step 2 reproduces
    # it, the next round would compute the same q, so we have a fixed point.
    n.changed <- sum(new.permutation != permutation)
    if (verbose) cat("Diff:", n.changed, "\n")
    permutation <- new.permutation
    if (isTRUE(n.changed == 0)) {
      converged <- TRUE
      break
    }
  }
  if (!converged) {
    warning("PostProcessLabels: permutations did not converge within ",
            max.iter, " iterations", call. = FALSE)
  }
  list(
    samples = apply.permutation.list(samples, permutation),
    permutation = permutation
  )
}