Skip to content

Commit

Permalink
Merge pull request #87 from mrc-ide/negative-binomial
Browse files Browse the repository at this point in the history
Add Negative Binomial distribution
  • Loading branch information
richfitz authored Oct 18, 2024
2 parents e6e3d12 + c4d310a commit 3d876ce
Show file tree
Hide file tree
Showing 11 changed files with 244 additions and 78 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: monty
Title: Monte Carlo Models
Version: 0.2.19
Version: 0.2.20
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "[email protected]"),
person("Wes", "Hinsley", role = "aut"),
Expand Down
8 changes: 6 additions & 2 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,12 @@ monty_rng_binomial <- function(ptr, n, r_size, r_prob, n_threads, is_float) {
.Call(`_monty_monty_rng_binomial`, ptr, n, r_size, r_prob, n_threads, is_float)
}

monty_rng_nbinomial <- function(ptr, n, r_size, r_prob, n_threads, is_float) {
.Call(`_monty_monty_rng_nbinomial`, ptr, n, r_size, r_prob, n_threads, is_float)
monty_rng_negative_binomial_prob <- function(ptr, n, r_size, r_prob, n_threads, is_float) {
.Call(`_monty_monty_rng_negative_binomial_prob`, ptr, n, r_size, r_prob, n_threads, is_float)
}

monty_rng_negative_binomial_mu <- function(ptr, n, r_size, r_mu, n_threads, is_float) {
.Call(`_monty_monty_rng_negative_binomial_mu`, ptr, n, r_size, r_mu, n_threads, is_float)
}

monty_rng_hypergeometric <- function(ptr, n, r_n1, r_n2, r_k, n_threads, is_float) {
Expand Down
34 changes: 33 additions & 1 deletion R/dsl-distributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ distr_gamma_rate <- distribution(

distr_gamma_scale <- distribution(
name = "Gamma",
variant = "rate",
variant = "scale",
density = function(x, shape, scale) {
dgamma(x, shape, scale = scale, log = TRUE)
},
Expand All @@ -125,6 +125,36 @@ distr_hypergeometric <- distribution(
mean = quote(k * n1 / (n1 + n2))),
cpp = list(density = "hypergeometric", sample = "hypergeometric"))

distr_negative_binomial_prob <- distribution(
name = "NegativeBinomial",
variant = "prob",
density = function(x, size, prob) {
dnbinom(x, size, prob = prob, log = TRUE)
},
domain = c(0, Inf),
sample = function(rng, size, prob) rng$negative_binomial_prob(1, size, prob),
expr = list(
density = quote(lgamma(x + size) - lgamma(size) - lgamma(x + 1) +
x * log(1 - prob) + size * log(prob)),
mean = quote(size * (1 - prob) / prob)),
cpp = list(density = "negative_binomial_prob",
sample = "negative_binomial_prob"))

distr_negative_binomial_mu <- distribution(
name = "NegativeBinomial",
variant = "mu",
density = function(x, size, mu) {
dnbinom(x, size, mu = mu, log = TRUE)
},
domain = c(0, Inf),
sample = function(rng, size, mu) rng$negative_binomial_mu(1, size, mu),
expr = list(
density = quote(lgamma(x + size) - lgamma(size) - lgamma(x + 1) +
size * log(size) + x * log(mu) -
(size + x) * log(size + mu)),
mean = quote(mu)),
cpp = list(density = "negative_binomial_mu", sample = "negative_binomial_mu"))

distr_normal <- distribution(
name = "Normal",
density = function(x, mean, sd) dnorm(x, mean, sd, log = TRUE),
Expand Down Expand Up @@ -166,6 +196,8 @@ dsl_distributions <- local({
distr_gamma_rate,
distr_gamma_scale,
distr_hypergeometric,
distr_negative_binomial_prob,
distr_negative_binomial_mu,
distr_normal,
distr_poisson,
distr_uniform)
Expand Down
27 changes: 23 additions & 4 deletions R/rng.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@
##' rng$binomial(5, 10, 0.3)
##'
##' # Negative binomially distributed random numbers with size and prob
##' rng$nbinomial(5, 10, 0.3)
##' rng$negative_binomial_prob(5, 10, 0.3)
##'
##' # Negative binomially distributed random numbers with size and mean mu
##' rng$negative_binomial_mu(5, 10, 25)
##'
##' # Hypergeometric distributed random numbers with parameters n1, n2 and k
##' rng$hypergeometric(5, 6, 10, 4)
Expand Down Expand Up @@ -283,9 +286,25 @@ monty_rng <- R6::R6Class(
##' (between 0 and 1, length 1 or n)
##'
##' @param n_threads Number of threads to use; see Details
nbinomial = function(n, size, prob, n_threads = 1L) {
monty_rng_nbinomial(private$ptr, n, size, prob, n_threads,
private$float)
negative_binomial_prob = function(n, size, prob, n_threads = 1L) {
monty_rng_negative_binomial_prob(private$ptr, n, size, prob, n_threads,
private$float)
},

##' @description Generate `n` numbers from a negative binomial distribution
##'
##' @param n Number of samples to draw (per stream)
##'
##' @param size The target number of successful trials
##' (zero or more, length 1 or n)
##'
##' @param mu The mean
##' (zero or more, length 1 or n)
##'
##' @param n_threads Number of threads to use; see Details
negative_binomial_mu = function(n, size, mu, n_threads = 1L) {
monty_rng_negative_binomial_mu(private$ptr, n, size, mu, n_threads,
private$float)
},

##' @description Generate `n` numbers from a hypergeometric distribution
Expand Down
41 changes: 0 additions & 41 deletions inst/include/monty/random/nbinomial.hpp

This file was deleted.

52 changes: 52 additions & 0 deletions inst/include/monty/random/negative_binomial.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#pragma once

#include <cmath>

#include "monty/random/gamma.hpp"
#include "monty/random/poisson.hpp"
#include "monty/random/generator.hpp"

namespace monty {
namespace random {

namespace {

template <typename real_type>
void negative_binomial_validate(real_type size, real_type prob) {
if(!R_FINITE(size) || !R_FINITE(prob) || size <= 0 || prob <= 0 || prob > 1) {
char buffer[256];
snprintf(buffer, 256,
"Invalid call to negative_binomial with size = %g, prob = %g",
size, prob);
monty::utils::fatal_error(buffer);
}
}

}

template <typename real_type, typename rng_state_type>
real_type negative_binomial_prob(rng_state_type& rng_state, real_type size, real_type prob) {
#ifdef __CUDA_ARCH__
static_assert("negative_binomial_prob() not implemented for GPU targets");
#endif
negative_binomial_validate(size, prob);

if (rng_state.deterministic) {
return (1 - prob) * size / prob;
}
return (prob == 1) ? 0 : poisson(rng_state, gamma_scale(rng_state, size, (1 - prob) / prob));
}

template <typename real_type, typename rng_state_type>
real_type negative_binomial_mu(rng_state_type& rng_state, real_type size, real_type mu) {
#ifdef __CUDA_ARCH__
static_assert("negative_binomial_mu() not implemented for GPU targets");
#endif
const auto prob = size / (size + mu);
negative_binomial_validate(size, prob);

return negative_binomial_prob(rng_state, size, prob);
}

}
}
2 changes: 1 addition & 1 deletion inst/include/monty/random/random.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include "monty/random/gamma.hpp"
#include "monty/random/hypergeometric.hpp"
#include "monty/random/multinomial.hpp"
#include "monty/random/nbinomial.hpp"
#include "monty/random/negative_binomial.hpp"
#include "monty/random/normal.hpp"
#include "monty/random/poisson.hpp"
#include "monty/random/uniform.hpp"
Expand Down
41 changes: 35 additions & 6 deletions man/monty_rng.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 12 additions & 4 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,17 @@ extern "C" SEXP _monty_monty_rng_binomial(SEXP ptr, SEXP n, SEXP r_size, SEXP r_
END_CPP11
}
// random.cpp
cpp11::sexp monty_rng_nbinomial(SEXP ptr, int n, cpp11::doubles r_size, cpp11::doubles r_prob, int n_threads, bool is_float);
extern "C" SEXP _monty_monty_rng_nbinomial(SEXP ptr, SEXP n, SEXP r_size, SEXP r_prob, SEXP n_threads, SEXP is_float) {
cpp11::sexp monty_rng_negative_binomial_prob(SEXP ptr, int n, cpp11::doubles r_size, cpp11::doubles r_prob, int n_threads, bool is_float);
extern "C" SEXP _monty_monty_rng_negative_binomial_prob(SEXP ptr, SEXP n, SEXP r_size, SEXP r_prob, SEXP n_threads, SEXP is_float) {
BEGIN_CPP11
return cpp11::as_sexp(monty_rng_nbinomial(cpp11::as_cpp<cpp11::decay_t<SEXP>>(ptr), cpp11::as_cpp<cpp11::decay_t<int>>(n), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(r_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(r_prob), cpp11::as_cpp<cpp11::decay_t<int>>(n_threads), cpp11::as_cpp<cpp11::decay_t<bool>>(is_float)));
return cpp11::as_sexp(monty_rng_negative_binomial_prob(cpp11::as_cpp<cpp11::decay_t<SEXP>>(ptr), cpp11::as_cpp<cpp11::decay_t<int>>(n), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(r_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(r_prob), cpp11::as_cpp<cpp11::decay_t<int>>(n_threads), cpp11::as_cpp<cpp11::decay_t<bool>>(is_float)));
END_CPP11
}
// random.cpp
cpp11::sexp monty_rng_negative_binomial_mu(SEXP ptr, int n, cpp11::doubles r_size, cpp11::doubles r_mu, int n_threads, bool is_float);
extern "C" SEXP _monty_monty_rng_negative_binomial_mu(SEXP ptr, SEXP n, SEXP r_size, SEXP r_mu, SEXP n_threads, SEXP is_float) {
BEGIN_CPP11
return cpp11::as_sexp(monty_rng_negative_binomial_mu(cpp11::as_cpp<cpp11::decay_t<SEXP>>(ptr), cpp11::as_cpp<cpp11::decay_t<int>>(n), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(r_size), cpp11::as_cpp<cpp11::decay_t<cpp11::doubles>>(r_mu), cpp11::as_cpp<cpp11::decay_t<int>>(n_threads), cpp11::as_cpp<cpp11::decay_t<bool>>(is_float)));
END_CPP11
}
// random.cpp
Expand Down Expand Up @@ -208,7 +215,8 @@ static const R_CallMethodDef CallEntries[] = {
{"_monty_monty_rng_jump", (DL_FUNC) &_monty_monty_rng_jump, 2},
{"_monty_monty_rng_long_jump", (DL_FUNC) &_monty_monty_rng_long_jump, 2},
{"_monty_monty_rng_multinomial", (DL_FUNC) &_monty_monty_rng_multinomial, 6},
{"_monty_monty_rng_nbinomial", (DL_FUNC) &_monty_monty_rng_nbinomial, 6},
{"_monty_monty_rng_negative_binomial_mu", (DL_FUNC) &_monty_monty_rng_negative_binomial_mu, 6},
{"_monty_monty_rng_negative_binomial_prob", (DL_FUNC) &_monty_monty_rng_negative_binomial_prob, 6},
{"_monty_monty_rng_normal", (DL_FUNC) &_monty_monty_rng_normal, 7},
{"_monty_monty_rng_pointer_init", (DL_FUNC) &_monty_monty_rng_pointer_init, 4},
{"_monty_monty_rng_pointer_sync", (DL_FUNC) &_monty_monty_rng_pointer_sync, 2},
Expand Down
Loading

0 comments on commit 3d876ce

Please sign in to comment.