Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds build_paramDict #26

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
3 changes: 2 additions & 1 deletion src/diseaseProg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ caseFatalityRatioHospital_given_COVID_by_age = [0.00856164, 0.03768844, 0.023213
agePopulationTotal = 1000*[8044.056, 7642.473, 8558.707, 9295.024, 8604.251,
9173.465, 7286.777, 5830.635, 3450.616]


function _agePopulationRatio(agePopulationTotal)
agePopulationTotal *= 55.98/66.27
return agePopulationTotal/sum(agePopulationTotal)
Expand All @@ -41,7 +42,7 @@ function einsum(str, a, b, T)
end

function einsum(str, a, T)
if str=="ijlml->ijlm"
if str=="ijlml->ijlm"
return _einsum4(a, T)
elseif str=="ijkj->ijk"
return _einsum6(a, T)
Expand Down
92 changes: 92 additions & 0 deletions src/diseaseProg_withCallableStructs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Library Imports
using Parameters
using Test

abstract type CType end
agePopulationRatio=0.5
## Existing Function
function adjustRatesByAge_KeepAverageRateTest(rate; agePopulationRatio=agePopulationRatio,
ageRelativeAdjustment::Array=[],
maxOutRate::Float64=10.0)
if rate == 0
return fill(0, size(ageRelativeAdjustment))
end
if rate >= maxOutRate
@warn("covidTesting::adjustRatesByAge_KeepAverageRate Input rate $rate >
maxOutRate $maxOutRate, returning input rates")
return rate*(fill(1, size(ageRelativeAdjustment)))
end
out = fill(0, size(ageRelativeAdjustment))
out[1] = maxOutRate + 1
while sum(out .>= maxOutRate) > 0
corrFactor = sum(agePopulationRatio ./ (1 .+ ageRelativeAdjustment))
out = rate * (1 .+ ageRelativeAdjustment) * corrFactor
if sum(out .>= maxOutRate) > 0
@warn("covidTesting::adjustRatesByAge_KeepAverageRate Adjusted rate
larger than $maxOutRate encountered, reducing ageAdjustment
variance by 10%")
tmp_mean = sum(ageRelativeAdjustment)/length(ageRelativeAdjustment)
ageRelativeAdjustment = tmp_mean .+ sqrt(0.9)*(
ageRelativeAdjustment .- tmp_mean)
end
end
return out
end

## Callable struct

@with_kw mutable struct adjustRatesByAge_KeepAverageRate <: CType
agePopulationRatio::Float64=agePopulationRatio
ageRelativeAdjustment::Array=[]
dinskid marked this conversation as resolved.
Show resolved Hide resolved
maxOutRate::Float64=10.0
end
# maxOutRate::Float64=10.0

function (f::adjustRatesByAge_KeepAverageRate)(rate; agePopulationRatio=agePopulationRatio,
ageRelativeAdjustment::Array=[],
maxOutRate::Float64=10.0
)
f.agePopulationRatio = agePopulationRatio
f.ageRelativeAdjustment = ageRelativeAdjustment
f.maxOutRate = maxOutRate
if rate == 0
return fill(0, size(f.ageRelativeAdjustment))
end
if rate >= f.maxOutRate
@warn("covidTesting::adjustRatesByAge_KeepAverageRate Input rate $rate >
maxOutRate $(f.maxOutRate), returning input rates")
return rate*(fill(1, size(f.ageRelativeAdjustment)))
end
out = fill(0, size(f.ageRelativeAdjustment))
out[1] = f.maxOutRate + 1
while sum(out .>= f.maxOutRate) > 0
corrFactor = sum(f.agePopulationRatio ./ (1 .+ f.ageRelativeAdjustment))
out = rate * (1 .+ f.ageRelativeAdjustment) * corrFactor
if sum(out .>= f.maxOutRate) > 0
@warn("covidTesting::adjustRatesByAge_KeepAverageRate Adjusted rate
larger than $(f.maxOutRate) encountered, reducing ageAdjustment
variance by 10%")
tmp_mean = sum(f.ageRelativeAdjustment)/length(f.ageRelativeAdjustment)
f.ageRelativeAdjustment = tmp_mean .+ sqrt(0.9)*(
f.ageRelativeAdjustment .- tmp_mean)
end
end
return out
end

@testset "Function vs Callable Struct" begin
@test adjustRatesByAge_KeepAverageRateTest(10; ageRelativeAdjustment=[1,2,3]) ==
adjustRatesByAge_KeepAverageRate()(10; ageRelativeAdjustment=[1,2,3])
end

istransparent(::Any) = false
istransparent(::CType) = true

params(m) = params(m, Val(istransparent(m)))
params(m, ::Val{false}) = m
function params(m, ::Val{true})
fields = fieldnames(typeof(m))
NamedTuple{fields}(Tuple([params(getfield(m, field)) for field in fields]))
end

println(params(adjustRatesByAge_KeepAverageRate()))