
Add binomial variables generator to random.py #16134

Closed
wants to merge 1 commit into from

Conversation

JiaYaobo
Contributor

As proposed and discussed in #480 (comment) and #13327, this PR adds a counterpart to np.random.binomial. The implementation is adapted from TensorFlow (https://github.com/tensorflow/tensorflow/blob/v2.2.0-rc3/tensorflow/core/kernels/random_binomial_op.cc) and tensorflow-probability, and ported to JAX.

  • Implementation
    For n * p < 10 we use _binomial_inverse, and _btrs otherwise. Both functions are adapted from TensorFlow / TFP, and follow the implementation of random.poisson, which bounds iterations with max_iters and handles corner cases (see jax.random.poisson: fix corner cases #9721, jax.random.poisson hangs if lam is NaN #9719).

  • Corner Cases Control
    Since binomial receives both n and p (count and prob), several cases need care:

  1. n is NaN or negative, but p is valid: produce -1 as output, for consistency with jax.random.poisson.
  2. n is valid, but p is NaN or negative: recall that binomial is n repeated Bernoulli trials, and jax.random.bernoulli produces False (i.e. 0) for NaN or negative inputs, so produce 0 as output for consistency.
  3. Both n and p are NaN or negative: produce -1 as output.
    I'm not sure about the rules above, especially condition 3 — suggestions welcome.
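To make the three rules concrete, here is an illustrative post-sampling masking step in plain NumPy; the function name `apply_binomial_sentinels` and the structure are mine, not the PR's actual code:

```python
import numpy as np

def apply_binomial_sentinels(raw, n, p):
    # Rules sketched above (illustrative, not the PR's exact implementation):
    #   bad n (NaN or negative)       -> -1 (matches jax.random.poisson's sentinel)
    #   bad p only                    ->  0 (matches jax.random.bernoulli on bad prob)
    #   both bad                      -> -1 (rule 3)
    bad_n = np.isnan(n) | (n < 0)
    bad_p = np.isnan(p) | (p < 0)
    out = np.where(bad_p & ~bad_n, 0, raw)  # rule 2
    out = np.where(bad_n, -1, out)          # rules 1 and 3
    return out

raw = np.array([5, 5, 5, 5])
n   = np.array([10.0, np.nan, 10.0, -1.0])
p   = np.array([0.2, 0.2, np.nan, np.nan])
print(apply_binomial_sentinels(raw, n, p))  # [ 5 -1  0 -1]
```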
  • Performance
    Unfortunately I don't have a GPU or TPU environment right now, so everything was tested on my Apple Silicon MacBook (M1). I made a simple comparison with TFP; correct me if I did something wrong!
import timeit

import jax.random as jr
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp

key = jr.PRNGKey(0)


print(
    timeit.timeit(
        lambda: jr.binomial(
            key, 10, 0.2, shape=(1000000,), dtype=jnp.int32
        ).block_until_ready(),
        number=10,
    ) / 10
)

binom1 = tfp.distributions.Binomial(total_count=10, probs=0.2)

print(
    timeit.timeit(
        lambda: binom1.sample((1000000,), key).block_until_ready(),
        number=10,
    ) / 10
)

print(
    timeit.timeit(
        lambda: jr.binomial(
            key, 100, 0.2, shape=(100000,), dtype=jnp.int32
        ).block_until_ready(),
        number=10,
    ) / 10
)

binom2 = tfp.distributions.Binomial(total_count=100, probs=0.2)

print(
    timeit.timeit(
        lambda: binom2.sample((100000,), key).block_until_ready(),
        number=10,
    ) / 10
)

and results are

0.15127999170217662  # mine
0.6121754749910906   # tfp
0.0645201083039865   # mine
0.46493839169852436  # tfp

I know this comparison is not fully convincing, since for TFP the CPU and non-CPU situations differ (#13327 (comment)).
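For reference, the small-mean inversion branch described under Implementation can be sketched in plain NumPy; this is an illustrative version of the standard inversion algorithm (summing geometric inter-success gaps), with names of my own, not the PR's JAX code:

```python
import numpy as np

def binomial_inversion(rng, n, p, size):
    # Inversion sampler for small n * p: repeatedly draw Geometric(p)
    # waiting times (trials until the next success) and count how many
    # successes land within the n trials.
    out = np.zeros(size, dtype=np.int64)
    log_q = np.log1p(-p)  # log(1 - p)
    for i in range(size):
        trials = 0.0
        k = -1
        while trials <= n:
            u = rng.random()
            trials += np.floor(np.log(u) / log_q) + 1  # Geometric(p) draw
            k += 1
        out[i] = k
    return out

rng = np.random.default_rng(0)
samples = binomial_inversion(rng, n=10, p=0.2, size=20000)
print(samples.mean())  # should be close to n * p = 2.0
```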

Finally, feel free to close this PR if you think it isn't up to standard!

@jakevdp jakevdp self-assigned this May 25, 2023
@jakevdp jakevdp self-requested a review May 25, 2023 17:09
Contributor

@axch axch left a comment


Large counts

  • We should define and test the behavior for large count arguments, i.e., the user passing a floating-point count that overflows int32 or even int64. Numpy barfs with a dtype conversion error in this case, but I don't think this code will do that, because it just casts count straight to probs.dtype without going through an integer dtype first.
  • It is possible to support large count arguments directly in floating-point, but then we should define and test what happens when count is +inf.
  • If we do support count = +inf, then for positive prob, the result should semantically be +inf. What's the right way to represent that if the user requests an integer output dtype?
  • Even without count = +inf, we should test that very large counts (near the top of the representable range of floating point) do not cause the rejection sampler to enter an infinite loop, as it might if something overflowed inside and created a nan in the upper bound computation.
  • Huge counts are a small corner case because the binomial distribution just becomes equivalent to the Gaussian distribution in that limit, but it seems like we still shouldn't gratuitously mess it up.
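For comparison, the NumPy behavior described above can be checked directly: a float count too large for a C integer fails with a conversion error rather than sampling (the exact exception type may vary between NumPy versions, hence the broad catch):

```python
import numpy as np

try:
    np.random.binomial(1e300, 0.5)  # count overflows any integer dtype
    print("no error")
except (OverflowError, ValueError) as e:
    print("raised:", type(e).__name__)
```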

Fractional counts

What should we do about non-integer count arguments? The numpy docs say "floating-point arguments are accepted but truncated to integers". This is unsatisfying to me because the binomial distribution has a perfectly good analytic extension to positive-real-valued "number of trials", but it precludes using the inversion sampler. Perhaps compatibility with numpy wins here, though I thought I'd mention it.
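The analytic extension mentioned above replaces the factorials in the binomial coefficient with gamma functions. A small self-contained sketch (my own illustration, not JAX's implementation):

```python
from math import lgamma, exp, log, log1p

def binom_pmf_real(n, k, p):
    # C(n, k) = Gamma(n+1) / (Gamma(k+1) * Gamma(n-k+1)),
    # well-defined for real n >= k, computed in log space for stability.
    log_coef = lgamma(n + 1) - lgamma(k + 1) - lgamma(n - k + 1)
    return exp(log_coef + k * log(p) + (n - k) * log1p(-p))

# For integer n this matches the usual pmf:
print(binom_pmf_real(10, 3, 0.2))   # ≈ 0.2013
# And it remains well-defined for fractional "number of trials":
print(binom_pmf_real(10.5, 3, 0.2))
```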

Floating-point output dtype

Looking at the code for poisson, it seems JAX is willing to return samples from integer-valued distributions in a floating-point representation. (In fact, in TFP this is even the default). We should probably be consistent between poisson and binomial on this point.

Allowing a floating-point output dtype also gives us nicer sentinel values to use on the output: in all the cases where the sampler is undefined, we can return nan, and in the case where the counts are +inf and the probability is positive, we can return +inf. However, nan cast to an integer produces 0, which is even worse than -1 to use as a sentinel because it's also a valid answer. Numpy doesn't have this problem because it errors out on inputs whose outputs would be nan, but I assume JAX doesn't have that luxury.

@jakevdp jakevdp assigned axch and unassigned jakevdp May 26, 2023
@JiaYaobo
Contributor Author

Hi, @axch thanks for your detailed and constructive review!

For Floating-point output dtype:

It seems poisson returns an integer type in JAX (https://github.com/google/jax/blob/main/jax/_src/random.py#L1372), with -1 for invalid lam inputs. But as you mentioned, a binomial with huge counts approaches a Gaussian, and a large lam has the same effect for poisson, so the original choice of poisson's return type could itself be discussed further.

For Large counts:

I tested inf for poisson, and there seems to be a bug:

import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)

jr.poisson(key, jnp.inf, shape=(2, 3))

and the output is

Array([[2147483647, 2147483647, 2147483647],
       [         0, 2147483647, 2147483647]], dtype=int32)

So I think before we can decide how binomial should treat large counts, we should make poisson behave well. tfp.Poisson returns inf, as you said. I guess the original poisson generator aims to be compatible with numpy, hence the integer return type; but for large lam, numpy can throw an error.

Finally, for Fractional counts:

I totally agree that a truncation operation is unsatisfying; that's another topic!

@jakevdp
Collaborator

jakevdp commented Nov 20, 2023

Replaced by #18228

@jakevdp jakevdp closed this Nov 20, 2023