
🚧 RSA vectorized prior achieved; L0 doesn't enumerate #2908

Draft · wants to merge 4 commits into dev
Conversation

@jmuchovej commented Aug 1, 2021

`structured_prior` will correctly enumerate (and specify individual sample sites, vs. making a single sample site of a tensor over all enumerations).

`listener0` doesn't enumerate over the support of `structured_prior`. It strictly visits the highest-probability sample from the `structured_prior`.
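To make the distinction concrete, here's a toy sketch (not code from this PR; the two latent components and their sizes are made up) of one sample site over the full joint tensor vs. individual, independently enumerable sample sites:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate

# One sample site whose support is the flattened joint: enumeration has to
# treat every joint configuration as a single category.
@config_enumerate
def single_site_prior():
    joint = pyro.sample("joint", dist.Categorical(torch.ones(4 * 3)))
    return joint // 3, joint % 3  # unpack the two hypothetical components

# Individual sample sites, each with its own small support, so TraceEnum-based
# inference can enumerate them separately.
@config_enumerate
def structured_prior_sketch():
    a = pyro.sample("a", dist.Categorical(torch.ones(4)))
    b = pyro.sample("b", dist.Categorical(torch.ones(3)))
    return a, b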
Conflicts:
- Altered execution counters in `examples/rsa/generics-vectorized.ipynb`
- Upgraded `maxsize` of `memoize` from 10 -> 100
@jmuchovej jmuchovej changed the title 🚧 Vectorized prior achieved; L0 doesn't enumerate 🚧 RSA vectorized prior achieved; L0 doesn't enumerate Aug 1, 2021
@eb8680 eb8680 self-requested a review August 3, 2021 14:08
@jmuchovej (Author) commented Oct 4, 2021

@eb8680 Just a bump on this. I can update this with some more work I've done – but right now I've hit a snag where the following doesn't appear to work:

import pyro
from torch import Tensor
from pyro.infer import config_enumerate

@config_enumerate
def listener0(utterance: Tensor, threshold: Tensor, States: RSAMarginal) -> Tensor:
    # the site that should be enumerated over the support of `States`
    state = pyro.sample("state", States)
    # ...
    return state

`RSAMarginal` is much like `HashingMarginal`, but it extracts the support from a `TracePosterior` (as `TraceEnum_ELBO` appears to do in `TraceEnum_ELBO._traces`).

I can push the changes so you can inspect them if it's helpful – but at the moment, I'm not sure how to "signal" that `RSAMarginal` can be enumerated the way `dist.Categorical`-like classes are.
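For context on what I mean by "signal": my rough understanding – an assumption on my part, not something I've verified against `TraceEnum_ELBO` internals – is that enumerable distributions advertise a finite support via `has_enumerate_support` and `enumerate_support()`. A toy sketch of that interface (the class `FiniteMarginal` below is hypothetical, not the `RSAMarginal` in this PR):

import torch
import pyro.distributions as dist
from pyro.distributions import TorchDistribution

# Hypothetical sketch: a distribution over a fixed, finite set of scalar values
# that exposes its support the same way dist.Categorical does.
class FiniteMarginal(TorchDistribution):
    arg_constraints = {}
    support = dist.constraints.real
    has_enumerate_support = True

    def __init__(self, values, logits):
        self.values = values                           # (N,) support points
        self._index = dist.Categorical(logits=logits)  # weights over those points
        super().__init__()

    def enumerate_support(self, expand=True):
        # enumeration stacks the support along a new leftmost dimension
        return self.values

    def sample(self, sample_shape=torch.Size()):
        return self.values[self._index.sample(sample_shape)]

    def log_prob(self, value):
        # look up each value's index in the (assumed unique) support
        idx = (value.unsqueeze(-1) == self.values).float().argmax(-1)
        return self._index.log_prob(idx)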

I've searched the forum pretty extensively and haven't seen any questions about this.

@eb8680 (Member) commented Oct 12, 2021

@jmuchovej sorry, I'll try to look at it this week.

@jmuchovej (Author) commented
Also, as a side note, I've actually replicated everything (that is, the models we're looking to build) in WebPPL. The only way to get a reasonable run-time was to use their `cache(..., N)` function. Just running the model (no data fitting) takes <5 min as a result, compared to hours without `cache(..., N)`. (I didn't let the model without `cache` run to completion.)

There's a `@memoize(...)` decorator in the existing `HashingMarginal`, which I would assume should be analogous to WebPPL's `cache` function, but it doesn't seem to be... (If memory serves – I've tried with `maxsize=100000` and no dice.)

I'm not exactly sure how WebPPL's `cache` works under the hood (in terms of relating it to Pyro nomenclature), but I would imagine it works similarly to `mem`, which uses the arguments to compute the hash for the lookup table. (This seems very similar to what `_dist_and_values` does, yet there's little to no performance uplift from memoizing `_dist_and_values`.)
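For reference, what I'd expect from a `cache(..., N)`-style helper is just a bounded memoizer keyed on the call arguments. A toy sketch (hypothetical – not the existing `@memoize`, and the tensor-to-key conversion is the part I suspect matters here):

import functools

# Toy sketch of a WebPPL cache(fn, N)-style bounded memoizer: arguments are
# reduced to a hashable key (tensors aren't hashable by value), and at most
# `maxsize` results are retained, evicting the oldest entry first.
def cache(maxsize=100):
    def decorator(fn):
        lookup = {}

        @functools.wraps(fn)
        def wrapped(*args):
            key = tuple(
                tuple(a.reshape(-1).tolist()) if hasattr(a, "reshape") else a
                for a in args
            )
            if key not in lookup:
                if len(lookup) >= maxsize:
                    lookup.pop(next(iter(lookup)))
                lookup[key] = fn(*args)
            return lookup[key]

        return wrapped

    return decorator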

WebPPL's docs on `cache` and on `mem`.

(Later this week, or early next week, I can try replicating the results I recall if that would be helpful.)
