Skip to content

Commit

Permalink
Add more arguments to psense_summary (#34)
Browse files Browse the repository at this point in the history
* add more arguments to psense_summary

* rST and api formatting

* use alphas instead of delta

* fix test due to small numerical difference in lower_alpha computation

* fix bug delta

---------

Co-authored-by: Oriol Abril-Pla <[email protected]>
  • Loading branch information
aloctavodia and OriolAbril authored Nov 5, 2024
1 parent c5a4f5f commit 1a499b6
Show file tree
Hide file tree
Showing 8 changed files with 128 additions and 40 deletions.
10 changes: 10 additions & 0 deletions docs/source/api/index.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
# API reference

## Functions

```{eval-rst}
.. autosummary::
:toctree: generated/
arviz_stats.psense
arviz_stats.psense_summary
```

## Accessors
Currently, using accessors is the recommended way to call functions from `arviz_stats`.

Expand Down
2 changes: 2 additions & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
"sphinx_design",
"jupyter_sphinx",
"sphinx_autosummary_accessors",
"IPython.sphinxext.ipython_directive",
"IPython.sphinxext.ipython_console_highlighting",
]

templates_path = ["_templates", sphinx_autosummary_accessors.templates_path]
Expand Down
11 changes: 9 additions & 2 deletions src/arviz_stats/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,17 @@ def power_scale_lw(self, alpha=1, dims=None):
"""Compute log weights for power-scaling of the DataTree."""
return get_function("power_scale_lw")(self._obj, alpha=alpha, dims=dims)

def power_scale_sense(self, lower_w=None, upper_w=None, delta=None, dims=None):
def power_scale_sense(
self, lower_w=None, upper_w=None, lower_alpha=None, upper_alpha=None, dims=None
):
"""Compute power-scaling sensitivity."""
return get_function("power_scale_sense")(
self._obj, lower_w=lower_w, upper_w=upper_w, delta=delta, dims=dims
self._obj,
lower_w=lower_w,
upper_w=upper_w,
lower_alpha=lower_alpha,
upper_alpha=upper_alpha,
dims=dims,
)


Expand Down
6 changes: 4 additions & 2 deletions src/arviz_stats/base/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ def power_scale_lw(self, ary, alpha=0, axes=-1):
)
return psl_ufunc(ary, out_shape=(ary.shape[i] for i in axes), alpha=alpha)

def power_scale_sense(self, ary, lower_w, upper_w, delta, chain_axis=-2, draw_axis=-1):
def power_scale_sense(
self, ary, lower_w, upper_w, lower_alpha, upper_alpha, chain_axis=-2, draw_axis=-1
):
"""Compute power-scaling sensitivity."""
if chain_axis is None:
ary = np.expand_dims(ary, axis=0)
Expand All @@ -181,7 +183,7 @@ def power_scale_sense(self, ary, lower_w, upper_w, delta, chain_axis=-2, draw_ax
pss_array = make_ufunc(
self._power_scale_sense, n_output=1, n_input=3, n_dims=2, ravel=False
)
return pss_array(ary, lower_w, upper_w, delta=delta)
return pss_array(ary, lower_w, upper_w, lower_alpha=lower_alpha, upper_alpha=upper_alpha)

def compute_ranks(self, ary, axes=-1, relative=False):
"""Compute ranks of MCMC samples."""
Expand Down
7 changes: 4 additions & 3 deletions src/arviz_stats/base/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,14 +299,15 @@ def power_scale_lw(self, da, alpha=0, dims=None):
kwargs={"axes": np.arange(-len(dims), 0, 1)},
)

def power_scale_sense(self, da, lower_w, upper_w, delta, dims=None):
def power_scale_sense(self, da, lower_w, upper_w, lower_alpha, upper_alpha, dims=None):
"""Compute power-scaling sensitivity."""
dims, chain_axis, draw_axis = validate_dims_chain_draw_axis(dims)
return apply_ufunc(
self.array_class.power_scale_sense,
*broadcast(da, lower_w, upper_w),
delta,
input_core_dims=[dims, dims, dims, []],
lower_alpha,
upper_alpha,
input_core_dims=[dims, dims, dims, [], []],
output_core_dims=[[]],
kwargs={"chain_axis": chain_axis, "draw_axis": draw_axis},
)
Expand Down
7 changes: 4 additions & 3 deletions src/arviz_stats/base/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,15 +551,16 @@ def _gpinv(probs, kappa, sigma, mu):

return q

def _power_scale_sense(self, ary, lower_w, upper_w, delta=0.01):
def _power_scale_sense(self, ary, lower_w, upper_w, lower_alpha, upper_alpha):
"""Compute power-scaling sensitivity by finite difference second derivative of CJS."""
ary = np.ravel(ary)
lower_w = np.ravel(lower_w)
upper_w = np.ravel(upper_w)
lower_cjs = max(self._cjs_dist(ary, lower_w), self._cjs_dist(-1 * ary, lower_w))
upper_cjs = max(self._cjs_dist(ary, upper_w), self._cjs_dist(-1 * ary, upper_w))
grad = (lower_cjs + upper_cjs) / (2 * np.log2(1 + delta))
return grad
lower_grad = -1 * lower_cjs / np.log2(lower_alpha)
upper_grad = upper_cjs / np.log2(upper_alpha)
return (lower_grad + upper_grad) / 2

def _power_scale_lw(self, ary, alpha):
"""Compute log weights for power-scaling component by alpha."""
Expand Down
119 changes: 92 additions & 27 deletions src/arviz_stats/psense.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@

def psense(
dt,
var_names=None,
filter_vars=None,
group="prior",
coords=None,
sample_dims=None,
alphas=(0.99, 1.01),
group_var_names=None,
group_coords=None,
var_names=None,
coords=None,
filter_vars=None,
delta=0.01,
):
"""
Compute power-scaling sensitivity values.
Expand All @@ -38,28 +38,30 @@ def psense(
Refer to documentation of :func:`arviz.convert_to_dataset` for details.
For ndarray: shape = (chain, draw).
For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``.
var_names : list of str, optional
Names of posterior variables to include in the power scaling sensitivity diagnostic
filter_vars: {None, "like", "regex"}, default None
Used for `var_names` only.
If ``None`` (default), interpret var_names as the real variables names.
If "like", interpret var_names as substrings of the real variables names.
If "regex", interpret var_names as regular expressions on the real variables names.
group : {"prior", "likelihood"}, default "prior"
If "likelihood", the pointwise log likelihood values are retrieved
from the ``log_likelihood`` group and added together.
If "prior", the log prior values are retrieved from the ``log_prior`` group.
coords : dict, optional
Coordinates defining a subset over the posterior. Only these variables will
be used when computing the prior sensitivity.
sample_dims : str or sequence of hashable, optional
    Dimensions to reduce over when computing the sensitivity diagnostic.
    Defaults to ``rcParams["data.sample_dims"]``
alphas : tuple
Lower and upper alpha values for gradient calculation. Defaults to (0.99, 1.01).
group_var_names : str, optional
Name of the prior or log likelihood variables to use
group_coords : dict, optional
Coordinates defining a subset over the group element for which to
compute the prior sensitivity diagnostic.
var_names : list of str, optional
Names of posterior variables to include in the power scaling sensitivity diagnostic
coords : dict, optional
Coordinates defining a subset over the posterior. Only these variables will
be used when computing the prior sensitivity.
filter_vars: {None, "like", "regex"}, default None
Used for `var_names` only.
If ``None`` (default), interpret var_names as the real variables names.
If "like", interpret var_names as substrings of the real variables names.
If "regex", interpret var_names as regular expressions on the real variables names.
delta : float
Value for finite difference derivative calculation.
Returns
-------
Expand All @@ -78,20 +80,22 @@ def psense(
References
----------
.. [1] Kallioinen et al, *Detecting and diagnosing prior and likelihood sensitivity with
power-scaling*, 2022, https://arxiv.org/abs/2107.14054
power-scaling*, Stat Comput 34, 57 (2024), https://doi.org/10.1007/s11222-023-10366-5
"""
dataset = extract(
dt, var_names=var_names, filter_vars=filter_vars, group="posterior", combined=False
dt,
var_names=var_names,
filter_vars=filter_vars,
group="posterior",
combined=False,
keep_dataset=True,
)
if coords is not None:
dataset = dataset.sel(coords)

lower_alpha = 1 / (1 + delta)
upper_alpha = 1 + delta

lower_w, upper_w = _get_power_scale_weights(
dt,
alphas=(lower_alpha, upper_alpha),
alphas=alphas,
group=group,
sample_dims=sample_dims,
group_var_names=group_var_names,
Expand All @@ -101,20 +105,52 @@ def psense(
return dataset.azstats.power_scale_sense(
lower_w=lower_w,
upper_w=upper_w,
delta=delta,
lower_alpha=alphas[0],
upper_alpha=alphas[1],
dims=sample_dims,
)


def psense_summary(data, threshold=0.05, round_to=3):
def psense_summary(
data,
var_names=None,
filter_vars=None,
coords=None,
sample_dims=None,
threshold=0.05,
alphas=(0.99, 1.01),
group_var_names=None,
group_coords=None,
round_to=3,
):
"""
Compute the prior/likelihood sensitivity based on power-scaling perturbations.
Parameters
----------
data : DataTree
var_names : list of str, optional
Names of posterior variables to include in the power scaling sensitivity diagnostic
filter_vars: {None, "like", "regex"}, default None
Used for `var_names` only.
If ``None`` (default), interpret var_names as the real variables names.
If "like", interpret var_names as substrings of the real variables names.
If "regex", interpret var_names as regular expressions on the real variables names.
coords : dict, optional
Coordinates defining a subset over the posterior. Only these variables will
be used when computing the prior sensitivity.
sample_dims : str or sequence of hashable, optional
    Dimensions to reduce over when computing the sensitivity diagnostic.
    Defaults to ``rcParams["data.sample_dims"]``
threshold : float, optional
Threshold value to determine the sensitivity diagnosis. Default is 0.05.
alphas : tuple
Lower and upper alpha values for gradient calculation. Defaults to (0.99, 1.01).
group_var_names : str, optional
Name of the prior or log likelihood variables to use
group_coords : dict, optional
Coordinates defining a subset over the group element for which to
compute the prior sensitivity diagnostic
round_to : int, optional
Number of decimal places to round the sensitivity values. Default is 3.
Expand All @@ -127,9 +163,38 @@ def psense_summary(data, threshold=0.05, round_to=3):
- "strong prior / weak likelihood" if the prior sensitivity is above threshold
and the likelihood sensitivity is below the threshold
- "-" otherwise
Examples
--------
.. ipython::
In [1]: from arviz_base import load_arviz_data
...: from arviz_stats import psense_summary
...: rugby = load_arviz_data("rugby")
...: psense_summary(rugby, var_names="atts")
"""
pssdp = psense(data, group="prior")
pssdl = psense(data, group="likelihood")
pssdp = psense(
data,
var_names=var_names,
filter_vars=filter_vars,
group="prior",
sample_dims=sample_dims,
coords=coords,
alphas=alphas,
group_var_names=group_var_names,
group_coords=group_coords,
)
pssdl = psense(
data,
var_names=var_names,
filter_vars=filter_vars,
group="likelihood",
coords=coords,
sample_dims=sample_dims,
alphas=alphas,
group_var_names=group_var_names,
group_coords=group_coords,
)

joined = xr.concat([pssdp, pssdl], dim="component").assign_coords(
component=["prior", "likelihood"]
Expand Down
6 changes: 3 additions & 3 deletions tests/test_psense.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def test_psense_var_names():
result_0 = psense(uni_dt, group="prior", group_var_names=["mu"], var_names=["mu"])
result_1 = psense(uni_dt, group="prior", var_names=["mu"])
for result in (result_0, result_1):
assert "sigma" != result.name
assert "mu" == result.name
assert not isclose(result_0, result_1)
assert "sigma" not in result.data_vars
assert "mu" in result.data_vars
assert not isclose(result_0["mu"], result_1["mu"])


def test_psense_summary():
Expand Down

0 comments on commit 1a499b6

Please sign in to comment.