Skip to content

Commit

Permalink
Added deterministic_callbacks tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Sayam753 committed Jul 17, 2020
1 parent c80d3a7 commit 6189ba1
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions tests/test_variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


@pytest.fixture(scope="function")
def simple_model():
def conjugate_normal_model():
unknown_mean = -5
known_sigma = 3
data_points = 1000
Expand Down Expand Up @@ -49,15 +49,15 @@ def approximation(request):
return request.param


def test_fit(approximation, simple_model):
model = simple_model["model"]()
def test_fit(approximation, conjugate_normal_model):
model = conjugate_normal_model["model"]()
approx = _test_kwargs[approximation]
advi = pm.fit(method=approx["method"](model), **approx["fit_kwargs"])
assert advi is not None
assert advi.losses.numpy().shape == (approx["fit_kwargs"].get("num_steps") or 10000,)

q_samples = advi.approximation.sample(10000)
estimated_mean = simple_model["estimated_mean"]
estimated_mean = conjugate_normal_model["estimated_mean"]
np.testing.assert_allclose(
np.mean(np.squeeze(q_samples.posterior["model/mu"].values, axis=0)),
estimated_mean,
Expand All @@ -84,3 +84,25 @@ def test_bivariate_shapes(bivariate_gaussian):

samples = advi.approximation.sample(5000)
assert samples.posterior["bivariate_gaussian/density"].values.shape == (1, 5000, 2)


def test_advi_with_deterministics(simple_model_with_deterministic):
advi = pm.fit(simple_model_with_deterministic(), num_steps=1000)
samples = advi.approximation.sample(100)
norm = "simple_model_with_deterministic/simple_model/norm"
determ = "simple_model_with_deterministic/determ"
np.testing.assert_allclose(samples.posterior[determ], samples.posterior[norm] * 2)


def test_advi_with_deterministics_in_nested_models(deterministics_in_nested_models):
(
model,
*_,
deterministic_mapping,
) = deterministics_in_nested_models
advi = pm.fit(model(), num_steps=1000)
samples = advi.approximation.sample(100)
for deterministic, (inputs, op) in deterministic_mapping.items():
np.testing.assert_allclose(
samples.posterior[deterministic], op(*[samples.posterior[i] for i in inputs]), rtol=1e-6
)

0 comments on commit 6189ba1

Please sign in to comment.