Expose ImproperUniform distribution #2516
Conversation
LGTM. I just have a comment about the default args of this distribution.
from .util import broadcast_shape

class ImproperUniform(TorchDistribution):
Just to make sure you have already considered it: would it be better to use the shorter name Improper?
Hmm... I guess there are other improper distributions, e.g. Unit.
arg_constraints = {}

def __init__(self, support, batch_shape=torch.Size(), event_shape=torch.Size()):
I think event_shape is a requirement. Otherwise, we can set it to None, so that the default event_shape can be (0,) * biject_to(support).event_dim. This way, log_prob will have a more appropriate shape.
I think I'd rather make them both required than perform a computation on the constraint. Currently our constraints are pretty sloppy: we should be using IndependentConstraint more thoroughly, and many of the computations would be wrong. This isn't a big issue yet because we don't use biject_to(support).event_dim for much.
Yeah, I agree that we should require event_shape. But for batch_shape, I guess it is better to let it keep a default shape as before (i.e. __init__(self, support, event_shape, batch_shape=torch.Size())) to make the code less verbose.
I agree a default of batch_shape=() makes sense, but I find it disorienting for batch_shape to be specified to the right of event_shape: I always see things as batch_shape, event_shape. 🤷
Haha, sounds reasonable to me.
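For concreteness, here is a minimal, hypothetical sketch of the interface converging in this thread: batch_shape listed before event_shape (both defaulting to ()), a log_prob that is identically zero, and sampling left unimplemented. The class name, defaults, and helpers below are illustrative and are not the code under review.

```python
import torch
from torch.distributions import constraints
from pyro.distributions import TorchDistribution
from pyro.distributions.util import broadcast_shape


class ImproperUniformSketch(TorchDistribution):
    """Illustrative sketch only, not the implementation in this PR."""

    arg_constraints = {}

    def __init__(self, support, batch_shape=torch.Size(), event_shape=torch.Size()):
        self._support = support
        super().__init__(torch.Size(batch_shape), torch.Size(event_shape))

    @constraints.dependent_property
    def support(self):
        return self._support

    def log_prob(self, value):
        # Improper: the log density is identically zero over the support.
        batch_shape = value.shape[: value.dim() - self.event_dim]
        batch_shape = broadcast_shape(batch_shape, self.batch_shape)
        return torch.zeros(batch_shape)

    def sample(self, sample_shape=torch.Size()):
        raise NotImplementedError("values must be provided by inference or a guide")


# With batch_shape=(3,) and event_shape=(2,), log_prob of a (3, 2) value has shape (3,).
d = ImproperUniformSketch(constraints.positive, batch_shape=(3,), event_shape=(2,))
print(d.log_prob(torch.ones(3, 2)).shape)  # torch.Size([3])
```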
Thanks @fritzo, I think this is a great distribution to have!
Thanks for your careful review, @fehiepsi! I'm glad you care about interface elegance 😄
Addresses #2426
Follows up on @fehiepsi's comment #2495 (review)
This exposes ImproperUniform as a public distribution. It is currently used only in SplitReparam. I plan to use it also for sampling auxiliary variables in CompartmentalModel._sequential_model() and ._vectorized_model(). Those classes currently use dist.Uniform(...).mask(False).expand(shape).to_event(), but I would like to generalize the logic to arbitrary constraints, while adding the safety of a not-implemented .sample() method.

Tested
- SplitReparam tests
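As a rough comparison of the two patterns described above (assuming an ImproperUniform(support, batch_shape, event_shape) interface; the exact public signature is what this PR settles), both yield an all-zeros log_prob over the declared shape:

```python
import torch
from torch.distributions import constraints
import pyro.distributions as dist

value = torch.randn(3, 2)

# Current pattern: a masked Uniform whose log_prob is forced to zero.
d_old = dist.Uniform(-2.0, 2.0).mask(False).expand((3, 2)).to_event(1)
print(d_old.log_prob(value))  # zeros of shape (3,)

# Proposed pattern: an improper uniform over an arbitrary constraint,
# with the added safety that .sample() is not implemented.
d_new = dist.ImproperUniform(constraints.real, torch.Size([3]), torch.Size([2]))
print(d_new.log_prob(value))  # zeros of shape (3,)
# d_new.sample()              # would raise NotImplementedError
```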