Skip to content

Commit

Permalink
Allow negative total_count in Extended distributions (#2439)
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo authored Apr 24, 2020
1 parent 560aafb commit e284f64
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 8 deletions.
6 changes: 3 additions & 3 deletions examples/sir_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def reparameterized_discrete_model(args, data):
dist.ExtendedBinomial(I_prev, prob_i),
obs=I2R)
pyro.sample("obs_{}".format(t),
dist.ExtendedBinomial(S2I.clamp(min=0), rho),
dist.ExtendedBinomial(S2I, rho),
obs=datum)


Expand Down Expand Up @@ -301,7 +301,7 @@ def continuous_model(args, data):
dist.ExtendedBinomial(I_prev, prob_i),
obs=I2R)
pyro.sample("obs_{}".format(t),
dist.ExtendedBinomial(S2I.clamp(min=0), rho),
dist.ExtendedBinomial(S2I, rho),
obs=datum)


Expand Down Expand Up @@ -408,7 +408,7 @@ def vectorized_model(args, data):
# Compute probability factors.
S2I_logp = dist.ExtendedBinomial(S_prev, -(rate_s * I_prev).expm1()).log_prob(S2I)
I2R_logp = dist.ExtendedBinomial(I_prev, prob_i).log_prob(I2R)
obs_logp = dist.ExtendedBinomial(S2I.clamp(min=0), rho).log_prob(data)
obs_logp = dist.ExtendedBinomial(S2I, rho).log_prob(data)

# Manually perform variable elimination.
logp = S_logp + (I_logp + obs_logp) + S2I_logp + I2R_logp
Expand Down
2 changes: 1 addition & 1 deletion pyro/contrib/epidemiology/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,5 +98,5 @@ def transition_bwd(self, params, prev, curr, t):
dist.ExtendedBinomial(prev["I"], prob_i),
obs=I2R)
pyro.sample("obs_{}".format(t),
dist.ExtendedBinomial(S2I.clamp(min=0), rho),
dist.ExtendedBinomial(S2I, rho),
obs=self.data[t])
16 changes: 12 additions & 4 deletions pyro/distributions/extended.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
class ExtendedBinomial(Binomial):
"""
EXPERIMENTAL :class:`~pyro.distributions.Binomial` distribution extended to
have logical support the entire integers. Numerical support is still the
integer interval ``[0, total_count]``
have logical support the entire integers and to allow arbitrary integer
``total_count``. Numerical support is still the integer interval ``[0,
total_count]``.
"""
arg_constraints = {"total_count": constraints.integer,
"probs": constraints.unit_interval,
"logits": constraints.real}
support = constraints.integer

def log_prob(self, value):
Expand All @@ -26,9 +30,13 @@ def log_prob(self, value):
class ExtendedBetaBinomial(BetaBinomial):
"""
EXPERIMENTAL :class:`~pyro.distributions.BetaBinomial` distribution
extended to have logical support the entire integers. Numerical support is
still the integer interval ``[0, total_count]``
extended to have logical support the entire integers and to allow arbitrary
integer ``total_count``. Numerical support is still the integer interval
``[0, total_count]``.
"""
arg_constraints = {"concentration1": constraints.positive,
"concentration0": constraints.positive,
"total_count": constraints.integer}
support = constraints.integer

def log_prob(self, value):
Expand Down
10 changes: 10 additions & 0 deletions tests/distributions/test_extended.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ def test_extended_binomial():
with pytest.raises(ValueError):
d2.log_prob(torch.tensor(0.5))

# Check on negative total_count.
total_count = torch.arange(-10, 0.)
d = dist.ExtendedBinomial(total_count, 0.5)
assert (d.log_prob(data) == -math.inf).all()


def test_extended_beta_binomial():
concentration1 = torch.tensor([1.0, 2.0, 1.0])
Expand Down Expand Up @@ -65,3 +70,8 @@ def test_extended_beta_binomial():
# Check on value error.
with pytest.raises(ValueError):
d2.log_prob(torch.tensor(0.5))

# Check on negative total_count.
total_count = torch.arange(-10, 0.)
d = dist.ExtendedBetaBinomial(1.5, 1.5, total_count)
assert (d.log_prob(data) == -math.inf).all()

0 comments on commit e284f64

Please sign in to comment.