Skip to content

Commit

Permalink
Do not use initval in test model
Browse files Browse the repository at this point in the history
PRs pymc-devs#7508 and pymc-devs#7492 introduced incompatible changes but were not tested simultaneously.

Deepcopying the steps in the tests leads to deepcopying the model which uses `clone_model`, which in turn does not support initvals.
  • Loading branch information
ricardoV94 committed Oct 8, 2024
1 parent 465d8ac commit 79200d2
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import pytensor
import pytensor.tensor as pt

from pytensor import config
from pytensor.compile.ops import as_op

import pymc as pm
Expand All @@ -30,9 +29,9 @@ def simple_model():
mu = -2.1
tau = 1.3
with Model() as model:
Normal("x", mu, tau=tau, size=2, initval=np.array([0.1, 0.1]).astype(config.floatX))
x = Normal("x", mu, tau=tau, size=2)

return model.initial_point(), model, (mu, tau**-0.5)
return {"x": np.array([0.1, 0.1], dtype=x.type.dtype)}, model, (mu, tau**-0.5)


def another_simple_model():
Expand All @@ -46,11 +45,11 @@ def simple_categorical():
p = np.array([0.1, 0.2, 0.3, 0.4])
v = np.array([0.0, 1.0, 2.0, 3.0])
with Model() as model:
Categorical("x", p, size=3, initval=[1, 2, 3])
x = Categorical("x", p, size=3)

mu = np.dot(p, v)
var = np.dot(p, (v - mu) ** 2)
return model.initial_point(), model, (mu, var)
return {"x": np.array([1, 2, 3], dtype=x.type.dtype)}, model, (mu, var)


def multidimensional_model():
Expand Down Expand Up @@ -98,15 +97,14 @@ def mv_simple():
p = np.array([[2.0, 0, 0], [0.05, 0.1, 0], [1.0, -0.05, 5.5]])
tau = np.dot(p, p.T)
with pm.Model() as model:
pm.MvNormal(
x = pm.MvNormal(
"x",
pt.constant(mu),
tau=pt.constant(tau),
initval=np.array([0.1, 1.0, 0.8]),
)
H = tau
C = np.linalg.inv(H)
return model.initial_point(), model, (mu, C)
return {"x": np.array([0.1, 1.0, 0.8], dtype=x.type.dtype)}, model, (mu, C)


def mv_simple_coarse():
Expand Down

0 comments on commit 79200d2

Please sign in to comment.