Add a shifted Stirling approximation to log Beta function #2500
Conversation
The implementation LGTM! I just have one question regarding the `shift` term (the tests passed, so I might be missing some point).
log_factor = functools.reduce(operator.mul, factors).log()

return (log_factor + (x - 0.5) * x.log() + (y - 0.5) * y.log()
        - (xy - 0.5) * xy.log() + (math.log(2 * math.pi) / 2 - shift))
Where does this `shift` come from?
Stirling's approximation of each of the `lgamma` terms is

    lgamma(x) ~ (x - 1/2) log(x) - x + log(2 pi)/2

Now when we use this to approximate a `log_beta` function, the `x` terms usually cancel:

    log_beta(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y)
                   = (x - 1/2) log(x) + (y - 1/2) log(y) - (x + y - 1/2) log(x + y)
                     + -x + -y + (x + y)            # <--- these cancel if shift = 0
                     + (1 + 1 - 1) log(2 pi)/2
Now the trick to make this more accurate is to use the `lgamma` recursion and increment each of `x`, `y`, and `x+y`, which is written in the code as

    x = x + 1
    y = y + 1
    xy = xy + 1

When we do that, the first line works unchanged, the second line gets an extra -1, and the last line remains unchanged. And if we increment each of `x`, `y`, and `x+y` `shift`-many times, with each increment contributing an extra -1, then we can simply add the term `-shift` (hence the `- shift` in the return expression above).
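To make this concrete, here is a minimal self-contained sketch of the shifted approximation; the name `log_beta_shifted` and the fixed `shift` argument are illustrative only (the helper in this PR instead derives the shift from a `tol` argument):

```python
import functools
import math
import operator

import torch


def log_beta_shifted(x, y, shift=4):
    # Apply lgamma(z) = lgamma(z + 1) - log(z) `shift` times to each of
    # x, y, and x + y.  Each round contributes a correction factor
    # (x + y) / (x * y) and, per the discussion above, an extra -1,
    # hence the trailing `- shift`.
    xy = x + y
    factors = []
    for _ in range(shift):
        factors.append(xy / (x * y))
        x, y, xy = x + 1, y + 1, xy + 1
    log_factor = functools.reduce(operator.mul, factors).log()
    return (log_factor + (x - 0.5) * x.log() + (y - 0.5) * y.log()
            - (xy - 0.5) * xy.log() + math.log(2 * math.pi) / 2 - shift)


x = torch.tensor([0.5, 1.0, 2.0, 10.0])
y = torch.tensor([0.5, 1.0, 5.0, 10.0])
exact = torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
print((log_beta_shifted(x, y) - exact).abs())  # error is largest at the smallest inputs
```

Increasing `shift` tightens the approximation at the cost of a few more elementwise ops, which is presumably the tradeoff the `tol` parameter controls.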
Oh yeah, it comes from the second line (I thought it cancelled out). Thanks for a clear explanation! I'll quickly add this approximation also to …
n = self.total_count
k = value
# k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
# (case logit < 0)                  = k * logit - n * log1p(e^logit)
# (case logit > 0)                  = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
#                                    = k * logit - n * logit - n * log1p(e^-logit)
# (merge two cases)                 = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
normalize_term = n * (_clamp_by_zero(self.logits) + self.logits.abs().neg().exp().log1p())
return (k * self.logits - normalize_term
        + log_binomial(n, k, tol=self.approx_log_prob_tol))
This was adapted from the upstream implementation:
log_factorial_n = torch.lgamma(self.total_count + 1)
log_factorial_k = torch.lgamma(value + 1)
log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
# k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p)
# (case logit < 0)                  = k * logit - n * log1p(e^logit)
# (case logit > 0)                  = k * logit - n * (log(p) - log(1 - p)) + n * log(p)
#                                    = k * logit - n * logit - n * log1p(e^-logit)
# (merge two cases)                 = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
normalize_term = (self.total_count * _clamp_by_zero(self.logits)
                  + self.total_count * torch.log1p(torch.exp(-torch.abs(self.logits)))
                  - log_factorial_n)
return value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term
I've updated to use …
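As a sanity check of the case analysis in the comments above, here is a small self-contained sketch (plain PyTorch, none of Pyro's helpers) confirming that the merged logit-space form equals the naive k * log(p) + (n - k) * log(1 - p):

```python
import torch

n = torch.tensor(20.0)
k = torch.tensor(7.0)
logits = torch.tensor([-3.0, -0.1, 0.1, 3.0])
p = torch.sigmoid(logits)

# Naive form; p saturates at 0 or 1 for extreme logits, so this is the
# numerically fragile version the logit-space rewrite avoids.
naive = k * torch.log(p) + (n - k) * torch.log1p(-p)

# Merged form from the comments:
#   k * logit - n * max(logit, 0) - n * log1p(e^-|logit|)
merged = (k * logits - n * logits.clamp(min=0)
          - n * torch.log1p(torch.exp(-logits.abs())))

print(torch.allclose(naive, merged))  # True
```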
Addresses #2426
Blocking #2498
See derivation notebook.

This implements a cheap approximate `log_beta()` function for use in `Binomial.log_prob()` and `BetaBinomial.log_prob()`. This is motivated by profiling results of #2498, which showed that `torch.lgamma()` was taking more than half of inference time. This PR also refactors to use the new approximation in `CompartmentalModel.heuristic()`, and I plan to use the approximation in more places in #2498.

Accuracy
Users specify a `tol` parameter, an absolute error bound (in log space); the error is worst for small values of `x, y` in `log_beta(x, y)`. When does this happen in `pyro.contrib.epidemiology`? The only circumstance when one of `x, y` is small is in superspreading models with small `k` and a tiny infectious population, so `k * I <= 1`; in all other cases this approximation is more accurate than the stated `tol`.
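As a throwaway illustration of the claim that the error is worst for small `x, y`, the zeroth-order Stirling formula from the review thread can be compared against the exact lgamma-based value (the real helper tightens this with the shift trick, so treat the numbers as qualitative only):

```python
import math

import torch


def exact_log_beta(x, y):
    return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)


def stirling0_log_beta(x, y):
    # Zeroth-order (shift = 0) Stirling approximation of log Beta(x, y).
    xy = x + y
    return ((x - 0.5) * x.log() + (y - 0.5) * y.log()
            - (xy - 0.5) * xy.log() + math.log(2 * math.pi) / 2)


for scale in [0.5, 1.0, 5.0, 50.0]:
    x = y = torch.tensor(scale)
    err = (stirling0_log_beta(x, y) - exact_log_beta(x, y)).abs().item()
    print(f"x = y = {scale:5.1f}  abs error = {err:.4f}")  # shrinks as x, y grow
```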
Speed

This approximation speeds up `log_beta()` by about 2.5x at the loosest tolerance of 0.1. The benchmark compares:
- `log_beta(x, y)`
- `log_beta_stirling(x, y, tol=0.1)`
- `log_beta_stirling(x, y, tol=0.05)`
- `log_beta_stirling(x, y, tol=0.02)`
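The 2.5x figure above is from the PR's own benchmark; the sketch below is not that benchmark, just a rough, hardware-dependent illustration of the per-op cost gap between `torch.lgamma` and the cheap elementwise ops the approximation substitutes for it:

```python
import timeit

import torch

x = torch.rand(100000) + 1.0

lgamma_time = timeit.timeit(lambda: torch.lgamma(x), number=1000)
stirling_time = timeit.timeit(lambda: (x - 0.5) * x.log(), number=1000)

print(f"torch.lgamma(x):     {lgamma_time:.3f}s")
print(f"(x - 0.5) * x.log(): {stirling_time:.3f}s")
```

On typical CPU backends `lgamma` is the more expensive of the two, which is why replacing three `lgamma` calls with a handful of log and multiply ops pays off.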
Tested
- `log_beta()` and `log_binomial()` at a variety of tolerances
- `Binomial.log_prob()` at a variety of tolerances
- `BetaBinomial` is covered by existing tests
- `CompartmentalModel.heuristic` is covered by superspreader smoke tests
- `ExtendedBinomial` and `ExtendedBetaBinomial`