-
-
Notifications
You must be signed in to change notification settings - Fork 50
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ff668d5
commit 941a88f
Showing
3 changed files
with
281 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
import itertools | ||
|
||
import numpy as np | ||
|
||
from pymc import Model | ||
from pymc.printing import str_for_dist, str_for_potential_or_deterministic | ||
from pytensor.compile.sharedvalue import SharedVariable | ||
from pytensor.graph.type import Constant | ||
from rich.box import SIMPLE_HEAD | ||
from rich.table import Table | ||
|
||
|
||
def _extract_value(var: SharedVariable | Constant) -> np.ndarray: | ||
if isinstance(var, SharedVariable): | ||
return var.get_value(borrow=True) | ||
else: | ||
return var.data | ||
|
||
|
||
def model_table( | ||
model: Model, | ||
split_groups: bool = True, | ||
truncate_deterministic: int | None = None, | ||
parameter_count: bool = True, | ||
) -> Table: | ||
"""Create a rich table with a summary of the model's variables and their expressions. | ||
Parameters | ||
---------- | ||
model : Model | ||
The PyMC model to summarize. | ||
split_groups : bool | ||
If True, each group of variables (data, free_RVs, deterministics, potentials, observed_RVs) | ||
will be separated by a section. | ||
truncate_deterministic : int | None | ||
If not None, truncate the expression of deterministic variables that go beyond this length. | ||
parameter_count : bool | ||
If True, add a row with the total number of parameters in the model. | ||
Returns | ||
------- | ||
Table | ||
A rich table with the model's variables, their expressions and dims. | ||
Examples | ||
-------- | ||
.. code-block:: python | ||
import numpy as np | ||
import pymc as pm | ||
from pymc_experimental.printing import model_table | ||
coords = {"subject": range(20), "param": ["a", "b"]} | ||
with pm.Model(coords=coords) as m: | ||
x = pm.Data("x", np.random.normal(size=(20, 2)), dims=("subject", "param")) | ||
y = pm.Data("y", np.random.normal(size=(20,)), dims="subject") | ||
beta = pm.Normal("beta", mu=0, sigma=1, dims="param") | ||
mu = pm.Deterministic("mu", pm.math.dot(x, beta), dims="subject") | ||
sigma = pm.HalfNormal("sigma", sigma=1) | ||
y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, dims="subject") | ||
table = model_table(m) | ||
table # Displays the following table in an interactive environment | ||
''' | ||
Variable Expression Dimensions | ||
───────────────────────────────────────────────────── | ||
x = Data subject[20] × param[2] | ||
y = Data subject[20] | ||
beta ~ Normal(0, 1) param[2] | ||
sigma ~ HalfNormal(0, 1) | ||
Parameter count = 3 | ||
mu = f(beta) subject[20] | ||
y_obs ~ Normal(mu, sigma) subject[20] | ||
''' | ||
Output can be explicitly rendered in a rich console or exported to text, html or svg. | ||
.. code-block:: python | ||
from rich.console import Console | ||
console = Console(record=True) | ||
console.print(table) | ||
text_export = console.export_text() | ||
html_export = console.export_html() | ||
svg_export = console.export_svg() | ||
""" | ||
table = Table( | ||
show_header=True, | ||
show_edge=False, | ||
box=SIMPLE_HEAD, | ||
highlight=False, | ||
collapse_padding=True, | ||
) | ||
table.add_column("Variable", justify="right") | ||
table.add_column("Expression", justify="left") | ||
table.add_column("Dimensions") | ||
|
||
dim_sizes = {k: _extract_value(v) for k, v in model.dim_lengths.items()} | ||
|
||
groups = ( | ||
model.data_vars, | ||
model.free_RVs, | ||
model.deterministics, | ||
model.potentials, | ||
model.observed_RVs, | ||
) | ||
if not split_groups: | ||
groups = (itertools.chain.from_iterable(groups),) | ||
|
||
for group in groups: | ||
if not group: | ||
continue | ||
|
||
for var in group: | ||
var_name = var.name | ||
dims = model.named_vars_to_dims.get(var_name, ()) | ||
|
||
is_data = var in model.data_vars | ||
is_deterministic = var in model.deterministics | ||
is_potential = var in model.potentials | ||
|
||
if is_data: | ||
var_expr = "Data" | ||
elif is_deterministic: | ||
str_repr = str_for_potential_or_deterministic(var, dist_name="") | ||
_, var_expr = str_repr.split(" ~ ") | ||
var_expr = var_expr[1:-1] # Remove outer parentheses (f(...)) | ||
if truncate_deterministic is not None and len(var_expr) > truncate_deterministic: | ||
contents = var_expr[2:-1].split(", ") | ||
str_len = 0 | ||
for show_n, content in enumerate(contents): | ||
str_len += len(content) + 2 | ||
if str_len > truncate_deterministic: | ||
break | ||
var_expr = f"f({', '.join(contents[:show_n])}, ...)" | ||
elif is_potential: | ||
var_expr = str_for_potential_or_deterministic(var, dist_name="Potential").split( | ||
" ~ " | ||
)[1] | ||
else: | ||
var_expr = str_for_dist(var).split(" ~ ")[1] | ||
|
||
dims_and_sizes = " × ".join(f"{dim}[{dim_sizes[dim]}]" for dim in dims) | ||
sep = f'[b]{" =" if (is_data or is_deterministic or is_potential) else " ~"}[/b]' | ||
table.add_row(var_name + sep, var_expr, dims_and_sizes) | ||
|
||
if parameter_count and (not split_groups or group == model.free_RVs): | ||
rv_shapes = model.eval_rv_shapes() | ||
n_parameters = np.sum( | ||
[np.prod(rv_shapes[free_rv.name]).astype(int) for free_rv in model.free_RVs] | ||
) | ||
table.add_row("", "", f"[i]Parameter count = {n_parameters}[/i]") | ||
|
||
table.add_section() | ||
|
||
return table |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import io | ||
|
||
import numpy as np | ||
import pymc as pm | ||
|
||
from rich.console import Console | ||
|
||
from pymc_experimental.printing import model_table | ||
|
||
|
||
def get_text(table) -> str: | ||
console = Console( | ||
record=True, | ||
file=io.StringIO(), | ||
force_terminal=False, | ||
force_interactive=False, | ||
force_jupyter=False, | ||
) | ||
console.print(table) | ||
return console.export_text() | ||
|
||
|
||
def test_model_table(): | ||
with pm.Model(coords={"trial": range(6), "subject": range(20)}) as model: | ||
x_data = pm.Data("x_data", np.random.normal(size=(6, 20)), dims=("trial", "subject")) | ||
y_data = pm.Data("y_data", np.random.normal(size=(6, 20)), dims=("trial", "subject")) | ||
|
||
mu = pm.Normal("mu", mu=0, sigma=1) | ||
sigma = pm.HalfNormal("sigma", sigma=1) | ||
global_intercept = pm.Normal("global_intercept", mu=0, sigma=1) | ||
intercept_subject = pm.Normal("intercept_subject", mu=0, sigma=1, dims="subject") | ||
beta_subject = pm.Normal("beta_subject", mu=mu, sigma=sigma, dims="subject") | ||
|
||
mu_trial = pm.Deterministic( | ||
"mu_trial", | ||
global_intercept + intercept_subject + beta_subject * x_data, | ||
dims=["trial", "subject"], | ||
) | ||
noise = pm.Exponential("noise", lam=1) | ||
y = pm.Normal("y", mu=mu_trial, sigma=noise, observed=y_data, dims=("trial", "subject")) | ||
|
||
pm.Potential("beta_subject_penalty", -pm.math.abs(beta_subject), dims="subject") | ||
|
||
table_txt = get_text(model_table(model)) | ||
assert table_txt == ( | ||
""" Variable Expression Dimensions | ||
──────────────────────────────────────────────────────────────────────────────── | ||
x_data = Data trial[6] × subject[20] | ||
y_data = Data trial[6] × subject[20] | ||
mu ~ Normal(0, 1) | ||
sigma ~ HalfNormal(0, 1) | ||
global_intercept ~ Normal(0, 1) | ||
intercept_subject ~ Normal(0, 1) subject[20] | ||
beta_subject ~ Normal(mu, sigma) subject[20] | ||
noise ~ Exponential(f()) | ||
Parameter count = 44 | ||
mu_trial = f(beta_subject, trial[6] × subject[20] | ||
intercept_subject, | ||
global_intercept) | ||
beta_subject_penalty = Potential(f(beta_subject)) subject[20] | ||
y ~ Normal(mu_trial, noise) trial[6] × subject[20] | ||
""" | ||
) | ||
|
||
table_txt = get_text(model_table(model, split_groups=False)) | ||
assert table_txt == ( | ||
""" Variable Expression Dimensions | ||
──────────────────────────────────────────────────────────────────────────────── | ||
x_data = Data trial[6] × subject[20] | ||
y_data = Data trial[6] × subject[20] | ||
mu ~ Normal(0, 1) | ||
sigma ~ HalfNormal(0, 1) | ||
global_intercept ~ Normal(0, 1) | ||
intercept_subject ~ Normal(0, 1) subject[20] | ||
beta_subject ~ Normal(mu, sigma) subject[20] | ||
noise ~ Exponential(f()) | ||
mu_trial = f(beta_subject, trial[6] × subject[20] | ||
intercept_subject, | ||
global_intercept) | ||
beta_subject_penalty = Potential(f(beta_subject)) subject[20] | ||
y ~ Normal(mu_trial, noise) trial[6] × subject[20] | ||
Parameter count = 44 | ||
""" | ||
) | ||
|
||
table_txt = get_text( | ||
model_table(model, split_groups=False, truncate_deterministic=30, parameter_count=False) | ||
) | ||
assert table_txt == ( | ||
""" Variable Expression Dimensions | ||
──────────────────────────────────────────────────────────────────────────── | ||
x_data = Data trial[6] × subject[20] | ||
y_data = Data trial[6] × subject[20] | ||
mu ~ Normal(0, 1) | ||
sigma ~ HalfNormal(0, 1) | ||
global_intercept ~ Normal(0, 1) | ||
intercept_subject ~ Normal(0, 1) subject[20] | ||
beta_subject ~ Normal(mu, sigma) subject[20] | ||
noise ~ Exponential(f()) | ||
mu_trial = f(beta_subject, ...) trial[6] × subject[20] | ||
beta_subject_penalty = Potential(f(beta_subject)) subject[20] | ||
y ~ Normal(mu_trial, noise) trial[6] × subject[20] | ||
""" | ||
) |