Add Improper Distribution #612
Conversation
arg_constraints = {}
unconstrained_radius = 2

def __init__(self, support, event_shape, batch_shape=(), validate_args=None):
Can we make this agree with Pyro? Either batch_shape, event_shape or batch_shape=(), event_shape=()? If you really feel strongly that event_shape, batch_shape=() is optimal, we should change the Pyro version.
👍 I agree a validate_args is a good idea.
How about using shape, event_dim=None here, where event_dim=len(shape) by default? I like it better than the current resolution. Having either event_shape=() or event_dim=0 by default will lead to wrong results, while having batch_shape=() or event_dim=len(shape) by default will work for most use cases (except for models with strong batch-shape restrictions, e.g. models with enumeration). If we go for this, I think we need to resolve the difference between event_dim in Pyro and event_ndim in NumPyro for the Delta distribution too.
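A rough sketch of the shape, event_dim=None convention proposed above (the helper name split_shape is hypothetical, purely to illustrate the shape arithmetic): with event_dim defaulting to len(shape), all dimensions are treated as event dimensions unless the user says otherwise.

```python
def split_shape(shape, event_dim=None):
    """Split a single `shape` into (batch_shape, event_shape).

    event_dim=None defaults to len(shape), i.e. all dims are
    event dims, matching the proposal above.
    """
    if event_dim is None:
        event_dim = len(shape)
    cut = len(shape) - event_dim
    return shape[:cut], shape[cut:]

# By default, everything is an event dim:
print(split_shape((3, 4)))     # ((), (3, 4))
# An explicit event_dim leaves the leading dims as batch dims:
print(split_shape((3, 4), 1))  # ((3,), (4,))
```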
Using batch_shape, event_shape is fine by me. But I worry that we need to explain those arguments properly. In addition, the previous usage of improper distributions does not require this distinction of shape:
param('x', np.ones(3), constraints.positive)
or
sample('x', dist.LogNormal(zeros(3), 1).mask(False), constraints.positive)
I also think @neerajprad will have more input about the interface choice. :)
Yeah, I like shape, event_dim=None. My main use case is to replicate an existing distribution, so your suggestion would allow
dist.ImproperUniform(fn.support, fn.shape(), fn.event_dim)
Does that also work in NumPyro?
We don't have a .shape() method or an event_dim property. The main reason is that at an early stage we ported the torch Distribution to NumPyro, rather than the Pyro Distribution, so some pieces are missing here in NumPyro. We'll add them. Thanks for pointing it out!
Either looks fine to me, but I don't fully understand why the default of treating all dims as event dims is a better choice than batch_shape=(), event_shape=(). Could you elaborate? The latter has the advantage that we won't have to specify either batch or event shape for a sample site under plate.
Here is an example where ImproperUniform(...).expand() would create an inconsistent intermediate distribution:
population = torch.tensor([1000., 2000., 3000.])
with pyro.plate("region", 3):
    incidence = pyro.sample("incidence",
                            ImproperUniform(support=interval(0, population),
                                            batch_shape=(3,),
                                            event_shape=()))
Interesting... I thought we needed to specify event_shape=population.shape here. Let me look at it more closely; probably your notion of event_shape is different from mine.
After going through the compartmental models and thinking about this for a long time, I still do not understand why it's better to use the args batch_shape, event_shape and to introduce that interface to users. It obviously does not make the internal code or user code shorter. Actually, we started with a discussion about the interface and agreed that the solutions are in principle equivalent. Then we raised some issues with the default values: event_shape=() or event_shape=None. The former requires a warning, while the latter might introduce some bugs or performance issues. Because I don't understand those latter issues, I'll go with Fritz's solution and try to deliver the relevant messages to users through the docs (the current docs lack those pieces of information).
It isn't a compromise between solutions. It is just that I don't understand the reasoning behind the issues of the (to me) better solution, and I think we have already put a lot of thought into this issue (which might not be worth our time).
Thanks for the discussions, @fritzo @neerajprad! I'll go with the current Pyro solution. I tried to simplify things for users, but I might be wrong that it will actually be helpful. :)
These discussions were very helpful, at least for me, in understanding the issues involved, and I still don't think I fully do.
@fehiepsi - My suggestion would be to have tests for the cases discussed here (especially cases with plate and batched constraints, and constraints with event shape), and verify that the sample and log_prob shapes agree with what we would expect (using whatever parametrization makes the most sense to you). In @fritzo's example above, for instance, I think even with event_shape=(3,) we'll get incorrectly shaped samples with the existing implementation, but you have probably thought about it much more and have a solution. That will also make it easier to clearly see and discuss different ways to parametrize this, and how the resulting model code looks.
@neerajprad I don't have a solution. I just don't see the issue with it. I'll add more tests to understand better.
This seems reasonable.
It would be nice to make the Pyro and NumPyro versions agree, both in interface and semantics. The interface is easy to fix. Maybe to make the semantics agree we could add an extra "sampler" kwarg that defaults to sampler=None but you could also pass sampler=Uniform(-2, 2) or something? Or maybe we shouldn't bother to achieve consistency?
@@ -350,6 +350,45 @@ def variance(self):
return np.broadcast_to(self.base_dist.variance, self.batch_shape + self.event_shape)

class ImproperUniform(Distribution):
    """
    A helper distribution with zero :meth:`log_prob` and improper
nit: "improper sample" -> "arbitrarily distributed sample" or something. The term 'improper' specifically means "non-normalized and possibly non-normalizable", as in "improper uniform prior over the real line".
Thanks for explaining! I'll make the change.
Re the sample method: I will raise NotImplementedError here as you did. It is better.
@fehiepsi yes, this (combined) with a …
def __call__(self, x):
    raise NotImplementedError

def check(self, value):
We expose this doc, so I want to match the PyTorch behavior here.
@fritzo @neerajprad I have removed … I also added the test …
Hi @vanAmsterdam, the ordinal-regression gist looks great! Do you want to make a PR for it? I believe …
sure! here it is: #619 (comment)
Good catch, adapting the InitMessenger!
test/test_infer_util.py (outdated)
event_shape=(3,)))

model_info = initialize_model(random.PRNGKey(0), model)
assert model_info.param_info.z['incidence'].shape == (3, 3)
What I wanted to verify for this example is that with event_shape=(), this should be (3,). Is that correct?
Let me test it. I don't know what happens when users provide an invalid event_shape.
edit: Oh, now I see what you and Fritz meant before. Here, the support is batched... interesting.
Hmm, the result is still (3, 3). Something is wrong...
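A back-of-the-envelope check of the shapes involved (plain tuples, assuming the usual batch_shape + event_shape layout for a site value): the plate contributes a batch dimension of size 3, so with event_shape=() the site value should have shape (3,); the buggy (3, 3) arises when the batched support bounds (population.shape) leak into the event shape while the plate still prepends its own batch dimension.

```python
plate_shape = (3,)          # from numpyro.plate("region", 3)
event_shape = ()            # what the user declared
support_bound_shape = (3,)  # population.shape, the batched support bounds

# Expected: value shape is the plate batch dims plus the declared event shape.
correct = plate_shape + event_shape
print(correct)  # (3,)

# Observed bug: the batched bounds are treated as event dims,
# and the plate still prepends its own batch dimension.
buggy = plate_shape + support_bound_shape
print(buggy)    # (3, 3)
```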
Thanks, @neerajprad! I just added that test case and fixed the issue at init_to_uniform. Does the fix sound correct to you?
- before
prototype_value = np.full(site['fn'].event_shape, np.nan)
unconstrained_event_shape = np.shape(transform.inv(prototype_value))
unconstrained_shape = site['fn'].batch_shape + unconstrained_event_shape
- after
prototype_value = np.full(site['fn'].shape(), np.nan)
unconstrained_shape = np.shape(transform.inv(prototype_value))
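A minimal numpy sketch of why the fix helps (using an illustrative inverse transform with a batched parameter, not NumPyro's actual transform classes): applying the inverse to a prototype of only event_shape lets the transform's batched parameters broadcast extra dims into the result, which then get double-counted when batch_shape is prepended; a prototype of the full shape() == batch_shape + event_shape avoids this.

```python
import numpy as np

batch_shape, event_shape = (3,), ()
scale = np.array([1000., 2000., 3000.])  # batched transform parameter

def transform_inv(y):
    # Illustrative inverse transform; broadcasting against the
    # batched `scale` can inflate the output shape.
    return y / scale

# Before the fix: prototype covers only the event shape.
proto_event = np.full(event_shape, np.nan)                  # shape ()
unconstrained_event = np.shape(transform_inv(proto_event))  # (3,) -- surprise!
unconstrained_before = batch_shape + unconstrained_event    # (3, 3), wrong

# After the fix: prototype covers batch_shape + event_shape directly.
proto_full = np.full(batch_shape + event_shape, np.nan)     # shape (3,)
unconstrained_after = np.shape(transform_inv(proto_full))   # (3,), correct
print(unconstrained_before, unconstrained_after)
```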
Great, that makes sense.
samples = site['fn'].sample(rng_key, sample_shape=(num_samples,) + sample_shape)
return np.median(samples, axis=0)
except NotImplementedError:
    return init_to_uniform(site)
Can we just return prototype_value = np.full(site['fn'].shape(), np.nan) instead of calling init_to_uniform, so that we don't need additional work like calling random.split?
For substitute_fn, I think it only works if we substitute(seed(model), ...) because seed will give an rng_key to the site where we want to apply substitute_fn. Inside substitute_fn, we don't use the sample primitive, so there is no random.split here, IIUC.
test/test_infer_util.py (outdated)
population = np.array([1000., 2000., 3000.])
with numpyro.plate("region", 3):
    numpyro.sample("incidence",
                   dist.ImproperUniform(support=constraints.interval(0, population),
Can we also add a log_prob shape check here?
d = dist.ImproperUniform(support=constraints.interval(0, population),
                         batch_shape=(3,),
                         event_shape=event_shape)
incidence = numpyro.sample("incidence", d)
assert d.log_prob(incidence).shape == (3,)
Thanks @neerajprad and @fritzo for reviewing!
This mimics pyro-ppl/pyro#2516 but adds an invalid sample method to generate random samples in the support. I also add batch_shape=() by default to make the usage of this distribution less verbose (there are some distributions in NumPyro that also have the argument batch_shape=() by default).

@fritzo I should ask this in your PR: should this distribution have a validate_args argument?

cc @vanAmsterdam: I hope this will unblock your work using the master branch.
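For readers skimming the thread, a condensed stand-in for the distribution being discussed (a simplified pure-numpy sketch, not the actual NumPyro implementation): log_prob is identically zero over the batch shape, and shape() is batch_shape + event_shape, matching the conventions debated above.

```python
import numpy as np

class ImproperUniformSketch:
    """Toy stand-in for the ImproperUniform discussed in this PR."""

    def __init__(self, support, batch_shape=(), event_shape=()):
        self.support = support  # e.g. a constraint object; unused in this sketch
        self.batch_shape = tuple(batch_shape)
        self.event_shape = tuple(event_shape)

    def shape(self, sample_shape=()):
        return tuple(sample_shape) + self.batch_shape + self.event_shape

    def log_prob(self, value):
        # Zero log-density: the event dims of `value` are summed out
        # (trivially here, since every contribution is 0), leaving
        # an array over the batch dims.
        batch_ndim = np.ndim(value) - len(self.event_shape)
        return np.zeros(np.shape(value)[:batch_ndim])

d = ImproperUniformSketch(support=None, batch_shape=(3,), event_shape=())
x = np.ones((3,))
print(d.log_prob(x).shape)  # (3,)
```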