Skip to content

Commit

Permalink
Add default values, set mask to be boolean
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Nov 15, 2024
1 parent a1d144a commit f521506
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions src/covvfit/_quasimultinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import numpyro
import numpyro.distributions as distrib
from jaxtyping import Array, Float
from jaxtyping import Array, Bool, Float
from scipy import optimize


Expand Down Expand Up @@ -428,7 +428,7 @@ class _ProblemData(NamedTuple):
n_variants: int
ts: Float[Array, "cities timepoints"]
ys: Float[Array, "cities timepoints variants"]
mask: Float[Array, "cities timepoints"]
mask: Bool[Array, "cities timepoints"]
n_quasimul: Float[Array, "cities timepoints"]
overdispersion: Float[Array, "cities timepoints"]

Expand Down Expand Up @@ -549,10 +549,11 @@ def _validate_and_pad(
padding_value=0.0,
)
out_mask = _create_padded_array(
values=1.0,
values=1,
expected_lengths=_lengths,
padding_length=max_timepoints,
padding_value=0.0,
padding_value=0,
_out_dtype=bool,
)

# Create the array with variant proportions, padded with constant vectors
Expand Down Expand Up @@ -584,13 +585,16 @@ def _quasiloglikelihood_single_city(
n_quasimul: Float[Array, " timepoints"],
overdispersion: Float[Array, " timepoints"],
) -> float:
weight = n_quasimul / overdispersion
logps = calculate_logps(
ts=ts,
midpoints=_add_first_variant(relative_offsets),
growths=_add_first_variant(relative_growths),
)
return jnp.sum(mask[:, None] * weight[:, None] * ys * logps)
# Ensure compatible shapes:
mask = jnp.asarray(mask, dtype=float)[:, None]
weight = (n_quasimul / overdispersion)[:, None]

return jnp.sum(mask * weight * ys * logps)


_RelativeGrowthsAndOffsetsFunction = Callable[
Expand Down Expand Up @@ -635,8 +639,8 @@ def quasiloglikelihood(
def construct_model(
ys: list[jax.Array],
ts: list[jax.Array],
ns: _OverDispersionType,
overdispersion: _OverDispersionType,
ns: _OverDispersionType = 1.0,
overdispersion: _OverDispersionType = 1.0,
sigma_growth: float = 10.0,
sigma_offset: float = 1000.0,
) -> Callable:
Expand Down Expand Up @@ -704,8 +708,8 @@ def model():
def construct_total_loss(
ys: list[jax.Array],
ts: list[jax.Array],
ns: _OverDispersionType,
overdispersion: _OverDispersionType,
ns: _OverDispersionType = 1.0,
overdispersion: _OverDispersionType = 1.0,
accept_theta: bool = True,
average_loss: bool = False,
) -> Callable[[_ThetaType], _Float] | _RelativeGrowthsAndOffsetsFunction:
Expand Down

0 comments on commit f521506

Please sign in to comment.