
Add a shifted Stirling approximation to log Beta function #2500

Merged (4 commits) May 24, 2020

Conversation

@fritzo (Member) commented May 24, 2020

Addresses #2426
Blocking #2498
See derivation notebook.

This implements a cheap approximate log_beta() function for use in Binomial.log_prob() and BetaBinomial.log_prob(). This is motivated by profiling results of #2498, which showed that torch.lgamma() was taking more than half of inference time. This PR also refactors CompartmentalModel.heuristic() to use the new approximation, and I plan to use the approximation in more places in #2498.

Accuracy

Users specify a tol parameter bounding the absolute error (in log space); the error is worst for small values of x, y in log_beta(x, y). When does this happen in pyro.contrib.epidemiology? The only circumstance in which one of x, y is small is in superspreading models with small k and a tiny infectious population, so that k * I <= 1; in all other cases the approximation is more accurate than the stated tol.
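
To make the idea concrete, here is a minimal self-contained sketch of a shifted Stirling approximation to log_beta, checked against the exact torch.lgamma value. It is illustrative only: the PR's actual log_beta() selects the shift from the user-specified tol, whereas this sketch takes a hypothetical shift parameter directly.

```python
import math
import torch

def log_beta_stirling(x, y, shift=4):
    # Illustrative sketch only (not the PR's implementation): apply the
    # recursion lgamma(z) = lgamma(z + 1) - log(z) `shift` times, then use
    # Stirling's formula lgamma(z) ~ (z - 1/2) log(z) - z + log(2 pi)/2
    # on the shifted arguments. Larger shift => smaller absolute error.
    xy = x + y
    log_factor = torch.zeros_like(xy)
    for _ in range(shift):
        # Corrections -log(x) - log(y) + log(x + y) from each recursion step.
        log_factor = log_factor - x.log() - y.log() + xy.log()
        x, y, xy = x + 1, y + 1, xy + 1
    return (log_factor
            + (x - 0.5) * x.log() + (y - 0.5) * y.log()
            - (xy - 0.5) * xy.log()
            + math.log(2 * math.pi) / 2 - shift)

x = torch.tensor([0.5, 1.0, 5.0, 100.0])
y = torch.tensor([0.5, 2.0, 10.0, 50.0])
exact = torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
print((log_beta_stirling(x, y) - exact).abs().max())  # worst at small x, y
```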

Speed

This approximation speeds up log_beta() by about 2.5x at the loosest tolerance, tol=0.1:

| time  | function                          |
|-------|-----------------------------------|
| 116ms | log_beta(x, y)                    |
| 46ms  | log_beta_stirling(x, y, tol=0.1)  |
| 74ms  | log_beta_stirling(x, y, tol=0.05) |
| 171ms | log_beta_stirling(x, y, tol=0.02) |
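
For reference, a timing harness in this spirit might look like the sketch below, reusing the hypothetical log_beta_stirling from the Accuracy section; the input sizes, ranges, and shift here are assumptions, not the PR's actual benchmark.

```python
import timeit
import torch

x = torch.rand(4000, 4000) * 100 + 1  # assumed input sizes and ranges
y = torch.rand(4000, 4000) * 100 + 1

def log_beta_exact(x, y):
    # Exact value via three lgamma evaluations, as in the table's baseline.
    return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)

for name, fn in [("exact", log_beta_exact),
                 ("stirling", lambda a, b: log_beta_stirling(a, b, shift=2))]:
    print(name, timeit.timeit(lambda: fn(x, y), number=10))
```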

Tested

  • added unit tests for log_beta() and log_binomial() at a variety of tolerances
  • added unit test of Binomial.log_prob() at a variety of tolerances
  • refactoring of BetaBinomial is covered by existing tests
  • new usage in CompartmentalModel.heuristic is covered by superspreader smoke tests
  • added test for NaN gradients in ExtendedBinomial and ExtendedBetaBinomial

@fritzo fritzo requested a review from fehiepsi May 24, 2020 03:40
@fehiepsi (Member) previously approved these changes May 24, 2020 and left a comment:

The implementation LGTM! I just have one question regarding the shift term (the tests passed, so I might be missing something).

log_factor = functools.reduce(operator.mul, factors).log()

return (log_factor + (x - 0.5) * x.log() + (y - 0.5) * y.log()
        - (xy - 0.5) * xy.log() + (math.log(2 * math.pi) / 2 - shift))
@fehiepsi (Member):

Where does this shift come from?

@fritzo (Member Author):
Stirling's approximation of each of the lgamma terms is

lgamma(x) ~ (x-1/2) log(x) - x + log(2 pi)/2

Now when we use this to approximate the log_beta function, the linear terms usually cancel:

log_beta(x,y) = lgamma(x) + lgamma(y) - lgamma(x+y)
              ~ (x-1/2) log(x) + (y-1/2) log(y) - (x+y-1/2) log(x+y)
                - x - y + (x+y)         # <--- these cancel if shift = 0
                + (1+1-1) log(2 pi)/2

Now the trick to make this more accurate is to use the lgamma recursion lgamma(z) = lgamma(z+1) - log(z) and increment each of x, y, and x+y, which is written in the code as

x = x + 1
y = y + 1
xy = xy + 1

When we do that, the first line works unchanged (at the incremented values), the second line picks up an extra -1, since -(x+1) - (y+1) + (x+y+1) = -1, and the last line remains unchanged; the recursion also produces the -log(x) - log(y) + log(x+y) corrections that are collected in log_factor. If we increment each of x, y, and x+y shift-many times, each increment contributes an extra -1, so we can account for them all by simply subtracting shift (the `- shift` term in the code).
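
As a toy numeric sanity check on this accounting (not part of the PR), one application of the recursion should leave a net -1 in the linear terms while improving accuracy:

```python
import math
import torch

def stirling(z):
    # Plain Stirling approximation, including the -z term.
    return (z - 0.5) * z.log() - z + math.log(2 * math.pi) / 2

x, y = torch.tensor(1.5), torch.tensor(2.5)
xy = x + y

# shift = 0: the -x - y + (x + y) linear terms cancel exactly.
approx0 = stirling(x) + stirling(y) - stirling(xy)

# shift = 1: lgamma(z) = lgamma(z + 1) - log(z) applied once to each term.
# The linear terms inside stirling() now sum to -(x+1) - (y+1) + (x+y+1) = -1;
# the PR's code, which drops the linear terms, adds this back via `- shift`.
approx1 = (stirling(x + 1) + stirling(y + 1) - stirling(xy + 1)
           - x.log() - y.log() + xy.log())

exact = torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(xy)
print(exact.item(), approx0.item(), approx1.item())  # approx1 is closer
```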

@fehiepsi (Member):

Oh yeah, it comes from the second line (I thought it cancelled out). Thanks for the clear explanation!

@fritzo (Member Author) commented May 24, 2020

I'll quickly add this approximation to dist.Binomial as well, which I hadn't realized also uses lgamma.

@fritzo fritzo dismissed stale reviews from martinjankowiak and fehiepsi via ddd6fe4 May 24, 2020 17:37
Comment on lines +80 to +89
n = self.total_count
k = value
# k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
#     (case logit < 0)              = k * logit - n * log1p(e^logit)
#     (case logit > 0)              = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
#                                   = k * logit - n * logit - n * log1p(e^-logit)
#     (merge two cases)             = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
normalize_term = n * (_clamp_by_zero(self.logits) + self.logits.abs().neg().exp().log1p())
return (k * self.logits - normalize_term
        + log_binomial(n, k, tol=self.approx_log_prob_tol))
@fritzo (Member Author):

This was adapted from the upstream implementation:

        log_factorial_n = torch.lgamma(self.total_count + 1)
        log_factorial_k = torch.lgamma(value + 1)
        log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
        # k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
        #     (case logit < 0)              = k * logit - n * log1p(e^logit)
        #     (case logit > 0)              = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
        #                                   = k * logit - n * logit - n * log1p(e^-logit)
        #     (merge two cases)             = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
        normalize_term = (self.total_count * _clamp_by_zero(self.logits)
                          + self.total_count * torch.log1p(torch.exp(-torch.abs(self.logits)))
                          - log_factorial_n)
        return value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term
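
For readers verifying the comment algebra, here is a small numeric check (not part of the PR) that the merged-cases expression equals k*log(p) + (n - k)*log(1 - p):

```python
import torch

n, k = 10.0, 3.0
for logit in torch.tensor([-5.0, -0.1, 0.1, 5.0]):
    p = torch.sigmoid(logit)
    # Direct form of the unnormalized log-probability.
    lhs = k * p.log() + (n - k) * (1 - p).log()
    # Merged-cases form: k * logit - n * max(logit, 0) - n * log1p(e^-|logit|).
    rhs = (k * logit - n * logit.clamp(min=0.0)
           - n * logit.abs().neg().exp().log1p())
    print(lhs.item(), rhs.item())  # agree up to floating point rounding
```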

@fritzo (Member Author) commented May 24, 2020

I've updated to use log_beta() in Binomial.log_prob(). Sorry for the duplicate reviews!

@fritzo fritzo requested a review from fehiepsi May 24, 2020 18:39