Allow arbitrary sample_shape in Predictive #639
Comments
I am not opposed to it, but would like to understand the use case better. For case 1, we should be able to flatten the chain dim and use
Yes, you are right that this is not an issue. I just found it more convenient to avoid

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

K = 3  # number of mixture components
D = 4  # dimension of each observation

def gmm(data=None, N=None):
    if N is None:
        assert data is not None
        N = data.shape[0]
    with numpyro.plate("num_components", K):
        means = numpyro.sample("means", dist.Normal().expand([D]).to_event(1))
        scale_trils = numpyro.sample("scale_trils", dist.LKJCholesky(D))
    mixture_weights = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
    with numpyro.plate("num_observations", N):
        z = numpyro.sample("assignments", dist.Categorical(mixture_weights))
        numpyro.sample("obs", dist.MultivariateNormal(means[z], scale_tril=scale_trils[z]), obs=data)

# `data` is assumed to be an (N, D) array of observations defined elsewhere.
nuts_kernel = NUTS(gmm)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), data)
samples = mcmc.get_samples()
# keep only the last draw; the [-1:] slice preserves a size-1 leading dimension
last_sample = {k: v[-1:] for k, v in samples.items()}
# pass N explicitly, since gmm asserts when both data and N are None
predictions = Predictive(gmm, last_sample, return_sites=['assignments', 'obs'])(random.PRNGKey(1), N=data.shape[0])
# there are two annoyances with this: last_sample has a leading dimension of size 1;
# the output also has a size-1 leading dimension.
# typically, we use the `trace` handler for this job, but it is too verbose
# (requires using `seed`, `trace`, `get_trace`, ...)
@fehiepsi - Sorry for the late reply on this. Is the reason for this that we have additional enumeration dims on the left, or is this a more general issue? (Additionally, can you send me some runnable code, even if I have to run it against your branch?) I am trying to understand all the interface issues involved with Predictive and assess whether it shares any commonality with Pyro. Fritz has also pointed out a couple of issues with Predictive in Pyro (the last one was about being able to use samples from
I guess the purpose of that Pyro PR is to align sample dimensions so that those samples can be substituted into a vectorized model, which is similar to the output of
Currently, we assume the sample_shape in Predictive is (num_samples,), but it might be more convenient to allow arbitrary shapes here. Some motivations are:
- group_by_chain=True. For models with discrete latent variables, we don't replay the model to collect the value of discrete sites, so we need to use Predictive to collect those samples. Of course, we can do {k: v.reshape(...)}, but it is a bit verbose.

Cons:
- num_samples to sample_shape, or adding a new keyword sample_shape for this purpose.

What do you think, @neerajprad?
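The `{k: v.reshape(...)}` workaround mentioned above can be sketched as follows. This is a minimal illustration in plain NumPy, assuming samples shaped like the output of `mcmc.get_samples(group_by_chain=True)`, i.e. with leading dimensions `(num_chains, num_samples)`. The helpers `flatten_chains`, `unflatten_chains`, and `fake_predictive` are hypothetical names introduced here; `fake_predictive` only stands in for a predictive function that expects a single leading sample dimension.

```python
import numpy as np


def flatten_chains(samples):
    """Merge the leading (num_chains, num_samples) dims into one sample dim."""
    return {k: v.reshape((-1,) + v.shape[2:]) for k, v in samples.items()}


def unflatten_chains(samples, num_chains):
    """Split the leading sample dim back into (num_chains, num_samples)."""
    return {k: v.reshape((num_chains, -1) + v.shape[1:]) for k, v in samples.items()}


def fake_predictive(flat_samples):
    # Hypothetical stand-in for a predictive call that assumes
    # sample_shape == (num_samples,); it reduces over the event dim only.
    return {"obs": flat_samples["means"].sum(axis=-1)}


# Shapes mimic chain-grouped samples: 2 chains, 3 draws each, event size 4.
grouped = {"means": np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)}

flat = flatten_chains(grouped)                       # leading dim becomes 6
preds = unflatten_chains(fake_predictive(flat), 2)   # leading dims back to (2, 3)
print(flat["means"].shape, preds["obs"].shape)
```

Supporting an arbitrary `sample_shape` directly would make the two reshape helpers unnecessary, which is the verbosity the proposal is trying to remove.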