Skip to content

Commit

Permalink
Merge pull request #370 from n-kall/pareto_diags_names
Browse files Browse the repository at this point in the history
Change pareto functions to return unnamed numerics
  • Loading branch information
paul-buerkner authored May 16, 2024
2 parents 86c8fba + 173f7d7 commit 4d3ccf1
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 57 deletions.
88 changes: 53 additions & 35 deletions R/pareto_smooth.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#' @template args-pareto
#' @template args-methods-dots
#' @template ref-vehtari-paretosmooth-2022
#' @return `khat` estimated Generalized Pareto Distribution shape parameter k
#' @template return-conv
#'
#' @seealso [`pareto_diags`] for additional related diagnostics, and
#' [`pareto_smooth`] for Pareto smoothed draws.
Expand Down Expand Up @@ -39,29 +39,27 @@ pareto_khat.default <- function(x,
verbose = verbose,
return_k = TRUE,
smooth_draws = FALSE,
are_log_weights = are_log_weights,
...)
return(smoothed$diagnostics)
are_log_weights = are_log_weights
)
return(smoothed$diagnostics$khat)
}

#' @rdname pareto_khat
#' @export
pareto_khat.rvar <- function(x, ...) {
draws_diags <- summarise_rvar_by_element_with_chains(
x,
pareto_smooth.default,
return_k = TRUE,
smooth_draws = FALSE,
pareto_khat.default,
...
)
dim(draws_diags) <- dim(draws_diags) %||% length(draws_diags)
margins <- seq_along(dim(draws_diags))

diags <- list(
khat = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat)
khat = apply(draws_diags, margins, function(x) x[[1]])
)

diags
diags$khat
}


Expand Down Expand Up @@ -107,8 +105,10 @@ pareto_khat.rvar <- function(x, ...) {
#' when the sample size is increased, compared to the central limit
#' theorem convergence rate. See Appendix B in Vehtari et al. (2024).
#'
#' @seealso [`pareto_khat`] for only calculating khat, and
#' [`pareto_smooth`] for Pareto smoothed draws.
#' @seealso [`pareto_khat`], [`pareto_min_ss`],
#' [`pareto_khat_threshold`], and [`pareto_convergence_rate`] for
#' individual diagnostics; and [`pareto_smooth`] for Pareto smoothing
#' draws.
#' @examples
#' mu <- extract_variable_matrix(example_draws(), "mu")
#' pareto_diags(mu)
Expand Down Expand Up @@ -151,21 +151,18 @@ pareto_diags.default <- function(x,
pareto_diags.rvar <- function(x, ...) {
draws_diags <- summarise_rvar_by_element_with_chains(
x,
pareto_smooth.default,
return_k = TRUE,
smooth_draws = FALSE,
extra_diags = TRUE,
pareto_diags.default,
...
)

dim(draws_diags) <- dim(draws_diags) %||% length(draws_diags)
margins <- seq_along(dim(draws_diags))

diags <- list(
khat = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat),
min_ss = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$min_ss),
khat_threshold = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$khat_threshold),
convergence_rate = apply(draws_diags, margins, function(x) x[[1]]$diagnostics$convergence_rate)
khat = apply(draws_diags, margins, function(x) x[[1]]$khat),
min_ss = apply(draws_diags, margins, function(x) x[[1]]$min_ss),
khat_threshold = apply(draws_diags, margins, function(x) x[[1]]$khat_threshold),
convergence_rate = apply(draws_diags, margins, function(x) x[[1]]$convergence_rate)
)

diags
Expand All @@ -192,13 +189,20 @@ pareto_diags.rvar <- function(x, ...) {
#' @template ref-vehtari-paretosmooth-2022
#' @return Either a vector `x` of smoothed values or a named list
#' containing the vector `x` and a named list `diagnostics`
#' containing Pareto smoothing diagnostics: * `khat`: estimated
#' Pareto k shape parameter, and optionally * `min_ss`: minimum
#' sample size for reliable Pareto smoothed estimate *
#' `khat_threshold`: khat-threshold for reliable Pareto smoothed
#' estimates * `convergence_rate`: Relative convergence rate for
#' containing numeric values:
#'
#' * `khat`: estimated Pareto k shape parameter, and optionally
#' * `min_ss`: minimum sample size for reliable Pareto smoothed
#' estimate
#' * `khat_threshold`: sample size specific khat threshold for
#' reliable Pareto smoothed estimates
#' * `convergence_rate`: relative convergence rate for
#' Pareto smoothed estimates
#'
#' If any of the draws is non-finite, that is, `NA`, `NaN`, `Inf`, or
#' `-Inf`, Pareto smoothing will not be performed, and the original
#' draws will be returned and the diagnostics will be `NA` (numeric).
#'
#' @seealso [`pareto_khat`] for only calculating khat, and
#' [`pareto_diags`] for additional diagnostics.
#' @examples
Expand Down Expand Up @@ -265,13 +269,27 @@ pareto_smooth.default <- function(x,
verbose <- as_one_logical(verbose)
are_log_weights <- as_one_logical(are_log_weights)

if (extra_diags) {
return_k <- TRUE
}

# check for infinite or na values
if (should_return_NA(x)) {
warning_no_call("Input contains infinite or NA values, or is constant. Fitting of generalized Pareto distribution not performed.")
if (!return_k) {
out <- x
} else if (!extra_diags) {
out <- list(x = x, diagnostics = list(khat = NA_real_))
} else {
out <- list(x = x, diagnostics = NA_real_)
out <- list(
x = x,
diagnostics = list(
khat = NA_real_,
min_ss = NA_real_,
khat_threshold = NA_real_,
convergence_rate = NA_real_
)
)
}
return(out)
}
Expand Down Expand Up @@ -379,13 +397,13 @@ pareto_khat_threshold <- function(x, ...) {
#' @rdname pareto_diags
#' @export
pareto_khat_threshold.default <- function(x, ...) {
c(khat_threshold = ps_khat_threshold(length(x)))
ps_khat_threshold(length(x))
}

#' @rdname pareto_diags
#' @export
pareto_khat_threshold.rvar <- function(x, ...) {
c(khat_threshold = ps_khat_threshold(ndraws(x)))
ps_khat_threshold(ndraws(x))
}

#' @rdname pareto_diags
Expand All @@ -397,15 +415,15 @@ pareto_min_ss <- function(x, ...) {
#' @rdname pareto_diags
#' @export
pareto_min_ss.default <- function(x, ...) {
k <- pareto_khat(x)$k
c(min_ss = ps_min_ss(k))
k <- pareto_khat(x)
ps_min_ss(k)
}

#' @rdname pareto_diags
#' @export
pareto_min_ss.rvar <- function(x, ...) {
k <- pareto_khat(x)$k
c(min_ss = ps_min_ss(k))
k <- pareto_khat(x)
ps_min_ss(k)
}

#' @rdname pareto_diags
Expand All @@ -417,15 +435,15 @@ pareto_convergence_rate <- function(x, ...) {
#' @rdname pareto_diags
#' @export
pareto_convergence_rate.default <- function(x, ...) {
k <- pareto_khat(x)$khat
c(convergence_rate = ps_convergence_rate(k, length(x)))
k <- pareto_khat(x)
ps_convergence_rate(k, length(x))
}

#' @rdname pareto_diags
#' @export
pareto_convergence_rate.rvar <- function(x, ...) {
k <- pareto_khat(x)
c(convergence_rate = ps_convergence_rate(k, ndraws(x)))
ps_convergence_rate(k, ndraws(x))
}


Expand Down Expand Up @@ -616,7 +634,7 @@ ps_convergence_rate <- function(k, ndraws, ...) {
}

#' Pareto tail length
#'
#'
#' Calculate the tail length from number of draws and relative efficiency
#' r_eff. See Appendix H in Vehtari et al. (2024). This function is
#' used internally and is exported to be available for other packages.
Expand Down
2 changes: 1 addition & 1 deletion man-roxygen/ref-vehtari-paretosmooth-2022.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
#' Aki Vehtari, Daniel Simpson, Andrew Gelman, Yuling Yao and
#' Jonah Gabry (2024). Pareto Smoothed Importance Sampling.
#' *Journal of Machine Learning Research*, 25(72):1-58.
#' [PDF](https://jmlr.org/papers/v25/19-556.html)
#' [PDF](https://jmlr.org/papers/v25/19-556.html)
6 changes: 4 additions & 2 deletions man/pareto_diags.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 13 additions & 1 deletion man/pareto_khat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 13 additions & 5 deletions man/pareto_smooth.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 4 additions & 13 deletions tests/testthat/test-pareto_smooth.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,3 @@
test_that("pareto_khat returns expected reasonable values", {
tau <- extract_variable_matrix(example_draws(), "tau")

pk <- pareto_khat(tau)
expect_true(names(pk) == "khat")

})


test_that("pareto_khat handles constant tail correctly", {

# left tail is constant, so khat should be NA, but for "both" it
Expand All @@ -30,16 +21,16 @@ test_that("pareto_khat handles tail argument", {
pkl <- pareto_khat(tau, tail = "left")
pkr <- pareto_khat(tau, tail = "right")
pkb <- pareto_khat(tau)
expect_true(pkl$khat < pkr$khat)
expect_equal(pkr$khat, pkb$khat)
expect_true(pkl < pkr)
expect_equal(pkr, pkb)
})

test_that("pareto_khat handles ndraws_tail argument", {

tau <- extract_variable_matrix(example_draws(), "tau")
pk10 <- pareto_khat(tau, tail = "right", ndraws_tail = 10)
pk25 <- pareto_khat(tau, tail = "right", ndraws_tail = 25)
expect_true(pk10$khat > pk25$khat)
expect_true(pk10 > pk25)

expect_warning(pareto_khat(tau, tail = "both", ndraws_tail = 201),
"Number of tail draws cannot be more than half ",
Expand All @@ -57,7 +48,7 @@ test_that("pareto_khat handles r_eff argument", {
tau <- extract_variable_matrix(example_draws(), "tau")
pk1 <- pareto_khat(tau, r_eff = 1)
pk0.6 <- pareto_khat(tau, r_eff = 0.6)
expect_true(pk1$khat < pk0.6$khat)
expect_true(pk1 < pk0.6)

})

Expand Down

0 comments on commit 4d3ccf1

Please sign in to comment.