Visualise models as graphs #754

Open
rhayes777 opened this issue Jul 6, 2023 · 5 comments
Labels: enhancement (New feature or request)


@rhayes777 (Owner)

Visualise model and prior relationships using a graph. Also visualise factor and variable relationships using a graph.

@rhayes777 rhayes777 added the enhancement New feature or request label Jul 6, 2023
@matthewghgriffiths (Collaborator)

I had set up some code for visualising FactorGraphs:

```python
from itertools import count

import numpy as np


# draw_graph and draw_graph_labels are methods of FactorGraph (shown here
# without the enclosing class); bipartite_layout is a module-level helper.
def draw_graph(
    self,
    pos=None,
    ax=None,
    size=20,
    color="k",
    fill="w",
    factor_shape="s",
    variable_shape="o",
    factor_labels=None,
    variable_labels=None,
    factor_kws=None,
    variable_kws=None,
    edge_kws=None,
    factors=None,
    draw_labels=False,
    label_kws=None,
    **kwargs,
):
    try:
        import matplotlib.pyplot as plt
        import networkx as nx
    except ImportError as e:
        raise ImportError(
            "Matplotlib and networkx required for draw_graph()"
        ) from e
    except RuntimeError as e:
        print("Matplotlib unable to open display")
        raise e

    if ax is None:
        ax = plt.gca()

    G = self.graph
    if pos is None:
        pos = bipartite_layout(factors or self.factors)

    # Shared marker style (matplotlib aliases: ms=markersize, c=color,
    # mec/mfc=marker edge/face colour, ls=linestyle; "" means no lines).
    kwargs.setdefault("ms", size)
    kwargs.setdefault("c", color)
    kwargs.setdefault("mec", color)
    kwargs.setdefault("mfc", fill)
    kwargs.setdefault("ls", "")
    # Copy so that setdefault does not mutate dicts passed in by the caller.
    factor_kws = dict(factor_kws or {})
    factor_kws.setdefault("marker", factor_shape)
    variable_kws = dict(variable_kws or {})
    variable_kws.setdefault("marker", variable_shape)

    # draw factors
    xy = np.array([pos[f] for f in self.factors]).T
    fs = ax.plot(*xy, **{**kwargs, **factor_kws})
    # draw variables
    xy = np.array([pos[v] for v in self.all_variables]).T
    vs = ax.plot(*xy, **{**kwargs, **variable_kws})
    # draw edges on the same axes as the nodes
    edges = nx.draw_networkx_edges(G, pos, ax=ax, **(edge_kws or {}))

    # remove ticks and tick labels from both axes
    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )
    if draw_labels:
        self.draw_graph_labels(
            pos,
            ax=ax,
            factor_labels=factor_labels,
            variable_labels=variable_labels,
            **(label_kws or {}),
        )
    return pos, fs, vs, edges


def draw_graph_labels(
    self,
    pos,
    factor_labels=None,
    variable_labels=None,
    shift=0.1,
    f_shift=None,
    v_shift=None,
    f_horizontalalignment="right",
    v_horizontalalignment="left",
    f_kws=None,
    v_kws=None,
    graph=None,
    ax=None,
    rescale=True,
    **kwargs,
):
    try:
        import matplotlib.pyplot as plt
        import networkx as nx
    except ImportError as e:
        raise ImportError(
            "Matplotlib and networkx required for draw_graph_labels()"
        ) from e

    ax = ax or plt.gca()
    graph = graph or self.graph
    factor_labels = factor_labels or {f: f.name for f in self.factors}
    variable_labels = variable_labels or {v: v.name for v in self.all_variables}
    # Alignment defaults apply unless overridden in f_kws / v_kws.
    f_kws = {"horizontalalignment": f_horizontalalignment, **(f_kws or {})}
    v_kws = {"horizontalalignment": v_horizontalalignment, **(v_kws or {})}

    # Factor labels sit left of their node, variable labels to the right.
    # Compare against None so that an explicit shift of 0 is respected.
    f_shift = shift if f_shift is None else f_shift
    f_pos = {node: (x - f_shift, y) for node, (x, y) in pos.items()}
    v_shift = shift if v_shift is None else v_shift
    v_pos = {node: (x + v_shift, y) for node, (x, y) in pos.items()}

    text = {
        **nx.draw_networkx_labels(
            graph, f_pos, labels=factor_labels, ax=ax, **f_kws, **kwargs
        ),
        **nx.draw_networkx_labels(
            graph, v_pos, labels=variable_labels, ax=ax, **v_kws, **kwargs
        ),
    }
    if rescale:
        # rescale_to_artists is a helper defined elsewhere (not in this snippet).
        rescale_to_artists(text.values(), ax=ax)
    return text


def bipartite_layout(factors):
    """Place factors at x=0 and their variables at x=1, spread over the same height."""
    n_factors = len(factors)
    n_variables = len(set().union(*(f.variables for f in factors)))
    n = max(n_factors, n_variables)
    factor_count = count()
    variable_count = count()
    pos = {}
    for factor in factors:
        pos[factor] = 0, next(factor_count) * n / n_factors
        for v in factor.variables:
            if v not in pos:
                pos[v] = 1, next(variable_count) * n / n_variables
    return pos
```
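To see what `bipartite_layout` produces without autofit, matplotlib or networkx installed, it can be run against stand-in factors. The `Factor` namedtuple below is a hypothetical stand-in for autofit's factor objects, assumed only to expose a hashable identity and a `variables` collection; the layout function itself is the one from the snippet above.

```python
from collections import namedtuple
from itertools import count

# Hypothetical stand-in for a factor: only `name` and `variables` are used.
Factor = namedtuple("Factor", ["name", "variables"])


def bipartite_layout(factors):
    # Factors go in column x=0 and variables in column x=1; both columns are
    # stretched to span the same height n = max(#factors, #variables).
    n_factors = len(factors)
    n_variables = len(set().union(*(f.variables for f in factors)))
    n = max(n_factors, n_variables)
    factor_count = count()
    variable_count = count()
    pos = {}
    for factor in factors:
        pos[factor] = 0, next(factor_count) * n / n_factors
        for v in factor.variables:
            if v not in pos:  # a variable shared by several factors is placed once
                pos[v] = 1, next(variable_count) * n / n_variables
    return pos


factors = [
    Factor("likelihood", ("x", "y", "z")),
    Factor("prior_x", ("x",)),
]
pos = bipartite_layout(factors)
for node, xy in pos.items():
    print(node.name if isinstance(node, Factor) else node, xy)
# likelihood (0, 0.0)
# x (1, 0.0)
# y (1, 1.0)
# z (1, 2.0)
# prior_x (0, 1.5)
```

With two factors and three variables, the factor column is spaced by 3/2 so both columns cover roughly the same vertical extent.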

@rhayes777 (Owner, Author)

rhayes777 commented Jul 6, 2023 via email

@matthewghgriffiths (Collaborator)

Yeah, perhaps something like pyvis could be nice: https://towardsdatascience.com/pyvis-visualize-interactive-network-graphs-in-python-77e059791f01
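A minimal sketch of what a pyvis export of a factor graph might look like, assuming pyvis is installed and using only its `Network`, `add_node`, `add_edge` and `write_html` calls. The `factor_graph_elements` and `write_pyvis_html` helpers and their `{factor_name: [variable_name, ...]}` input are hypothetical, standing in for iterating over a FactorGraph's `factors` and each factor's `variables`. Square factors and round variables mirror the defaults of `draw_graph`.

```python
def factor_graph_elements(factor_variables):
    """Turn {factor_name: [variable_name, ...]} into node and edge lists.

    Node ids are prefixed so that a factor and a variable may share a name.
    """
    nodes = {}
    edges = []
    for factor, variables in factor_variables.items():
        f_id = f"factor:{factor}"
        nodes[f_id] = {"label": factor, "shape": "square"}
        for variable in variables:
            v_id = f"variable:{variable}"
            # A variable shared between factors becomes a single node.
            nodes.setdefault(v_id, {"label": variable, "shape": "dot"})
            edges.append((f_id, v_id))
    return nodes, edges


def write_pyvis_html(factor_variables, path="factor_graph.html"):
    # Imported lazily so the data-building step above works without pyvis.
    from pyvis.network import Network

    nodes, edges = factor_graph_elements(factor_variables)
    net = Network(height="600px", width="100%")
    for node_id, attrs in nodes.items():
        net.add_node(node_id, **attrs)
    for source, target in edges:
        net.add_edge(source, target)
    net.write_html(path)  # interactive page, viewable in a browser
    return path


nodes, edges = factor_graph_elements({"likelihood": ["x", "y"], "prior_x": ["x"]})
```

Keeping the graph extraction separate from the pyvis call means the same node/edge lists could feed networkx or another backend instead.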

@rhayes777 (Owner, Author)

rhayes777 commented Jul 6, 2023 via email

@Jammy2211 (Collaborator)

The first steps towards this are here: #923

I have opened a number of issues that aim to make the model text output more concise:

#925
#926
#930
#931

It goes without saying that the code that makes the model text output more concise and the code that offers the right constructs for a visual graph are likely one and the same, so all of these issues should be considered together.
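As an illustration of that overlap, here is a sketch in which a single traversal of a model yields `(path, prior)` pairs that drive both a compact text summary and a graph edge list. The nested-dict model, the `Prior` class and the `walk`, `text_summary` and `edges` helpers are all hypothetical stand-ins, not autofit's API. The point it shows is that a prior shared between two parameters is one object, so it is listed once in the text and appears once in the graph with two incoming edges.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(eq=False)  # identity-based eq/hash: a shared prior is one object
class Prior:
    name: str


def walk(model, path=()):
    """Yield (path, prior) for every prior in a nested-dict model."""
    for key, value in model.items():
        if isinstance(value, dict):
            yield from walk(value, path + (key,))
        else:
            yield path + (key,), value


def text_summary(model):
    # Group parameter paths by prior so that shared priors are listed once.
    by_prior = defaultdict(list)
    for path, prior in walk(model):
        by_prior[prior].append(".".join(path))
    return "\n".join(f"{p.name}: {', '.join(paths)}" for p, paths in by_prior.items())


def edges(model):
    # Same traversal, emitted as (parameter path, prior name) graph edges.
    return [(".".join(path), prior.name) for path, prior in walk(model)]


centre = Prior("centre")  # shared by both components
model = {
    "gaussian_0": {"centre": centre, "sigma": Prior("sigma_0")},
    "gaussian_1": {"centre": centre, "sigma": Prior("sigma_1")},
}
print(text_summary(model))
# centre: gaussian_0.centre, gaussian_1.centre
# sigma_0: gaussian_0.sigma
# sigma_1: gaussian_1.sigma
```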

@Jammy2211 Jammy2211 self-assigned this May 3, 2024