Add random.binomial and random.multinomial #13327

Open
carlosgmartin opened this issue Nov 19, 2022 · 11 comments · May be fixed by #25688
Labels
enhancement: New feature or request
P3 (no schedule): We have no plan to work on this and, if it is unassigned, we would be happy to review a PR

Comments

@carlosgmartin
Contributor

carlosgmartin commented Nov 19, 2022

Add JAX counterparts of numpy.random.binomial and numpy.random.multinomial to the jax.random package. See #480 (comment) for context. A current workaround is to use the JAX substrate of TensorFlow Probability:

from jax import random, numpy as jnp
from tensorflow_probability.substrates import jax as tfp

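def binomial(key, n, p, shape=()):
    # Draw `shape` samples of Binomial(n, p) with TFP's JAX backend.
    # TFP returns floating-point samples, hence the trailing dots in the output.
    return tfp.distributions.Binomial(n, probs=p).sample(
        seed=key,
        sample_shape=shape,
    )

def multinomial(key, n, p, shape=()):
    # Draw `shape` count vectors of Multinomial(n, p); p is a probability vector.
    return tfp.distributions.Multinomial(n, probs=p).sample(
        seed=key,
        sample_shape=shape,
    )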

key = random.PRNGKey(0)

key, subkey = random.split(key)
print(binomial(subkey, 9, .8, [2, 5]))

key, subkey = random.split(key)
print(multinomial(subkey, 9, jnp.array([.7, .2, .1]), [4]))

Output:

[[7. 8. 8. 7. 5.]
 [5. 8. 9. 8. 5.]]
[[7. 1. 1.]
 [4. 3. 2.]
 [5. 3. 1.]
 [4. 4. 1.]]
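Later JAX releases include jax.random.binomial (the December 2023 comment below notes it was merged), so the binomial half of this workaround no longer needs TFP there. A minimal sketch, assuming a JAX version that ships jax.random.binomial with its default floating-point output; the sampled values are not expected to match the TFP output above.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)

# A (2, 5) array of Binomial(n=9, p=0.8) draws. jax.random.binomial returns
# floating-point values by default, like the TFP samples above.
samples = jax.random.binomial(subkey, 9, 0.8, shape=(2, 5))
print(samples.astype(jnp.int32))  # cast to integers only for display
```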
@carlosgmartin carlosgmartin added the enhancement New feature or request label Nov 19, 2022
@zhangqiaorjc
Collaborator

@sharadmv, is it possible to port TFP's implementation to jax.random?

@sharadmv
Collaborator

Possibly, yes. The implementation is fairly complex though, and makes some accelerator-specific tradeoffs IIRC. cc: @srvasude @brianwa84.

@jakevdp jakevdp added the P3 (no schedule) We have no plan to work on this and, if it is unassigned, we would be happy to review a PR label Nov 28, 2022
@hylkedonker

hylkedonker commented Dec 6, 2023

Any plans to add random.multinomial now that random.binomial has been merged?
The multinomial sampler could be implemented either as a sequence of conditional binomial draws or by repeated categorical draws.
See also the Wikipedia entry on multinomial variate generation.
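The conditional-binomial route can be sketched on top of jax.random.binomial: the count for category i is drawn as Binomial(trials not yet assigned, p_i divided by the probability mass not yet assigned), and whatever trials remain go to the last category. This is a minimal sketch for a single 1-D probability vector, assuming a JAX version that includes jax.random.binomial; `multinomial_conditional` is a hypothetical helper, not a jax.random API. It performs k - 1 sequential binomial draws (via lax.scan) rather than one categorical draw per trial.

```python
import jax
import jax.numpy as jnp

def multinomial_conditional(key, n, p):
    """Hypothetical helper: one Multinomial(n, p) draw via conditional binomials.

    Assumes p is a 1-D array of nonnegative probabilities summing to 1.
    Returns float32 counts of shape (k,) that sum to n.
    """
    p = jnp.asarray(p, dtype=jnp.float32)
    k = p.shape[0]
    keys = jax.random.split(key, k - 1)

    def step(carry, inputs):
        remaining_n, remaining_p = carry
        subkey, p_i = inputs
        # Probability of category i given that the trial did not land in an
        # earlier category; clipped to guard against rounding error.
        q = jnp.clip(p_i / jnp.maximum(remaining_p, 1e-12), 0.0, 1.0)
        x = jax.random.binomial(subkey, remaining_n, q, dtype=jnp.float32)
        return (remaining_n - x, remaining_p - p_i), x

    init = (jnp.asarray(n, dtype=jnp.float32), jnp.asarray(1.0, dtype=jnp.float32))
    (last, _), counts = jax.lax.scan(step, init, (keys, p[:-1]))
    # All trials not assigned so far go to the final category.
    return jnp.concatenate([counts, last[None]])
```

Batched draws can be obtained by mapping the helper over split keys with jax.vmap.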

@jakevdp
Collaborator

jakevdp commented Dec 6, 2023

Thanks for reaching out – I don't know of anyone working on this currently.

@andportnoy
Contributor

andportnoy commented Jan 28, 2024

A workaround if you want pure JAX is to pass the log of your probability vector (nonnegative entries summing to 1) to jax.random.categorical:

jax.random.categorical(key, jnp.log(p))

(Based on the "contract" of jax.random.categorical: "logits: Unnormalized log probabilities of the categorical distribution(s) to sample from, so that softmax(logits, axis) gives the corresponding probabilities.")

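Each call yields a single category index, so a Multinomial(n, p) sample still requires counting n such draws. That can't be the most efficient way to sample though...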
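Turning the categorical workaround into multinomial counts means drawing n indices and tallying them per category. A minimal sketch; `multinomial_via_categorical` is a hypothetical helper name, and n must be a static Python int because it sets the shape of the draws. Because jax.random.categorical uses the Gumbel-max trick, this materializes an n × k array of noise, which is one reason it scales poorly for large n.

```python
import jax
import jax.numpy as jnp

def multinomial_via_categorical(key, n, p):
    """Hypothetical helper: one Multinomial(n, p) draw built from n categorical draws."""
    # Zero probabilities become -inf logits, which are never selected.
    logits = jnp.log(p)
    # n independent category indices, each in [0, k).
    draws = jax.random.categorical(key, logits, shape=(n,))
    # Tally each category; `length` must be static so this also works under jit.
    return jnp.bincount(draws, length=p.shape[0])
```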

@brianwa84
Contributor

brianwa84 commented Jan 28, 2024 via email

@ShawnGeller

I would use a multinomial feature if added! The categorical workaround is a nonstarter if I want a lot of samples.

@twallema

I would also benefit from a multinomial feature to experiment with speeding up the numerical simulation of jump processes in my simulation code pySODM.

@yoavram

yoavram commented Dec 15, 2024

👍

@cthorrez

I would benefit from multinomial as well. I'm currently using NumPy's for efficient bootstrap sampling, and not having it here makes it harder to switch to JAX.
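For the bootstrap use case, the per-point resampling counts that numpy.random.multinomial(n, [1/n] * n, size=B) would produce can be obtained in pure JAX by drawing indices uniformly with replacement and counting them, since those counts follow a uniform Multinomial with n trials over the n points. A minimal sketch; `bootstrap_means` is a hypothetical helper and the mean is only an example statistic.

```python
import jax
import jax.numpy as jnp

def bootstrap_means(key, x, num_resamples):
    """Hypothetical helper: bootstrap means of a 1-D array x from multinomial counts."""
    n = x.shape[0]
    # Each row is one resample: n indices drawn uniformly with replacement.
    idx = jax.random.randint(key, (num_resamples, n), 0, n)
    # counts[b, i] = number of times point i appears in resample b.
    counts = jax.vmap(lambda row: jnp.bincount(row, length=n))(idx)
    # Count-weighted average of x for every resample.
    return counts @ x / n
```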

@carlosgmartin carlosgmartin linked a pull request Dec 27, 2024 that will close this issue
@carlosgmartin
Contributor Author

I've created a PR for this: #25688.
