Skip to content

Commit

Permalink
Merge pull request #83 from koheiw/dev-new-words
Browse files Browse the repository at this point in the history
Dev new words
  • Loading branch information
koheiw authored Aug 17, 2024
2 parents 37b96bd + b0a68f6 commit f52bd94
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 13 deletions.
3 changes: 1 addition & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ LinkingTo: Rcpp, RcppArmadillo (>= 0.7.600.1.0), quanteda, testthat
Suggests:
spelling,
testthat,
topicmodels,
keyATM
topicmodels
RoxygenNote: 7.3.1
Roxygen: list(markdown = TRUE)
Language: en-US
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## Changes in v1.3.3

- Add `update_model` to update terms of existing models to classify documents with unseen words more accurately.

## Changes in v1.3.2

- Improve the way to convert std::vector to arma::mat.
Expand Down
19 changes: 15 additions & 4 deletions R/lda.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#' document is affected by the previous document's topics.
#' @param model a fitted LDA model; if provided, `textmodel_lda()` inherits
#' parameters from an existing model. See details.
#' @param update_model if `TRUE`, update the terms of `model` to recognize unseen
#' words.
#' @details If `auto_iter = TRUE`, the iteration stops even before `max_iter`
#' when `delta <= 0`. `delta` is computed to measure the changes in the number
#' of words whose topics are updated by the Gibbs sampler in every 100
Expand Down Expand Up @@ -79,21 +81,30 @@
#' }
textmodel_lda <- function(
x, k = 10, max_iter = 2000, auto_iter = FALSE, alpha = 0.5, beta = 0.1, gamma = 0,
model = NULL, batch_size = 1.0, verbose = quanteda_options("verbose")
model = NULL, update_model = FALSE, batch_size = 1.0,
verbose = quanteda_options("verbose")
) {
UseMethod("textmodel_lda")
}

#' @export
textmodel_lda.dfm <- function(
x, k = 10, max_iter = 2000, auto_iter = FALSE, alpha = 0.5, beta = 0.1, gamma = 0,
model = NULL, batch_size = 1.0, verbose = quanteda_options("verbose")
model = NULL, update_model = FALSE, batch_size = 1.0,
verbose = quanteda_options("verbose")
) {

if (!is.null(model)) {
if (!is.textmodel_lda(model))
stop("model must be a fitted textmodel_lda")
x <- dfm_match(x, colnames(model$phi))

words <- model$words
if (update_model) {
words <- t(dfm_match(as.dfm(t(words)), featnames(x)))
} else {
x <- dfm_match(x, rownames(words))
}

k <- model$k
label <- rownames(model$phi)
alpha <- model$alpha
Expand All @@ -103,7 +114,7 @@ textmodel_lda.dfm <- function(
} else {
gamma <- 0
}
words <- model$words

warning("k, alpha, beta and gamma values are overwritten by the fitted model", call. = FALSE)
} else {
label <- paste0("topic", seq_len(k))
Expand Down
1 change: 0 additions & 1 deletion R/seededlda.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#' A Semi-supervised Algorithm for Topic-specific Analysis of Sentences".
#' doi:10.1177/08944393231178605. *Social Science Computer Review*.
#' @returns The same as [textmodel_lda()] with extra elements for `dictionary`.
#' @seealso [keyATM][keyATM::keyATM]
#' @examples
#' \donttest{
#' require(seededlda)
Expand Down
4 changes: 4 additions & 0 deletions man/textmodel_lda.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions man/textmodel_seededlda.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 12 additions & 3 deletions tests/testthat/test-textmodel_lda.R
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ test_that("LDA works with empty documents", {
test_that("model argument works with LDA", {
skip_on_cran()

dfmt_train <- head(dfmt, 450)
dfmt_test <- tail(dfmt, 50)
dfmt_train <- dfm_trim(head(dfmt, 450))
dfmt_test <- dfm_trim(tail(dfmt, 50))

# fit new model
lda <- textmodel_lda(dfmt_train, k = 5)
Expand All @@ -203,12 +203,21 @@ test_that("model argument works with LDA", {
expect_warning({
lda2 <- textmodel_lda(dfmt_test, model = lda)
}, "k, alpha, beta and gamma values are overwritten by the fitted model")
expect_false(all(lda$phi == lda2$phi))
expect_identical(dimnames(lda$phi), dimnames(lda2$phi))
expect_equal(
levels(topics(lda2)),
c("topic1", "topic2", "topic3", "topic4", "topic5")
)

# out-of-sample with new words
expect_warning({
lda3 <- textmodel_lda(dfmt_test, model = lda, update_model = TRUE)
}, "k, alpha, beta and gamma values are overwritten by the fitted model")
expect_false(identical(dimnames(lda$phi), dimnames(lda3$phi)))
expect_equal(
levels(topics(lda3)),
c("topic1", "topic2", "topic3", "topic4", "topic5")
)
})

test_that("select and min_prob are working", {
Expand Down

0 comments on commit f52bd94

Please sign in to comment.