Use log_beta and log_binomial in Binomial.log_prob()
fritzo committed May 24, 2020
1 parent 6994cb4 commit ddd6fe4
Showing 6 changed files with 110 additions and 38 deletions.
14 changes: 9 additions & 5 deletions pyro/contrib/epidemiology/distributions.py
@@ -42,9 +42,10 @@ def set_approx_sample_thresh(thresh):
 def set_approx_log_prob_tol(tol):
     """
     EXPERIMENTAL Temporarily set the global default value of
-    ``BetaBinomial.approx_log_prob_tol``, thereby decreasing the computational
-    complexity of scoring :class:`~pyro.distributions.BetaBinomial`
-    distributions.
+    ``Binomial.approx_log_prob_tol`` and ``BetaBinomial.approx_log_prob_tol``,
+    thereby decreasing the computational complexity of scoring
+    :class:`~pyro.distributions.Binomial` and
+    :class:`~pyro.distributions.BetaBinomial` distributions.
 
     This is used internally by
     :class:`~pyro.contrib.epidemiology.compartmental.CompartmentalModel`.
@@ -54,12 +55,15 @@ def set_approx_log_prob_tol(tol):
     """
     assert isinstance(tol, (float, int))
     assert tol > 0
-    old = dist.BetaBinomial.approx_log_prob_tol
+    old1 = dist.Binomial.approx_log_prob_tol
+    old2 = dist.BetaBinomial.approx_log_prob_tol
     try:
+        dist.Binomial.approx_log_prob_tol = tol
        dist.BetaBinomial.approx_log_prob_tol = tol
         yield
     finally:
-        dist.BetaBinomial.approx_log_prob_tol = old
+        dist.Binomial.approx_log_prob_tol = old1
+        dist.BetaBinomial.approx_log_prob_tol = old2
 
 
 def infection_dist(*,
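
For orientation (not part of the commit), the context manager above is used roughly like this; a minimal sketch assuming only the imports shown in the diff:

    import torch
    import pyro.distributions as dist
    from pyro.contrib.epidemiology.distributions import set_approx_log_prob_tol

    d = dist.Binomial(torch.tensor(1000.), probs=torch.tensor(0.3))
    with set_approx_log_prob_tol(0.1):
        # Inside the block, Binomial and BetaBinomial scoring uses the cheaper
        # shifted Stirling approximation, accurate to roughly 0.1 nats.
        approx = d.log_prob(torch.tensor(300.))
    # On exit, the exact lgamma-based computation is restored.
    exact = d.log_prob(torch.tensor(300.))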
17 changes: 6 additions & 11 deletions pyro/distributions/conjugate.py
@@ -1,7 +1,6 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
 
-import functools
 import numbers
 
 import torch
@@ -10,7 +9,7 @@
 
 from pyro.distributions.torch import Beta, Binomial, Dirichlet, Gamma, Multinomial, Poisson
 from pyro.distributions.torch_distribution import TorchDistribution
-from pyro.ops.special import log_beta, log_beta_stirling
+from pyro.ops.special import log_beta, log_binomial
 
 
 def _log_beta_1(alpha, value, is_sparse):
@@ -44,9 +43,9 @@ class BetaBinomial(TorchDistribution):
     has_enumerate_support = True
     support = Binomial.support
 
-    # EXPERIMENTAL If set to a positive value, the .log_prob() method will use a
-    # shifted Stirling's approximation to the Beta function, reducing
-    # computational cost from 9 lgamma() evaluations to four log() evaluations
+    # EXPERIMENTAL If set to a positive value, the .log_prob() method will use
+    # a shifted Stirling's approximation to the Beta function, reducing
+    # computational cost from 9 lgamma() evaluations to 12 log() evaluations
     # plus arithmetic. Recommended values are between 0.1 and 0.01.
     approx_log_prob_tol = 0.
@@ -82,16 +81,12 @@ def log_prob(self, value):
         if self._validate_args:
             self._validate_sample(value)
 
-        if self.approx_log_prob_tol > 0:
-            lbeta = functools.partial(log_beta_stirling, tol=self.approx_log_prob_tol)
-        else:
-            lbeta = log_beta
-
         n = self.total_count
         k = value
         a = self.concentration1
         b = self.concentration0
-        return lbeta(k + a, n - k + b) - lbeta(a, b) - lbeta(k + 1, n - k + 1) - n.log1p()
+        tol = self.approx_log_prob_tol
+        return log_binomial(n, k, tol) + log_beta(k + a, n - k + b, tol) - log_beta(a, b, tol)
 
     @property
     def mean(self):
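
A quick aside (not part of the commit) on why the old and new expressions agree: the binomial coefficient satisfies log C(n, k) = -log(n + 1) - log B(k + 1, n - k + 1), so the dropped -lbeta(k + 1, n - k + 1) - n.log1p() terms are exactly log_binomial(n, k). A sketch verifying this numerically, assuming only the two functions imported above:

    import torch
    from pyro.ops.special import log_beta, log_binomial

    n, k = torch.tensor(20.), torch.tensor(7.)
    a, b = torch.tensor(2.5), torch.tensor(1.5)
    old = (log_beta(k + a, n - k + b) - log_beta(a, b)
           - log_beta(k + 1, n - k + 1) - n.log1p())
    new = log_binomial(n, k) + log_beta(k + a, n - k + b) - log_beta(a, b)
    assert torch.allclose(old, new)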
27 changes: 27 additions & 0 deletions pyro/distributions/torch.py
@@ -9,6 +9,12 @@
 from pyro.distributions.constraints import IndependentConstraint
 from pyro.distributions.torch_distribution import TorchDistributionMixin
 from pyro.distributions.util import sum_rightmost
+from pyro.ops.special import log_binomial
+
+
+def _clamp_by_zero(x):
+    # works like clamp(x, min=0) but has a gradient of 0.5 at 0
+    return (x.clamp(min=0) + x - x.clamp(max=0)) / 2
 
 
 class Beta(torch.distributions.Beta, TorchDistributionMixin):
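
An aside (not part of the commit): the averaging trick in _clamp_by_zero can be checked with autograd. Both clamp terms see the kink at 0 symmetrically, so their average yields the subgradient 0.5 there:

    import torch

    def _clamp_by_zero(x):
        # clamp(x, min=0), written so the two one-sided derivatives average at 0
        return (x.clamp(min=0) + x - x.clamp(max=0)) / 2

    x = torch.tensor(0., requires_grad=True)
    _clamp_by_zero(x).backward()
    assert x.grad.item() == 0.5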
@@ -36,6 +42,12 @@ class Binomial(torch.distributions.Binomial, TorchDistributionMixin):
     # sampling very large populations.
     approx_sample_thresh = math.inf
 
+    # EXPERIMENTAL If set to a positive value, the .log_prob() method will use
+    # a shifted Stirling's approximation to the Beta function, reducing
+    # computational cost from 3 lgamma() evaluations to 4 log() evaluations
+    # plus arithmetic. Recommended values are between 0.1 and 0.01.
+    approx_log_prob_tol = 0.
+
     def sample(self, sample_shape=torch.Size()):
         if self.approx_sample_thresh < math.inf:
             exact = self.total_count <= self.approx_sample_thresh
@@ -61,6 +73,21 @@ def sample(self, sample_shape=torch.Size()):
             return sample
         return super().sample(sample_shape)
 
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+
+        n = self.total_count
+        k = value
+        # k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
+        # (case logit < 0)  = k * logit - n * log1p(e^logit)
+        # (case logit > 0)  = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
+        #                   = k * logit - n * logit - n * log1p(e^-logit)
+        # (merge two cases) = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
+        normalize_term = n * (_clamp_by_zero(self.logits) + self.logits.abs().neg().exp().log1p())
+        return (k * self.logits - normalize_term
+                + log_binomial(n, k, tol=self.approx_log_prob_tol))
+
 
 # This overloads .log_prob() and .enumerate_support() to speed up evaluating
 # log_prob on the support of this variable: we can completely avoid tensor ops
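
The merged-case identity in the comments, max(logit, 0) + log1p(e^-|logit|) == log1p(e^logit), is what keeps the normalizer stable for large positive logits. A small numerical check (a sketch, not part of the commit):

    import torch

    logits = torch.linspace(-30., 30., 61)
    merged = logits.clamp(min=0) + logits.abs().neg().exp().log1p()
    naive = logits.exp().log1p()  # overflows sooner as logits grow
    assert torch.allclose(merged, naive)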
50 changes: 32 additions & 18 deletions pyro/ops/special.py
@@ -5,24 +5,16 @@
 import math
 import operator
 
-
-def log_beta(x, y):
-    """
-    Log Beta function.
-
-    :param torch.Tensor x: A positive tensor.
-    :param torch.Tensor y: A positive tensor.
-    """
-    return x.lgamma() + y.lgamma() - (x + y).lgamma()
+import torch
 
 
-def log_beta_stirling(x, y, tol=0.1):
+def log_beta(x, y, tol=0.):
     """
-    Shifted Stirling's approximation to the log Beta function.
+    Computes log Beta function.
 
-    This is useful as a cheaper alternative to :func:`log_beta`.
-    This adapts Stirling's approximation of the log Gamma function::
+    When ``tol >= 0.02`` this uses a shifted Stirling's approximation to the
+    log Beta function. The approximation adapts Stirling's approximation of the
+    log Gamma function::
 
         lgamma(z) ≈ (z - 1/2) * log(z) - z + log(2 * pi) / 2
@@ -31,8 +23,8 @@ def log_beta_stirling(x, y, tol=0.1):
         log_beta(x, y) ≈ ((x-1/2) * log(x) + (y-1/2) * log(y)
                           - (x+y-1/2) * log(x+y) + log(2*pi)/2)
 
-    This additionally improves accuracy near zero by iteratively shifting
-    the log Gamma approximation using the recursion::
+    The approximation additionally improves accuracy near zero by iteratively
+    shifting the log Gamma approximation using the recursion::
 
         lgamma(x) = lgamma(x + 1) - log(x)
@@ -44,11 +36,12 @@ def log_beta_stirling(x, y, tol=0.1):
     :param torch.Tensor y: A positive tensor.
     :param float tol: Bound on maximum absolute error. Defaults to 0.1. For
         very small ``tol``, this function simply defers to :func:`log_beta`.
     :rtype: torch.Tensor
     """
+    assert isinstance(tol, (float, int)) and tol >= 0
     if tol < 0.02:
-        # Eventually it is cheaper to defer to torch.lgamma().
-        return log_beta(x, y)
+        # At small tolerance it is cheaper to defer to torch.lgamma().
+        return x.lgamma() + y.lgamma() - (x + y).lgamma()
 
     # This bound holds for arbitrary x,y. We could do better with large x,y.
     shift = int(math.ceil(0.082 / tol))
@@ -65,3 +58,24 @@
 
     return (log_factor + (x - 0.5) * x.log() + (y - 0.5) * y.log()
             - (xy - 0.5) * xy.log() + (math.log(2 * math.pi) / 2 - shift))
+
+
+@torch.no_grad()
+def log_binomial(n, k, tol=0.):
+    """
+    Computes log binomial coefficient.
+
+    When ``tol >= 0.02`` this uses a shifted Stirling's approximation to the
+    log Beta function via :func:`log_beta`.
+
+    :param torch.Tensor n: A nonnegative integer tensor.
+    :param torch.Tensor k: An integer tensor ranging in ``[0, n]``.
+    :rtype: torch.Tensor
+    """
+    assert isinstance(tol, (float, int)) and tol >= 0
+    n_plus_1 = n + 1
+    if tol < 0.02:
+        # At small tolerance it is cheaper to defer to torch.lgamma().
+        return n_plus_1.lgamma() - (k + 1).lgamma() - (n_plus_1 - k).lgamma()
+
+    return -n_plus_1.log() - log_beta(k + 1, n_plus_1 - k, tol=tol)
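
For a sense of the tol knob (a sketch, not part of the commit): tol below 0.02 takes the exact lgamma path, while larger tol buys speed at a bounded accuracy cost:

    import torch
    from pyro.ops.special import log_beta, log_binomial

    x = torch.logspace(-3, 3, 7)
    y = x.unsqueeze(-1)
    exact = log_beta(x, y)            # tol=0. defers to torch.lgamma()
    approx = log_beta(x, y, tol=0.1)  # shifted Stirling approximation
    assert (approx - exact).abs().max() < 0.1

    n, k = torch.tensor(50.), torch.tensor(20.)
    assert (log_binomial(n, k, tol=0.1) - log_binomial(n, k)).abs() < 0.1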
19 changes: 18 additions & 1 deletion tests/distributions/test_binomial.py
@@ -2,9 +2,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import pytest
+import torch
 
 import pyro.distributions as dist
-from pyro.contrib.epidemiology.distributions import set_approx_sample_thresh
+from pyro.contrib.epidemiology.distributions import set_approx_log_prob_tol, set_approx_sample_thresh
 from tests.common import assert_close
@@ -33,3 +34,19 @@ def test_beta_binomial_approx_sample(concentration1, concentration0, total_count):
 
     assert_close(expected.mean(), actual.mean(), rtol=0.1)
     assert_close(expected.std(), actual.std(), rtol=0.1)
+
+
+@pytest.mark.parametrize("tol", [
+    1e-8, 1e-6, 1e-4, 1e-2, 0.02, 0.05, 0.1, 0.2, 0.5, 1.,
+])
+def test_binomial_approx_log_prob(tol):
+    logits = torch.linspace(-10., 10., 100)
+    k = torch.arange(100.).unsqueeze(-1)
+    n_minus_k = torch.arange(100.).unsqueeze(-1).unsqueeze(-1)
+    n = k + n_minus_k
+
+    expected = torch.distributions.Binomial(n, logits=logits).log_prob(k)
+    with set_approx_log_prob_tol(tol):
+        actual = dist.Binomial(n, logits=logits).log_prob(k)
+
+    assert_close(actual, expected, atol=tol)
21 changes: 18 additions & 3 deletions tests/ops/test_special.py
@@ -4,18 +4,33 @@
 import pytest
 import torch
 
-from pyro.ops.special import log_beta, log_beta_stirling
+from pyro.ops.special import log_beta, log_binomial
 
 
 @pytest.mark.parametrize("tol", [
     1e-8, 1e-6, 1e-4, 1e-2, 0.02, 0.05, 0.1, 0.2, 0.5, 1.,
 ])
 def test_log_beta_stirling(tol):
-    x = torch.logspace(-5, 5, 100)
+    x = torch.logspace(-5, 5, 200)
     y = x.unsqueeze(-1)
 
     expected = log_beta(x, y)
-    actual = log_beta_stirling(x, y, tol=tol)
+    actual = log_beta(x, y, tol=tol)
 
     assert (actual <= expected).all()
     assert (expected < actual + tol).all()
+
+
+@pytest.mark.parametrize("tol", [
+    1e-8, 1e-6, 1e-4, 1e-2, 0.02, 0.05, 0.1, 0.2, 0.5, 1.,
+])
+def test_log_binomial_stirling(tol):
+    k = torch.arange(200.)
+    n_minus_k = k.unsqueeze(-1)
+    n = k + n_minus_k
+
+    # Test binomial coefficient choose(n, k).
+    expected = (n + 1).lgamma() - (k + 1).lgamma() - (n_minus_k + 1).lgamma()
+    actual = log_binomial(n, k, tol=tol)
+
+    assert (actual - expected).abs().max() < tol
