Skip to content

Commit

Permalink
prototype binomial obs model
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam Abbott committed Jan 7, 2025
1 parent 104bdd8 commit cc08506
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 8 deletions.
5 changes: 3 additions & 2 deletions EpiAware/src/EpiObsModels/EpiObsModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ using Turing, Distributions, DocStringExtensions, SparseArrays, LinearAlgebra
using LogExpFunctions: xexpy, log1pexp

# Observation error models
export PoissonError, NegativeBinomialError
export PoissonError, NegativeBinomialError, BinomialError

# Observation error model functions
export generate_observation_error_priors, observation_error
export generate_observation_error_priors, define_y_t, observation_error

# Observation model modifiers
export LatentDelay, Ascertainment, PrefixObservationModel, RecordExpectedObs
Expand All @@ -41,6 +41,7 @@ include("StackObservationModels.jl")
include("ObservationErrorModels/methods.jl")
include("ObservationErrorModels/NegativeBinomialError.jl")
include("ObservationErrorModels/PoissonError.jl")
include("ObservationErrorModels/BinomialError.jl")
include("utils.jl")

end
70 changes: 70 additions & 0 deletions EpiAware/src/EpiObsModels/ObservationErrorModels/BinomialError.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
@doc raw"
The `BinomialError` struct represents an observation model for
binomial errors. It is a subtype of `AbstractTuringObservationErrorModel`,
which incorporates the number of trials `N` (default 1000) and models the number of successes `y_t`. Unlike the `NegativeBinomialError` model,
the `BinomialError` model does not require `Y_t` to be an expected count
but instead a probability of success. It also requires that `y_t` is a named tuple with `N` and `y_t` as fields. Where `N` is the number of trials
and `y_t` is the number of successes.
## Constructors
- `BinomialError(N::Integer)`: Constructs a `BinomialError` object with default values for `N` and `y_t`.
- `BinomialError(; N::Integer = 1000)`: Constructs a `BinomialError` object, allowing the user to set the number
of trials if no `y_t` is provided.
## Examples
```julia
using EpiAware
# Create a BinomialError model with default number of trials
bin = BinomialError()
# Create a BinomialError model with custom number of trials
bin_custom = BinomialError(20)
# Define observation data with number of trials and the number of successes
y_data = (N = fill(10, 10), y_t = fill(3.0, 10))
# Generate observations using the BinomialError model
bin_model = generate_observations(bin, y_data, fill(0.5, 10))
# Sample from the generated model
sample = rand(bin_model)
```
"
@kwdef struct BinomialError{I <: Integer} <: AbstractTuringObservationErrorModel
"The number of trials for the binomial distribution."
N::I = 1000
end

@doc raw"
Defines `y_t` when it is `missing` for `BinomialError`.
Constructs a NamedTuple with `N` from the model `obs_model` and `y_t` as a vector of `Missing` values.
"
function define_y_t(
obs_model::BinomialError, y_t, Y_t)
if ismissing(y_t)
y_t = Vector{Missing}(missing, length(Y_t))
else
y_t = y_t.y_t
end
return y_t
end

@doc raw"
Generates priors for the `BinomialError` model. Extracts `N` from `y_t` when it is a NamedTuple.
"
@model function generate_observation_error_priors(obs_model::BinomialError, y_t, Y_t)
if ismissing(y_t)
N = fill(obs_model.N, length(Y_t))
else
N = y_t.N
end
return (N = N,)
end

@doc raw"
This function generates the observation error model based on the binomial error model. It dispatches to the `Binomial` distribution using the number of trials `N` and the probability of success `p`.
"
function observation_error(obs_model::BinomialError, Y_t, N)
return Binomial(N, Y_t)
end
22 changes: 16 additions & 6 deletions EpiAware/src/EpiObsModels/ObservationErrorModels/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,21 @@ It dispatches to the `observation_error` function to generate the observation er
obs_model::AbstractTuringObservationErrorModel,
y_t,
Y_t)

@submodel priors = generate_observation_error_priors(obs_model, y_t, Y_t)

if ismissing(y_t)
y_t = Vector{Missing}(missing, length(Y_t))
end
y_t_defined = define_y_t(obs_model, y_t, Y_t)

diff_t = length(y_t) - length(Y_t)
diff_t = length(y_t_defined) - length(Y_t)
@assert diff_t>=0 "The observation vector must be longer than or equal to the expected observation vector"

pad_Y_t = Y_t .+ 1e-6

for i in eachindex(Y_t)
y_t[i + diff_t] ~ observation_error(obs_model, pad_Y_t[i], priors...)
y_t_defined[i + diff_t] ~ observation_error(obs_model, pad_Y_t[i], priors...)
end

return y_t
return y_t_defined
end

@doc raw"
Expand All @@ -35,6 +34,17 @@ Generates priors for the observation error model. This should return a named tup
return NamedTuple()
end

@doc raw"
Defines `y_t` when it is `missing` by dispatching on the type of `obs_model`
"
function define_y_t(
obs_model::AbstractTuringObservationErrorModel, y_t, Y_t)
if ismissing(y_t)
y_t = Vector{Missing}(missing, length(Y_t))
end
return y_t
end

@doc raw"
The observation error distribution for the observation error model. This function should return the distribution for the observation error given the expected observation value `Y_t` and the priors generated by `generate_observation_error_priors`.
"
Expand Down

0 comments on commit cc08506

Please sign in to comment.