
Add Improper Distribution #612

Merged · 14 commits · Jun 6, 2020

Conversation

fehiepsi
Member

@fehiepsi fehiepsi commented Jun 2, 2020

This mimics pyro-ppl/pyro#2516 but adds an (invalid) sample method to generate random samples in the support. I also add batch_shape=() by default to make the usage of this distribution less verbose (some distributions in NumPyro also take batch_shape=() by default).

@fritzo I should have asked this in your PR: should this distribution have a validate_args argument?

cc @vanAmsterdam: hopefully this will unblock your work on the master branch.

@fehiepsi fehiepsi requested a review from fritzo June 2, 2020 02:39
arg_constraints = {}
unconstrained_radius = 2

def __init__(self, support, event_shape, batch_shape=(), validate_args=None):
Member

Can we make this agree with Pyro? Either batch_shape, event_shape or batch_shape=(), event_shape=()? If you really feel strongly that event_shape, batch_shape=() is optimal, we should change the Pyro version.

👍 I agree a validate_args is a good idea.

Member Author

How about using shape, event_dim=None here, where event_dim=len(shape) by default? I like it better than the current resolution. Having either event_shape=() or event_dim=0 by default will lead to wrong results, while having batch_shape=() or event_dim=len(shape) by default will work for most use cases (except for models with strong batch-shape restrictions, e.g. models with enumeration). If we go for this, I think we also need to resolve the difference between event_dim in Pyro and event_ndim in NumPyro for the Delta distribution.
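The trade-off above can be made concrete with a tiny sketch. This is plain Python, not actual NumPyro code; `split_shape` is a hypothetical helper illustrating how a single `shape` plus an `event_dim` defaulting to `len(shape)` would split into batch and event shapes:

```python
# Hypothetical helper (not real NumPyro code) for the proposed
# `shape, event_dim=None` interface: by default, every dimension of
# `shape` is treated as an event dimension.
def split_shape(shape, event_dim=None):
    if event_dim is None:
        event_dim = len(shape)
    cut = len(shape) - event_dim
    return shape[:cut], shape[cut:]

print(split_shape((3, 4)))               # batch=(), event=(3, 4)
print(split_shape((3, 4), event_dim=1))  # batch=(3,), event=(4,)
```

With the default, all dimensions land in the event shape, which is what makes the common fully-dependent case work without extra arguments.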

Using batch_shape, event_shape is fine with me. But I worry that we would need to explain those arguments properly. In addition, the previous usage of improper distributions did not require this distinction of shapes:

param('x', np.ones(3), constraints.positive)

or

sample('x', dist.LogNormal(zeros(3), 1).mask(False), constraints.positive)

I also think @neerajprad will have more input on the interface choice. :)

Member

Yeah, I like shape, event_dim=None. My main use case is to replicate an existing distribution, so your suggestion would allow

dist.ImproperUniform(fn.support, fn.shape(), fn.event_dim)

Does that also work in NumPyro?

Member Author

We don't have a .shape() method or an event_dim property. The main reason is that at an early stage we ported torch's Distribution to NumPyro, rather than Pyro's Distribution, so there are some missing pieces here in NumPyro. We'll add them. Thanks for pointing it out!

Member

Either looks fine to me, but I don't fully understand why the default of treating all dims as event dims is a better choice than batch_shape=(), event_shape=(). Could you elaborate? The latter has the advantage that we won't have to specify either batch or event shape for a sample site under plate.

Member

Here is an example where ImproperUniform(...).expand() would create an inconsistent intermediate distribution:

population = torch.tensor([1000., 2000., 3000.])
with pyro.plate("region", 3):
    incidence = pyro.sample("incidence",
                            ImproperUniform(support=interval(0, population),
                                            batch_shape=(3,),
                                            event_shape=())) 
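What makes this example subtle is that the support itself is batched: the interval constraint carries one upper bound per plate element. A plain-NumPy sketch (not Pyro/NumPyro code) of such a batched support:

```python
import numpy as np

# Plain-NumPy sketch (not Pyro/NumPyro code): an interval support whose
# upper bound is batched, one bound per plate element.
population = np.array([1000., 2000., 3000.])

def in_support(value, lower=0.0, upper=population):
    # Element-wise membership check; `value` broadcasts against the bounds.
    return (value >= lower) & (value <= upper)

print(in_support(np.array([500., 2500., 1500.])))  # [ True False  True]
```

Each element of `value` is checked against its own bound, so the constraint's check already has a batch dimension even when event_shape=().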

Member Author

Interesting... I thought we needed to specify event_shape=population.shape here. Let me look at it more closely; probably your notion of event_shape is different from mine.

Member Author

After going through the compartmental models and giving this a lot of thought, I still do not understand why it is better to use the args batch_shape, event_shape and to introduce that interface to users. It obviously does not make the internal code or the user code shorter. We started with a discussion about the interface and agreed that the solutions are in principle equivalent. Then we raised some issues with the default values: event_shape=() or event_shape=None. The former requires a warning, while the latter might introduce some bugs or performance issues. Because I don't understand those latter issues, I'll go with Fritz's solution and try to convey the caveats to users through the docs (the current docs lack that information).

This isn't a compromise between solutions. It is just that I don't understand the reasoning behind the issues with the (to me) better solution, and I think we have already put a lot of thought into this issue (which might not be worth our time).

Thanks for the discussions, @fritzo @neerajprad! I'll go with the current Pyro solution. I tried to simplify things for users, but I might be wrong that it would actually be helpful. :)

Member

These discussions were very helpful at least for me to understand the issues involved, and I still don't think I fully do.

@fehiepsi - My suggestion would be to have tests for the cases discussed here (especially cases with plate and batched constraints, and constraints with an event shape), and verify that the sample and log_prob shapes agree with what we would expect (using whatever parametrization makes the most sense to you). In @fritzo's example above, for instance, I think even with event_shape=(3,), we'll get incorrectly shaped samples with the existing implementation, but you have probably thought about it much more and have a solution. That will also make it easier to clearly see and discuss the different ways to parametrize this and what the resulting model code looks like.

Member Author

@neerajprad I don't have a solution. I just don't see the issue with it. I'll add more tests to understand better.

Member

@fritzo fritzo left a comment

This seems reasonable.

It would be nice to make the Pyro and NumPyro versions agree, both in interface and semantics. The interface is easy to fix. Maybe to make the semantics agree we could add an extra "sampler" kwarg that defaults to sampler=None but you could also pass sampler=Uniform(-2,2) or something? Or maybe we shouldn't bother to achieve consistency?

@@ -350,6 +350,45 @@ def variance(self):
return np.broadcast_to(self.base_dist.variance, self.batch_shape + self.event_shape)


class ImproperUniform(Distribution):
"""
A helper distribution with zero :meth:`log_prob` and improper
Member

nit: "improper sample" -> "arbitrarily distributed sample" or something. The term "improper" specifically means "non-normalized and possibly non-normalizable", as in "improper uniform prior over the real line".

Member Author

Thanks for explaining! I'll make the change.

Re the sample method: I will raise NotImplementedError here as you did. It is better.

@vanAmsterdam
Contributor


@fehiepsi yes, this, combined with a sample(..., obs=...) statement, gives back the same flexibility as before; I made a tiny educational example for myself: https://gist.github.com/vanAmsterdam/c76164de2f39cc515dbf81b27ffa4b75

def __call__(self, x):
raise NotImplementedError

def check(self, value):
Member Author

We expose this doc so I want to match PyTorch behavior here.

@fehiepsi
Member Author

fehiepsi commented Jun 5, 2020

@fritzo @neerajprad I have removed the sample method and matched the constructor with the interface of Pyro. Because there is no sample method, I changed the logic of init_to_uniform and init_to_median a bit to bypass the NotImplementedError. Also, please let me know if there is any confusing point in the docs.

I also added the test test_improper_expand but still couldn't reproduce the expand issue. I still think that things should work as expected. >"< Could you elaborate a bit more?

@fehiepsi
Member Author

fehiepsi commented Jun 5, 2020

Hi @vanAmsterdam, the ordinal-regression gist looks great! Do you want to make a PR for it? I believe ordinal_regression is a very good example for the next NumPyro release.

@vanAmsterdam
Contributor


sure! here it is: #619 (comment)

fritzo
fritzo previously approved these changes Jun 5, 2020
Member

@fritzo fritzo left a comment

Good catch, adapting the InitMessenger!

numpyro/distributions/distribution.py (outdated)
event_shape=(3,)))

model_info = initialize_model(random.PRNGKey(0), model)
assert model_info.param_info.z['incidence'].shape == (3, 3)
Member

What I wanted to verify for this example is that with event_shape=(), this should be (3,). Is that correct?

Member Author

@fehiepsi fehiepsi Jun 5, 2020

Let me test it. I don't know what happens when users provide an invalid event_shape.

edit: Oh, now I see what you and Fritz meant before. Here, the support is batched... interesting.

Member Author

Hmm, the result is still (3, 3). Something is wrong...

Member Author

@fehiepsi fehiepsi Jun 5, 2020

Thanks, @neerajprad! I just added that test case and fixed the issue at init_to_uniform. Does the fix sound correct to you?

  • before
prototype_value = np.full(site['fn'].event_shape, np.nan)
unconstrained_event_shape = np.shape(transform.inv(prototype_value))
unconstrained_shape = site['fn'].batch_shape + unconstrained_event_shape
  • after
prototype_value = np.full(site['fn'].shape(), np.nan)
unconstrained_shape = np.shape(transform.inv(prototype_value))
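The essential difference between the two versions is which shape the NaN prototype starts from. A minimal sketch in plain Python (`full_shape` mirrors what a `Distribution.shape()` method computes; it is not the actual NumPyro implementation):

```python
# Sketch (plain Python): a distribution's full sample shape is
# sample_shape + batch_shape + event_shape. Building the NaN prototype
# from the full shape (the "after" version) means the inverse transform
# sees values of the right shape even when the support's constraint
# parameters are themselves batched.
def full_shape(batch_shape, event_shape, sample_shape=()):
    return sample_shape + batch_shape + event_shape

# With batch_shape=(3,) and event_shape=(3,), the prototype shape is
# (3, 3), not just the event shape (3,).
print(full_shape((3,), (3,)))  # (3, 3)
```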

Member

Great, that makes sense.

samples = site['fn'].sample(rng_key, sample_shape=(num_samples,) + sample_shape)
return np.median(samples, axis=0)
except NotImplementedError:
return init_to_uniform(site)
Member

Can we just return prototype_value = np.full(site['fn'].shape(), np.nan) instead of calling init_to_uniform, so that we don't need additional work like calling random.split?

Member Author

For substitute_fn, I think it only works if we substitute(seed(model), ...), because seed will provide an rng_key to the site we want to apply substitute_fn to. Inside substitute_fn, we don't use the sample primitive, so there is no random.split here, IIUC.
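The fallback pattern under discussion (try a sample-based init, fall back on NotImplementedError) can be sketched self-containedly; the class and function names below are hypothetical stand-ins, not the actual NumPyro code:

```python
# Self-contained sketch (hypothetical names, not the NumPyro implementation)
# of the fallback: try a sample-based init strategy, and fall back when
# the distribution's sample() is intentionally unimplemented.
class ImproperLike:
    def sample(self, *args, **kwargs):
        raise NotImplementedError

class NormalLike:
    def sample(self, *args, **kwargs):
        return 0.0  # stands in for a real draw

def init_strategy(fn):
    try:
        return fn.sample()
    except NotImplementedError:
        return "init_to_uniform"  # stands in for the init_to_uniform fallback

print(init_strategy(NormalLike()))    # 0.0
print(init_strategy(ImproperLike()))  # init_to_uniform
```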

population = np.array([1000., 2000., 3000.])
with numpyro.plate("region", 3):
numpyro.sample("incidence",
dist.ImproperUniform(support=constraints.interval(0, population),
Member

@neerajprad neerajprad Jun 5, 2020

Can we also add a log_prob shape check here?

d = dist.ImproperUniform(support=constraints.interval(0, population),
                         batch_shape=(3,),
                         event_shape=event_shape)
incidence = numpyro.sample("incidence", d)
assert d.log_prob(incidence).shape == (3,)

Member

@neerajprad neerajprad left a comment

Thanks for adding this distribution, @fehiepsi. Thanks @fritzo for reviewing.

@neerajprad neerajprad merged commit 4c49031 into pyro-ppl:master Jun 6, 2020
@fehiepsi
Member Author

fehiepsi commented Jun 6, 2020

Thanks @neerajprad and @fritzo for reviewing!
