Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

issue 492: Create distribution that returns <: Union{Int,BigInt} #497

Merged
merged 22 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion EpiAware/src/EpiAwareUtils/EpiAwareUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ using DynamicPPL: Model, fix, condition, @submodel, @model
using MCMCChains: Chains
using Random: AbstractRNG, randexp
using Tables: rowtable
import Base: eltype

using Distributions, DocStringExtensions, QuadGK, Statistics, Turing

#Export Structures
export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial
export HalfNormal, DirectSample, SafePoisson, SafeNegativeBinomial, SafeIntValued, SafeInt,
SafeDiscreteUnivariateDistribution

#Export functions
export scan, spread_draws, censored_cdf, censored_pmf, get_param_array, prefix_submodel, ∫F
Expand All @@ -32,6 +34,7 @@ include("turing-methods.jl")
include("DirectSample.jl")
include("post-inference.jl")
include("get_param_array.jl")
include("SafeInt.jl")
include("SafePoisson.jl")
include("SafeNegativeBinomial.jl")

Expand Down
16 changes: 16 additions & 0 deletions EpiAware/src/EpiAwareUtils/SafeInt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
const SafeInt = Union{Int, BigInt}

"""
A type to represent real-valued distributions, the purpose of this type is to avoid problems
with the `eltype` function when having `rand` calls in the model.
"""
struct SafeIntValued <: Distributions.ValueSupport end
function Base.eltype(::Type{<:Distributions.Sampleable{F, SafeIntValued}}) where {F}
SafeInt
end

"""
A constant alias for `Distribution{Univariate, SafeIntValued}`. This type represents a univariate distribution with real-valued outcomes.
"""
const SafeDiscreteUnivariateDistribution = Distributions.Distribution{
Distributions.Univariate, SafeIntValued}
2 changes: 1 addition & 1 deletion EpiAware/src/EpiAwareUtils/SafeNegativeBinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ var(d)
2.4617291430060293e40
```
"
struct SafeNegativeBinomial{T <: Real} <: DiscreteUnivariateDistribution
struct SafeNegativeBinomial{T <: Real} <: SafeDiscreteUnivariateDistribution
r::T
p::T

Expand Down
6 changes: 3 additions & 3 deletions EpiAware/src/EpiAwareUtils/SafePoisson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ var(d)
7.016735912097631e20
```
"
struct SafePoisson{T <: Real} <: DiscreteUnivariateDistribution
struct SafePoisson{T <: Real} <: SafeDiscreteUnivariateDistribution
λ::T

SafePoisson{T}(λ::Real) where {T <: Real} = new{T}(λ)
Expand Down Expand Up @@ -86,7 +86,7 @@ Distributions.rate(d::SafePoisson) = d.λ
### Statistics

Distributions.mean(d::SafePoisson) = d.λ
Distributions.mode(d::SafePoisson) = _safe_int_floor(d.λ)
Distributions.mode(d::SafePoisson) = floor(d.λ)
Distributions.var(d::SafePoisson) = d.λ
Distributions.skewness(d::SafePoisson) = one(typeof(d.λ)) / sqrt(d.λ)
Distributions.kurtosis(d::SafePoisson) = one(typeof(d.λ)) / d.λ
Expand Down Expand Up @@ -229,7 +229,7 @@ function log1pmx(x::Float64)
end

# Procedure F
function procf(λ, K::Int, s::Float64)
function procf(λ, K::SafeInt, s::Float64)
# can be pre-computed, but does not seem to affect performance
ω = 0.3989422804014327 / s
b1 = 0.041666666666666664 / λ
Expand Down
8 changes: 8 additions & 0 deletions EpiAware/test/EpiAwareUtils/SafeInt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
@testitem "SafeInt Type Tests" begin
using Distributions
struct DummySampleable <: Sampleable{Univariate, SafeIntValued} end

@test SafeIntValued <: Distributions.ValueSupport
@test eltype(DummySampleable) <: Union{Int, BigInt}
@test SafeDiscreteUnivariateDistribution == Distribution{Univariate, SafeIntValued}
end
2 changes: 1 addition & 1 deletion EpiAware/test/EpiAwareUtils/SafeNegativeBinomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ end

dist = SafeNegativeBinomial(r, p)
@testset "Large value of mean samples a BigInt with SafePoisson" begin
@test rand(dist) isa BigInt
@test rand(dist) isa Union{Int, BigInt}
end
@testset "Large value of mean sample failure with Poisson" begin
_dist = EpiAware.EpiAwareUtils._negbin(dist)
Expand Down
8 changes: 4 additions & 4 deletions EpiAware/test/EpiAwareUtils/SafePoisson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
λ = 10.0
dist = SafePoisson(λ)
@test typeof(dist) <: SafePoisson
@test rand(dist) isa Int
@test rand(dist, 10) isa Vector{Int}
@test rand(dist, 10, 10) isa Array{Int}
@test rand(dist) isa SafeInt
@test rand(dist, 10) isa Vector{SafeInt}
@test rand(dist, 10, 10) isa Array{SafeInt}
end

@testitem "Check distribution properties of SafePoisson" begin
Expand Down Expand Up @@ -54,7 +54,7 @@ end
bigλ = exp(48.0) #Large value of λ
dist = SafePoisson(bigλ)
@testset "Large value of mean samples a BigInt with SafePoisson" begin
@test rand(dist) isa BigInt
@test rand(dist) isa SafeInt
end
@testset "Large value of mean sample failure with Poisson" begin
_dist = Poisson(dist.λ)
Expand Down
4 changes: 1 addition & 3 deletions EpiAware/test/EpiObsModels/modifiers/LatentDelay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,7 @@ end
ExpGrowthRate,
Renewal] .|>
em_type -> em_type(
data = EpiData([0.2, 0.5, 0.3],
em_type == Renewal ? softplus : exp
),
data = EpiData([0.2, 0.5, 0.3], exp),
initialisation_prior = Normal(log(100.0), 0.01)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@ end

mdl = generate_observations(model, missing, 10)
draw = rand(mdl)
@test typeof(draw[:var"Test.y_t[1]"]) <: Int
@test typeof(draw[:var"Test.y_t[1]"]) <: Real
end
Loading