Merge pull request #78 from koheiw/add-perplexity
Add perplexity
koheiw authored May 27, 2024
2 parents 0062e6f + 770a63b commit 6882c24
Showing 11 changed files with 159 additions and 24 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,7 +1,7 @@
Package: seededlda
Type: Package
Title: Seeded Sequential LDA for Topic Modeling
Version: 1.2.1
Version: 1.3.0
Authors@R: c(person("Kohei", "Watanabe", email = "[email protected]", role = c("aut", "cre", "cph")),
person("Phan", "Xuan-Hieu", email = "[email protected]", role = c("aut", "cph"), comment = "GibbsLDA++"))
Description: Seeded Sequential LDA can classify sentences of texts into pre-defined topics with a small number of seed words (Watanabe & Baturo, 2023) <doi:10.1177/08944393231178605>.
2 changes: 2 additions & 0 deletions NAMESPACE
@@ -1,6 +1,7 @@
# Generated by roxygen2: do not edit by hand

S3method(divergence,textmodel_lda)
S3method(perplexity,textmodel_lda)
S3method(predict,textmodel_lda)
S3method(print,textmodel_lda)
S3method(sizes,textmodel_lda)
@@ -10,6 +11,7 @@ S3method(textmodel_seededlda,dfm)
S3method(textmodel_seqlda,dfm)
S3method(topics,textmodel_lda)
export(divergence)
export(perplexity)
export(sizes)
export(terms)
export(textmodel_lda)
4 changes: 4 additions & 0 deletions NEWS.md
@@ -1,3 +1,7 @@
## Changes in v1.3.0

- Add `perplexity()` to compute perplexity scores of fitted LDA models.

## Changes in v1.2.1

- Fix tests on systems when the TBB library is unavailable.
4 changes: 2 additions & 2 deletions R/lda.R
@@ -101,7 +101,7 @@ textmodel_lda.dfm <- function(
gamma <- 0
}
words <- model$words
warning("k, alpha and beta values are overwriten by the fitted model", call. = FALSE)
warning("k, alpha, beta and gamma values are overwritten by the fitted model", call. = FALSE)
} else {
label <- paste0("topic", seq_len(k))
words <- NULL
@@ -162,7 +162,7 @@ lda <- function(x, k, label, max_iter, auto_iter, alpha, beta, gamma,
dimnames(result$theta) <- list(rownames(x), label)
result$data <- x
result$batch_size <- batch_size
result$call <- match.call(sys.function(-2), call = sys.call(-2))
result$call <- try(match.call(sys.function(-2), call = sys.call(-2)), silent = TRUE)
result$version <- utils::packageVersion("seededlda")
class(result) <- c("textmodel_lda", "textmodel", "list")
return(result)
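For context on the updated warning: passing a fitted model via the `model` argument makes `textmodel_lda()` take `k`, `alpha`, `beta` and `gamma` from that model, ignoring any values supplied in the call. A minimal usage sketch (the dfms `dfmt_train` and `dfmt_test` are assumptions, mirroring the tests below):

lda <- textmodel_lda(dfmt_train, k = 5)
# hyper-parameters come from `lda`, so values passed here would be
# ignored -- hence the "overwritten by the fitted model" warning
lda_test <- textmodel_lda(dfmt_test, model = lda)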
58 changes: 49 additions & 9 deletions R/utils.R
@@ -1,41 +1,50 @@
#' Optimize the number of topics
#' Optimize the number of topics for LDA
#'
#' `divergence()` computes the regularized topic divergence to find the optimal
#' number of topics for LDA.
#' `divergence()` computes the regularized topic divergence scores to help users
#' find the optimal number of topics for LDA.
#' @param x a LDA model fitted by [textmodel_seededlda()] or [textmodel_lda()].
#' @param min_size the minimum size of topics for regularized topic divergence.
#' Ignored when `regularize = FALSE`.
#' @param select names of topics for which the divergence is computed.
#' @param x a LDA model fitted by [textmodel_seededlda()] or [textmodel_lda()].
#' @param regularize if `TRUE`, returns the regularized divergence.
#' @param newdata if provided, `theta` and `phi` are estimated through fresh
#' Gibbs sampling.
#' @param ... additional arguments passed to [textmodel_lda].
#' @details `divergence()` computes the average Jensen-Shannon divergence
#' between all pairs of topic vectors in `x$phi`. The divergence score
#' is maximized when the chosen number of topics `k` is optimal (Deveaud et al.,
#' 2014). The regularized divergence penalizes topics smaller than `min_size`
#' to avoid fragmentation (Watanabe & Baturo, 2023).
#' @seealso [sizes]
#' @seealso [perplexity]
#' @references Deveaud, Romain et al. (2014). "Accurate and Effective Latent
#' Concept Modeling for Ad Hoc Information Retrieval".
#' doi:10.3166/DN.17.1.61-84. *Document Numérique*.
#'
#' Watanabe, Kohei & Baturo, Alexander. (2023). "Seeded Sequential LDA:
#' A Semi-supervised Algorithm for Topic-specific Analysis of Sentences".
#' Watanabe, Kohei & Baturo, Alexander. (2023). "Seeded Sequential LDA: A
#' Semi-supervised Algorithm for Topic-specific Analysis of Sentences".
#' doi:10.1177/08944393231178605. *Social Science Computer Review*.
#' @export
divergence <- function(x, min_size = 0.01, select = NULL,
regularize = TRUE) {
regularize = TRUE, newdata = NULL, ...) {
UseMethod("divergence")
}

#' @importFrom proxyC dist
#' @export
divergence.textmodel_lda <- function(x, min_size = 0.01, select = NULL,
regularize = TRUE) {
regularize = TRUE, newdata = NULL, ...) {

min_size <- check_double(min_size, min = 0, max = 1)
select <- check_character(select, min_len = 2, max_len = nrow(x$phi),
strict = TRUE, allow_null = TRUE)
regularize <- check_logical(regularize, strict = TRUE)

if (!is.null(newdata)) {
suppressWarnings({
x <- textmodel_lda(newdata, model = x, ...)
})
}

if (is.null(select)) {
l <- rep(TRUE, nrow(x$phi))
} else {
Expand All @@ -56,6 +65,37 @@ divergence.textmodel_lda <- function(x, min_size = 0.01, select = NULL,
sum(as.matrix(div[l, l]) * w, na.rm = TRUE) + (min_size ^ 2)
}
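Per the details above, a sketch of how `divergence()` is used to choose `k`: fit models over a range of candidates and pick the maximizer (the dfm `dfmt` and the candidate range are illustrative assumptions, not part of this diff):

ks <- 2:10
divs <- sapply(ks, function(k) {
    lda <- textmodel_lda(dfmt, k = k, max_iter = 1000)
    divergence(lda) # regularized topic divergence; larger is better
})
ks[which.max(divs)] # candidate for the optimal number of topics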

#' Optimize the hyper-parameters for LDA
#'
#' `perplexity()` computes the perplexity score to help users choose the
#' optimal values of hyper-parameters for LDA.
#' @param x a LDA model fitted by [textmodel_seededlda()] or [textmodel_lda()].
#' @param newdata if provided, `theta` and `phi` are estimated through fresh
#' Gibbs sampling.
#' @param ... additional arguments passed to [textmodel_lda].
#' @details `perplexity()` predicts the distribution of words in the dfm based
#' on `x$alpha` and `x$gamma` and then computes the disparity between the
#' predicted and observed frequencies. The perplexity score is minimized when
#' the chosen values of hyper-parameters such as `k`, `alpha` and `gamma` are
#' optimal.
#' @seealso [divergence]
#' @export
perplexity <- function(x, newdata = NULL, ...) {
UseMethod("perplexity")
}

#' @export
perplexity.textmodel_lda <- function(x, newdata = NULL, ...) {
if (!is.null(newdata)) {
suppressWarnings({
x <- textmodel_lda(newdata, model = x, ...)
})
}
#exp(-sum(log(x$theta %*% x$phi[,featnames(x$data)]) * x$data) / sum(x$data))
mat <- as(x$data, "TsparseMatrix")
exp(-sum(log(colSums(x$phi[,mat@j + 1] * t(x$theta)[,mat@i + 1])) * mat@x) / sum(mat@x))
}
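In the sparse computation above, `mat@i` and `mat@j` are the zero-based document and word indices of the non-zero counts in the `TsparseMatrix` (hence the `+ 1`), and `mat@x` holds the counts, so only observed cells of `theta %*% phi` are evaluated. The commented-out line shows the dense equivalent; a standalone sketch of it, for illustration only, since it materializes the full document-word matrix and suits only small dfms:

# exp of the negative average per-token log-likelihood under theta %*% phi
perplexity_dense <- function(x) {
    p <- x$theta %*% x$phi[, quanteda::featnames(x$data)] # predicted probabilities
    n <- as.matrix(x$data) # observed counts
    exp(-sum(log(p) * n) / sum(n))
}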


#' Compute the sizes of topics
#'
26 changes: 19 additions & 7 deletions man/divergence.Rd


30 changes: 30 additions & 0 deletions man/perplexity.Rd


2 changes: 1 addition & 1 deletion man/textmodel_lda.Rd


4 changes: 2 additions & 2 deletions tests/testthat/test-textmodel_lda.R
@@ -190,7 +190,7 @@ test_that("model argument works with LDA", {
# in-sample prediction
expect_warning({
lda1 <- textmodel_lda(dfmt_train[1:50,], model = lda)
}, "k, alpha and beta values are overwriten by the fitted model")
}, "k, alpha, beta and gamma values are overwritten by the fitted model")
expect_false(all(lda$phi == lda1$phi))
expect_identical(dimnames(lda$phi), dimnames(lda1$phi))
expect_true(mean(topics(lda)[1:50] == topics(lda1)) > 0.8)
Expand All @@ -202,7 +202,7 @@ test_that("model argument works with LDA", {
# out-of-sample prediction
expect_warning({
lda2 <- textmodel_lda(dfmt_test, model = lda)
}, "k, alpha and beta values are overwriten by the fitted model")
}, "k, alpha, beta and gamma values are overwritten by the fitted model")
expect_false(all(lda$phi == lda2$phi))
expect_identical(dimnames(lda$phi), dimnames(lda2$phi))
expect_equal(
4 changes: 2 additions & 2 deletions tests/testthat/test-textmodel_seededlda.R
@@ -150,7 +150,7 @@ test_that("model argument works with seeded LDA", {
# in-sample prediction
expect_warning({
lda1 <- textmodel_lda(dfmt_train[1:50,], model = lda)
}, "k, alpha and beta values are overwriten by the fitted model")
}, "k, alpha, beta and gamma values are overwritten by the fitted model")
expect_false(all(lda$phi == lda1$phi))
expect_identical(dimnames(lda$phi), dimnames(lda1$phi))
expect_gt(mean(topics(lda)[1:50] == topics(lda1)), 0.8)
Expand All @@ -162,7 +162,7 @@ test_that("model argument works with seeded LDA", {
# out-of-sample prediction
expect_warning({
lda2 <- textmodel_lda(dfmt_test, model = lda)
}, "k, alpha and beta values are overwriten by the fitted model")
}, "k, alpha, beta and gamma values are overwritten by the fitted model")
expect_false(all(lda$phi == lda2$phi))
expect_identical(dimnames(lda$phi), dimnames(lda2$phi))
expect_equal(
47 changes: 47 additions & 0 deletions tests/testthat/test-utils.R
@@ -20,6 +20,53 @@ slda <- textmodel_seededlda(dfmt, dict, residual = TRUE, weight = 0.02,

test_that("divergence() is working", {

# in-sample
div1 <- divergence(lda)
expect_equal(div1, 0.33, tolerance = 0.01)

toks_val <- tokens(data_corpus_moviereviews[501:600],
remove_punct = TRUE,
remove_symbols = TRUE,
remove_numbers = TRUE)
dfmt_val <- dfm(toks_val) %>%
dfm_remove(stopwords(), min_nchar = 2) %>%
dfm_trim(max_docfreq = 0.1, docfreq_type = "prop")

# out-of-sample
set.seed(1234)
expect_output(
div2 <- divergence(lda, newdata = dfmt_val, max_iter = 100, verbose = TRUE),
"Fitting LDA with 5 topics.*"
)
expect_equal(div2, 0.34, tolerance = 0.01)
})

test_that("perplexity() is working", {

# in-sample
ppl1 <- perplexity(lda)
expect_equal(ppl1, 7742, tolerance = 1)

toks_val <- tokens(data_corpus_moviereviews[501:600],
remove_punct = TRUE,
remove_symbols = TRUE,
remove_numbers = TRUE)
dfmt_val <- dfm(toks_val) %>%
dfm_remove(stopwords(), min_nchar = 2) %>%
dfm_trim(max_docfreq = 0.1, docfreq_type = "prop")

# out-of-sample
set.seed(1234)
expect_output(
ppl2 <- perplexity(lda, newdata = dfmt_val, max_iter = 100, verbose = TRUE),
"Fitting LDA with 5 topics.*"
)
expect_equal(ppl2, 7534, tolerance = 1)
})


test_that("regularize is working", {

# LDA
expect_equal(divergence(lda),
0.34, tolerance = 0.01)
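Finally, a hedged sketch of the hyper-parameter tuning that `perplexity()` enables, echoing its documentation above (the training dfm `dfmt`, the held-out `dfmt_val` and the grid of values are assumptions):

# lower held-out perplexity suggests better hyper-parameter values
alphas <- c(0.1, 0.5, 1.0)
ppl <- sapply(alphas, function(a) {
    lda <- textmodel_lda(dfmt, k = 5, alpha = a, max_iter = 1000)
    perplexity(lda, newdata = dfmt_val, max_iter = 100)
})
alphas[which.min(ppl)] # candidate for the optimal alpha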
