Add a shifted Stirling approximation to log Beta function #2500

Merged: 4 commits, May 24, 2020
9 changes: 9 additions & 0 deletions docs/source/ops.rst
@@ -34,6 +34,15 @@ Newton Optimizers
:show-inheritance:
:member-order: bysource

Special Functions
-----------------

.. automodule:: pyro.ops.special
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

Tensor Utilities
----------------

3 changes: 2 additions & 1 deletion pyro/contrib/epidemiology/compartmental.py
@@ -23,7 +23,7 @@
from pyro.infer.smcfilter import SMCFailed
from pyro.util import warn_if_nan

from .distributions import set_approx_sample_thresh
from .distributions import set_approx_log_prob_tol, set_approx_sample_thresh
from .util import align_samples, cat2, clamp, quantize, quantize_enumerate

logger = logging.getLogger(__name__)
@@ -144,6 +144,7 @@ def _clear_plates(self):
full_mass = False

@torch.no_grad()
@set_approx_log_prob_tol(0.1)
@set_approx_sample_thresh(100) # This is robust to gross approximation.
def heuristic(self, num_particles=1024, ess_threshold=0.5, retries=10):
"""
28 changes: 28 additions & 0 deletions pyro/contrib/epidemiology/distributions.py
@@ -38,6 +38,34 @@ def set_approx_sample_thresh(thresh):
dist.Binomial.approx_sample_thresh = old


@contextmanager
def set_approx_log_prob_tol(tol):
"""
EXPERIMENTAL Temporarily set the global default value of
``Binomial.approx_log_prob_tol`` and ``BetaBinomial.approx_log_prob_tol``,
thereby decreasing the computational complexity of scoring
:class:`~pyro.distributions.Binomial` and
:class:`~pyro.distributions.BetaBinomial` distributions.

This is used internally by
:class:`~pyro.contrib.epidemiology.compartmental.CompartmentalModel`.

:param tol: New temporary tolerance.
:type tol: int or float.
"""
assert isinstance(tol, (float, int))
assert tol > 0
old1 = dist.Binomial.approx_log_prob_tol
old2 = dist.BetaBinomial.approx_log_prob_tol
try:
dist.Binomial.approx_log_prob_tol = tol
dist.BetaBinomial.approx_log_prob_tol = tol
yield
finally:
dist.Binomial.approx_log_prob_tol = old1
dist.BetaBinomial.approx_log_prob_tol = old2


def infection_dist(*,
individual_rate,
num_infectious,
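A minimal usage sketch, not part of this diff (the distribution parameters are illustrative): the context manager flips the class-level tolerance for the duration of a block and restores the exact default afterwards.

    import torch
    import pyro.distributions as dist
    from pyro.contrib.epidemiology.distributions import set_approx_log_prob_tol

    d = dist.Binomial(total_count=torch.tensor(1000.), logits=torch.tensor(-2.))
    value = torch.tensor(100.)

    exact = d.log_prob(value)            # default approx_log_prob_tol = 0. -> lgamma
    with set_approx_log_prob_tol(0.1):   # approximation active inside the block
        approx = d.log_prob(value)
    assert (exact - approx).abs() < 0.1  # error is bounded by tol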
26 changes: 15 additions & 11 deletions pyro/distributions/conjugate.py
@@ -9,10 +9,7 @@

from pyro.distributions.torch import Beta, Binomial, Dirichlet, Gamma, Multinomial, Poisson
from pyro.distributions.torch_distribution import TorchDistribution


def _log_beta(x, y):
return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
from pyro.ops.special import log_beta, log_binomial


def _log_beta_1(alpha, value, is_sparse):
@@ -46,6 +43,12 @@ class BetaBinomial(TorchDistribution):
has_enumerate_support = True
support = Binomial.support

# EXPERIMENTAL If set to a positive value, the .log_prob() method will use
# a shifted Stirling's approximation to the Beta function, reducing
# computational cost from 9 lgamma() evaluations to 12 log() evaluations
# plus arithmetic. Recommended values are between 0.1 and 0.01.
approx_log_prob_tol = 0.

def __init__(self, concentration1, concentration0, total_count=1, validate_args=None):
concentration1, concentration0, total_count = broadcast_all(
concentration1, concentration0, total_count)
@@ -77,12 +80,13 @@ def sample(self, sample_shape=()):
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
log_factorial_n = torch.lgamma(self.total_count + 1)
log_factorial_k = torch.lgamma(value + 1)
log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
return (log_factorial_n - log_factorial_k - log_factorial_nmk +
_log_beta(value + self.concentration1, self.total_count - value + self.concentration0) -
_log_beta(self.concentration0, self.concentration1))

n = self.total_count
k = value
a = self.concentration1
b = self.concentration0
tol = self.approx_log_prob_tol
return log_binomial(n, k, tol) + log_beta(k + a, n - k + b, tol) - log_beta(a, b, tol)

@property
def mean(self):
@@ -223,7 +227,7 @@ def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
post_value = self.concentration + value
return -_log_beta(self.concentration, value + 1) - post_value.log() + \
return -log_beta(self.concentration, value + 1) - post_value.log() + \
self.concentration * self.rate.log() - post_value * (1 + self.rate).log()

@property
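A rough numeric sketch with illustrative parameters (not from this diff): since ``BetaBinomial.log_prob`` approximates three terms (one ``log_binomial`` and two ``log_beta`` calls), its total error is bounded by ``3 * tol`` rather than ``tol``.

    import torch
    import pyro.distributions as dist
    from pyro.contrib.epidemiology.distributions import set_approx_log_prob_tol

    d = dist.BetaBinomial(concentration1=2., concentration0=5., total_count=100)
    value = torch.tensor(30.)

    exact = d.log_prob(value)        # default tol = 0. -> exact lgamma path
    with set_approx_log_prob_tol(0.1):
        approx = d.log_prob(value)   # log_binomial + two log_beta approximations
    assert (exact - approx).abs() < 3 * 0.1  # three approximated terms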
27 changes: 27 additions & 0 deletions pyro/distributions/torch.py
@@ -9,6 +9,12 @@
from pyro.distributions.constraints import IndependentConstraint
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.distributions.util import sum_rightmost
from pyro.ops.special import log_binomial


def _clamp_by_zero(x):
# works like clamp(x, min=0) but has gradient 0.5 at x = 0
return (x.clamp(min=0) + x - x.clamp(max=0)) / 2


class Beta(torch.distributions.Beta, TorchDistributionMixin):
@@ -36,6 +42,12 @@ class Binomial(torch.distributions.Binomial, TorchDistributionMixin):
# sampling very large populations.
approx_sample_thresh = math.inf

# EXPERIMENTAL If set to a positive value, the .log_prob() method will use
# a shifted Stirling's approximation to the Beta function, reducing
# computational cost from 3 lgamma() evaluations to 4 log() evaluations
# plus arithmetic. Recommended values are between 0.1 and 0.01.
approx_log_prob_tol = 0.

def sample(self, sample_shape=torch.Size()):
if self.approx_sample_thresh < math.inf:
exact = self.total_count <= self.approx_sample_thresh
@@ -61,6 +73,21 @@ def sample(self, sample_shape=torch.Size()):
return sample
return super().sample(sample_shape)

def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)

n = self.total_count
k = value
# k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
# (case logit < 0) = k * logit - n * log1p(e^logit)
# (case logit > 0) = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
# = k * logit - n * logit - n * log1p(e^-logit)
# (merge two cases) = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
normalize_term = n * (_clamp_by_zero(self.logits) + self.logits.abs().neg().exp().log1p())
return (k * self.logits - normalize_term
+ log_binomial(n, k, tol=self.approx_log_prob_tol))
Review comment on lines +80 to +89 (Member, author):
This was adapted from the upstream implementation:

        log_factorial_n = torch.lgamma(self.total_count + 1)
        log_factorial_k = torch.lgamma(value + 1)
        log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
        # k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
        #     (case logit < 0)              = k * logit - n * log1p(e^logit)
        #     (case logit > 0)              = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
        #                                   = k * logit - n * logit - n * log1p(e^-logit)
        #     (merge two cases)             = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
        normalize_term = (self.total_count * _clamp_by_zero(self.logits)
                          + self.total_count * torch.log1p(torch.exp(-torch.abs(self.logits)))
                          - log_factorial_n)
        return value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term



# This overloads .log_prob() and .enumerate_support() to speed up evaluating
# log_prob on the support of this variable: we can completely avoid tensor ops
81 changes: 81 additions & 0 deletions pyro/ops/special.py
@@ -0,0 +1,81 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import functools
import math
import operator

import torch


def log_beta(x, y, tol=0.):
"""
Computes log Beta function.

When ``tol >= 0.02`` this uses a shifted Stirling's approximation to the
log Beta function. The approximation adapts Stirling's approximation of the
log Gamma function::

lgamma(z) ≈ (z - 1/2) * log(z) - z + log(2 * pi) / 2

to approximate the log Beta function::

log_beta(x, y) ≈ ((x-1/2) * log(x) + (y-1/2) * log(y)
- (x+y-1/2) * log(x+y) + log(2*pi)/2)

The approximation additionally improves accuracy near zero by iteratively
shifting the log Gamma approximation using the recursion::

lgamma(x) = lgamma(x + 1) - log(x)

If this recursion is applied ``n`` times, then absolute error is bounded by
``error < 0.082 / n < tol``, thus we choose ``n`` based on the user
provided ``tol``.

:param torch.Tensor x: A positive tensor.
:param torch.Tensor y: A positive tensor.
:param float tol: Bound on maximum absolute error. Defaults to 0. For
small ``tol < 0.02``, this function simply computes the exact result
via :func:`torch.lgamma`.
:rtype: torch.Tensor
"""
assert isinstance(tol, (float, int)) and tol >= 0
if tol < 0.02:
# At small tolerance it is cheaper to defer to torch.lgamma().
return x.lgamma() + y.lgamma() - (x + y).lgamma()

# This bound holds for arbitrary x,y. We could do better with large x,y.
shift = int(math.ceil(0.082 / tol))

xy = x + y
factors = []
for _ in range(shift):
factors.append(xy / (x * y))
x = x + 1
y = y + 1
xy = xy + 1

log_factor = functools.reduce(operator.mul, factors).log()

return (log_factor + (x - 0.5) * x.log() + (y - 0.5) * y.log()
- (xy - 0.5) * xy.log() + (math.log(2 * math.pi) / 2 - shift))
Review comment (Member):
Where does this shift come from?

Reply (Member, author):

Stirling's approximation of each of the lgamma terms is

lgamma(x) ~ (x-1/2) log(x) - x + log(2 pi)/2

Now when we use this to approximate a log_beta function the x terms usually cancel

log_beta(x,y) = lgamma(x) + lgamma(y) - lgamma(x+y)
              = (x-1/2) log(x) + (y-1/2) log(y) - (x+y-1/2) log(x+y)
              + -x + -y + (x+y)       # <--- these cancel if shift = 0
              + (1+1-1) log(2 pi)/2

Now the trick to make this more accurate is to use the lgamma recursion and increment each of x, y, and x+y, which is written in the code as

x = x + 1
y = y + 1
xy = xy + 1

When we do that, the first line works unchanged, the second line gets an extra -1, and the last line remains unchanged. Each of the shift-many increments of x, y, and x+y contributes one such extra -1, so we simply subtract shift from the constant term (the accumulated log factors account for the recursion's log(x), log(y), log(x+y) corrections).

Reply (Member):
Oh yeah, it comes from the second line (I thought it cancelled out). Thanks for a clear explanation!

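To make the bound concrete, here is a quick check (a sketch, not part of the PR): requesting ``tol = 0.082`` makes the code pick ``shift = 1``, and the worst-case error over a grid of small and large arguments then stays below 0.082.

    import torch
    from pyro.ops.special import log_beta

    x = torch.logspace(-3, 3, 100)
    y = x.unsqueeze(-1)

    exact = x.lgamma() + y.lgamma() - (x + y).lgamma()
    approx = log_beta(x, y, tol=0.082)   # ceil(0.082 / 0.082) = 1 shift
    print((exact - approx).abs().max())  # stays below the 0.082 bound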


@torch.no_grad()
def log_binomial(n, k, tol=0.):
"""
Computes log binomial coefficient.

When ``tol >= 0.02`` this uses a shifted Stirling's approximation to the
log Beta function via :func:`log_beta`.

:param torch.Tensor n: A nonnegative integer tensor.
:param torch.Tensor k: An integer tensor ranging in ``[0, n]``.
:param float tol: Bound on maximum absolute error. Defaults to 0., i.e.
the exact computation via :func:`torch.lgamma`.
:rtype: torch.Tensor
"""
assert isinstance(tol, (float, int)) and tol >= 0
n_plus_1 = n + 1
if tol < 0.02:
# At small tolerance it is cheaper to defer to torch.lgamma().
return n_plus_1.lgamma() - (k + 1).lgamma() - (n_plus_1 - k).lgamma()

return -n_plus_1.log() - log_beta(k + 1, n_plus_1 - k, tol=tol)
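As a usage sketch (assuming only the functions defined above), the identity ``log C(n, k) = -log(n + 1) - log B(k + 1, n - k + 1)`` behind ``log_binomial`` can be checked directly:

    import torch
    from pyro.ops.special import log_binomial

    n = torch.tensor(50.)
    k = torch.tensor(20.)

    exact = (n + 1).lgamma() - (k + 1).lgamma() - (n - k + 1).lgamma()
    approx = log_binomial(n, k, tol=0.1)  # approximate via log_beta
    assert (exact - approx).abs() < 0.1   # within the requested tolerance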
19 changes: 18 additions & 1 deletion tests/distributions/test_binomial.py
@@ -2,9 +2,10 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

import pyro.distributions as dist
from pyro.contrib.epidemiology.distributions import set_approx_sample_thresh
from pyro.contrib.epidemiology.distributions import set_approx_log_prob_tol, set_approx_sample_thresh
from tests.common import assert_close


@@ -33,3 +34,19 @@ def test_beta_binomial_approx_sample(concentration1, concentration0, total_count

assert_close(expected.mean(), actual.mean(), rtol=0.1)
assert_close(expected.std(), actual.std(), rtol=0.1)


@pytest.mark.parametrize("tol", [
1e-8, 1e-6, 1e-4, 1e-2, 0.02, 0.05, 0.1, 0.2, 1.,
])
def test_binomial_approx_log_prob(tol):
logits = torch.linspace(-10., 10., 100)
k = torch.arange(100.).unsqueeze(-1)
n_minus_k = torch.arange(100.).unsqueeze(-1).unsqueeze(-1)
n = k + n_minus_k

expected = torch.distributions.Binomial(n, logits=logits).log_prob(k)
with set_approx_log_prob_tol(tol):
actual = dist.Binomial(n, logits=logits).log_prob(k)

assert_close(actual, expected, atol=tol)
36 changes: 36 additions & 0 deletions tests/ops/test_special.py
@@ -0,0 +1,36 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from pyro.ops.special import log_beta, log_binomial


@pytest.mark.parametrize("tol", [
1e-8, 1e-6, 1e-4, 1e-2, 0.02, 0.05, 0.1, 0.2, 1.,
])
def test_log_beta_stirling(tol):
x = torch.logspace(-5, 5, 200)
y = x.unsqueeze(-1)

expected = log_beta(x, y)
actual = log_beta(x, y, tol=tol)

assert (actual <= expected).all()
assert (expected < actual + tol).all()


@pytest.mark.parametrize("tol", [
1e-8, 1e-6, 1e-4, 1e-2, 0.02, 0.05, 0.1, 0.2, 1.,
])
def test_log_binomial_stirling(tol):
k = torch.arange(200.)
n_minus_k = k.unsqueeze(-1)
n = k + n_minus_k

# Test binomial coefficient choose(n, k).
expected = (n + 1).lgamma() - (k + 1).lgamma() - (n_minus_k + 1).lgamma()
actual = log_binomial(n, k, tol=tol)

assert (actual - expected).abs().max() < tol