diff --git a/netgraph/_main.py b/netgraph/_main.py index 56584fa..2323e29 100755 --- a/netgraph/_main.py +++ b/netgraph/_main.py @@ -1852,6 +1852,68 @@ def _on_motion(self, event): self.fig.canvas.draw_idle() +class EmphasizeOnClick(object): + """Emphasize matplotlib artists when clicking on them by desaturating all other artists.""" + + def __init__(self, artist_to_mapping): + self.artist_to_mapping = artist_to_mapping + self.emphasizeable_artists = self.artist_to_mapping.keys() + self.artists = list(self.node_artists.values()) + list( + self.edge_artists.values() + ) + keys = list(self.node_artists.keys()) + list(self.edge_artists.keys()) + self.artist_to_key = dict(zip(self.artists, keys)) + self.mapping = None + self._base_alpha = {artist: artist.get_alpha() for artist in self.artists} + self.deemphasized_artists = [] + + try: + (self.fig,) = set(list(artist.figure for artist in self.artists)) + except ValueError: + raise Exception("All artists have to be on the same figure!") + + try: + (self.ax,) = set(list(artist.axes for artist in self.artists)) + except ValueError: + raise Exception("All artists have to be on the same axis!") + + self.fig.canvas.mpl_connect("button_release_event", self._on_release) + + def _add_mapping(self, selected_artist): + self.mapping = self.artist_to_mapping[selected_artist] + emphasized_artists = [] + for value in self.mapping: + if value in self.node_artists: + emphasized_artists.append(self.node_artists[value]) + elif value in self.edge_artists: + emphasized_artists.append(self.edge_artists[value]) + for artist in self.artists: + if artist not in emphasized_artists: + artist.set_alpha(self._base_alpha[artist] / 5) + self.deemphasized_artists.append(artist) + + def _remove_mapping(self): + for artist in self.deemphasized_artists: + artist.set_alpha(self._base_alpha[artist]) + self.deemphasized_artists = [] + self.mapping = None + + def _on_release(self, event): + if event.inaxes == self.ax: + for artist in self.emphasizeable_artists: + if artist.contains(event)[0]: + if self.mapping: + self._remove_mapping() + else: + self._add_mapping(artist) + self.fig.canvas.draw() + break + else: + if self.mapping: + self._remove_mapping() + self.fig.canvas.draw() + + class DraggableGraphWithGridMode(DraggableGraph): """ Implements a grid-mode, in which node positions are fixed to a grid. @@ -2080,6 +2142,128 @@ def _on_motion(self, event): self.fig.canvas.draw_idle() +class EmphasizeOnClickGraph(Graph, EmphasizeOnClick): + """Combines `EmphasizeOnClick` with the `Graph` class such that nodes are emphasized when clicking on them with the mouse. + + Parameters + ---------- + graph : various formats + Graph object to plot. Various input formats are supported. + In order of precedence: + + - Edge list: + Iterable of (source, target) or (source, target, weight) tuples, + or equivalent (E, 2) or (E, 3) ndarray, where E is the number of edges. + - Adjacency matrix: + Full-rank (V, V) ndarray, where V is the number of nodes/vertices. + The absence of a connection is indicated by a zero. + + .. note:: If V <= 3, any (2, 2) or (3, 3) matrices will be interpreted as edge lists.** + + - networkx.Graph, igraph.Graph, or graph_tool.Graph object + + mouseover_highlight_mapping : dict or None, default None + Determines which nodes and/or edges are highlighted when clicking on any given node or edge. + The keys of the dictionary are node and/or edge IDs, while the values are iterables of node and/or edge IDs. + If the parameter is None, a default dictionary is constructed, which maps + + - edges to themselves as well as their source and target nodes, and + - nodes to themselves as well as their immediate neighbours and any edges between them. + + *args, **kwargs + Parameters passed through to `Graph`. See its documentation for a full list of available arguments. + + Attributes + ---------- + node_artists : dict + Mapping of node IDs to matplotlib PathPatch artists. + edge_artists : dict + Mapping of edge IDs to matplotlib PathPatch artists. + node_label_artists : dict + Mapping of node IDs to matplotlib text objects (if applicable). + edge_label_artists : dict + Mapping of edge IDs to matplotlib text objects (if applicable). + node_positions : dict node : (x, y) tuple + Mapping of node IDs to node positions. + + See also + -------- + Graph + + """ + + def __init__(self, graph, mouseover_highlight_mapping=None, *args, **kwargs): + Graph.__init__(self, graph, *args, **kwargs) + + artists = list(self.node_artists.values()) + list(self.edge_artists.values()) + keys = list(self.node_artists.keys()) + list(self.edge_artists.keys()) + self.artist_to_key = dict(zip(artists, keys)) + EmphasizeOnClick.__init__(self, mouseover_highlight_mapping) + + if mouseover_highlight_mapping is None: # construct default mapping + self.mouseover_highlight_mapping = ( + self._get_default_mouseover_highlight_mapping() + ) + else: # this includes empty mappings! + self._check_mouseover_highlight_mapping(mouseover_highlight_mapping) + self.mouseover_highlight_mapping = mouseover_highlight_mapping + + def _get_default_mouseover_highlight_mapping(self): + mapping = dict() + + # mapping for edges: source node, target node and the edge itself + for source, target in self.edges: + mapping[(source, target)] = [(source, target), source, target] + + # mapping for nodes: the node itself, its neighbours, and any edges between them + adjacency_list = _edge_list_to_adjacency_list(self.edges, directed=False) + for node, neighbours in adjacency_list.items(): + mapping[node] = [node] + for neighbour in neighbours: + mapping[node].append(neighbour) + if (node, neighbour) in self.edge_artists: + mapping[node].append((node, neighbour)) + if (neighbour, node) in self.edge_artists: + mapping[node].append((neighbour, node)) + + return mapping + + def _check_mouseover_highlight_mapping(self, mapping): + if not isinstance(mapping, dict): + raise TypeError( + f"Parameter `mouseover_highlight_mapping` is a dictionary, not {type(mapping)}." + ) + + invalid_keys = [] + for key in mapping: + if key in self.node_artists: + pass + elif key in self.edge_artists: + pass + else: + invalid_keys.append(key) + if invalid_keys: + msg = "Parameter `mouseover_highlight_mapping` contains invalid keys:" + for key in invalid_keys: + msg += f"\n\t- {key}" + raise ValueError(msg) + + invalid_values = [] + for values in mapping.values(): + for value in values: + if value in self.node_artists: + pass + elif value in self.edge_artists: + pass + else: + invalid_values.append(value) + if invalid_values: + msg = "Parameter `mouseover_highlight_mapping` contains invalid values:" + for value in set(invalid_values): + msg += f"\n\t- {value}" + raise ValueError(msg) + + class AnnotateOnClick(object): """Show or hide annotations when clicking on matplotlib artists.""" @@ -2299,6 +2483,71 @@ def _remove_table(self): self.table = None +class TableOnHover(object): + """Show or hide tabular information when hovering over matplotlib artists.""" + + def __init__(self, artist_to_table, table_kwargs=None): + self.artist_to_table = artist_to_table + self.artists = list(self.node_artists.values()) + list( + self.edge_artists.values() + ) + self.table = None + self.table_fontsize = None + self.table_kwargs = dict( + # bbox = [1.1, 0.1, 0.5, 0.8], + # edges = 'horizontal', + ) + + if table_kwargs: + if "fontsize" in table_kwargs: + self.table_fontsize = table_kwargs["fontsize"] + self.table_kwargs.update(table_kwargs) + + try: + (self.fig,) = set(list(artist.figure for artist in artist_to_table)) + except ValueError: + raise Exception("All artists have to be on the same figure!") + + try: + (self.ax,) = set(list(artist.axes for artist in artist_to_table)) + except ValueError: + raise Exception("All artists have to be on the same axis!") + + self.fig.canvas.mpl_connect("motion_notify_event", self._on_motion) + if self.table: + self.table.remove() + self.table = None + + def _on_motion(self, event): + if event.inaxes == self.ax: + # on artist + selected_artist = None + for artist in self.artists: + if artist.contains(event)[0]: # returns two arguments for some reason + selected_artist = artist + break + + if selected_artist: + try: + df = self.artist_to_table[selected_artist] + except KeyError: + return + self.table = self.ax.table( + cellText=df.values.tolist(), + rowLabels=df.index.values, + colLabels=df.columns.values, + **self.table_kwargs, + ) + self.fig.canvas.draw() + + # not on any artist + if selected_artist is None: + if self.table: + self.table.remove() + self.table = None + self.fig.canvas.draw() + + class TableOnClickGraph(Graph, TableOnClick): """Combines `TableOnClick` with the `Graph` class such that nodes or edges can have toggleable tabular annotations.""" @@ -2321,6 +2570,30 @@ def __init__(self, *args, **kwargs): TableOnClick.__init__(self, artist_to_table) +class TableOnHoverGraph(Graph, TableOnHover): + """Combines `TableOnHover` with the `Graph` class such that nodes or edges can have toggleable tabular annotations.""" + + def __init__(self, *args, **kwargs): + Graph.__init__(self, *args, **kwargs) + + self.artist_to_table = dict() + if "tables" in kwargs: + for key, table in kwargs["tables"].items(): + if key in self.nodes: + self.artist_to_table[self.node_artists[key]] = table + elif key in self.edges: + self.artist_to_table[self.edge_artists[key]] = table + else: + raise ValueError( + f"There is no node or edge with the ID {key} for the table '{table}'." + ) + + if "table_kwargs" in kwargs: + TableOnHover.__init__(self, self.artist_to_table, kwargs["table_kwargs"]) + else: + TableOnHover.__init__(self, self.artist_to_table) + + class InteractiveGraph(DraggableGraphWithGridMode, EmphasizeOnHoverGraph, AnnotateOnClickGraph, TableOnClickGraph): """Extends the `Graph` class to support node placement with the mouse, emphasis of graph elements when hovering over them, and toggleable annotations.