Add binomial variables generator to random.py #16134
**Large counts**

- We should define and test the behavior for large `count` arguments, i.e., the user passing a floating-point `count` that overflows `int32` or even `int64`. Numpy barfs with a dtype conversion error in this case, but I don't think this code will do that, because it just casts `count` straight to `probs.dtype` without going through an integer dtype first (see the sketch after this list).
- It is possible to support large `count` arguments directly in floating point, but then we should define and test what happens when `count` is `+inf`.
- If we do support `count = +inf`, then for positive `prob`, the result should semantically be `+inf`. What's the right way to represent that if the user requests an integer output dtype?
- Even without `count = +inf`, we should test that very large counts (near the top of the representable range of floating point) do not cause the rejection sampler to enter an infinite loop, as it might if something overflowed inside and created a `nan` in the upper bound computation.
- Huge counts are a small corner case because the binomial distribution just becomes equivalent to the Gaussian distribution in that limit, but it seems like we still shouldn't gratuitously mess it up.
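To make the first bullet concrete, here is a minimal sketch of the contrast. The NumPy half relies on the conversion-error behavior described above; the exact exception type may vary by NumPy version, so it is caught broadly here:

```python
import numpy as np

rng = np.random.default_rng(0)
try:
    # NumPy routes the count through an integer dtype, so 1e20 (> int64 max)
    # fails with a conversion error rather than sampling.
    rng.binomial(n=1e20, p=0.5)
except (OverflowError, ValueError) as err:
    print("numpy rejects it:", err)

# Casting the count straight to a float dtype instead silently keeps the
# value, so the sampler sees a huge-but-finite count and must not loop forever:
count = np.float32(1e20)
print(count, np.isfinite(count))  # 1e+20 True
```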
**Fractional counts**

What should we do about non-integer `count` arguments? The numpy docs say "floating-point arguments are accepted but truncated to integers". This is unsatisfying to me because the binomial distribution has a perfectly good analytic extension to positive-real-valued "number of trials", but it precludes using the inversion sampler. Perhaps compatibility with numpy wins here, though I thought I'd mention it.
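For reference, one common way to see that extension is to replace the factorials in the binomial coefficient with gamma functions, so the log-pmf is well defined for fractional trial counts. A minimal sketch (the helper name `binomial_logpmf` is illustrative, not part of the PR):

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln

def binomial_logpmf(k, n, p):
    # log C(n, k) via gammaln extends to real-valued n:
    # C(n, k) = Gamma(n+1) / (Gamma(k+1) * Gamma(n-k+1))
    log_comb = gammaln(n + 1.0) - gammaln(k + 1.0) - gammaln(n - k + 1.0)
    return log_comb + k * jnp.log(p) + (n - k) * jnp.log1p(-p)

# A fractional "number of trials" such as n = 4.5 is perfectly well defined:
print(jnp.exp(binomial_logpmf(jnp.arange(5.0), 4.5, 0.3)))
```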
**Floating-point output dtype**

Looking at the code for `poisson`, it seems JAX is willing to return samples from integer-valued distributions in a floating-point representation. (In fact, in TFP this is even the default.) We should probably be consistent between `poisson` and `binomial` on this point.

Allowing a floating-point output dtype also gives us nicer sentinel values to use on the output: in all the cases where the sampler is undefined, we can return `nan`, and in the case where the counts are `+inf` and the probability is positive, we can return `+inf`. However, `nan` cast to an integer produces 0, which is even worse than -1 to use as a sentinel because it's also a valid answer. Numpy doesn't have this problem because it errors out on inputs whose outputs would be `nan`, but I assume JAX doesn't have that luxury.
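A quick check of the cast behavior described above (assuming the float-to-integer conversion semantics the review describes; XLA's behavior here can be backend-dependent):

```python
import jax.numpy as jnp

# nan survives as a sentinel in a float dtype...
print(jnp.asarray(jnp.nan))                    # nan
# ...but cast to an integer dtype it collapses to 0, which is also a
# perfectly valid binomial sample, so it cannot serve as a sentinel:
print(jnp.asarray(jnp.nan).astype(jnp.int32))  # 0
```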
Hi, @axch thanks for your detailed and constructive review! For **Floating-point output dtype**, it seems

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
jr.poisson(key, jnp.inf, shape=(2, 3))
```

and the output is *(screenshot not preserved)*. So I think before we could decide which way to treat …

Finally, for **Fractional counts**, I totally agree that a truncation operation is bad, but this is another topic!
Replaced by #18228
As proposed and discussed in #480 (comment) and #13327, this PR aims to add a counterpart for `np.random.binomial`. The implementation is from tensorflow (https://github.com/tensorflow/tensorflow/blob/v2.2.0-rc3/tensorflow/core/kernels/random_binomial_op.cc) and tensorflow-probability, lifted up for `jax`.

**Implementation**

For `n * p < 10`, use `_binomial_inverse`, and `_btrs` otherwise (a sketch of the inversion half follows below). The two functions are mainly from tensorflow or tfp, and follow the implementation of `random.poisson`, which does `max_iters` restriction and corner-case control (mentioned in jax.random.poisson: fix corner cases #9721, jax.random.poisson hangs if lam is NaN #9719), etc.
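For reference, a minimal, self-contained sketch of sequential-search inversion for scalar inputs. This is a generic textbook version, not the PR's `_binomial_inverse`, and the BTRS rejection sampler is omitted for brevity:

```python
import jax
import jax.numpy as jnp

def binomial_by_inversion(key, n, p, max_iters=1000):
    """Sample Binomial(n, p) by walking the CDF until it passes a uniform draw.

    Assumes 0 < p < 1; efficient only when n * p is small.
    """
    n, p = jnp.float32(n), jnp.float32(p)
    u = jax.random.uniform(key)
    pmf0 = (1.0 - p) ** n  # P(X = 0)

    def cond(state):
        k, pmf, cdf = state
        return (cdf < u) & (k < max_iters)  # max_iters guards against hangs

    def body(state):
        k, pmf, cdf = state
        # Recurrence: P(X = k+1) = P(X = k) * (n - k) / (k + 1) * p / (1 - p)
        pmf = pmf * (n - k) / (k + 1.0) * (p / (1.0 - p))
        return k + 1.0, pmf, cdf + pmf

    k, _, _ = jax.lax.while_loop(cond, body, (jnp.float32(0.0), pmf0, pmf0))
    return k

key = jax.random.PRNGKey(0)
print(binomial_by_inversion(key, n=20, p=0.3))
```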
**Corner Cases Control**

Since `binomial` receives both `n` and `p` (or `count` and `prob`), we need to take care of several cases (a sketch of the masking follows below):

1. `n` is `nan` or negative, but `p` is neither `nan` nor negative: produce `-1` as outputs, for consistency with `jax.random.poisson`.
2. `n` is neither `nan` nor negative, but `p` is `nan` or negative: recall that binomial is n-fold bernoulli, and `jax.random.bernoulli` will produce `False` (or 0) for `nan` or negative inputs, so for this case produce 0 as outputs, for consistency.
3. `n` is `nan` or negative and `p` is `nan` or negative: produce -1 as outputs.

I'm not sure about the above choices for corner-case control, especially for condition 3; suggestions needed...
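A hypothetical sketch of masking along these lines (the helper name `_mask_corner_cases` and its placement are illustrative, not the PR's code):

```python
import jax.numpy as jnp

def _mask_corner_cases(samples, n, p):
    n_invalid = jnp.isnan(n) | (n < 0)
    p_invalid = jnp.isnan(p) | (p < 0)
    # Case 2: valid n but invalid p -> 0, matching jax.random.bernoulli.
    samples = jnp.where(~n_invalid & p_invalid, 0, samples)
    # Cases 1 and 3: invalid n -> -1, matching jax.random.poisson.
    samples = jnp.where(n_invalid, -1, samples)
    return samples
```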
Unluckily, I don't have a GPU or TPU environment right now, so everything is tested on my Apple Silicon MacBook (M1), and I made some trivial comparison with `tfp` (results attached as screenshots, not preserved here); correct me if I did something wrong!

And I know that this comparison is not convincing enough, since for `tfp` the CPU and non-CPU situations are different #13327 (comment).

At last, feel free to close this PR if you think it's not qualified enough!