Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow opting out of model nesting by setting model=None #7352

Merged
merged 3 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 98 additions & 86 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Literal,
Optional,
TypeVar,
Union,
cast,
overload,
)
Expand Down Expand Up @@ -107,18 +108,10 @@ def __new__(cls, name, bases, dct, **kwargs):

def __enter__(self):
self.__class__.context_class.get_contexts().append(self)
# self._pytensor_config is set in Model.__new__
self._config_context = None
if hasattr(self, "_pytensor_config"):
self._config_context = pytensor.config.change_flags(**self._pytensor_config)
self._config_context.__enter__()
return self

def __exit__(self, typ, value, traceback):
self.__class__.context_class.get_contexts().pop()
# self._pytensor_config is set in Model.__new__
if self._config_context:
self._config_context.__exit__(typ, value, traceback)

dct[__enter__.__name__] = __enter__
dct[__exit__.__name__] = __exit__
Expand Down Expand Up @@ -400,76 +393,108 @@ class Model(WithMemoization, metaclass=ContextMeta):
name : str
name that will be used as prefix for names of all random
variables defined within model
coords : dict
Xarray-like coordinate keys and values. These coordinates can be used
to specify the shape of random variables and to label (but not specify)
the shape of Deterministic, Potential and Data objects.
Other than specifying the shape of random variables, coordinates have no
effect on the model. They can't be used for label-based broadcasting or indexing.
You must use numpy-like operations for those behaviors.
check_bounds : bool
Ensure that input parameters to distributions are in a valid
range. If your model is built in a way where you know your
parameters can only take on valid values you can set this to
False for increased speed. This should not be used if your model
contains discrete variables.
model : PyMC model, optional
A parent model that this model belongs to. If not specified and the current model
is created inside another model's context, the parent model will be set to that model.
If `None` the model will not have a parent.

Examples
--------
How to define a custom model
Use context manager to define model and respective variables

.. code-block:: python

class CustomModel(Model):
# 1) override init
def __init__(self, mean=0, sigma=1, name=''):
# 2) call super's init first, passing model and name
# to it name will be prefix for all variables here if
# no name specified for model there will be no prefix
super().__init__(name, model)
# now you are in the context of instance,
# `modelcontext` will return self you can define
# variables in several ways note, that all variables
# will get model's name prefix

# 3) you can create variables with the register_rv method
self.register_rv(Normal.dist(mu=mean, sigma=sigma), 'v1', initval=1)
# this will create variable named like '{name::}v1'
# and assign attribute 'v1' to instance created
# variable can be accessed with self.v1 or self['v1']

# 4) this syntax will also work as we are in the
# context of instance itself, names are given as usual
Normal('v2', mu=mean, sigma=sigma)

# something more complex is allowed, too
half_cauchy = HalfCauchy('sigma', beta=10, initval=1.)
Normal('v3', mu=mean, sigma=half_cauchy)

# Deterministic variables can be used in usual way
Deterministic('v3_sq', self.v3 ** 2)

# Potentials too
Potential('p1', pt.constant(1))

# After defining a class CustomModel you can use it in several
# ways

# I:
# state the model within a context
with Model() as model:
CustomModel()
# arbitrary actions

# II:
# use new class as entering point in context
with CustomModel() as model:
Normal('new_normal_var', mu=1, sigma=0)

# III:
# just get model instance with all that was defined in it
model = CustomModel()

# IV:
# use many custom models within one context
with Model() as model:
CustomModel(mean=1, name='first')
CustomModel(mean=2, name='second')

# variables inside both scopes will be named like `first::*`, `second::*`
import pymc as pm

with pm.Model() as model:
x = pm.Normal("x")


Use object API to define model and respective variables

.. code-block:: python

import pymc as pm

model = pm.Model()
x = pm.Normal("x", model=model)


Use coords for defining the shape of random variables and labeling other model variables

.. code-block:: python

import pymc as pm
import numpy as np

coords = {
"feature": ["A", "B", "C"],
"trial": [1, 2, 3, 4, 5],
}

with pm.Model(coords=coords) as model:
intercept = pm.Normal("intercept", shape=(3,)) # Variable will have default dim label `intercept__dim_0`
beta = pm.Normal("beta", dims=("feature",)) # Variable will have shape (3,) and dim label `feature`

# Dims below are only used for labeling, they have no effect on shape
idx = pm.Data("idx", np.array([0, 1, 1, 2, 2])) # Variable will have default dim label `idx__dim_0`
x = pm.Data("x", np.random.normal(size=(5, 3)), dims=("trial", "feature"))
mu = pm.Deterministic("mu", intercept[idx] + x @ beta, dims="trial") # single dim can be passed as string

# Dims controls the shape of the variable
# If not specified, it would be inferred from the shape of the observations
y = pm.Normal("y", mu=mu, observed=[-1, 0, 0, 1, 1], dims=("trial",))


Define nested models, and provide name for variable name prefixing

.. code-block:: python

import pymc as pm

with pm.Model(name="root") as root:
x = pm.Normal("x") # Variable will be named "root::x"

with pm.Model(name='first') as first:
# Variable will belong to root and first
y = pm.Normal("y", mu=x) # Variable will be named "root::first::y"

# Can pass parent model explicitly
with pm.Model(name='second', model=root) as second:
# Variable will belong to root and second
z = pm.Normal("z", mu=y) # Variable will be named "root::second::z"

# Set None for standalone model
with pm.Model(name="third", model=None) as third:
# Variable will belong to third only
w = pm.Normal("w") # Variable will be named "third::w"


Set `check_bounds` to False for models with only continuous variables and default transformers.
PyMC will then remove the bounds check from the model logp, which can speed up sampling.

.. code-block:: python

import pymc as pm

with pm.Model(check_bounds=False) as model:
sigma = pm.HalfNormal("sigma")
x = pm.Normal("x", sigma=sigma) # No bounds check will be performed on `sigma`


"""

if TYPE_CHECKING:
Expand All @@ -478,20 +503,13 @@ def __enter__(self: Self) -> Self: ...

def __exit__(self, exc_type: None, exc_val: None, exc_tb: None) -> None: ...

def __new__(cls, *args, **kwargs):
def __new__(cls, *args, model: Union[Literal[UNSET], None, "Model"] = UNSET, **kwargs):
# resolves the parent instance
instance = super().__new__(cls)
if kwargs.get("model") is not None:
instance._parent = kwargs.get("model")
else:
if model is UNSET:
instance._parent = cls.get_context(error_if_none=False)
pytensor_config = kwargs.get("pytensor_config", {})
if pytensor_config:
warnings.warn(
"pytensor_config is deprecated. Use pytensor.config or pytensor.config.change_flags context manager instead.",
FutureWarning,
)
instance._pytensor_config = pytensor_config
else:
instance._parent = model
return instance

@staticmethod
Expand All @@ -507,10 +525,9 @@ def __init__(
check_bounds=True,
*,
coords_mutable=None,
pytensor_config=None,
model=None,
model: Union[Literal[UNSET], None, "Model"] = UNSET,
):
del pytensor_config, model # used in __new__
del model # used in __new__ to define the parent of this model
self.name = self._validate_name(name)
self.check_bounds = check_bounds

Expand Down Expand Up @@ -560,11 +577,6 @@ def __init__(
functools.partial(str_for_model, formatting="latex"), self
)

@property
def model(self):
warnings.warn("Model.model property is deprecated. Just use Model.", FutureWarning)
return self

@property
def parent(self):
return self._parent
Expand Down Expand Up @@ -671,7 +683,7 @@ def compile_d2logp(
jacobian : bool
Whether to include jacobian terms in logprob graph. Defaults to True.
"""
return self.model.compile_fn(
return self.compile_fn(
self.d2logp(vars=vars, jacobian=jacobian, negate_output=negate_output),
**compile_kwargs,
)
Expand Down
4 changes: 1 addition & 3 deletions pymc/model/fgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,7 @@ def first_non_model_var(var):
else:
return var

model = Model()
if model.parent is not None:
raise RuntimeError("model_to_fgraph cannot be called inside a PyMC model context")
model = Model(model=None) # Do not inherit from any model in the context manager

_coords = getattr(fgraph, "_coords", {})
_dim_lengths = getattr(fgraph, "_dim_lengths", {})
Expand Down
2 changes: 1 addition & 1 deletion pymc/model/transform/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pytensor import Variable
from pytensor.graph import ancestors

from pymc import Model
from pymc.model.core import Model
from pymc.model.fgraph import (
ModelObservedRV,
ModelVar,
Expand Down
2 changes: 1 addition & 1 deletion pymc/model/transform/conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
from pytensor.graph import ancestors
from pytensor.tensor import TensorVariable

from pymc import Model
from pymc.logprob.transforms import Transform
from pymc.logprob.utils import rvs_in_graph
from pymc.model.core import Model
from pymc.model.fgraph import (
ModelDeterministic,
ModelFreeRV,
Expand Down
2 changes: 1 addition & 1 deletion pymc/sampling/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def compute_deterministics(
model = modelcontext(model)

if var_names is None:
deterministics = model.deterministics
deterministics = list(model.deterministics)
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
var_names = [det.name for det in deterministics]
else:
deterministics = [model[var_name] for var_name in var_names]
Expand Down
39 changes: 15 additions & 24 deletions pymc/stats/log_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

__all__ = ("compute_log_likelihood", "compute_log_prior")

from pymc.model.transform.conditioning import remove_value_transforms


def compute_log_likelihood(
idata: InferenceData,
Expand Down Expand Up @@ -126,46 +128,35 @@ def compute_log_density(
if kind not in ("likelihood", "prior"):
raise ValueError("kind must be either 'likelihood' or 'prior'")

# We need to disable transforms, because the InferenceData only keeps the untransformed values
umodel = remove_value_transforms(model)
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved

if kind == "likelihood":
target_rvs = model.observed_RVs
target_rvs = list(umodel.observed_RVs)
target_str = "observed_RVs"
else:
target_rvs = model.free_RVs
target_rvs = list(umodel.free_RVs)
target_str = "free_RVs"

if var_names is None:
vars = target_rvs
var_names = tuple(rv.name for rv in vars)
else:
vars = [model.named_vars[name] for name in var_names]
vars = [umodel.named_vars[name] for name in var_names]
if not set(vars).issubset(target_rvs):
raise ValueError(f"var_names must refer to {target_str} in the model. Got: {var_names}")

# We need to temporarily disable transforms, because the InferenceData only keeps the untransformed values
try:
original_rvs_to_values = model.rvs_to_values
original_rvs_to_transforms = model.rvs_to_transforms

model.rvs_to_values = {
rv: rv.clone() if rv not in model.observed_RVs else value
for rv, value in model.rvs_to_values.items()
}
model.rvs_to_transforms = {rv: None for rv in model.basic_RVs}

elemwise_logdens_fn = model.compile_fn(
inputs=model.value_vars,
outs=model.logp(vars=vars, sum=False),
on_unused_input="ignore",
)
finally:
model.rvs_to_values = original_rvs_to_values
model.rvs_to_transforms = original_rvs_to_transforms
elemwise_logdens_fn = umodel.compile_fn(
inputs=umodel.value_vars,
outs=umodel.logp(vars=vars, sum=False),
on_unused_input="ignore",
)

coords, dims = coords_and_dims_for_inferencedata(model)
coords, dims = coords_and_dims_for_inferencedata(umodel)

logdens_dataset = apply_function_over_dataset(
elemwise_logdens_fn,
posterior[[rv.name for rv in model.free_RVs]],
posterior[[rv.name for rv in umodel.free_RVs]],
output_var_names=var_names,
sample_dims=sample_dims,
dims=dims,
Expand Down
Loading
Loading