Skip to content

Commit

Permalink
Summarize model with rich table
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 14, 2024
1 parent ff668d5 commit 941a88f
Show file tree
Hide file tree
Showing 3 changed files with 281 additions and 0 deletions.
9 changes: 9 additions & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,12 @@ Model Transforms

autoreparam.vip_reparametrize
autoreparam.VIP


Printing
========
.. currentmodule:: pymc_experimental.printing
.. autosummary::
:toctree: generated/

model_table
164 changes: 164 additions & 0 deletions pymc_experimental/printing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import itertools

import numpy as np

from pymc import Model
from pymc.printing import str_for_dist, str_for_potential_or_deterministic
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.type import Constant
from rich.box import SIMPLE_HEAD
from rich.table import Table


def _extract_value(var: SharedVariable | Constant) -> np.ndarray:
if isinstance(var, SharedVariable):
return var.get_value(borrow=True)
else:
return var.data


def model_table(
model: Model,
split_groups: bool = True,
truncate_deterministic: int | None = None,
parameter_count: bool = True,
) -> Table:
"""Create a rich table with a summary of the model's variables and their expressions.
Parameters
----------
model : Model
The PyMC model to summarize.
split_groups : bool
If True, each group of variables (data, free_RVs, deterministics, potentials, observed_RVs)
will be separated by a section.
truncate_deterministic : int | None
If not None, truncate the expression of deterministic variables that go beyond this length.
parameter_count : bool
If True, add a row with the total number of parameters in the model.
Returns
-------
Table
A rich table with the model's variables, their expressions and dims.
Examples
--------
.. code-block:: python
import numpy as np
import pymc as pm
from pymc_experimental.printing import model_table
coords = {"subject": range(20), "param": ["a", "b"]}
with pm.Model(coords=coords) as m:
x = pm.Data("x", np.random.normal(size=(20, 2)), dims=("subject", "param"))
y = pm.Data("y", np.random.normal(size=(20,)), dims="subject")
beta = pm.Normal("beta", mu=0, sigma=1, dims="param")
mu = pm.Deterministic("mu", pm.math.dot(x, beta), dims="subject")
sigma = pm.HalfNormal("sigma", sigma=1)
y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, dims="subject")
table = model_table(m)
table # Displays the following table in an interactive environment
'''
Variable Expression Dimensions
─────────────────────────────────────────────────────
x = Data subject[20] × param[2]
y = Data subject[20]
beta ~ Normal(0, 1) param[2]
sigma ~ HalfNormal(0, 1)
Parameter count = 3
mu = f(beta) subject[20]
y_obs ~ Normal(mu, sigma) subject[20]
'''
Output can be explicitly rendered in a rich console or exported to text, html or svg.
.. code-block:: python
from rich.console import Console
console = Console(record=True)
console.print(table)
text_export = console.export_text()
html_export = console.export_html()
svg_export = console.export_svg()
"""
table = Table(
show_header=True,
show_edge=False,
box=SIMPLE_HEAD,
highlight=False,
collapse_padding=True,
)
table.add_column("Variable", justify="right")
table.add_column("Expression", justify="left")
table.add_column("Dimensions")

dim_sizes = {k: _extract_value(v) for k, v in model.dim_lengths.items()}

groups = (
model.data_vars,
model.free_RVs,
model.deterministics,
model.potentials,
model.observed_RVs,
)
if not split_groups:
groups = (itertools.chain.from_iterable(groups),)

for group in groups:
if not group:
continue

for var in group:
var_name = var.name
dims = model.named_vars_to_dims.get(var_name, ())

is_data = var in model.data_vars
is_deterministic = var in model.deterministics
is_potential = var in model.potentials

if is_data:
var_expr = "Data"
elif is_deterministic:
str_repr = str_for_potential_or_deterministic(var, dist_name="")
_, var_expr = str_repr.split(" ~ ")
var_expr = var_expr[1:-1] # Remove outer parentheses (f(...))
if truncate_deterministic is not None and len(var_expr) > truncate_deterministic:
contents = var_expr[2:-1].split(", ")
str_len = 0
for show_n, content in enumerate(contents):
str_len += len(content) + 2
if str_len > truncate_deterministic:
break
var_expr = f"f({', '.join(contents[:show_n])}, ...)"
elif is_potential:
var_expr = str_for_potential_or_deterministic(var, dist_name="Potential").split(
" ~ "
)[1]
else:
var_expr = str_for_dist(var).split(" ~ ")[1]

dims_and_sizes = " × ".join(f"{dim}[{dim_sizes[dim]}]" for dim in dims)
sep = f'[b]{" =" if (is_data or is_deterministic or is_potential) else " ~"}[/b]'
table.add_row(var_name + sep, var_expr, dims_and_sizes)

if parameter_count and (not split_groups or group == model.free_RVs):
rv_shapes = model.eval_rv_shapes()
n_parameters = np.sum(
[np.prod(rv_shapes[free_rv.name]).astype(int) for free_rv in model.free_RVs]
)
table.add_row("", "", f"[i]Parameter count = {n_parameters}[/i]")

table.add_section()

return table
108 changes: 108 additions & 0 deletions tests/test_printing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import io

import numpy as np
import pymc as pm

from rich.console import Console

from pymc_experimental.printing import model_table


def get_text(table) -> str:
console = Console(
record=True,
file=io.StringIO(),
force_terminal=False,
force_interactive=False,
force_jupyter=False,
)
console.print(table)
return console.export_text()


def test_model_table():
with pm.Model(coords={"trial": range(6), "subject": range(20)}) as model:
x_data = pm.Data("x_data", np.random.normal(size=(6, 20)), dims=("trial", "subject"))
y_data = pm.Data("y_data", np.random.normal(size=(6, 20)), dims=("trial", "subject"))

mu = pm.Normal("mu", mu=0, sigma=1)
sigma = pm.HalfNormal("sigma", sigma=1)
global_intercept = pm.Normal("global_intercept", mu=0, sigma=1)
intercept_subject = pm.Normal("intercept_subject", mu=0, sigma=1, dims="subject")
beta_subject = pm.Normal("beta_subject", mu=mu, sigma=sigma, dims="subject")

mu_trial = pm.Deterministic(
"mu_trial",
global_intercept + intercept_subject + beta_subject * x_data,
dims=["trial", "subject"],
)
noise = pm.Exponential("noise", lam=1)
y = pm.Normal("y", mu=mu_trial, sigma=noise, observed=y_data, dims=("trial", "subject"))

pm.Potential("beta_subject_penalty", -pm.math.abs(beta_subject), dims="subject")

table_txt = get_text(model_table(model))
assert table_txt == (
""" Variable Expression Dimensions
────────────────────────────────────────────────────────────────────────────────
x_data = Data trial[6] × subject[20]
y_data = Data trial[6] × subject[20]
mu ~ Normal(0, 1)
sigma ~ HalfNormal(0, 1)
global_intercept ~ Normal(0, 1)
intercept_subject ~ Normal(0, 1) subject[20]
beta_subject ~ Normal(mu, sigma) subject[20]
noise ~ Exponential(f())
Parameter count = 44
mu_trial = f(beta_subject, trial[6] × subject[20]
intercept_subject,
global_intercept)
beta_subject_penalty = Potential(f(beta_subject)) subject[20]
y ~ Normal(mu_trial, noise) trial[6] × subject[20]
"""
)

table_txt = get_text(model_table(model, split_groups=False))
assert table_txt == (
""" Variable Expression Dimensions
────────────────────────────────────────────────────────────────────────────────
x_data = Data trial[6] × subject[20]
y_data = Data trial[6] × subject[20]
mu ~ Normal(0, 1)
sigma ~ HalfNormal(0, 1)
global_intercept ~ Normal(0, 1)
intercept_subject ~ Normal(0, 1) subject[20]
beta_subject ~ Normal(mu, sigma) subject[20]
noise ~ Exponential(f())
mu_trial = f(beta_subject, trial[6] × subject[20]
intercept_subject,
global_intercept)
beta_subject_penalty = Potential(f(beta_subject)) subject[20]
y ~ Normal(mu_trial, noise) trial[6] × subject[20]
Parameter count = 44
"""
)

table_txt = get_text(
model_table(model, split_groups=False, truncate_deterministic=30, parameter_count=False)
)
assert table_txt == (
""" Variable Expression Dimensions
────────────────────────────────────────────────────────────────────────────
x_data = Data trial[6] × subject[20]
y_data = Data trial[6] × subject[20]
mu ~ Normal(0, 1)
sigma ~ HalfNormal(0, 1)
global_intercept ~ Normal(0, 1)
intercept_subject ~ Normal(0, 1) subject[20]
beta_subject ~ Normal(mu, sigma) subject[20]
noise ~ Exponential(f())
mu_trial = f(beta_subject, ...) trial[6] × subject[20]
beta_subject_penalty = Potential(f(beta_subject)) subject[20]
y ~ Normal(mu_trial, noise) trial[6] × subject[20]
"""
)

0 comments on commit 941a88f

Please sign in to comment.