Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo committed May 15, 2020
1 parent 51317f9 commit 937e344
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
16 changes: 12 additions & 4 deletions pyro/infer/autoguide/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,18 @@ def __call__(self, site):

def init_to_generated(site=None, generate=lambda: init_to_uniform):
"""
Initialize to the value specified in the ``values`` dict returned by
``values_fn``. This is similar to ``init_to_value(values=values_fn())`` but
calls ``values_fn`` once per model execution, thereby permitting multiple
randomized initializations.
Initialize to another initialization strategy returned by the callback
``generate`` which is called once per model execution.
This is like :func:`init_to_value` but can produce different (e.g. random)
values once per model execution. For example to generate values and return
``init_to_value`` you could define::
def generate():
values = {"x": torch.randn(100), "y": torch.rand(5)}
return init_to_value(values=values)
my_init_fn = init_to_generated(generate=generate)
:param callable generate: A callable returning another initialization
function, e.g. returning an ``init_to_value(values={...})`` populated
Expand Down
2 changes: 1 addition & 1 deletion tests/infer/mcmc/test_mcmc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def model():
init_to_uniform(radius=0.1),
init_to_value(values={"x": torch.tensor(3.)}),
init_to_generated(
generate=lambda: init_to_value(values={"x": torch.randn(())})),
generate=lambda: init_to_value(values={"x": torch.rand(())})),
], ids=str)
def test_init_strategy_smoke(init_strategy):
def model():
Expand Down

0 comments on commit 937e344

Please sign in to comment.