You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
`test_get_fitted_player_model_numpyro` is currently marked as an xfail because it fails during NumPyro MCMC initialization.
The error is:
./airsenal/tests/test_score_predictions.py::test_get_fitted_player_model_numpyro Failed: [undefined]RuntimeError: Cannot find valid initial parameters. Please check your model again.
def test_get_fitted_player_model_numpyro():
pm = NumpyroPlayerModel()
assert isinstance(pm, NumpyroPlayerModel)
with test_past_data_session_scope() as ts:
> fpm = fit_player_data("FWD", "1819", 12, model=pm, dbsession=ts)
airsenal/tests/test_score_predictions.py:269:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
airsenal/framework/prediction_utils.py:525: in fit_player_data
fitted_model = model.fit(data)
airsenal/framework/player_model.py:177: in fit
mcmc.run(
.venv/lib/python3.11/site-packages/numpyro/infer/mcmc.py:628: in run
states_flat, last_state = partial_map_fn(map_args)
.venv/lib/python3.11/site-packages/numpyro/infer/mcmc.py:410: in _single_chain_mcmc
new_init_state = self.sampler.init(
.venv/lib/python3.11/site-packages/numpyro/infer/hmc.py:713: in init
init_params = self._init_state(
.venv/lib/python3.11/site-packages/numpyro/infer/hmc.py:657: in _init_state
) = initialize_model(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
rng_key = Array([3923418436, 1366451097], dtype=uint32)
model = <numpyro.handlers.substitute object at 0x2880dec50>
def initialize_model(
rng_key,
model,
*,
init_strategy=init_to_uniform,
dynamic_args=False,
model_args=(),
model_kwargs=None,
forward_mode_differentiation=False,
validate_grad=True,
):
"""
(EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn`
and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood
to return a tuple of (`init_params_info`, `potential_fn`, `postprocess_fn`, `model_trace`).
:param jax.random.PRNGKey rng_key: random number generator seed to
sample from the prior. The returned `init_params` will have the
batch shape ``rng_key.shape[:-1]``.
:param model: Python callable containing Pyro primitives.
:param callable init_strategy: a per-site initialization function.
See :ref:`init_strategy` section for available functions.
:param bool dynamic_args: if `True`, the `potential_fn` and
`constraints_fn` are themselves dependent on model arguments.
When provided a `*model_args, **model_kwargs`, they return
`potential_fn` and `constraints_fn` callables, respectively.
:param tuple model_args: args provided to the model.
:param dict model_kwargs: kwargs provided to the model.
:param bool forward_mode_differentiation: whether to use forward-mode differentiation
or reverse-mode differentiation. By default, we use reverse mode but the forward
mode can be useful in some cases to improve the performance. In addition, some
control flow utility on JAX such as `jax.lax.while_loop` or `jax.lax.fori_loop`
only supports forward-mode differentiation. See
`JAX's The Autodiff Cookbook <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html>`_
for more information.
:param bool validate_grad: whether to validate gradient of the initial params.
Defaults to True.
:return: a namedtupe `ModelInfo` which contains the fields
(`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
`param_info` is a namedtuple `ParamInfo` containing values from the prior
used to initiate MCMC, their corresponding potential energy, and their gradients;
`postprocess_fn` is a callable that uses inverse transforms
to convert unconstrained HMC samples to constrained values that
lie within the site's support, in addition to returning values
at `deterministic` sites in the model.
"""
model_kwargs = {} if model_kwargs is None else model_kwargs
substituted_model = substitute(
seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
substitute_fn=init_strategy,
)
(
inv_transforms,
replay_model,
has_enumerate_support,
model_trace,
) = _get_model_transforms(substituted_model, model_args, model_kwargs)
# substitute param sites from model_trace to model so
# we don't need to generate again parameters of `numpyro.module`
model = substitute(
model,
data={
k: site["value"]
for k, site in model_trace.items()
if site["type"] in ["param"]
},
)
constrained_values = {
k: v["value"]
for k, v in model_trace.items()
if v["type"] == "sample"
and not v["is_observed"]
and not v["fn"].support.is_discrete
}
if has_enumerate_support:
from numpyro.contrib.funsor import config_enumerate, enum
if not isinstance(model, enum):
max_plate_nesting = _guess_max_plate_nesting(model_trace)
_validate_model(model_trace, plate_warning="error")
model = enum(config_enumerate(model), -max_plate_nesting - 1)
else:
_validate_model(model_trace, plate_warning="loose")
potential_fn, postprocess_fn = get_potential_fn(
model,
inv_transforms,
replay_model=replay_model,
enum=has_enumerate_support,
dynamic_args=dynamic_args,
model_args=model_args,
model_kwargs=model_kwargs,
)
init_strategy = (
init_strategy if isinstance(init_strategy, partial) else init_strategy()
)
if (init_strategy.func is init_to_value) and not replay_model:
init_values = init_strategy.keywords.get("values")
unconstrained_values = transform_fn(inv_transforms, init_values, invert=True)
init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
(init_params, pe, grad), is_valid = find_valid_initial_params(
rng_key,
substitute(
model,
data={
k: site["value"]
for k, site in model_trace.items()
if site["type"] in ["plate"]
},
),
init_strategy=init_strategy,
enum=has_enumerate_support,
model_args=model_args,
model_kwargs=model_kwargs,
prototype_params=prototype_params,
forward_mode_differentiation=forward_mode_differentiation,
validate_grad=validate_grad,
)
if not_jax_tracer(is_valid):
if device_get(~jnp.all(is_valid)):
with numpyro.validation_enabled(), trace() as tr:
# validate parameters
substituted_model(*model_args, **model_kwargs)
# validate values
for site in tr.values():
if site["type"] == "sample":
with warnings.catch_warnings(record=True) as ws:
site["fn"]._validate_sample(site["value"])
if len(ws) > 0:
for w in ws:
# at site information to the warning message
w.message.args = (
"Site {}: {}".format(
site["name"], w.message.args[0]
),
) + w.message.args[1:]
warnings.showwarning(
w.message,
w.category,
w.filename,
w.lineno,
file=w.file,
line=w.line,
)
> raise RuntimeError(
"Cannot find valid initial parameters. Please check your model again."
)
E RuntimeError: Cannot find valid initial parameters. Please check your model again.
.venv/lib/python3.11/site-packages/numpyro/infer/util.py:745: RuntimeError
The text was updated successfully, but these errors were encountered:
airsenal/tests/test_score_predictions.py::test_get_fitted_player_model_numpyro
/Users/jroberts/GitHub/AIrsenal/airsenal/framework/player_model.py:177: UserWarning: Site obs: Out-of-support values provided to log prob method. The value argument should be within the support.
mcmc.run(
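`test_get_fitted_player_model_numpyro` is marked as an xfail; the error is shown above.
Note: the UserWarning for site `obs` probably points to the root cause. When NumPyro cannot find valid initial parameters, it re-runs the model with validation enabled and warns about every sample site whose value is outside its distribution's support. Only then does it raise the RuntimeError (see the `initialize_model` excerpt in the traceback). Here the offending values are at the observed site `obs`. The log-probability of the observed data therefore cannot be evaluated for any choice of initial parameters. The fix likely belongs in the data passed to the model, or in the distribution used for `obs`, rather than in the sampler or initialization settings.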
The text was updated successfully, but these errors were encountered: