
Allow arbitrary sample_shape in Predictive #639

Closed
fehiepsi opened this issue Jun 17, 2020 · 4 comments

Comments

@fehiepsi
Member

Currently, we assume the sample_shape in Predictive is (num_samples,), but it might be more convenient to allow arbitrary shapes here. Some motivations are:

  • We can apply Predictive to samples with group_by_chain=True. For models with discrete latent variables, we don't replay the model to collect the values of discrete sites, so we need to use Predictive to collect those samples. Of course, we can do {k: v.reshape(...)} ourselves, but it is a bit verbose.
  • I find using Predictive more convenient than using handlers to collect a trace and read off sites' values. I recall that this question has already been raised on the forum; actually, this is the main motivation for raising this issue. In my gmm example, I just want to draw the prediction for a single sample of cluster means/scales.
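The reshape workaround mentioned in the first point can be sketched as follows (plain NumPy with placeholder shapes, just for illustration; the site names and shapes are made up):

```python
import numpy as np

# Hypothetical chain-grouped posterior samples, as returned with
# group_by_chain=True: shape (num_chains, num_samples_per_chain, ...).
chain_samples = {"means": np.zeros((4, 250, 3)), "phi": np.zeros((4, 250))}

# Flatten the chain dimension into the single (num_samples,) leading
# dimension that Predictive currently assumes...
flat_samples = {k: v.reshape((-1,) + v.shape[2:]) for k, v in chain_samples.items()}

# ...and afterwards the outputs would have to be reshaped back to
# (num_chains, num_samples_per_chain, ...) by hand.
```

This per-site flattening and un-flattening is exactly the verbosity the proposal wants to avoid.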

Cons:

  • We might change num_samples to sample_shape or add a new keyword sample_shape for this purpose.

What do you think, @neerajprad?

@neerajprad
Member

I am not opposed to it, but I would like to understand the use case better. For case 1, we should be able to flatten the chain dim and use Predictive, right? For the second case, could you point me to the code so that I can understand what the issues are?

@fehiepsi
Member Author

fehiepsi commented Jun 18, 2020

Yes, you are right that this is not really an issue; I just find it more convenient to avoid the reshaping. Here is my code for gmm (there is nothing special about it); predictions will have a leading dimension of size 1. If we allowed something like sample_shape=(), the output wouldn't have that leading dimension. While writing this, I realized that batch_ndim might be a better argument (by default, it would be 1; with chains, 2; with a single sample, 0).

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

K = 3  # number of mixture components
D = 4  # data dimension

def gmm(data=None, N=None):
    if N is None:
        assert data is not None
        N = data.shape[0]

    with numpyro.plate("num_components", K):
        means = numpyro.sample("means", dist.Normal().expand([D]).to_event(1))
        scale_trils = numpyro.sample("scale_trils", dist.LKJCholesky(D))

    mixture_weights = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
    with numpyro.plate("num_observations", N):
        z = numpyro.sample("assignments", dist.Categorical(mixture_weights))
        numpyro.sample("obs", dist.MultivariateNormal(means[z], scale_tril=scale_trils[z]), obs=data)

nuts_kernel = NUTS(gmm)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0), data)  # `data` is an (N, D) array prepared elsewhere
samples = mcmc.get_samples()

last_sample = {k: v[-1:] for k, v in samples.items()}
predictions = Predictive(gmm, last_sample, return_sites=['assignments', 'obs'])(random.PRNGKey(1))
# there are two rough edges with this: last_sample must keep a size-1 leading dimension,
# and the output also has a size-1 leading dimension.
# typically, we use the handler `trace` for this job, but it is too verbose
# (it requires using `seed`, `trace`, `get_trace`, ...)

@neerajprad
Member

neerajprad commented Jun 23, 2020

> last_sample has leading dimension with size 1;

@fehiepsi - Sorry for the late reply on this. Is the reason for this that we have additional enumeration dims on the left, or is this a more general issue? (Also, could you send me some runnable code, even if I have to run it against your branch?) I am trying to understand all the interface issues involved with Predictive and assess whether it shares any commonality with Pyro.

Fritz has also pointed out a couple of issues with Predictive in Pyro (the last one was about being able to use samples from Predictive to replay a model with an outermost batch dim, which requires aligning all the samples; see pyro-ppl/pyro#2466 (comment)), and I want to understand whether these might be related.

@fehiepsi
Member Author

I guess the purpose of that Pyro PR is to align sample dimensions so that those samples can be substituted into a vectorized model, similar to the output of the Predictive.get_vectorized_trace method. I didn't intend to address that in this issue. My proposal is a bit simpler: to avoid having to reshape the samples just to use Predictive. Currently, we require sample_shape = (num_samples,), but I think it would be better to also support sample_shape = (num_chains, num_samples_per_chain) or sample_shape = (). As you pointed out in your earlier comment, we can achieve this by flattening the samples to a single leading dimension and reshaping the output back to sample_shape, but I think it would be better to provide a convenient usage pattern for users. With a batch_ndim arg, users could use Predictive with various patterns of sample_shape.
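A minimal sketch of the semantics a batch_ndim argument could have, written as a standalone helper in plain NumPy (the name predictive_with_batch_ndim and the predict_fn stand-in are hypothetical, not NumPyro API):

```python
import numpy as np

def predictive_with_batch_ndim(samples, predict_fn, batch_ndim=1):
    """Flatten the leading `batch_ndim` dims of each sample site into the
    single (num_samples,) dim that Predictive currently expects, run the
    prediction, then restore the batch shape on every output.
    `predict_fn` stands in for a Predictive call on flattened samples."""
    batch_shape = next(iter(samples.values())).shape[:batch_ndim]
    flat = {k: v.reshape((-1,) + v.shape[batch_ndim:]) for k, v in samples.items()}
    out = predict_fn(flat)
    return {k: v.reshape(batch_shape + v.shape[1:]) for k, v in out.items()}
```

With batch_ndim=1 this matches the current behavior; batch_ndim=2 handles chain-grouped samples, and batch_ndim=0 covers the single-sample case (no leading sample dimension on inputs or outputs).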
