Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🚧 RSA vectorized prior achieved; L0 doesn't enumerate #2908

Draft
wants to merge 4 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 230 additions & 0 deletions examples/rsa/generics-vectorized.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
],
"outputs": [],
},
{
"cell_type": "code",
"execution_count": 2,
"source": [
"import argparse\n",
"import collections\n",
"import numbers\n",
"\n",
"import torch\n",
"from search_inference import HashingMarginal, Search, memoize\n",
"\n",
"import pyro\n",
"import pyro.distributions as dist\n",
"import pyro.poutine as poutine\n",
"from pyro.infer import config_enumerate, TraceEnum_ELBO\n",
"from pyro.ops.indexing import Vindex\n",
"\n",
"torch.set_default_dtype(torch.float64) # double precision for numerical stability\n",
"torch.manual_seed(42)\n",
"\n",
"utterances = [\n",
" \"generic is true\", \"generic is false\",\n",
" \"mu\", \"some\", \"most\", \"all\",\n",
"]\n",
"\n",
"from vectorized_search import (\n",
" VectoredSearch as VSearch,\n",
" VectoredHashingMarginal as VHMarginal,\n",
")\n",
"\n",
"\n",
"def Marginal(fn):\n",
" return memoize(lambda *args: VHMarginal(VSearch(config_enumerate(fn)).run(*args)))\n",
"\n",
"\n",
"Params = collections.namedtuple(\"Params\", [\"theta\", \"gamma\", \"delta\"])\n",
"beta_bins = torch.tensor([0., 0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99])"
],
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<torch._C.Generator at 0x7fd519eede70>"
]
},
"metadata": {},
"execution_count": 2
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 5,
"source": [
"@Marginal\n",
"def structured_prior(params: Params) -> torch.Tensor:\n",
" # computing the Beta pdf for discretized bins above for enumerated Search\n",
" shape_alpha = params.gamma * params.delta - 1\n",
" shape_beta = (1. - params.gamma) * params.delta - 1\n",
" discrete_bins = (beta_bins ** shape_alpha) * ((1. - beta_bins) ** shape_beta) * params.theta\n",
" discrete_bins[0] = (1 - params.theta)\n",
" idx = pyro.sample(\"bin\", dist.Categorical(probs=discrete_bins / discrete_bins.sum()))\n",
"\n",
" return beta_bins[idx]"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 6,
"source": [
"wings_prior_params = Params(theta=0.5, gamma=0.99, delta=10.0)\n",
"wings_prior = structured_prior(wings_prior_params)\n",
"\n",
"for el in wings_prior.enumerate_support():\n",
" print(el.item(), wings_prior.log_prob(el).exp().item())"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0.0 0.015988904498871442\n",
"0.01 2.2204460492503185e-16\n",
"0.1 2.213097002755768e-11\n",
"0.2 1.175451410947732e-08\n",
"0.3 4.893387852027631e-07\n",
"0.4 7.274723747978211e-06\n",
"0.5 6.245665819871659e-05\n",
"0.6 0.00038682171379413966\n",
"0.7 0.00197597813839096\n",
"0.8 0.009340983321304765\n",
"0.9 0.0497252576984154\n",
"0.99 0.922511822131846\n"
]
}
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 7,
"source": [
"def utterance_prior() -> torch.Tensor:\n",
" utts = torch.arange(0, len(utterances), 1)\n",
" probs = torch.ones_like(utts) / len(utts)\n",
" idx = pyro.sample(\"utterance\", dist.Categorical(probs=probs))\n",
" return utts[idx]"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 8,
"source": [
"def threshold_prior() -> torch.Tensor:\n",
" bins = torch.arange(0.0, 1.0, 0.1)\n",
" idx = pyro.sample(\"threshold\", dist.Categorical(logits=torch.zeros_like(bins)))\n",
" return bins[idx]"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 9,
"source": [
"def meaning(utterance: torch.Tensor, state: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:\n",
" possible_evals = {\n",
" \"as_genT\": (state > threshold),\n",
" \"as_genF\": (state <= threshold),\n",
" \"is_mu\" : torch.full_like(state, True, dtype=bool),\n",
" \"is_some\": (state > 0),\n",
" \"is_most\": (state >= 0.5),\n",
" \"is_all\" : (state >= 0.99),\n",
" \"as_num\" : (state == utterance),\n",
" \"default\": torch.full_like(state, True, dtype=bool),\n",
" }\n",
"\n",
" meanings = torch.stack(list(possible_evals.values()))\n",
"\n",
" while utterance.ndim < meanings.ndim: # expand utterance to be used as an indexer\n",
" utterance = utterance[None]\n",
"\n",
" return torch.gather(meanings, dim=0, index=utterance.long()).float().squeeze()"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"# Listener 0"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 10,
"source": [
"@Marginal\n",
"def listener0(utterances: torch.Tensor, thresholds: torch.Tensor, prior: HashingMarginal) -> torch.Tensor:\n",
" state = pyro.sample(f\"state\", prior)\n",
" means = meaning(utterances, state, thresholds)\n",
" pyro.factor(f\"listener0-true\", torch.where(means == 1., 0., -99_999.))\n",
" return state"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": 11,
"source": [
"wings_posterior = listener0(torch.tensor([1]), torch.tensor([0.1]), wings_prior)\n",
"for el in wings_posterior.enumerate_support():\n",
" print(el, wings_posterior.log_prob(el).exp().item())\n"
],
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor(0.9000) 1.0\n"
]
}
],
"metadata": {}
}
],
"metadata": {
"orig_nbformat": 4,
"language_info": {
"name": "python",
"version": "3.7.10",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3.7.10 64-bit ('pyro': conda)"
},
"interpreter": {
"hash": "9bc28b369b6d70ba89e337d74d0d547816f4fa8c93eafedc4eb46222a9da357b"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
120 changes: 120 additions & 0 deletions examples/rsa/generics-vectorized.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#!/usr/bin/env python3

# %%
from IPython import get_ipython

# %%
get_ipython().run_line_magic('load_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')


# %%
import argparse
import collections
import numbers

import torch
from search_inference import HashingMarginal, Search, memoize

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import config_enumerate, TraceEnum_ELBO
from pyro.ops.indexing import Vindex

torch.set_default_dtype(torch.float64) # double precision for numerical stability
torch.manual_seed(42)

utterances = [
"generic is true", "generic is false",
"mu", "some", "most", "all",
]

from vectorized_search import (
VectoredSearch as VSearch,
VectoredHashingMarginal as VHMarginal,
)


def Marginal(fn):
return memoize(lambda *args: VHMarginal(VSearch(config_enumerate(fn)).run(*args)))


Params = collections.namedtuple("Params", ["theta", "gamma", "delta"])
beta_bins = torch.tensor([0., 0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99])


# %%
@Marginal
def structured_prior(params: Params) -> torch.Tensor:
# computing the Beta pdf for discretized bins above for enumerated Search
shape_alpha = params.gamma * params.delta - 1
shape_beta = (1. - params.gamma) * params.delta - 1
discrete_bins = (beta_bins ** shape_alpha) * ((1. - beta_bins) ** shape_beta) * params.theta
discrete_bins[0] = (1 - params.theta)
idx = pyro.sample("bin", dist.Categorical(probs=discrete_bins / discrete_bins.sum()))

return beta_bins[idx]


# %%
wings_prior_params = Params(theta=0.5, gamma=0.99, delta=10.0)
wings_prior = structured_prior(wings_prior_params)

for el in wings_prior.enumerate_support():
print(el.item(), wings_prior.log_prob(el).exp().item())


# %%
def utterance_prior() -> torch.Tensor:
utts = torch.arange(0, len(utterances), 1)
probs = torch.ones_like(utts) / len(utts)
idx = pyro.sample("utterance", dist.Categorical(probs=probs))
return utts[idx]


# %%
def threshold_prior() -> torch.Tensor:
bins = torch.arange(0.0, 1.0, 0.1)
idx = pyro.sample("threshold", dist.Categorical(logits=torch.zeros_like(bins)))
return bins[idx]


# %%
def meaning(utterance: torch.Tensor, state: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
possible_evals = {
"as_genT": (state > threshold),
"as_genF": (state <= threshold),
"is_mu" : torch.full_like(state, True, dtype=bool),
"is_some": (state > 0),
"is_most": (state >= 0.5),
"is_all" : (state >= 0.99),
"as_num" : (state == utterance),
"default": torch.full_like(state, True, dtype=bool),
}

meanings = torch.stack(list(possible_evals.values()))

while utterance.ndim < meanings.ndim: # expand utterance to be used as an indexer
utterance = utterance[None]

return torch.gather(meanings, dim=0, index=utterance.long()).float().squeeze()

# %% [markdown]
# # Listener 0

# %%
@Marginal
def listener0(utterances: torch.Tensor, thresholds: torch.Tensor, prior: HashingMarginal) -> torch.Tensor:
state = pyro.sample(f"state", prior)
means = meaning(utterances, state, thresholds)
pyro.factor(f"listener0-true", torch.where(means == 1., 0., -99_999.))
return state


# %%
wings_posterior = listener0(torch.tensor([1]), torch.tensor([0.1]), wings_prior)
for el in wings_posterior.enumerate_support():
print(el, wings_posterior.log_prob(el).exp().item())


Loading