Skip to content

Commit

Permalink
Change scan to specialise on a subtype of AbstractEpiModel
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBrand1 committed Feb 28, 2024
1 parent c4b6a9a commit d676612
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 5 deletions.
4 changes: 2 additions & 2 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ using Distributions,
DataFramesMeta

# Exported utilities
export create_discrete_pmf, spread_draws
export create_discrete_pmf, spread_draws, scan

# Exported types
export EpiData, Renewal, ExpGrowthRate, DirectInfections
export EpiData, Renewal, ExpGrowthRate, DirectInfections, AbstractEpiModel

# Exported Turing model constructors
export make_epi_inference_model
Expand Down
29 changes: 28 additions & 1 deletion EpiAware/src/utilities.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,32 @@

function scan(f::F, init, xs) where {F}
"""
scan(f::F, init, xs) where {F <: AbstractEpiModel}
Apply `f` to each element of `xs` and accumulate the results.
`f` must be a [callable](https://docs.julialang.org/en/v1/manual/methods/#Function-like-objects)
on a sub-type of `AbstractEpiModel`.
### Design note
`scan` is being restricted to `AbstractEpiModel` sub-types to ensure:
1. That compiler specialization is [activated](https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing)
2. Also avoids potential compiler [overhead](https://docs.julialang.org/en/v1/devdocs/functions/#compiler-efficiency-issues)
from specialisation on `f<: Function`.
# Arguments
- `f`: A callable/functor that takes two arguments, `carry` and `x`, and returns a new
`carry` and a result `y`.
- `init`: The initial value for the `carry` variable.
- `xs`: An iterable collection of elements.
# Returns
- `ys`: An array containing the results of applying `f` to each element of `xs`.
- `carry`: The final value of the `carry` variable after processing all elements of `xs`.
"""
function scan(f::F, init, xs) where {F <: AbstractEpiModel}
carry = init
ys = similar(xs)
for (i, x) in enumerate(xs)
Expand Down
28 changes: 26 additions & 2 deletions EpiAware/test/test_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,19 @@
xs = [1, 2, 3, 4, 5]
expected_ys = [1, 3, 6, 10, 15]
expected_carry = 15
ys, carry = EpiAware.scan(add, 0, xs)

# Check that a generic function CAN'T be used
@test_throws MethodError EpiAware.scan(add, 0, xs)

# Check that a callable subtype of `AbstractEpiModel` CAN be used
struct TestEpiModelAdd <: AbstractEpiModel
end
function (epimodel::TestEpiModelAdd)(a, b)
return a + b, a + b
end

ys, carry = EpiAware.scan(TestEpiModelAdd(), 0, xs)

@test ys == expected_ys
@test carry == expected_carry
end
Expand All @@ -22,7 +34,19 @@ end
expected_ys = [1, 2, 6, 24, 120]
expected_carry = 120

ys, carry = EpiAware.scan(multiply, 1, xs)
# Check that a generic function CAN'T be used
@test_throws MethodError ys, carry=EpiAware.scan(multiply, 1, xs)

# Check that a callable subtype of `AbstractEpiModel` CAN be used
struct TestEpiModelMult <: AbstractEpiModel
end

function (epimodel::TestEpiModelMult)(a, b)
return a * b, a * b
end

ys, carry = EpiAware.scan(TestEpiModelMult(), 1, xs)

@test ys == expected_ys
@test carry == expected_carry
end
Expand Down

0 comments on commit d676612

Please sign in to comment.