Skip to content

Commit

Permalink
Vectorize pmfs (#1)
Browse files Browse the repository at this point in the history
* use scipy factorial

* Vectorize pmf's

* repair tests
  • Loading branch information
swo authored Feb 27, 2025
1 parent 596d515 commit 935144e
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 20 deletions.
57 changes: 42 additions & 15 deletions src/reedfrost/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import math

import numpy as np
import scipy.optimize
import scipy.stats
from numpy.typing import NDArray
from scipy.optimize import brentq
from scipy.special import factorial


@functools.cache
def _gontcharoff(k: int, q: float, m: int) -> float:
def _gontcharoff1(k: int, q: float, m: int) -> float:
"""
Gontcharoff polynomials, specific to the Lefevre & Picard
formulations of Reed-Frost final outbreak size pmf calculations
Expand All @@ -27,31 +29,56 @@ def _gontcharoff(k: int, q: float, m: int) -> float:
else:
return 1.0 / math.factorial(k) - sum(
[
q ** ((m + i) * (k - i)) / math.factorial(k - i) * _gontcharoff(i, q, m)
q ** ((m + i) * (k - i)) / factorial(k - i) * _gontcharoff1(i, q, m)
for i in range(0, k)
]
)


def pmf(k: int, n: int, p: float, m: int = 1) -> float:
def _gontcharoff(
k: int | NDArray[np.int64], q: float, m: int
) -> float | NDArray[np.float64]:
"""
Gontcharoff polynomials, specific to the Lefevre & Picard
formulations of Reed-Frost final outbreak size pmf calculations
See Lefevre & Picard 1990 (doi:10.2307/1427595) equation 2.1
Args:
k (int, or int array): degree
q (float): 1 - probability of "effective" contact
m (int): number of initial infected
Returns:
float, or float array: value of the polynomial
"""
if isinstance(k, int):
return _gontcharoff1(k=k, q=q, m=m)
else:
return np.array([_gontcharoff1(k=kk, q=q, m=m) for kk in k])


def pmf(
k: int | NDArray[np.int64], n: int, p: float, m: int = 1
) -> float | NDArray[np.float64]:
"""
Probability mass function for final size of a Reed-Frost outbreak
See Lefevre & Picard 1990 (doi:10.2307/1427595) equation 3.10
Args:
k (int): number of total infections
k (int, or int array): number of total infections
n (int): initial number susceptible
m (int): initial number infected
p (float): probability of "effective contact" (i.e., infection)
Returns:
float: pmf of the total infection distribution
float, or float array: pmf of the total infection distribution
"""
q = 1.0 - p
return (
math.factorial(n)
/ math.factorial(n - k)
factorial(n)
/ factorial(n - k)
* q ** ((n - k) * (m + k))
* _gontcharoff(k, q, m)
)
Expand All @@ -76,26 +103,26 @@ def f(t):

# do type checking here because type hinting gets confused about whether
# this results a tuple or a float
result = scipy.optimize.brentq(f, 0.0, 1.0, full_output=False)
result = brentq(f, 0.0, 1.0, full_output=False)
assert isinstance(result, float)
return result


def dist_large(n: int, lambda_: float, i_n: int = 1):
def pmf_large(
k: int | NDArray[np.int64], n: int, lambda_: float, i_n: int = 1
) -> float | NDArray[np.float64]:
"""Distribution of outbreak sizes, given a large outbreak
See Barbour & Sergey 2004 (doi:10.1016/j.spa.2004.03.013) corollary 3.4
Args:
k (int, or int array): number of total infections
n (int): initial number of susceptibles
lambda_ (float): reproduction number
i_n (int, optional): initial number of susceptibles. Defaults to 1.
Returns:
scipy.stats.norm: RV object
Examples:
dist_large(100, 1.5, 1).pdf(np.linspace(0, 100))
float, or float array: pmf of the total infection distribution
"""
if not lambda_ > 1.0:
raise RuntimeWarning(
Expand All @@ -105,4 +132,4 @@ def dist_large(n: int, lambda_: float, i_n: int = 1):
sigma = np.sqrt(theta * (1.0 - theta) / (1 - lambda_ * theta) ** 2)
sd = np.sqrt(n) * sigma
mean = n * (1.0 - theta)
return scipy.stats.norm(loc=mean, scale=sd)
return scipy.stats.norm.pdf(x=k, loc=mean, scale=sd)
24 changes: 19 additions & 5 deletions tests/test_reed_frost.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import pytest
from pytest import approx

import reedfrost as rf

Expand All @@ -8,11 +8,25 @@ def test_pmf_2():
"""For 1 infected and 2 susceptible, do the math by hand"""

for p in [0.1, 0.7]:
assert rf.pmf(0, 2, p, m=1) == approx((1 - p) ** 2, abs=1e-6)
assert rf.pmf(1, 2, p, m=1) == approx(2 * p * (1 - p) ** 2, abs=1e-6)
assert rf.pmf(2, 2, p, m=1) == approx(p**2 + 2 * p**2 * (1 - p), abs=1e-6)
current = rf.pmf(k=np.array([0, 1, 2]), n=2, p=p, m=1)
expected = [(1 - p) ** 2, 2 * p * (1 - p) ** 2, p**2 + 2 * p**2 * (1 - p)]
np.testing.assert_allclose(current, expected, atol=1e-6, rtol=0.0)


def test_pmf_vector():
"""PMF can take vectors"""
result = rf.pmf(k=np.array([0, 1, 2]), n=2, p=0.1, m=1)
assert isinstance(result, np.ndarray)
assert len(result) == 3


def test_pmf_large():
current = rf.pmf_large(k=np.array([0, 10, 50, 90]), n=100, lambda_=1.5, i_n=1)
np.testing.assert_allclose(
current, [2.383925e-07, 8.889213e-06, 2.349612e-02, 1.623636e-03], rtol=1e-6
)


def test_large_dist_warning():
with pytest.raises(RuntimeWarning):
rf.dist_large(n=10, lambda_=0.5)
rf.pmf_large(k=1, n=10, lambda_=0.5)

0 comments on commit 935144e

Please sign in to comment.