From ddd6fe4e1fa6b5c6a06f9db2d4455f3298eb5cd5 Mon Sep 17 00:00:00 2001
From: Fritz Obermeyer
Date: Sun, 24 May 2020 10:37:32 -0700
Subject: [PATCH] Use log_beta and log_binomial in Binomial.log_prob()

---
 pyro/contrib/epidemiology/distributions.py | 14 +++---
 pyro/distributions/conjugate.py            | 17 +++-----
 pyro/distributions/torch.py                | 27 ++++++++++++
 pyro/ops/special.py                        | 50 ++++++++++++++--------
 tests/distributions/test_binomial.py       | 19 +++++++-
 tests/ops/test_special.py                  | 21 +++++++--
 6 files changed, 110 insertions(+), 38 deletions(-)

diff --git a/pyro/contrib/epidemiology/distributions.py b/pyro/contrib/epidemiology/distributions.py
index 5856ad5f48..ef166e6cb5 100644
--- a/pyro/contrib/epidemiology/distributions.py
+++ b/pyro/contrib/epidemiology/distributions.py
@@ -42,9 +42,10 @@ def set_approx_sample_thresh(thresh):
 def set_approx_log_prob_tol(tol):
     """
     EXPERIMENTAL Temporarily set the global default value of
-    ``BetaBinomial.approx_log_prob_tol``, thereby decreasing the computational
-    complexity of scoring :class:`~pyro.distributions.BetaBinomial`
-    distributions.
+    ``Binomial.approx_log_prob_tol`` and ``BetaBinomial.approx_log_prob_tol``,
+    thereby decreasing the computational complexity of scoring
+    :class:`~pyro.distributions.Binomial` and
+    :class:`~pyro.distributions.BetaBinomial` distributions.
 
     This is used internally by
     :class:`~pyro.contrib.epidemiology.compartmental.CompartmentalModel`.
@@ -54,12 +55,15 @@ def set_approx_log_prob_tol(tol):
     """
     assert isinstance(tol, (float, int))
     assert tol > 0
-    old = dist.BetaBinomial.approx_log_prob_tol
+    old1 = dist.Binomial.approx_log_prob_tol
+    old2 = dist.BetaBinomial.approx_log_prob_tol
     try:
+        dist.Binomial.approx_log_prob_tol = tol
         dist.BetaBinomial.approx_log_prob_tol = tol
         yield
     finally:
-        dist.BetaBinomial.approx_log_prob_tol = old
+        dist.Binomial.approx_log_prob_tol = old1
+        dist.BetaBinomial.approx_log_prob_tol = old2
 
 
 def infection_dist(*,
diff --git a/pyro/distributions/conjugate.py b/pyro/distributions/conjugate.py
index 7c457f1cd2..7783fba614 100644
--- a/pyro/distributions/conjugate.py
+++ b/pyro/distributions/conjugate.py
@@ -1,7 +1,6 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
 
-import functools
 import numbers
 
 import torch
@@ -10,7 +9,7 @@
 from pyro.distributions.torch import Beta, Binomial, Dirichlet, Gamma, Multinomial, Poisson
 from pyro.distributions.torch_distribution import TorchDistribution
-from pyro.ops.special import log_beta, log_beta_stirling
+from pyro.ops.special import log_beta, log_binomial
 
 
 def _log_beta_1(alpha, value, is_sparse):
@@ -44,9 +43,9 @@ class BetaBinomial(TorchDistribution):
     has_enumerate_support = True
     support = Binomial.support
 
-    # EXPERIMENTAL If set to a positive value, the .log_prob() method will use a
-    # shifted Sterling's approximation to the Beta function, reducing
-    # computational cost from 9 lgamma() evaluations to four log() evaluations
+    # EXPERIMENTAL If set to a positive value, the .log_prob() method will use
+    # a shifted Stirling's approximation to the Beta function, reducing
+    # computational cost from 9 lgamma() evaluations to 12 log() evaluations
     # plus arithmetic. Recommended values are between 0.1 and 0.01.
     approx_log_prob_tol = 0.
@@ -82,16 +81,12 @@ def log_prob(self, value):
         if self._validate_args:
             self._validate_sample(value)
 
-        if self.approx_log_prob_tol > 0:
-            lbeta = functools.partial(log_beta_stirling, tol=self.approx_log_prob_tol)
-        else:
-            lbeta = log_beta
-
         n = self.total_count
         k = value
         a = self.concentration1
         b = self.concentration0
-        return lbeta(k + a, n - k + b) - lbeta(a, b) - lbeta(k + 1, n - k + 1) - n.log1p()
+        tol = self.approx_log_prob_tol
+        return log_binomial(n, k, tol) + log_beta(k + a, n - k + b, tol) - log_beta(a, b, tol)
 
     @property
     def mean(self):
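A quick standalone sanity check, not part of the patch (log_beta below is a
local stand-in for pyro.ops.special.log_beta at tol=0, and the values are
arbitrary): the refactored BetaBinomial.log_prob() is an algebraic
rearrangement of the old expression, via the identity
log_binomial(n, k) = -log1p(n) - log_beta(k + 1, n - k + 1):

    import torch

    def log_beta(x, y):
        return x.lgamma() + y.lgamma() - (x + y).lgamma()

    n, k = torch.tensor(20.), torch.tensor(7.)
    a, b = torch.tensor(2.5), torch.tensor(1.5)

    # Old expression: three log_beta terms plus a log1p term.
    old = (log_beta(k + a, n - k + b) - log_beta(a, b)
           - log_beta(k + 1, n - k + 1) - n.log1p())
    # New expression: log binomial coefficient plus two log_beta terms.
    log_binom = (n + 1).lgamma() - (k + 1).lgamma() - (n - k + 1).lgamma()
    new = log_binom + log_beta(k + a, n - k + b) - log_beta(a, b)
    assert torch.allclose(old, new)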
diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py
index e1b5c9b3e8..d999fc2b4f 100644
--- a/pyro/distributions/torch.py
+++ b/pyro/distributions/torch.py
@@ -9,6 +9,12 @@
 from pyro.distributions.constraints import IndependentConstraint
 from pyro.distributions.torch_distribution import TorchDistributionMixin
 from pyro.distributions.util import sum_rightmost
+from pyro.ops.special import log_binomial
+
+
+def _clamp_by_zero(x):
+    # works like clamp(x, min=0) but the gradient at 0 is 0.5
+    return (x.clamp(min=0) + x - x.clamp(max=0)) / 2
 
 
 class Beta(torch.distributions.Beta, TorchDistributionMixin):
@@ -36,6 +42,12 @@ class Binomial(torch.distributions.Binomial, TorchDistributionMixin):
     # sampling very large populations.
     approx_sample_thresh = math.inf
 
+    # EXPERIMENTAL If set to a positive value, the .log_prob() method will use
+    # a shifted Stirling's approximation to the Beta function, reducing
+    # computational cost from 3 lgamma() evaluations to 4 log() evaluations
+    # plus arithmetic. Recommended values are between 0.1 and 0.01.
+    approx_log_prob_tol = 0.
+
     def sample(self, sample_shape=torch.Size()):
         if self.approx_sample_thresh < math.inf:
             exact = self.total_count <= self.approx_sample_thresh
@@ -61,6 +73,21 @@ def sample(self, sample_shape=torch.Size()):
             return sample
         return super().sample(sample_shape)
 
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+
+        n = self.total_count
+        k = value
+        # k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
+        #     (case logit < 0)              = k * logit - n * log1p(e^logit)
+        #     (case logit > 0)              = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
+        #                                   = k * logit - n * logit - n * log1p(e^-logit)
+        #     (merge two cases)             = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
+        normalize_term = n * (_clamp_by_zero(self.logits) + self.logits.abs().neg().exp().log1p())
+        return (k * self.logits - normalize_term
+                + log_binomial(n, k, tol=self.approx_log_prob_tol))
+
 
 # This overloads .log_prob() and .enumerate_support() to speed up evaluating
 # log_prob on the support of this variable: we can completely avoid tensor ops
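A standalone check, not part of the patch (example values are arbitrary), that
the logit-space algebra above is correct: normalize_term equals
n * softplus(logits), i.e. -n * log(1 - p), so the new log_prob() agrees with
PyTorch's exact Binomial.log_prob():

    import torch

    def _clamp_by_zero(x):
        # works like clamp(x, min=0) but the gradient at 0 is 0.5
        return (x.clamp(min=0) + x - x.clamp(max=0)) / 2

    logits = torch.linspace(-8., 8., 5)
    n, k = torch.tensor(30.), torch.tensor(11.)

    # n * softplus(logits), computed stably for large |logits|.
    normalize_term = n * (_clamp_by_zero(logits) + logits.abs().neg().exp().log1p())
    log_binom = (n + 1).lgamma() - (k + 1).lgamma() - (n - k + 1).lgamma()
    actual = k * logits - normalize_term + log_binom
    expected = torch.distributions.Binomial(n, logits=logits).log_prob(k)
    assert torch.allclose(actual, expected, atol=1e-4)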
diff --git a/pyro/ops/special.py b/pyro/ops/special.py
index 75189cb31b..6b2f218260 100644
--- a/pyro/ops/special.py
+++ b/pyro/ops/special.py
@@ -5,24 +5,16 @@
 import math
 import operator
 
-
-def log_beta(x, y):
-    """
-    Log Beta function.
-
-    :param torch.Tensor x: A positive tensor.
-    :param torch.Tensor y: A positive tensor.
-    """
-    return x.lgamma() + y.lgamma() - (x + y).lgamma()
+import torch
 
 
-def log_beta_stirling(x, y, tol=0.1):
+def log_beta(x, y, tol=0.):
     """
-    Shifted Stirling's approximation to the log Beta function.
+    Computes the log Beta function.
 
-    This is useful as a cheaper alternative to :func:`log_beta`.
-
-    This adapts Stirling's approximation of the log Gamma function::
+    When ``tol >= 0.02`` this uses a shifted Stirling's approximation to the
+    log Beta function. The approximation adapts Stirling's approximation of the
+    log Gamma function::
 
         lgamma(z) ≈ (z - 1/2) * log(z) - z + log(2 * pi) / 2
 
@@ -31,8 +23,8 @@
         log_beta(x, y) ≈ ((x-1/2) * log(x) + (y-1/2) * log(y)
                           - (x+y-1/2) * log(x+y) + log(2*pi)/2)
 
-    This additionally improves accuracy near zero by iteratively shifting
-    the log Gamma approximation using the recursion::
+    The approximation additionally improves accuracy near zero by iteratively
+    shifting the log Gamma approximation using the recursion::
 
         lgamma(x) = lgamma(x + 1) - log(x)
 
@@ -44,11 +36,12 @@
     :param torch.Tensor y: A positive tensor.
-    :param float tol: Bound on maximum absolute error. Defaults to 0.1.
-        For very small ``tol``, this function simply defers to :func:`log_beta`.
+    :param float tol: Bound on maximum absolute error. Defaults to 0.
+        For ``tol < 0.02``, this function computes exactly via ``lgamma()``.
+    :rtype: torch.Tensor
     """
     assert isinstance(tol, (float, int)) and tol >= 0
     if tol < 0.02:
-        # Eventually it is cheaper to defer to torch.lgamma().
-        return log_beta(x, y)
+        # At small tolerance it is cheaper to defer to torch.lgamma().
+        return x.lgamma() + y.lgamma() - (x + y).lgamma()
 
     # This bound holds for arbitrary x,y. We could do better with large x,y.
     shift = int(math.ceil(0.082 / tol))
@@ -65,3 +58,24 @@
     return (log_factor + (x - 0.5) * x.log() + (y - 0.5) * y.log()
             - (xy - 0.5) * xy.log()
             + (math.log(2 * math.pi) / 2 - shift))
+
+
+@torch.no_grad()
+def log_binomial(n, k, tol=0.):
+    """
+    Computes the log binomial coefficient.
+
+    When ``tol >= 0.02`` this uses a shifted Stirling's approximation to the
+    log Beta function via :func:`log_beta`.
+
+    :param torch.Tensor n: A nonnegative integer tensor.
+    :param torch.Tensor k: An integer tensor ranging in ``[0, n]``.
+    :rtype: torch.Tensor
+    """
+    assert isinstance(tol, (float, int)) and tol >= 0
+    n_plus_1 = n + 1
+    if tol < 0.02:
+        # At small tolerance it is cheaper to defer to torch.lgamma().
+        return n_plus_1.lgamma() - (k + 1).lgamma() - (n_plus_1 - k).lgamma()
+
+    return -n_plus_1.log() - log_beta(k + 1, n_plus_1 - k, tol=tol)
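Usage sketch for the reworked helpers (assumes this patch is applied): ``tol``
trades accuracy for cheaper log()-based evaluation, and ``tol < 0.02`` falls
back to exact lgamma() evaluation:

    import torch
    from pyro.ops.special import log_beta, log_binomial

    x = torch.logspace(-3, 3, 7)
    y = x.unsqueeze(-1)
    exact = log_beta(x, y)            # tol=0. computes via lgamma()
    approx = log_beta(x, y, tol=0.1)  # shifted Stirling's approximation
    assert (approx - exact).abs().max() < 0.1

    n = torch.arange(10.)
    # choose(n, 0) = 1, so log_binomial(n, 0) = 0.
    assert torch.allclose(log_binomial(n, torch.zeros(10)), torch.zeros(10))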
diff --git a/tests/distributions/test_binomial.py b/tests/distributions/test_binomial.py
index 1d6a04a950..7b077736a4 100644
--- a/tests/distributions/test_binomial.py
+++ b/tests/distributions/test_binomial.py
@@ -2,9 +2,10 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import pytest
+import torch
 
 import pyro.distributions as dist
-from pyro.contrib.epidemiology.distributions import set_approx_sample_thresh
+from pyro.contrib.epidemiology.distributions import set_approx_log_prob_tol, set_approx_sample_thresh
 from tests.common import assert_close
 
 
@@ -33,3 +34,19 @@ def test_beta_binomial_approx_sample(concentration1, concentration0, total_count):
 
     assert_close(expected.mean(), actual.mean(), rtol=0.1)
     assert_close(expected.std(), actual.std(), rtol=0.1)
+
+
+@pytest.mark.parametrize("tol", [
+    1e-8, 1e-6, 1e-4, 1e-2, 0.02, 0.05, 0.1, 0.2, 0.5, 1.,
+])
+def test_binomial_approx_log_prob(tol):
+    logits = torch.linspace(-10., 10., 100)
+    k = torch.arange(100.).unsqueeze(-1)
+    n_minus_k = torch.arange(100.).unsqueeze(-1).unsqueeze(-1)
+    n = k + n_minus_k
+
+    expected = torch.distributions.Binomial(n, logits=logits).log_prob(k)
+    with set_approx_log_prob_tol(tol):
+        actual = dist.Binomial(n, logits=logits).log_prob(k)
+
+    assert_close(actual, expected, atol=tol)
diff --git a/tests/ops/test_special.py b/tests/ops/test_special.py
index b5c6b3d9ad..fad3d236d1 100644
--- a/tests/ops/test_special.py
+++ b/tests/ops/test_special.py
@@ -4,18 +4,33 @@
 import pytest
 import torch
 
-from pyro.ops.special import log_beta, log_beta_stirling
+from pyro.ops.special import log_beta, log_binomial
 
 
 @pytest.mark.parametrize("tol", [
     1e-8, 1e-6, 1e-4, 1e-2, 0.02, 0.05, 0.1, 0.2, 0.1, 1.,
 ])
 def test_log_beta_stirling(tol):
-    x = torch.logspace(-5, 5, 100)
+    x = torch.logspace(-5, 5, 200)
     y = x.unsqueeze(-1)
 
     expected = log_beta(x, y)
-    actual = log_beta_stirling(x, y, tol=tol)
+    actual = log_beta(x, y, tol=tol)
 
     assert (actual <= expected).all()
    assert (expected < actual + tol).all()
+
+
+@pytest.mark.parametrize("tol", [
+    1e-8, 1e-6, 1e-4, 1e-2, 0.02, 0.05, 0.1, 0.2, 0.5, 1.,
+])
+def test_log_binomial_stirling(tol):
+    k = torch.arange(200.)
+    n_minus_k = k.unsqueeze(-1)
+    n = k + n_minus_k
+
+    # Test the binomial coefficient choose(n, k).
+    expected = (n + 1).lgamma() - (k + 1).lgamma() - (n_minus_k + 1).lgamma()
+    actual = log_binomial(n, k, tol=tol)
+
+    assert (actual - expected).abs().max() < tol
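Usage sketch for the extended context manager (assumes this patch is applied);
it now loosens log_prob() accuracy for both Binomial and BetaBinomial inside
the block, with absolute error bounded by tol:

    import torch
    import pyro.distributions as dist
    from pyro.contrib.epidemiology.distributions import set_approx_log_prob_tol

    d = dist.Binomial(torch.tensor(1000.), logits=torch.tensor(-2.))
    exact = d.log_prob(torch.tensor(100.))
    with set_approx_log_prob_tol(0.1):
        approx = d.log_prob(torch.tensor(100.))  # cheaper approximate score
    assert (approx - exact).abs() < 0.1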