Add a shifted Stirling approximation to log Beta function #2500
pyro/ops/special.py
@@ -0,0 +1,81 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import functools
import math
import operator

import torch


def log_beta(x, y, tol=0.):
    """
    Computes the log Beta function.

    When ``tol >= 0.02`` this uses a shifted Stirling's approximation to the
    log Beta function. The approximation adapts Stirling's approximation of
    the log Gamma function::

        lgamma(z) ≈ (z - 1/2) * log(z) - z + log(2 * pi) / 2

    to approximate the log Beta function::

        log_beta(x, y) ≈ ((x-1/2) * log(x) + (y-1/2) * log(y)
                          - (x+y-1/2) * log(x+y) + log(2*pi)/2)

    The approximation additionally improves accuracy near zero by iteratively
    shifting the log Gamma approximation using the recursion::

        lgamma(x) = lgamma(x + 1) - log(x)

    If this recursion is applied ``n`` times, then absolute error is bounded
    by ``error < 0.082 / n < tol``, thus we choose ``n`` based on the user
    provided ``tol``.

    :param torch.Tensor x: A positive tensor.
    :param torch.Tensor y: A positive tensor.
    :param float tol: Bound on maximum absolute error. Defaults to 0., in
        which case this function simply defers to ``torch.lgamma``.
    :rtype: torch.Tensor
    """
    assert isinstance(tol, (float, int)) and tol >= 0
    if tol < 0.02:
        # At small tolerance it is cheaper to defer to torch.lgamma().
        return x.lgamma() + y.lgamma() - (x + y).lgamma()

    # This bound holds for arbitrary x,y. We could do better with large x,y.
    shift = int(math.ceil(0.082 / tol))

    xy = x + y
    factors = []
    for _ in range(shift):
        factors.append(xy / (x * y))
        x = x + 1
        y = y + 1
        xy = xy + 1

    log_factor = functools.reduce(operator.mul, factors).log()

    return (log_factor + (x - 0.5) * x.log() + (y - 0.5) * y.log()
            - (xy - 0.5) * xy.log() + (math.log(2 * math.pi) / 2 - shift))

Comment: Where does this ``- shift`` in the return statement come from?

Reply: Stirling's approximation of each of the three lgamma terms is::

    lgamma(z) ≈ (z - 1/2) * log(z)
                - z + log(2 * pi) / 2

Now when we use this to approximate a shifted value ``lgamma(z + n)``, the
trick to make this more accurate is to use the recursion
``lgamma(z) = lgamma(z + 1) - log(z)`` a total of ``n`` times before applying
Stirling. When we do that, the first line works unchanged, the second line
gets an extra ``- n``. Summed over the three terms of
``log_beta(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y)``, the extra terms
are ``-n - n + n = -n``, which is exactly the ``- shift`` in the final line.

Reply: Oh yeah, it comes from the second line (I thought it cancelled out).
Thanks for a clear explanation!
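
To make this concrete, here is a small standalone check (illustrative only,
not part of the PR; ``lgamma_shifted_stirling`` is a name invented for this
sketch). It applies the recursion ``n`` times, then Stirling at the shifted
point, and tracks ``torch.lgamma`` with error shrinking roughly like
``0.082 / n``::

    import math

    import torch

    def lgamma_shifted_stirling(z, n):
        # lgamma(z) = lgamma(z + n) - log(z) - log(z + 1) - ... - log(z + n - 1).
        log_shift = sum(torch.log(z + k) for k in range(n))
        zn = z + n
        # Stirling at the shifted point; "- zn" expands to "- z - n", and the
        # "- n" parts are what sum to the "- shift" in log_beta() above.
        stirling = (zn - 0.5) * zn.log() - zn + math.log(2 * math.pi) / 2
        return stirling - log_shift

    z = torch.tensor(0.5)
    for n in [1, 2, 5, 10]:
        err = (lgamma_shifted_stirling(z, n) - z.lgamma()).abs().item()
        print(n, err)  # the absolute error decreases as n grows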

@torch.no_grad()
def log_binomial(n, k, tol=0.):
    """
    Computes the log binomial coefficient.

    When ``tol >= 0.02`` this uses a shifted Stirling's approximation to the
    log Beta function via :func:`log_beta`.

    :param torch.Tensor n: A nonnegative integer tensor.
    :param torch.Tensor k: An integer tensor ranging in ``[0, n]``.
    :param float tol: Bound on maximum absolute error, passed through to
        :func:`log_beta`. Defaults to 0.
    :rtype: torch.Tensor
    """
    assert isinstance(tol, (float, int)) and tol >= 0
    n_plus_1 = n + 1
    if tol < 0.02:
        # At small tolerance it is cheaper to defer to torch.lgamma().
        return n_plus_1.lgamma() - (k + 1).lgamma() - (n_plus_1 - k).lgamma()

    # Use the identity: choose(n, k) = 1 / ((n + 1) * Beta(k + 1, n - k + 1)).
    return -n_plus_1.log() - log_beta(k + 1, n_plus_1 - k, tol=tol)
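
As a quick sanity check of the identity in the last line (an illustrative
snippet, not part of the PR), one can compare against Python's
``math.comb``::

    import math

    import torch

    from pyro.ops.special import log_binomial

    n = torch.tensor(10.)
    k = torch.tensor(3.)
    print(math.log(math.comb(10, 3)))          # log C(10, 3) = log 120 ≈ 4.787
    print(log_binomial(n, k).item())           # exact path, tol defaults to 0.
    print(log_binomial(n, k, tol=0.1).item())  # shifted Stirling, within 0.1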

tests/ops/test_special.py
@@ -0,0 +1,36 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from pyro.ops.special import log_beta, log_binomial


@pytest.mark.parametrize("tol", [
    1e-8, 1e-6, 1e-4, 1e-2, 0.02, 0.05, 0.1, 0.2, 0.5, 1.,
])
def test_log_beta_stirling(tol):
    x = torch.logspace(-5, 5, 200)
    y = x.unsqueeze(-1)

    expected = log_beta(x, y)
    actual = log_beta(x, y, tol=tol)

    assert (actual <= expected).all()
    assert (expected < actual + tol).all()


@pytest.mark.parametrize("tol", [
    1e-8, 1e-6, 1e-4, 1e-2, 0.02, 0.05, 0.1, 0.2, 0.5, 1.,
])
def test_log_binomial_stirling(tol):
    k = torch.arange(200.)
    n_minus_k = k.unsqueeze(-1)
    n = k + n_minus_k

    # Test binomial coefficient choose(n, k).
    expected = (n + 1).lgamma() - (k + 1).lgamma() - (n_minus_k + 1).lgamma()
    actual = log_binomial(n, k, tol=tol)

    assert (actual - expected).abs().max() < tol
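
Note that ``test_log_beta_stirling`` asserts a one-sided guarantee: the
Stirling path never overshoots the exact value and stays within ``tol`` of
it. A single-point illustration of the same property (assuming a pyro
install that includes this change)::

    import torch

    from pyro.ops.special import log_beta

    x, y = torch.tensor(0.1), torch.tensor(2.5)
    exact = log_beta(x, y)            # tol=0. defers to torch.lgamma
    approx = log_beta(x, y, tol=0.1)  # uses shift = ceil(0.082 / 0.1) = 1
    assert (approx <= exact).item()
    assert (exact < approx + 0.1).item()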

Comment: This was adapted from the upstream implementation: