Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New Trace_ELBO that generalizes Trace_ELBO, TraceEnum_ELBO, and TraceGraph_ELBO #2893

Draft
wants to merge 62 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
b00948c
trace_elbo
Jul 3, 2021
da2f887
lint
Jul 3, 2021
f6c95e4
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
Jul 5, 2021
3ec076d
test_gradient
Jul 5, 2021
a22ff4e
copy traceenum_elbo and add test model with poisson dist
Jul 16, 2021
d551fa2
lint
Jul 16, 2021
b68bb3f
use constant funsor
Jul 21, 2021
bfb13bf
working version
Jul 28, 2021
ca1a1fe
pass second test
Jul 28, 2021
6d6a9ed
clean up trace_elbo
Jul 29, 2021
0f23b42
add another test
Aug 8, 2021
91384ed
lazy eval
Aug 20, 2021
c18a8bd
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
Sep 18, 2021
34d9a3c
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
Sep 30, 2021
b0182c0
vectorize particles; update tests
Sep 30, 2021
dc31767
minor fixes; pin to funsor@normalize-logaddexp
Sep 30, 2021
5c0fe75
update docs/requirements
Sep 30, 2021
2b15fe1
combine Trace_ELBO and TraceEnum_ELBO
Sep 30, 2021
351090b
eager evaluation
Oct 1, 2021
7d029c7
rm file
Oct 1, 2021
1bb7380
lazy
Oct 1, 2021
42ad4fa
remove memoize
Oct 1, 2021
5b6afdb
merge TraceEnum_ELBO
Oct 10, 2021
33628aa
skip test
Oct 11, 2021
18a973b
fixes
Oct 12, 2021
2c3ead3
convert Tensor to Categorical
Oct 12, 2021
5fb1522
restore docs/requirements.txt
Oct 12, 2021
f907f93
pin funsor in docs/requirements
Oct 12, 2021
902e445
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
Oct 12, 2021
0042f85
use funsor.optimizer.apply_optimizer; higher precision in the test
Oct 12, 2021
ee5a5ad
pin funsor to the latest commit
Oct 12, 2021
e4c6760
optimize logzq
Oct 12, 2021
aba300a
optimize logzq
Oct 13, 2021
d823153
restore TraceEnum_ELBO
Oct 13, 2021
c06e9e4
revert hmm changes
Oct 13, 2021
eee297d
_tensor_to_categorical helper function
Oct 13, 2021
d748efa
lazy to_funsor
Oct 13, 2021
a1970d6
reduce over particle_var
Oct 13, 2021
4c1ee9e
address comment in tests
Oct 13, 2021
5df30c8
import pyroapi
Oct 13, 2021
46ff6f4
compute expected grads using dice factors
Oct 14, 2021
d7ee7ee
add test with guide enumeration
Oct 15, 2021
49553c3
add two more tests
Oct 15, 2021
835f815
pin funsor
Oct 15, 2021
760eeb0
lint
Oct 15, 2021
ab3831c
remove breakpoint
Oct 15, 2021
0b46f3a
Merge branch 'dev' of https://github.com/pyro-ppl/pyro into fix-funso…
Oct 29, 2021
b6ff8e0
Approximate(ops.sample, ...) based approach
Nov 3, 2021
b5bece7
Importance funsor based approach
Nov 4, 2021
d6e246e
fixes
Nov 4, 2021
6582d7d
Merge branch 'dev' into fix-funsor-traceelbo
Apr 6, 2022
714fd62
fix funsor model enumeration
Apr 9, 2022
2d2210e
Merge branch 'fix-model-enumeration-funsor' into fix-funsor-traceelbo
Apr 9, 2022
29bad7a
use Sampled funsor
Apr 11, 2022
9144be1
fixes
Apr 11, 2022
e4c8a47
git fixes
Apr 11, 2022
c147ad9
Merge branch 'dev' into fix-funsor-traceelbo
Apr 11, 2022
703a2fa
use Provenance funsor
Apr 11, 2022
3137b1b
clean up
Apr 12, 2022
88713f6
fixes
May 5, 2022
99a0647
Merge branch 'dev' into fix-funsor-traceelbo
May 5, 2022
14131ad
use provenance
Jun 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 59 additions & 19 deletions pyro/contrib/funsor/infer/trace_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import contextlib

import funsor
from funsor.adjoint import AdjointTape

from pyro.contrib.funsor import to_data, to_funsor
from pyro.contrib.funsor.handlers import enum, plate, replay, trace
Expand All @@ -12,7 +13,7 @@
from pyro.infer import Trace_ELBO as _OrigTrace_ELBO

from .elbo import ELBO, Jit_ELBO
from .traceenum_elbo import terms_from_trace
from .traceenum_elbo import apply_optimizer, terms_from_trace


@copy_docs_from(_OrigTrace_ELBO)
Expand All @@ -21,29 +22,68 @@ def differentiable_loss(self, model, guide, *args, **kwargs):
with enum(), plate(
size=self.num_particles
) if self.num_particles > 1 else contextlib.ExitStack():
guide_tr = trace(config_enumerate(default="flat")(guide)).get_trace(
*args, **kwargs
)
guide_tr = trace(
config_enumerate(default="flat", num_samples=self.num_particles)(guide)
ordabayevy marked this conversation as resolved.
Show resolved Hide resolved
).get_trace(*args, **kwargs)
model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs)

model_terms = terms_from_trace(model_tr)
guide_terms = terms_from_trace(guide_tr)

log_measures = guide_terms["log_measures"] + model_terms["log_measures"]
log_factors = model_terms["log_factors"] + [
-f for f in guide_terms["log_factors"]
]
plate_vars = model_terms["plate_vars"] | guide_terms["plate_vars"]
measure_vars = model_terms["measure_vars"] | guide_terms["measure_vars"]

elbo = funsor.Integrate(
sum(log_measures, to_funsor(0.0)),
sum(log_factors, to_funsor(0.0)),
measure_vars,
)
elbo = elbo.reduce(funsor.ops.add, plate_vars)

return -to_data(elbo)
with funsor.terms.eager:
costs = model_terms["log_factors"] + [
-f for f in guide_terms["log_factors"]
]

# compute expected cost
# Cf. pyro.infer.util.Dice.compute_expectation()
# https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/util.py#L212
# TODO Replace this with funsor.Expectation
plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]
# compute the marginal logq in the guide corresponding to each cost term
targets = dict()
for cost in costs:
input_vars = frozenset(cost.inputs)
if input_vars not in targets:
targets[input_vars] = funsor.Tensor(
funsor.ops.new_zeros(
funsor.tensor.get_default_prototype(),
tuple(v.size for v in cost.inputs.values()),
),
cost.inputs,
cost.dtype,
)
with AdjointTape() as tape:
logzq = funsor.sum_product.sum_product(
funsor.ops.logaddexp,
funsor.ops.add,
guide_terms["log_measures"] + list(targets.values()),
plates=plate_vars,
eliminate=(plate_vars | guide_terms["measure_vars"]),
)
marginals = tape.adjoint(
funsor.ops.logaddexp, funsor.ops.add, logzq, tuple(targets.values())
)
# finally, integrate out guide variables in the elbo and all plates
elbo = to_funsor(0, output=funsor.Real)
for cost in costs:
target = targets[frozenset(cost.inputs)]
logzq_local = marginals[target].reduce(
funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars
)
log_prob = marginals[target] - logzq_local
elbo_term = funsor.Integrate(
log_prob,
cost,
guide_terms["measure_vars"] & frozenset(log_prob.inputs),
)
elbo += elbo_term.reduce(
funsor.ops.add, plate_vars & frozenset(cost.inputs)
)

# evaluate the elbo, using memoize to share tensor computation where possible
with funsor.interpretations.memoize():
return -to_data(apply_optimizer(elbo))


class JitTrace_ELBO(Jit_ELBO, Trace_ELBO):
Expand Down
93 changes: 93 additions & 0 deletions tests/contrib/funsor/test_gradient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import logging

import pytest
import torch

from tests.common import assert_equal

# put all funsor-related imports here, so test collection works without funsor
try:
import funsor

import pyro.contrib.funsor

funsor.set_backend("torch")
from pyroapi import distributions as dist
from pyroapi import handlers, infer, pyro, pyro_backend
except ImportError:
pytestmark = pytest.mark.skip(reason="funsor is not installed")

logger = logging.getLogger(__name__)

# _PYRO_BACKEND = os.environ.get("TEST_ENUM_PYRO_BACKEND", "contrib.funsor")


@pytest.mark.parametrize(
"Elbo,backend",
[
("TraceEnum_ELBO", "pyro"),
("Trace_ELBO", "contrib.funsor"),
],
)
def test_particle_gradient(Elbo, backend):
with pyro_backend(backend):
pyro.clear_param_store()
data = torch.tensor([-0.5, 2.0])
# Normal = dist.Normal if reparameterized else fakes.NonreparameterizedNormal

def model():
with pyro.plate("data", len(data)) as ind:
x = data[ind]
z = pyro.sample("z", dist.Poisson(3))
pyro.sample("x", dist.Normal(z, 1), obs=x)

def guide():
# scale = pyro.param("scale", lambda: torch.tensor([1.0]))
with pyro.plate("data", len(data)):
rate = pyro.param("rate", lambda: torch.tensor([3.5, 1.5]), event_dim=0)
z_dist = dist.Poisson(rate)
# if has_rsample is not None:
# z_dist.has_rsample_(has_rsample)
pyro.sample("z", z_dist)

elbo = getattr(infer, Elbo)(
max_plate_nesting=1, # set this to ensure rng agrees across runs
num_particles=1,
strict_enumeration_warning=False,
)

# Elbo gradient estimator
pyro.set_rng_seed(0)
elbo.loss_and_grads(model, guide)
params = dict(pyro.get_param_store().named_parameters())
actual_grads = {
name: param.grad.detach().cpu() for name, param in params.items()
}

# capture sample values and log_probs
pyro.set_rng_seed(0)
guide_tr = handlers.trace(guide).get_trace()
model_tr = handlers.trace(handlers.replay(model, guide_tr)).get_trace()
guide_tr.compute_log_prob()
model_tr.compute_log_prob()
z = guide_tr.nodes["z"]["value"].data
rate = pyro.param("rate").data

loss_i = (
model_tr.nodes["x"]["log_prob"].data
+ model_tr.nodes["z"]["log_prob"].data
- guide_tr.nodes["z"]["log_prob"].data
)
dlogq_drate = z / rate - 1
expected_grads = {
"rate": -(dlogq_drate * loss_i - dlogq_drate),
}

for name in sorted(params):
logger.info("expected {} = {}".format(name, expected_grads[name]))
logger.info("actual {} = {}".format(name, actual_grads[name]))

assert_equal(actual_grads, expected_grads, prec=1e-4)