From bcbc041302fda2d0ec19891b925e1eb84d48e91e Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 28 Aug 2024 13:48:31 -0500 Subject: [PATCH 1/2] cleanup names --- docs/source/tutorials/periodic_effects.qmd | 2 +- pyrenew/process/rtperiodicdiffar.py | 34 +++++++++++----------- test/test_rtperiodicdiff.py | 6 ++-- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/docs/source/tutorials/periodic_effects.qmd b/docs/source/tutorials/periodic_effects.qmd index 1603ed59..9ad019e6 100644 --- a/docs/source/tutorials/periodic_effects.qmd +++ b/docs/source/tutorials/periodic_effects.qmd @@ -27,7 +27,7 @@ from pyrenew import process, deterministic rt_proc = process.RtWeeklyDiffARProcess( name="rt_weekly_diff", offset=0, - log_rt_rv=deterministic.DeterministicVariable( + log_rt_init_rv=deterministic.DeterministicVariable( name="log_rt", value=jnp.array([0.1, 0.2]) ), autoreg_rv=deterministic.DeterministicVariable( diff --git a/pyrenew/process/rtperiodicdiffar.py b/pyrenew/process/rtperiodicdiffar.py index 9186b9ef..747d6982 100644 --- a/pyrenew/process/rtperiodicdiffar.py +++ b/pyrenew/process/rtperiodicdiffar.py @@ -54,7 +54,7 @@ def __init__( name: str, offset: int, period_size: int, - log_rt_rv: RandomVariable, + log_rt_init_rv: RandomVariable, autoreg_rv: RandomVariable, periodic_diff_sd_rv: RandomVariable, ar_process_suffix: str = "_first_diff_ar_process_noise", @@ -69,7 +69,7 @@ def __init__( offset : int Relative point at which data starts, must be between 0 and period_size - 1. - log_rt_rv : RandomVariable + log_rt_init_rv : RandomVariable Log Rt prior for the first two observations. autoreg_rv : RandomVariable Autoregressive parameter. @@ -87,7 +87,7 @@ def __init__( """ self.validate( - log_rt_rv=log_rt_rv, + log_rt_init_rv=log_rt_init_rv, autoreg_rv=autoreg_rv, periodic_diff_sd_rv=periodic_diff_sd_rv, ) @@ -95,7 +95,7 @@ def __init__( self.name = name self.period_size = period_size self.offset = offset - self.log_rt_rv = log_rt_rv + self.log_rt_init_rv = log_rt_init_rv self.autoreg_rv = autoreg_rv self.periodic_diff_sd_rv = periodic_diff_sd_rv self.ar_diff = DifferencedProcess( @@ -109,7 +109,7 @@ def __init__( @staticmethod def validate( - log_rt_rv: any, + log_rt_init_rv: any, autoreg_rv: any, periodic_diff_sd_rv: any, ) -> None: @@ -118,7 +118,7 @@ def validate( Parameters ---------- - log_rt_rv : any + log_rt_init_rv : any Log Rt prior for the first two observations. autoreg_rv : any Autoregressive parameter. @@ -130,7 +130,7 @@ def validate( None """ - _assert_sample_and_rtype(log_rt_rv) + _assert_sample_and_rtype(log_rt_init_rv) _assert_sample_and_rtype(autoreg_rv) _assert_sample_and_rtype(periodic_diff_sd_rv) @@ -159,9 +159,9 @@ def sample( """ # Initial sample - log_rt_rv = self.log_rt_rv.sample(**kwargs)[0].value - b = self.autoreg_rv.sample(**kwargs)[0].value - s_r = self.periodic_diff_sd_rv.sample(**kwargs)[0].value + log_rt_init = self.log_rt_init_rv.sample(**kwargs)[0].value + autoreg = self.autoreg_rv.sample(**kwargs)[0].value + noise_sd = self.periodic_diff_sd_rv.sample(**kwargs)[0].value # How many periods to sample? n_periods = (duration + self.period_size - 1) // self.period_size @@ -170,11 +170,11 @@ def sample( log_rt = self.ar_diff( n=n_periods, - init_vals=jnp.array([log_rt_rv[0]]), - autoreg=b, - noise_sd=s_r, + init_vals=jnp.array([log_rt_init[0]]), + autoreg=autoreg, + noise_sd=noise_sd, fundamental_process_init_vals=jnp.array( - [log_rt_rv[1] - log_rt_rv[0]] + [log_rt_init[1] - log_rt_init[0]] ), )[0] @@ -201,7 +201,7 @@ def __init__( self, name: str, offset: int, - log_rt_rv: RandomVariable, + log_rt_init_rv: RandomVariable, autoreg_rv: RandomVariable, periodic_diff_sd_rv: RandomVariable, ) -> None: @@ -214,7 +214,7 @@ def __init__( Name of the site. offset : int Relative point at which data starts, must be between 0 and 6. - log_rt_rv : RandomVariable + log_rt_init_rv : RandomVariable Log Rt prior for the first two observations. autoreg_rv : RandomVariable Autoregressive parameter. @@ -230,7 +230,7 @@ def __init__( name=name, offset=offset, period_size=7, - log_rt_rv=log_rt_rv, + log_rt_init_rv=log_rt_init_rv, autoreg_rv=autoreg_rv, periodic_diff_sd_rv=periodic_diff_sd_rv, ) diff --git a/test/test_rtperiodicdiff.py b/test/test_rtperiodicdiff.py index 8d1ac28a..73bc3382 100644 --- a/test/test_rtperiodicdiff.py +++ b/test/test_rtperiodicdiff.py @@ -17,7 +17,7 @@ def test_rtweeklydiff() -> None: params = { "name": "test", "offset": 0, - "log_rt_rv": DeterministicVariable( + "log_rt_init_rv": DeterministicVariable( name="log_rt", value=jnp.array([0.1, 0.2]) ), "autoreg_rv": DeterministicVariable( @@ -65,7 +65,7 @@ def test_rtweeklydiff_no_autoregressive() -> None: params = { "name": "test", "offset": 0, - "log_rt_rv": DeterministicVariable( + "log_rt_init_rv": DeterministicVariable( name="log_rt", value=jnp.array([0.0, 0.0]) ), # No autoregression! @@ -109,7 +109,7 @@ def test_rtperiodicdiff_smallsample(inits): params = { "name": "test", "offset": 0, - "log_rt_rv": DeterministicVariable( + "log_rt_init_rv": DeterministicVariable( name="log_rt", value=inits, ), From 80eb58fc23cc7ec42437cffb6a37c184fa488da1 Mon Sep 17 00:00:00 2001 From: damonbayer Date: Wed, 28 Aug 2024 14:28:41 -0500 Subject: [PATCH 2/2] Remove RtWeeklyDiffARProcess --- docs/source/tutorials/periodic_effects.qmd | 5 ++- pyrenew/process/__init__.py | 6 +-- pyrenew/process/rtperiodicdiffar.py | 46 ---------------------- test/test_rtperiodicdiff.py | 13 +++--- 4 files changed, 12 insertions(+), 58 deletions(-) diff --git a/docs/source/tutorials/periodic_effects.qmd b/docs/source/tutorials/periodic_effects.qmd index 9ad019e6..e215dedd 100644 --- a/docs/source/tutorials/periodic_effects.qmd +++ b/docs/source/tutorials/periodic_effects.qmd @@ -24,9 +24,10 @@ from pyrenew import process, deterministic ```{python} # The random process for Rt -rt_proc = process.RtWeeklyDiffARProcess( +rt_proc = process.RtPeriodicDiffARProcess( name="rt_weekly_diff", offset=0, + period_size=7, log_rt_init_rv=deterministic.DeterministicVariable( name="log_rt", value=jnp.array([0.1, 0.2]) ), @@ -57,7 +58,7 @@ for i in range(0, 30, 7): plt.show() ``` -The implementation of the `RtWeeklyDiffARProcess` (which is an instance of `RtPeriodicDiffARProcess`), uses `repeat_until_n` to repeating values: `repeat_until_n(..., period_size=7)`. The `RtWeeklyDiff` class is a particular case of `RtPeriodicDiff` with a period size of seven. +The implementation of the `RtPeriodicDiffARProcess` uses `repeat_until_n` to repeating values: `repeat_until_n(..., period_size=7)`. The `RtWeeklyDiff` class is a particular case of `RtPeriodicDiff` with a period size of seven. ## Repeated sequences (tiling) diff --git a/pyrenew/process/__init__.py b/pyrenew/process/__init__.py index 638ea45d..45996193 100644 --- a/pyrenew/process/__init__.py +++ b/pyrenew/process/__init__.py @@ -10,10 +10,7 @@ ) from pyrenew.process.periodiceffect import DayOfWeekEffect, PeriodicEffect from pyrenew.process.randomwalk import RandomWalk, StandardNormalRandomWalk -from pyrenew.process.rtperiodicdiffar import ( - RtPeriodicDiffARProcess, - RtWeeklyDiffARProcess, -) +from pyrenew.process.rtperiodicdiffar import RtPeriodicDiffARProcess __all__ = [ "IIDRandomSequence", @@ -25,5 +22,4 @@ "PeriodicEffect", "DayOfWeekEffect", "RtPeriodicDiffARProcess", - "RtWeeklyDiffARProcess", ] diff --git a/pyrenew/process/rtperiodicdiffar.py b/pyrenew/process/rtperiodicdiffar.py index 747d6982..2af85c7b 100644 --- a/pyrenew/process/rtperiodicdiffar.py +++ b/pyrenew/process/rtperiodicdiffar.py @@ -190,49 +190,3 @@ def sample( t_unit=self.t_unit, ), ) - - -class RtWeeklyDiffARProcess(RtPeriodicDiffARProcess): - """ - Weekly Rt with autoregressive first differences. - """ - - def __init__( - self, - name: str, - offset: int, - log_rt_init_rv: RandomVariable, - autoreg_rv: RandomVariable, - periodic_diff_sd_rv: RandomVariable, - ) -> None: - """ - Default constructor for RtWeeklyDiffARProcess class. - - Parameters - ---------- - name : str - Name of the site. - offset : int - Relative point at which data starts, must be between 0 and 6. - log_rt_init_rv : RandomVariable - Log Rt prior for the first two observations. - autoreg_rv : RandomVariable - Autoregressive parameter. - periodic_diff_sd_rv : RandomVariable - Standard deviation of the noise. - - Returns - ------- - None - """ - - super().__init__( - name=name, - offset=offset, - period_size=7, - log_rt_init_rv=log_rt_init_rv, - autoreg_rv=autoreg_rv, - periodic_diff_sd_rv=periodic_diff_sd_rv, - ) - - return None diff --git a/test/test_rtperiodicdiff.py b/test/test_rtperiodicdiff.py index 73bc3382..413a9f9a 100644 --- a/test/test_rtperiodicdiff.py +++ b/test/test_rtperiodicdiff.py @@ -8,7 +8,7 @@ from numpy.testing import assert_array_equal from pyrenew.deterministic import DeterministicVariable -from pyrenew.process import RtWeeklyDiffARProcess +from pyrenew.process import RtPeriodicDiffARProcess def test_rtweeklydiff() -> None: @@ -26,10 +26,11 @@ def test_rtweeklydiff() -> None: "periodic_diff_sd_rv": DeterministicVariable( name="periodic_diff_sd_rv", value=jnp.array([0.1]) ), + "period_size": 7, } duration = 30 - rtwd = RtWeeklyDiffARProcess(**params) + rtwd = RtPeriodicDiffARProcess(**params) with numpyro.handlers.seed(rng_seed=223): rt = rtwd(duration=duration).rt.value @@ -44,7 +45,7 @@ def test_rtweeklydiff() -> None: # Checking start off a different day of the week params["offset"] = 5 - rtwd = RtWeeklyDiffARProcess(**params) + rtwd = RtPeriodicDiffARProcess(**params) with numpyro.handlers.seed(rng_seed=223): rt2 = rtwd(duration=duration).rt.value @@ -76,9 +77,10 @@ def test_rtweeklydiff_no_autoregressive() -> None: name="periodic_diff_sd_rv", value=jnp.array([0.1]), ), + "period_size": 7, } - rtwd = RtWeeklyDiffARProcess(**params) + rtwd = RtPeriodicDiffARProcess(**params) duration = 1000 @@ -120,9 +122,10 @@ def test_rtperiodicdiff_smallsample(inits): name="periodic_diff_sd_rv", value=jnp.array([0.1]), ), + "period_size": 7, } - rtwd = RtWeeklyDiffARProcess(**params) + rtwd = RtPeriodicDiffARProcess(**params) with numpyro.handlers.seed(rng_seed=223): rt = rtwd(duration=6).rt.value