From 4a30802c6d12a78f66276e4ee9e834ad0c2f4e8e Mon Sep 17 00:00:00 2001 From: edknock Date: Thu, 17 Oct 2024 18:13:52 +0100 Subject: [PATCH] setup negative binomial --- DESCRIPTION | 2 +- R/dsl-distributions.R | 30 ++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index eb1b244c..4975d5b1 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: monty Title: Monte Carlo Models -Version: 0.2.19 +Version: 0.2.20 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Wes", "Hinsley", role = "aut"), diff --git a/R/dsl-distributions.R b/R/dsl-distributions.R index acd4576d..048e7f3c 100644 --- a/R/dsl-distributions.R +++ b/R/dsl-distributions.R @@ -125,6 +125,34 @@ distr_hypergeometric <- distribution( mean = quote(k * n1 / (n1 + n2))), cpp = list(density = "hypergeometric", sample = "hypergeometric")) +distr_negative_binomial_mu <- distribution( + name = "NegativeBinomial", + variant = "mu", + density = function(x, size, mu) { + dnbinom(x, size, mu = mu, log = TRUE) + }, + domain = c(0, Inf), + sample = function(rng, size, mu) rng$negative_binomial_mu(1, size, mu), + expr = list( + density = quote(lchoose(x + size - 1, x) + size * log(size) + + x * log(mu) - (size + x) * log(size + mu)), + mean = quote(mu)), + cpp = list(density = "negative_binomial_mu", sample = "negative_binomial_mu")) + +distr_negative_binomial_prob <- distribution( + name = "NegativeBinomial", + variant = "prob", + density = function(x, size, prob) { + dnbinom(x, size, prob = prob, log = TRUE) + }, + domain = c(0, Inf), + sample = function(rng, size, prob) rng$negative_binomial_prob(1, size, prob), + expr = list( + density = quote(lchoose(x + size - 1, x) + x * log(1 - prob) + + size * log(prob)), + mean = quote(size * (1 - prob) / prob)), + cpp = list(density = "negative_binomial_prob", sample = "negative_binomial_prob")) + distr_normal <- distribution( name = "Normal", density = function(x, mean, sd) dnorm(x, mean, sd, log = TRUE), @@ -166,6 +194,8 @@ dsl_distributions <- local({ distr_gamma_rate, distr_gamma_scale, distr_hypergeometric, + distr_negative_binomial_mu, + distr_negative_binomial_prob, distr_normal, distr_poisson, distr_uniform)