Avoid recompiling initial_point and logp functions in sample
Also removes costly `model.check_start_vals()`
Also makes all Step arguments expect vars keyword-only
Also set `trust_input=True` for automatically generated logp_dlogp_function
ricardoV94 committed Nov 25, 2024
1 parent 4e06974 commit 342669f
Showing 10 changed files with 217 additions and 142 deletions.
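Note on the `trust_input` bullet: for a compiled PyTensor function, `trust_input=True` skips per-call input validation and conversion. A minimal standalone sketch (not from this diff) of what the flag does:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
f = pytensor.function([x], (x**2).sum())

# By default every call validates and converts its inputs. With
# trust_input=True the function assumes inputs already have the exact
# expected type and dtype, trading safety for per-call overhead.
f.trust_input = True
print(f(np.array([1.0, 2.0, 3.0])))  # caller must now pass a float64 ndarray
```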
7 changes: 5 additions & 2 deletions pymc/backends/__init__.py
@@ -72,6 +72,7 @@
 from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
 from pymc.backends.base import BaseTrace, IBaseTrace
 from pymc.backends.ndarray import NDArray
+from pymc.blocking import PointType
 from pymc.model import Model
 from pymc.step_methods.compound import BlockedStep, CompoundStep
@@ -100,11 +101,12 @@ def _init_trace(
     trace: BaseTrace | None,
     model: Model,
     trace_vars: list[TensorVariable] | None = None,
+    initial_point: PointType | None = None,
 ) -> BaseTrace:
     """Initialize a trace backend for a chain."""
     strace: BaseTrace
     if trace is None:
-        strace = NDArray(model=model, vars=trace_vars)
+        strace = NDArray(model=model, vars=trace_vars, test_point=initial_point)
     elif isinstance(trace, BaseTrace):
         if len(trace) > 0:
             raise ValueError("Continuation of traces is no longer supported.")
@@ -122,7 +124,7 @@ def init_traces(
     chains: int,
     expected_length: int,
     step: BlockedStep | CompoundStep,
-    initial_point: Mapping[str, np.ndarray],
+    initial_point: PointType,
     model: Model,
     trace_vars: list[TensorVariable] | None = None,
 ) -> tuple[RunType | None, Sequence[IBaseTrace]]:
@@ -145,6 +147,7 @@ def init_traces(
             trace=backend,
             model=model,
             trace_vars=trace_vars,
+            initial_point=initial_point,
         )
         for chain_number in range(chains)
     ]
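Note: the point of threading `initial_point` through `_init_trace` is that `NDArray` needs a test point to size its arrays, and falling back to `model.initial_point()` compiles a fresh initial-point function. A hedged sketch of the reuse pattern:

```python
import pymc as pm
from pymc.backends.ndarray import NDArray

with pm.Model() as model:
    x = pm.Normal("x", shape=3)

ip = model.initial_point()                    # computed (and compiled) once
strace = NDArray(model=model, test_point=ip)  # reused instead of recomputed
```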
6 changes: 3 additions & 3 deletions pymc/model/core.py
@@ -48,7 +48,7 @@
     ShapeError,
     ShapeWarning,
 )
-from pymc.initial_point import make_initial_point_fn
+from pymc.initial_point import PointType, make_initial_point_fn
 from pymc.logprob.basic import transformed_conditional_logp
 from pymc.logprob.transforms import Transform
 from pymc.logprob.utils import ParameterValueError, replace_rvs_by_values
@@ -174,7 +174,7 @@ def __init__(
         casting="no",
         compute_grads=True,
         model=None,
-        initial_point=None,
+        initial_point: PointType | None = None,
         ravel_inputs: bool | None = None,
         **kwargs,
     ):
@@ -533,7 +533,7 @@ def logp_dlogp_function(
         self,
         grad_vars=None,
         tempered=False,
-        initial_point=None,
+        initial_point: PointType | None = None,
         ravel_inputs: bool | None = None,
         **kwargs,
     ):
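Note: `PointType` is a `dict[str, np.ndarray]` mapping value-variable names to numeric values. A sketch of the annotated call, assuming the post-commit signature:

```python
import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x", shape=2)

ip = model.initial_point()  # a PointType, e.g. {"x": array([0., 0.])}
# Supplying it avoids an internal model.initial_point() call during compilation:
fn = model.logp_dlogp_function(initial_point=ip, ravel_inputs=True)
```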
174 changes: 109 additions & 65 deletions pymc/sampling/mcmc.py
@@ -101,7 +101,9 @@ def instantiate_steppers(
     model: Model,
     steps: list[Step],
     selected_steps: Mapping[type[BlockedStep], list[Any]],
+    *,
     step_kwargs: dict[str, dict] | None = None,
+    initial_point: PointType | None = None,
 ) -> Step | list[Step]:
     """Instantiate steppers assigned to the model variables.
@@ -131,13 +133,22 @@ def instantiate_steppers(
         step_kwargs = {}

     used_keys = set()
-    for step_class, vars in selected_steps.items():
-        if vars:
-            name = getattr(step_class, "name")
-            args = step_kwargs.get(name, {})
-            used_keys.add(name)
-            step = step_class(vars=vars, model=model, **args)
-            steps.append(step)
+    if selected_steps:
+        if initial_point is None:
+            initial_point = model.initial_point()
+
+        for step_class, vars in selected_steps.items():
+            if vars:
+                name = getattr(step_class, "name")
+                kwargs = step_kwargs.get(name, {})
+                used_keys.add(name)
+                step = step_class(
+                    vars=vars,
+                    model=model,
+                    initial_point=initial_point,
+                    **kwargs,
+                )
+                steps.append(step)

     unused_args = set(step_kwargs).difference(used_keys)
     if unused_args:
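Note: with the bare `*`, everything after `selected_steps` is now keyword-only. A minimal check, assuming the post-commit signature:

```python
import pymc as pm
from pymc.sampling.mcmc import instantiate_steppers

with pm.Model() as model:
    x = pm.Normal("x")

try:
    # step_kwargs passed positionally now fails fast:
    instantiate_steppers(model, [], {}, {"nuts": {}})
except TypeError as err:
    print(err)  # e.g. "takes 3 positional arguments but 4 were given"
```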
@@ -161,18 +172,22 @@ def assign_step_methods(
     model: Model,
     step: Step | Sequence[Step] | None = None,
     methods: Sequence[type[BlockedStep]] | None = None,
-    step_kwargs: dict[str, Any] | None = None,
-) -> Step | list[Step]:
+) -> tuple[list[Step], dict[type[BlockedStep], list[Variable]]]:
     """Assign model variables to appropriate step methods.

-    Passing a specified model will auto-assign its constituent stochastic
-    variables to step methods based on the characteristics of the variables.
+    Passing a specified model will auto-assign its constituent value
+    variables to step methods based on the characteristics of the respective
+    random variables, and whether the logp can be differentiated with respect to it.
     This function is intended to be called automatically from ``sample()``, but
     may be called manually. Each step method passed should have a
     ``competence()`` method that returns an ordinal competence value
     corresponding to the variable passed to it. This value quantifies the
     appropriateness of the step method for sampling the variable.

+    The outputs of this function can then be passed to `instantiate_steppers()`
+    to initialize the assigned step samplers.
+
     Parameters
     ----------
     model : Model object
@@ -183,24 +198,32 @@ def assign_step_methods(
     methods : iterable of step method classes, optional
         The set of step methods from which the function may choose. Defaults
         to the main step methods provided by PyMC.
-    step_kwargs : dict, optional
-        Parameters for the samplers. Keys are the lower case names of
-        the step method, values a dict of arguments.

     Returns
     -------
-    methods : list
-        List of step methods associated with the model's variables.
+    provided_steps: list of Step instances
+        List of user provided instantiated step(s)
+    assigned_steps: dict of Step class to Variable
+        Dictionary with automatically selected step classes as keys and associated value variables as values
     """
-    steps: list[Step] = []
+    provided_steps: list[Step] = []
     assigned_vars: set[Variable] = set()

     if step is not None:
-        if isinstance(step, BlockedStep | CompoundStep):
-            steps.append(step)
+        if isinstance(step, BlockedStep):
+            provided_steps = [step]
+        elif isinstance(step, Sequence):
+            provided_steps = list(step)
         else:
-            steps.extend(step)
-    for step in steps:
+            raise ValueError(f"Step should be a Step or a sequence of Steps, got {step}")

+    for step in provided_steps:
         if not isinstance(step, BlockedStep | CompoundStep):
+            if issubclass(step, BlockedStep | CompoundStep):
+                raise ValueError(f"Provided {step} was not initialized")
+            else:
+                raise ValueError(f"{step} is not a Step instance")
+
         for var in step.vars:
             if var not in model.value_vars:
                 raise ValueError(
@@ -235,7 +258,7 @@ def assign_step_methods(
             )
             selected_steps.setdefault(selected, []).append(var)

-    return instantiate_steppers(model, steps, selected_steps, step_kwargs)
+    return provided_steps, selected_steps


 def _print_step_hierarchy(s: Step, level: int = 0) -> None:
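Note: as the updated docstring says, the two return values feed straight into `instantiate_steppers()`. A hedged end-to-end sketch of the new two-stage split:

```python
import pymc as pm
from pymc.sampling.mcmc import assign_step_methods, instantiate_steppers

with pm.Model() as model:
    x = pm.Normal("x")

# Stage 1: pure assignment -- no compilation, no step instantiation.
provided_steps, selected_steps = assign_step_methods(model, step=None, methods=pm.STEP_METHODS)

# Stage 2: instantiate once, sharing a single precomputed initial point.
step = instantiate_steppers(
    model,
    steps=provided_steps,
    selected_steps=selected_steps,
    step_kwargs={},
    initial_point=model.initial_point(),
)
```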
@@ -719,22 +742,23 @@ def joined_blas_limiter():
         msg = f"Only {draws} samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
         _log.warning(msg)

-    auto_nuts_init = True
-    if step is not None:
-        if isinstance(step, CompoundStep):
-            for method in step.methods:
-                if isinstance(method, NUTS):
-                    auto_nuts_init = False
-        elif isinstance(step, NUTS):
-            auto_nuts_init = False
-
     initial_points = None
-    step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
+    provided_steps, selected_steps = assign_step_methods(model, step, methods=pm.STEP_METHODS)
+    exclusive_nuts = (
+        # User provided an instantiated NUTS step, and nothing else is needed
+        (not selected_steps and len(provided_steps) == 1 and isinstance(provided_steps[0], NUTS))
+        or
+        # Only automatically selected NUTS step is needed
+        (
+            not provided_steps
+            and len(selected_steps) == 1
+            and issubclass(next(iter(selected_steps)), NUTS)
+        )
+    )

     if nuts_sampler != "pymc":
-        if not isinstance(step, NUTS):
+        if not exclusive_nuts:
             raise ValueError(
-                "Model can not be sampled with NUTS alone. Your model is probably not continuous."
+                "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability."
             )

     with joined_blas_limiter():
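Note: `exclusive_nuts` is True in exactly two situations. An illustrative stand-in (not PyMC code) of the truth table:

```python
def is_exclusive_nuts(provided_steps, selected_steps, nuts_cls):
    # Mirrors the expression in the diff, with stand-in arguments.
    return (
        (not selected_steps and len(provided_steps) == 1
         and isinstance(provided_steps[0], nuts_cls))
        or (not provided_steps and len(selected_steps) == 1
            and issubclass(next(iter(selected_steps)), nuts_cls))
    )

class FakeNUTS:  # stand-in for pm.NUTS
    pass

assert is_exclusive_nuts([FakeNUTS()], {}, FakeNUTS)       # user-provided NUTS only
assert is_exclusive_nuts([], {FakeNUTS: ["x"]}, FakeNUTS)  # auto-selected NUTS only
assert not is_exclusive_nuts([FakeNUTS()], {FakeNUTS: ["x"]}, FakeNUTS)  # mixed
```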
@@ -755,13 +779,11 @@ def joined_blas_limiter():
             **kwargs,
         )

-    if isinstance(step, list):
-        step = CompoundStep(step)
-    elif isinstance(step, NUTS) and auto_nuts_init:
+    if exclusive_nuts and not provided_steps:
+        # Special path for NUTS initialization
         if "nuts" in kwargs:
             nuts_kwargs = kwargs.pop("nuts")
             [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
-        _log.info("Auto-assigning NUTS sampler...")
         with joined_blas_limiter():
             initial_points, step = init_nuts(
                 init=init,
@@ -775,9 +797,8 @@ def joined_blas_limiter():
                 initvals=initvals,
                 **kwargs,
             )
-
-    if initial_points is None:
-        # Time to draw/evaluate numeric start points for each chain.
+    else:
+        # Get initial points
         ipfns = make_initial_point_fns_per_chain(
             model=model,
             overrides=initvals,
@@ -786,11 +807,16 @@ def joined_blas_limiter():
         )
         initial_points = [ipfn(seed) for ipfn, seed in zip(ipfns, random_seed_list)]

-    # One final check that shapes and logps at the starting points are okay.
-    ip: dict[str, np.ndarray]
-    for ip in initial_points:
-        model.check_start_vals(ip)
-        _check_start_shape(model, ip)
+    # Instantiate automatically selected steps
+    step = instantiate_steppers(
+        model,
+        steps=provided_steps,
+        selected_steps=selected_steps,
+        step_kwargs=kwargs,
+        initial_point=initial_points[0],
+    )
+    if isinstance(step, list):
+        step = CompoundStep(step)

     if var_names is not None:
         trace_vars = [v for v in model.unobserved_RVs if v.name in var_names]
@@ -806,7 +832,7 @@ def joined_blas_limiter():
             expected_length=draws + tune,
             step=step,
             trace_vars=trace_vars,
-            initial_point=ip,
+            initial_point=initial_points[0],
             model=model,
         )

@@ -954,7 +980,6 @@ def _sample_return(
         f"took {t_sampling:.0f} seconds."
     )

-    idata = None
     if compute_convergence_checks or return_inferencedata:
         ikwargs: dict[str, Any] = {"model": model, "save_warmup": not discard_tuned_samples}
         ikwargs.update(idata_kwargs)
@@ -1159,7 +1184,6 @@ def _iter_sample(
     diverging : bool
         Indicates if the draw is divergent. Only available with some samplers.
     """
-    model = modelcontext(model)
     draws = int(draws)

     if draws < 1:
@@ -1174,8 +1198,6 @@ def _iter_sample(
     if hasattr(step, "reset_tuning"):
         step.reset_tuning()
     for i in range(draws):
-        diverging = False
-
         if i == 0 and hasattr(step, "iter_count"):
            step.iter_count = 0
        if i == tune:
@@ -1298,6 +1320,7 @@ def _init_jitter(
     seeds: Sequence[int] | np.ndarray,
     jitter: bool,
     jitter_max_retries: int,
+    logp_dlogp_func=None,
 ) -> list[PointType]:
     """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.
@@ -1328,19 +1351,30 @@ def _init_jitter(
     if not jitter:
         return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)]

+    model_logp_fn: Callable
+    if logp_dlogp_func is None:
+        model_logp_fn = model.compile_logp()
+    else:
+
+        def model_logp_fn(ip):
+            q, _ = DictToArrayBijection.map(ip)
+            return logp_dlogp_func([q], extra_vars={})[0]
+
     initial_points = []
     for ipfn, seed in zip(ipfns, seeds):
-        rng = np.random.RandomState(seed)
+        rng = np.random.default_rng(seed)
         for i in range(jitter_max_retries + 1):
             point = ipfn(seed)
-            if i < jitter_max_retries:
-                try:
+            point_logp = model_logp_fn(point)
+            if not np.isfinite(point_logp):
+                if i == jitter_max_retries:
+                    # Print informative message on last attempted point
                     model.check_start_vals(point)
-                except SamplingError:
-                    # Retry with a new seed
-                    seed = rng.randint(2**30, dtype=np.int64)
-                else:
-                    break
+                # Retry with a new seed
+                seed = rng.integers(2**30, dtype=np.int64)
+            else:
+                break

         initial_points.append(point)
     return initial_points
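Note on the new `model_logp_fn` wrapper: `DictToArrayBijection.map` flattens a point dict into one raveled vector, which is exactly the single input a `ravel_inputs=True` function expects. (Also note the move from the legacy `np.random.RandomState.randint` to `np.random.default_rng(...).integers`.) A small sketch of the mapping:

```python
import numpy as np
from pymc.blocking import DictToArrayBijection

ip = {"x": np.zeros(3), "y": np.array(0.5)}
q, point_map_info = DictToArrayBijection.map(ip)
print(q.shape)  # (4,) -- every value concatenated into one flat vector
```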
Expand Down Expand Up @@ -1436,10 +1470,12 @@ def init_nuts(

_log.info(f"Initializing NUTS using {init}...")

cb = [
pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"),
pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
]
cb = []
if "advi" in init:
cb = [
pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"),
pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
]

logp_dlogp_func = model.logp_dlogp_function(ravel_inputs=True)
logp_dlogp_func.trust_input = True
@@ -1449,6 +1485,7 @@ def init_nuts(
         seeds=random_seed_list,
         jitter="jitter" in init,
         jitter_max_retries=jitter_max_retries,
+        logp_dlogp_func=logp_dlogp_func,
     )

     apoints = [DictToArrayBijection.map(point) for point in initial_points]
@@ -1562,7 +1599,14 @@ def init_nuts(
     else:
         raise ValueError(f"Unknown initializer: {init}.")

-    step = pm.NUTS(potential=potential, model=model, rng=random_seed_list[0], **kwargs)
+    step = pm.NUTS(
+        potential=potential,
+        model=model,
+        rng=random_seed_list[0],
+        initial_point=initial_points[0],
+        logp_dlogp_func=logp_dlogp_func,
+        **kwargs,
+    )

     # Filter deterministics from initial_points
     value_var_names = [var.name for var in model.value_vars]
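Note: the upshot of this file is that `init_nuts` now compiles `logp_dlogp_func` once and hands it to both `_init_jitter` and `pm.NUTS`, instead of each compiling its own. A hedged usage sketch of that single compiled function, with the call style copied from the `model_logp_fn` wrapper above:

```python
import pymc as pm
from pymc.blocking import DictToArrayBijection

with pm.Model() as model:
    x = pm.Normal("x", shape=2)

fn = model.logp_dlogp_function(ravel_inputs=True)
fn.trust_input = True  # inputs come from DictToArrayBijection, already validated

ip = model.initial_point()
q, _ = DictToArrayBijection.map(ip)
logp = fn([q], extra_vars={})[0]  # [0] is the logp, as in the wrapper above
```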