From aa1d0b288dfbc5bb42714a7a6d792b94c2bdb7ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=B6kcen=20Eraslan?= Date: Sat, 10 Nov 2018 23:08:40 -0500 Subject: [PATCH] Simplify LDA input parameterization. --- misc/cluster/lda/corr-lda.stan | 30 ++++++++++++--------- misc/cluster/lda/lda.data.R | 48 ++++++---------------------------- misc/cluster/lda/lda.stan | 20 ++++++++------ misc/cluster/lda/sim-lda.R | 16 ++++++------ 4 files changed, 45 insertions(+), 69 deletions(-) diff --git a/misc/cluster/lda/corr-lda.stan b/misc/cluster/lda/corr-lda.stan index 14c4d9da0..7f38220f5 100644 --- a/misc/cluster/lda/corr-lda.stan +++ b/misc/cluster/lda/corr-lda.stan @@ -2,28 +2,26 @@ data { int K; // num topics int V; // num words int M; // num docs - int N; // total word instances - int w[N]; // word n - int doc[N]; // doc ID for word n + int corpus[M,V]; // word freq matrix, doc x word vector[V] beta; // word prior } parameters { vector[K] mu; // topic mean corr_matrix[K] Omega; // correlation matrix vector[K] sigma; // scales - vector[K] eta[M]; // logit topic dist for doc m + vector[K] eta[M]; // logit topic dist for doc m simplex[V] phi[K]; // word dist for topic k } transformed parameters { simplex[K] theta[M]; // simplex topic dist for doc m cov_matrix[K] Sigma; // covariance matrix for (m in 1:M) - theta[m] <- softmax(eta[m]); + theta[m] = softmax(eta[m]); for (m in 1:K) { - Sigma[m,m] <- sigma[m] * sigma[m] * Omega[m,m]; + Sigma[m,m] = sigma[m] * sigma[m] * Omega[m,m]; for (n in (m+1):K) { - Sigma[m,n] <- sigma[m] * sigma[n] * Omega[m,n]; - Sigma[n,m] <- Sigma[m,n]; + Sigma[m,n] = sigma[m] * sigma[n] * Omega[m,n]; + Sigma[n,m] = Sigma[m,n]; } } } @@ -38,10 +36,16 @@ model { for (m in 1:M) eta[m] ~ multi_normal(mu,Sigma); // token probabilities - for (n in 1:N) { - real gamma[K]; - for (k in 1:K) - gamma[k] <- log(theta[doc[n],k]) + log(phi[k,w[n]]); - increment_log_prob(log_sum_exp(gamma)); // likelihood + for (i in 1:M) { + for (j in 1:V) { + int count = corpus[i,j]; + real gamma[K]; + if (count > 0) { + for (k in 1:K) { + gamma[k] = (log(theta[i,k]) + log(phi[k,j]))*count; + } + increment_log_prob(log_sum_exp(gamma)); // likelihood + } + } } } diff --git a/misc/cluster/lda/lda.data.R b/misc/cluster/lda/lda.data.R index 32e476246..d2108b0f8 100644 --- a/misc/cluster/lda/lda.data.R +++ b/misc/cluster/lda/lda.data.R @@ -4,46 +4,14 @@ V <- 5 M <- 25 -N <- -262 -w <- -c(4L, 3L, 5L, 4L, 3L, 3L, 3L, 3L, 3L, 4L, 5L, 3L, 4L, 4L, 5L, -3L, 4L, 4L, 4L, 3L, 5L, 4L, 5L, 2L, 3L, 3L, 1L, 5L, 5L, 1L, 4L, -3L, 1L, 2L, 5L, 4L, 4L, 3L, 5L, 4L, 2L, 4L, 5L, 3L, 4L, 1L, 4L, -4L, 3L, 2L, 1L, 2L, 1L, 2L, 2L, 2L, 1L, 2L, 2L, 3L, 1L, 2L, 2L, -4L, 4L, 5L, 4L, 5L, 5L, 4L, 3L, 5L, 4L, 4L, 4L, 2L, 2L, 1L, 1L, -2L, 1L, 3L, 1L, 2L, 1L, 1L, 1L, 3L, 2L, 3L, 3L, 5L, 4L, 5L, 4L, -3L, 5L, 4L, 2L, 2L, 2L, 1L, 3L, 2L, 1L, 3L, 1L, 3L, 1L, 1L, 2L, -1L, 2L, 2L, 4L, 4L, 4L, 5L, 5L, 4L, 4L, 5L, 4L, 3L, 3L, 3L, 1L, -3L, 3L, 4L, 2L, 1L, 3L, 4L, 4L, 5L, 4L, 4L, 4L, 3L, 4L, 3L, 4L, -5L, 1L, 2L, 1L, 3L, 2L, 1L, 1L, 2L, 3L, 3L, 3L, 3L, 4L, 1L, 4L, -4L, 4L, 4L, 3L, 4L, 4L, 1L, 2L, 2L, 3L, 3L, 1L, 1L, 4L, 1L, 3L, -1L, 5L, 3L, 2L, 2L, 1L, 1L, 2L, 3L, 3L, 4L, 4L, 5L, 3L, 4L, 3L, -1L, 5L, 5L, 5L, 3L, 3L, 4L, 5L, 3L, 3L, 3L, 2L, 3L, 1L, 3L, 3L, -1L, 3L, 1L, 5L, 5L, 5L, 2L, 2L, 3L, 3L, 3L, 1L, 1L, 5L, 5L, 5L, -3L, 1L, 5L, 4L, 1L, 3L, 3L, 3L, 3L, 4L, 2L, 5L, 1L, 3L, 5L, 2L, -5L, 5L, 2L, 1L, 3L, 3L, 5L, 3L, 5L, 3L, 3L, 5L, 1L, 2L, 2L, 1L, -1L, 2L, 1L, 2L, 3L, 1L, 1L) -doc <- -c(1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 2L, 2L, 2L, 2L, 2L, -2L, 2L, 2L, 2L, 2L, 2L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, -3L, 3L, 3L, 3L, 3L, 3L, 3L, 4L, 4L, 4L, 4L, 4L, 4L, 4L, 4L, 4L, -4L, 4L, 4L, 4L, 4L, 5L, 5L, 5L, 5L, 5L, 5L, 5L, 5L, 5L, 5L, 5L, -6L, 6L, 6L, 6L, 6L, 6L, 6L, 7L, 7L, 7L, 7L, 7L, 8L, 8L, 8L, 8L, -8L, 8L, 8L, 8L, 8L, 8L, 8L, 8L, 8L, 8L, 8L, 8L, 9L, 9L, 9L, 9L, -9L, 9L, 9L, 10L, 10L, 10L, 10L, 10L, 10L, 10L, 10L, 10L, 10L, -10L, 10L, 10L, 10L, 10L, 10L, 11L, 11L, 11L, 11L, 11L, 11L, 12L, -12L, 12L, 12L, 13L, 13L, 13L, 13L, 13L, 13L, 13L, 13L, 13L, 14L, -14L, 14L, 14L, 14L, 14L, 14L, 14L, 14L, 14L, 14L, 15L, 15L, 15L, -15L, 15L, 15L, 15L, 15L, 15L, 15L, 15L, 16L, 16L, 16L, 16L, 16L, -16L, 16L, 16L, 16L, 16L, 17L, 17L, 17L, 17L, 17L, 17L, 17L, 17L, -17L, 18L, 18L, 18L, 18L, 18L, 18L, 18L, 18L, 18L, 18L, 19L, 19L, -19L, 19L, 19L, 19L, 19L, 19L, 19L, 20L, 20L, 20L, 20L, 20L, 20L, -20L, 20L, 20L, 21L, 21L, 21L, 21L, 21L, 21L, 21L, 21L, 21L, 21L, -22L, 22L, 22L, 22L, 22L, 22L, 22L, 22L, 23L, 23L, 23L, 23L, 23L, -23L, 23L, 23L, 23L, 23L, 24L, 24L, 24L, 24L, 24L, 24L, 24L, 24L, -24L, 24L, 24L, 24L, 24L, 24L, 24L, 24L, 24L, 24L, 24L, 24L, 24L, -25L, 25L, 25L, 25L, 25L, 25L, 25L, 25L, 25L, 25L, 25L) +corpus <- +structure(c(4, 1, 1, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3, 3, 0, +1, 0, 0, 1, 4, 0, 1, 1, 1, 1, 6, 6, 0, 0, 1, 1, 0, 0, 1, 0, 0, +0, 3, 2, 1, 1, 1, 1, 0, 2, 1, 3, 2, 2, 5, 6, 1, 2, 6, 3, 1, 2, +3, 6, 1, 5, 1, 4, 6, 1, 3, 3, 3, 1, 2, 3, 2, 0, 1, 0, 0, 2, 4, +2, 7, 2, 1, 7, 0, 6, 2, 4, 0, 0, 1, 1, 1, 2, 2, 0, 0, 0, 1, 0, +2, 1, 0, 4, 7, 0, 1, 1, 3, 2, 4, 3, 3, 5, 0, 0, 1, 1, 2, 2, 0, +0, 0, 2), .Dim = c(25L, 5L)) alpha <- c(0.5, 0.5) beta <- diff --git a/misc/cluster/lda/lda.stan b/misc/cluster/lda/lda.stan index 39fae61a6..4bc121752 100644 --- a/misc/cluster/lda/lda.stan +++ b/misc/cluster/lda/lda.stan @@ -2,9 +2,7 @@ data { int K; // num topics int V; // num words int M; // num docs - int N; // total word instances - int w[N]; // word n - int doc[N]; // doc ID for word n + int corpus[M,V]; // word freq matrix, doc x word vector[K] alpha; // topic prior vector[V] beta; // word prior } @@ -17,10 +15,16 @@ model { theta[m] ~ dirichlet(alpha); // prior for (k in 1:K) phi[k] ~ dirichlet(beta); // prior - for (n in 1:N) { - real gamma[K]; - for (k in 1:K) - gamma[k] <- log(theta[doc[n],k]) + log(phi[k,w[n]]); - increment_log_prob(log_sum_exp(gamma)); // likelihood + for (i in 1:M) { + for (j in 1:V) { + int count = corpus[i,j]; + real gamma[K]; + if (count > 0) { + for (k in 1:K) { + gamma[k] <- (log(theta[i,k]) + log(phi[k,j]))*count; + } + increment_log_prob(log_sum_exp(gamma)); // likelihood + } + } } } diff --git a/misc/cluster/lda/sim-lda.R b/misc/cluster/lda/sim-lda.R index 7eae66524..29c0f65bb 100644 --- a/misc/cluster/lda/sim-lda.R +++ b/misc/cluster/lda/sim-lda.R @@ -8,6 +8,8 @@ phi <- array(NA,c(2,5)); phi[1,] = c(0.330, 0.330, 0.330, 0.005, 0.005); phi[2,] = c(0.005, 0.005, 0.330, 0.330, 0.330); +set.seed(123) + M <- 25; # docs avg_doc_length <- 10; doc_length <- rpois(M,avg_doc_length); @@ -18,16 +20,14 @@ beta <- rep(1/V,V); theta <- rdirichlet(M,alpha); -w <- rep(NA,N); -doc <- rep(NA,N); -n <- 1; +corpus <- matrix(0, nrow = M, ncol = V); for (m in 1:M) { for (i in 1:doc_length[m]) { - z <- which(rmultinom(1,1,theta[m,]) == 1); - w[n] <- which(rmultinom(1,1,phi[z,]) == 1); - doc[n] <- m; - n <- n + 1; + topic_id <- which(rmultinom(1,1,theta[m,]) == 1); + word_id <- which(rmultinom(1,1,phi[topic_id,]) == 1); + corpus[m, word_id] = corpus[m, word_id] + 1; } } +stopifnot(all(rowSums(corpus) == doc_length)); -dump(c("K","V","M","N","z","w","doc","alpha","beta"),"lda.data.R"); +dump(c("K","V","M","corpus","alpha","beta"),"lda.data.R"); \ No newline at end of file