Skip to content

Commit

Permalink
Merge branch 'master' into add-test-adjust
Browse files Browse the repository at this point in the history
  • Loading branch information
koheiw authored Sep 4, 2024
2 parents fb58ab7 + f0d18b8 commit 18481da
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 5 deletions.
3 changes: 2 additions & 1 deletion R/lda.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@
#' \item{k}{the number of topics.}
#' \item{last_iter}{the number of iterations in Gibbs sampling.}
#' \item{max_iter}{the maximum number of iterations in Gibbs sampling.}
#' \item{auto_iter}{`auto_iter` is used if `TRUE`.}
#' \item{auto_iter}{the use of `auto_iter`}
#' \item{adjust_alpha}{the value of `adjust_alpha`.}
#' \item{alpha}{the smoothing parameter for `theta`.}
#' \item{beta}{the smoothing parameter for `phi`.}
#' \item{epsilon}{the amount of adjustment for `adjust_alpha`.}
Expand Down
1 change: 1 addition & 0 deletions R/seededlda.R
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ tfm <- function(x, dictionary, levels = 1,
if (!quanteda::is.dictionary(dictionary))
stop("dictionary must be a dictionary object", call. = FALSE)

docvars(x) <- NULL # sanitize dfm
dict <- flatten_dictionary(dictionary, levels)
key <- names(dict)
feat <- featnames(x)
Expand Down
3 changes: 2 additions & 1 deletion man/textmodel_lda.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions src/lda.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <RcppArmadillo.h>
#include "lib.h"
#include "dev.h"
#include "lda.h"
Expand Down Expand Up @@ -38,6 +39,7 @@ List cpp_lda(arma::sp_mat &mt, int k, int max_iter, double min_delta,
Rcpp::Named("max_iter") = lda.max_iter,
Rcpp::Named("last_iter") = lda.iter,
Rcpp::Named("auto_iter") = (lda.min_delta == 0),
Rcpp::Named("adjust_alpha") = lda.adjust,
Rcpp::Named("alpha") = as<NumericVector>(wrap(lda.alpha)),
Rcpp::Named("beta") = as<NumericVector>(wrap(lda.beta)),
Rcpp::Named("epsilon") = as<NumericVector>(wrap(lda.epsilon)),
Expand Down
18 changes: 18 additions & 0 deletions tests/testthat/test-internal.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,24 @@ test_that("tfm works with ngrams", {
c("un" = 2, "icc" = 2, "other" = 0))
})

test_that("tfm works with dfm with x in docvars (#87)", {

dict <- dictionary(list("A" = "a", "B" = "b"))
dat <- data.frame(text = c("a b c", "A B C"),
x = c(1, 2))
corp <- corpus(dat)
toks <- tokens(corp)
dfmt <- dfm(toks)

expect_equal(
as.matrix(seededlda:::tfm(dfmt, dict, residula = 1)),
matrix(c(2, 0, 0, 0, 2, 0, 0, 0 ,0), nrow = 3,
dimnames = list(c("A", "B", "other"), c("a", "b", "c")))
)

})


test_that("levels is working", {

dict <- dictionary(list(A = list(
Expand Down
4 changes: 3 additions & 1 deletion tests/testthat/test-textmodel_lda.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ test_that("LDA is working", {
)
expect_equal(
names(lda),
c("k", "max_iter", "last_iter", "auto_iter", "alpha", "beta", "epsilon", "gamma",
c("k", "max_iter", "last_iter", "auto_iter", "adjust_alpha",
"alpha", "beta", "epsilon", "gamma",
"phi", "theta", "words", "data", "batch_size", "call", "version")
)
expect_equal(lda$last_iter, 200)
Expand Down Expand Up @@ -124,6 +125,7 @@ test_that("adjust_alpha works", {
expect_equivalent(rowSums(lda$theta), rep(1.0, ndoc(lda$data)))
expect_equivalent(rowSums(lda$phi), rep(1.0, lda$k))

expect_equal(lda$adjust_alpha, 0.5)
expect_true(all(lda$alpha != 0.5))
expect_true(all(lda$epsilon > 0))

Expand Down
5 changes: 3 additions & 2 deletions tests/testthat/test-textmodel_seededlda.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,9 @@ test_that("seeded LDA is working", {
)
expect_equal(
names(lda),
c("k", "max_iter", "last_iter", "auto_iter", "alpha", "beta", "epsilon", "gamma",
"phi", "theta", "words", "data", "batch_size", "call", "version",
c("k", "max_iter", "last_iter", "auto_iter", "adjust_alpha",
"alpha", "beta", "epsilon", "gamma", "phi", "theta",
"words", "data", "batch_size", "call", "version",
"dictionary", "valuetype", "case_insensitive", "seeds",
"residual", "weight")
)
Expand Down

0 comments on commit 18481da

Please sign in to comment.