From 33dc63c6c6a954f9041b0fdb264a85dc2e2e2e00 Mon Sep 17 00:00:00 2001 From: Mihaela Duta Date: Fri, 22 Nov 2024 15:17:11 +0000 Subject: [PATCH] Fix linting. --- l2gv2/network/graph.py | 103 +++++++++++++------ l2gv2/network/npgraph.py | 212 +++++++++++++++++++++++---------------- l2gv2/network/tgraph.py | 190 +++++++++++++++++++++++++---------- l2gv2/network/utils.py | 30 ++++-- 4 files changed, 357 insertions(+), 178 deletions(-) diff --git a/l2gv2/network/graph.py b/l2gv2/network/graph.py index 40684de..3de91a4 100644 --- a/l2gv2/network/graph.py +++ b/l2gv2/network/graph.py @@ -17,34 +17,42 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Sequence, Collection, Iterable +"""TODO: module docstring for network/graph.py""" + +from typing import Sequence, Iterable +from abc import abstractmethod import networkx as nx -from abc import ABC, abstractmethod import numpy as np - +# pylint: disable=too-many-instance-attributes +# pylint: disable=too-many-public-methods class Graph: """ numpy backed graph class with support for memmapped edge_index """ + weights: Sequence degree: Sequence - device = 'cpu' + device = "cpu" @staticmethod - def _convert_input(input): - return input + def _convert_input(inp): + return inp @classmethod def from_tg(cls, data): - return cls(edge_index=data.edge_index, - edge_attr=data.edge_attr, - x=data.x, - y=data.y, - num_nodes=data.num_nodes) + """ TODO: docstring for from_tg.""" + return cls( + edge_index=data.edge_index, + edge_attr=data.edge_attr, + x=data.x, + y=data.y, + num_nodes=data.num_nodes, + ) @classmethod def from_networkx(cls, nx_graph: nx.Graph, weight=None): + """ TODO: docstring for from_networkx.""" undir = not nx_graph.is_directed() if undir: nx_graph = nx_graph.to_directed(as_view=True) @@ -57,27 +65,45 @@ def from_networkx(cls, nx_graph: nx.Graph, weight=None): if w is not None: weights.append(w) if weights and len(weights) != num_edges: - raise RuntimeError('some edges have missing weight') + raise RuntimeError("some edges have missing weight") if weight is not None: weights = np.array(weights) else: weights = None - return cls(edge_index, weights, num_nodes=num_nodes, ensure_sorted=True, undir=undir) + return cls( + edge_index, weights, num_nodes=num_nodes, ensure_sorted=True, undir=undir + ) @abstractmethod - def __init__(self, edge_index, edge_attr=None, x=None, y=None, num_nodes=None, adj_index=None, - ensure_sorted=False, undir=None, nodes=None): + def __init__( + self, + edge_index, + edge_attr=None, + x=None, + y=None, + num_nodes=None, + adj_index=None, + ensure_sorted=False, + undir=None, + nodes=None, + ): """ Initialise graph Args: - edge_index: edge index such that ``edge_index[0]`` lists the source and ``edge_index[1]`` the target node for each edge + edge_index: edge index such that ``edge_index[0]`` lists the source + and ``edge_index[1]`` the target node for each edge + edge_attr: optionally provide edge weights + num_nodes: specify number of nodes (default: ``max(edge_index)+1``) + ensure_sorted: if ``False``, assume that the ``edge_index`` input is already sorted - undir: boolean indicating if graph is directed. If not provided, the ``edge_index`` is checked to determine this value. + + undir: boolean indicating if graph is directed. + If not provided, the ``edge_index`` is checked to determine this value. """ self.edge_index = self._convert_input(edge_index) self.edge_attr = self._convert_input(edge_attr) @@ -97,20 +123,23 @@ def weighted(self): @property def num_edges(self): + """ TODO: docstring for num_edges.""" return self.edge_index.shape[1] @property def num_features(self): + """ TODO: docstring for num_features.""" return 0 if self.x is None else self.x.shape[1] @property def nodes(self): + """ TODO: docstring for nodes.""" if self._nodes is None: return range(self.num_nodes) - else: - return self._nodes + return self._nodes def has_node_labels(self): + """ TODO: docstring for has_node_labels.""" return self._nodes is not None def adj(self, node: int): @@ -124,7 +153,7 @@ def adj(self, node: int): neighbours """ - return self.edge_index[1][self.adj_index[node]:self.adj_index[node + 1]] + return self.edge_index[1][self.adj_index[node] : self.adj_index[node + 1]] def adj_weighted(self, node: int): """ @@ -136,7 +165,9 @@ def adj_weighted(self, node: int): neighbours, weights """ - return self.adj(node), self.weights[self.adj_index[node]:self.adj_index[node + 1]] + return self.adj(node), self.weights[ + self.adj_index[node] : self.adj_index[node + 1] + ] @abstractmethod def edges(self): @@ -154,6 +185,7 @@ def edges_weighted(self): @abstractmethod def is_edge(self, source, target): + """ TODO: docstring for is_edge.""" raise NotImplementedError @abstractmethod @@ -203,6 +235,7 @@ def nodes_in_lcc(self): return (i for i, c in enumerate(self.connected_component_ids()) if c == 0) def lcc(self, relabel=False): + """ TODO: docstring for lcc.""" return self.subgraph(self.nodes_in_lcc(), relabel) def to_networkx(self): @@ -219,18 +252,21 @@ def to_networkx(self): return nxgraph def to(self, graph_cls): + """ TODO: docstring for to.""" if self.__class__ is graph_cls: return self - else: - return graph_cls(edge_index=self.edge_index, - edge_attr=self.edge_attr, - x=self.x, - y=self.y, - num_nodes=self.num_nodes, - adj_index=self.adj_index, - ensure_sorted=False, - undir=self.undir, - nodes=self._nodes) + + return graph_cls( + edge_index=self.edge_index, + edge_attr=self.edge_attr, + x=self.x, + y=self.y, + num_nodes=self.num_nodes, + adj_index=self.adj_index, + ensure_sorted=False, + undir=self.undir, + nodes=self._nodes, + ) @abstractmethod def bfs_order(self, start=0): @@ -248,11 +284,16 @@ def bfs_order(self, start=0): @abstractmethod def partition_graph(self, partition, self_loops=True): + """ TODO: docstring for partition_graph.""" raise NotImplementedError @abstractmethod def sample_negative_edges(self, num_samples): + """ TODO: docstring for sample_negative_edges.""" raise NotImplementedError def sample_positive_edges(self, num_samples): + """ TODO: docstring for sample_positive_edges.""" raise NotImplementedError +# pylint: enable=too-many-public-methods +# pylint: enable=too-many-instance-attributes diff --git a/l2gv2/network/npgraph.py b/l2gv2/network/npgraph.py index d51fc9f..ed0892b 100644 --- a/l2gv2/network/npgraph.py +++ b/l2gv2/network/npgraph.py @@ -17,10 +17,10 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +"""TODO: module docstring for network/npgraph.py""" import json from pathlib import Path -from tempfile import TemporaryFile from random import randrange import numpy as np @@ -28,83 +28,87 @@ import numba from numba.experimental import jitclass +from l2gv2 import progress from .graph import Graph -from local2global_embedding import progress rng = np.random.default_rng() spec = [ - ('edge_index', numba.int64[:, :]), - ('adj_index', numba.int64[:]), - ('degree', numba.int64[:]), - + ("edge_index", numba.int64[:, :]), + ("adj_index", numba.int64[:]), + ("degree", numba.int64[:]), ] +# pylint: disable=too-many-instance-attributes class NPGraph(Graph): """ numpy backed graph class with support for memmapped edge_index """ + @staticmethod - def _convert_input(input): - if input is None: - return input - elif isinstance(input, torch.Tensor): - return np.asanyarray(input.cpu()) - else: - return np.asanyarray(input) + def _convert_input(inp): + if inp is None: + return inp + + if isinstance(inp, torch.Tensor): + return np.asanyarray(inp.cpu()) + + return np.asanyarray(inp) @classmethod def load(cls, folder, mmap_edges=None, mmap_features=None): + """TODO: docstring for load.""" folder = Path(folder) kwargs = {} - kwargs['edge_index'] = np.load(folder / 'edge_index.npy', mmap_mode=mmap_edges) + kwargs["edge_index"] = np.load(folder / "edge_index.npy", mmap_mode=mmap_edges) - attr_file = folder / 'edge_attr.npy' + attr_file = folder / "edge_attr.npy" if attr_file.is_file(): - kwargs['edge_attr'] = np.load(attr_file, mmap_mode=mmap_edges) + kwargs["edge_attr"] = np.load(attr_file, mmap_mode=mmap_edges) - info_file = folder / 'info.json' + info_file = folder / "info.json" if info_file.is_file(): - with open(info_file) as f: + with open(info_file, encoding="utf-8") as f: info = json.load(f) kwargs.update(info) - feat_file = folder / 'node_feat.npy' + feat_file = folder / "node_feat.npy" if feat_file.is_file(): - kwargs['x'] = np.load(feat_file, mmap_mode=mmap_features) + kwargs["x"] = np.load(feat_file, mmap_mode=mmap_features) - label_file = folder / 'node_label.npy' + label_file = folder / "node_label.npy" if label_file.is_file(): - kwargs['y'] = np.load(label_file) + kwargs["y"] = np.load(label_file) - index_file = folder / 'adj_index.npy' + index_file = folder / "adj_index.npy" if index_file.is_file(): - kwargs['adj_index'] = np.load(index_file) + kwargs["adj_index"] = np.load(index_file) return cls(**kwargs) def save(self, folder): + """TODO: docstring for save.""" folder = Path(folder) - np.save(folder / 'edge_index.npy', self.edge_index) + np.save(folder / "edge_index.npy", self.edge_index) if self.weighted: - np.save(folder / 'edge_attr.npy', self.edge_attr) + np.save(folder / "edge_attr.npy", self.edge_attr) - np.save(folder / 'adj_index.npy', self.adj_index) + np.save(folder / "adj_index.npy", self.adj_index) - info = {'num_nodes': self.num_nodes, 'undir': self.undir} - with open(folder / 'info.json', 'w') as f: + info = {"num_nodes": self.num_nodes, "undir": self.undir} + with open(folder / "info.json", "w", encoding="utf-8") as f: json.dump(info, f) if self.y is not None: - np.save(self.y, folder / 'node_label.npy') + np.save(self.y, folder / "node_label.npy") if self.x is not None: - np.save(self.x, folder / 'node_feat.npy') + np.save(self.x, folder / "node_feat.npy") def __init__(self, *args, ensure_sorted=False, **kwargs): super().__init__(*args, **kwargs) @@ -114,12 +118,15 @@ def __init__(self, *args, ensure_sorted=False, **kwargs): if ensure_sorted: if isinstance(self.edge_index, np.memmap): - raise NotImplementedError("Sorting for memmapped arrays not yet implemented") - else: - index = np.argsort(self.edge_index[0]*self.num_nodes + self.edge_index[1]) - self.edge_index = self.edge_index[:, index] - if self.edge_attr is not None: - self.edge_attr = self.edge_attr[index] + raise NotImplementedError( + "Sorting for memmapped arrays not yet implemented" + ) + index = np.argsort( + self.edge_index[0] * self.num_nodes + self.edge_index[1] + ) + self.edge_index = self.edge_index[:, index] + if self.edge_attr is not None: + self.edge_attr = self.edge_attr[index] self._jitgraph = JitGraph(self.edge_index, self.num_nodes, self.adj_index, None) self.adj_index = self._jitgraph.adj_index self.degree = self._jitgraph.degree @@ -130,19 +137,27 @@ def __init__(self, *args, ensure_sorted=False, **kwargs): self.strength = np.zeros(self.num_nodes) #: tensor of node strength np.add.at(self.strength, self.edge_index[0], self.weights) else: - self.weights = np.broadcast_to(np.ones(1), (self.num_edges,)) # use expand to avoid actually allocating large array + self.weights = np.broadcast_to( + np.ones(1), (self.num_edges,) + ) # use expand to avoid actually allocating large array self.strength = self.degree - self.device = 'cpu' + self.device = "cpu" if self.undir is None: if isinstance(self.edge_index, np.memmap): - raise NotImplementedError("Checking directedness for memmapped arrays not yet implemented") - else: - index = np.argsort(self.edge_index[1]*self.num_nodes + self.edge_index[0]) - edge_reverse = self.edge_index[::-1, index] - self.undir = np.array_equal(self.edge_index, edge_reverse) - if self.weighted: - self.undir = self.undir and np.array_equal(self.weights, self.weights[index]) + raise NotImplementedError( + "Checking directedness for memmapped arrays not yet implemented" + ) + + index = np.argsort( + self.edge_index[1] * self.num_nodes + self.edge_index[0] + ) + edge_reverse = self.edge_index[::-1, index] + self.undir = np.array_equal(self.edge_index, edge_reverse) + if self.weighted: + self.undir = self.undir and np.array_equal( + self.weights, self.weights[index] + ) def edges(self): """ @@ -154,8 +169,10 @@ def edges_weighted(self): """ return list of edges where each edge is a tuple ``(source, target, weight)`` """ - return ((e[0], e[1], w[0] if w.size > 1 else w) - for e, w in zip(self.edge_index.T, self.weights)) + return ( + (e[0], e[1], w[0] if w.size > 1 else w) + for e, w in zip(self.edge_index.T, self.weights) + ) def is_edge(self, source, target): return self._jitgraph.is_edge(source, target) @@ -211,14 +228,16 @@ def subgraph(self, nodes: torch.Tensor, relabel=False, keep_x=True, keep_y=True) y = self.y[nodes] else: y = None - return self.__class__(edge_index=edge_index, - edge_attr=edge_attr[index] if edge_attr is not None else None, - num_nodes=len(nodes), - ensure_sorted=False, - undir=self.undir, - nodes=node_labels, - x=x, - y=y) + return self.__class__( + edge_index=edge_index, + edge_attr=edge_attr[index] if edge_attr is not None else None, + num_nodes=len(nodes), + ensure_sorted=False, + undir=self.undir, + nodes=node_labels, + x=x, + y=y, + ) def connected_component_ids(self): """ @@ -270,14 +289,18 @@ def bfs_order(self, start=0): new_nodes = new_nodes[not_visited[new_nodes]] number_new_nodes = len(new_nodes) not_visited[new_nodes] = False - bfs_list[append_pointer:append_pointer+number_new_nodes] = new_nodes + bfs_list[append_pointer : append_pointer + number_new_nodes] = new_nodes append_pointer += number_new_nodes return bfs_list def partition_graph(self, partition, self_loops=True): partition = np.asanyarray(partition) - partition_edges, weights = self._jitgraph.partition_graph_edges(partition, self_loops) - return self.__class__(edge_index=partition_edges, edge_attr=weights, undir=self.undir) + partition_edges, weights = self._jitgraph.partition_graph_edges( + partition, self_loops + ) + return self.__class__( + edge_index=partition_edges, edge_attr=weights, undir=self.undir + ) def sample_negative_edges(self, num_samples): return self._jitgraph.sample_negative_edges(num_samples) @@ -287,6 +310,9 @@ def sample_positive_edges(self, num_samples): return self.edge_index[:, index] +# pylint: enable=too-many-instance-attributes + + @numba.njit def _subgraph_edges(edge_index, adj_index, degs, num_nodes, sources): max_edges = degs[sources].sum() @@ -296,8 +322,8 @@ def _subgraph_edges(edge_index, adj_index, degs, num_nodes, sources): target_index[sources] = np.arange(len(sources)) count = 0 - for s in range(len(sources)): - for i in range(adj_index[sources[s]], adj_index[sources[s]+1]): + for s, source in enumerate(sources): + for i in range(adj_index[source], adj_index[source + 1]): t = target_index[edge_index[1, i]] if t >= 0: subgraph_edge_index[0, count] = s @@ -311,7 +337,7 @@ def _subgraph_edges(edge_index, adj_index, degs, num_nodes, sources): def _memmap_degree(edge_index, num_nodes): degree = np.zeros(num_nodes, dtype=np.int64) with numba.objmode: - print('computing degrees') + print("computing degrees") progress.reset(edge_index.shape[1]) for it, source in enumerate(edge_index[0]): degree[source] += 1 @@ -325,13 +351,15 @@ def _memmap_degree(edge_index, num_nodes): @jitclass( [ - ('edge_index', numba.int64[:, :]), - ('adj_index', numba.int64[:]), - ('degree', numba.int64[:]), - ('num_nodes', numba.int64) + ("edge_index", numba.int64[:, :]), + ("adj_index", numba.int64[:]), + ("degree", numba.int64[:]), + ("num_nodes", numba.int64), ] ) class JitGraph: + """TODO: docstring for JitGraph.""" + def __init__(self, edge_index, num_nodes=None, adj_index=None, degree=None): if num_nodes is None: num_nodes_int = edge_index.max() + 1 @@ -339,7 +367,7 @@ def __init__(self, edge_index, num_nodes=None, adj_index=None, degree=None): num_nodes_int = num_nodes if adj_index is None: - adj_index_ar = np.zeros((num_nodes_int+1,), dtype=np.int64) + adj_index_ar = np.zeros((num_nodes_int + 1,), dtype=np.int64) else: adj_index_ar = adj_index @@ -350,7 +378,7 @@ def __init__(self, edge_index, num_nodes=None, adj_index=None, degree=None): degree[s] += 1 adj_index_ar[1:] = degree.cumsum() else: - degree = adj_index_ar[1:]-adj_index_ar[:-1] + degree = adj_index_ar[1:] - adj_index_ar[:-1] self.edge_index = edge_index self.adj_index = adj_index_ar @@ -358,15 +386,23 @@ def __init__(self, edge_index, num_nodes=None, adj_index=None, degree=None): self.num_nodes = num_nodes_int def is_edge(self, source, target): + """TODO: docstring for is_edge.""" if source not in range(self.num_nodes) or target not in range(self.num_nodes): return False - index = np.searchsorted(self.edge_index[1, self.adj_index[source]:self.adj_index[source + 1]], target) - if index < self.degree[source] and self.edge_index[1, self.adj_index[source] + index] == target: + index = np.searchsorted( + self.edge_index[1, self.adj_index[source] : self.adj_index[source + 1]], + target, + ) + if ( + index < self.degree[source] + and self.edge_index[1, self.adj_index[source] + index] == target + ): return True - else: - return False + + return False def sample_negative_edges(self, num_samples): + """TODO: docstring for sample_negative_edges.""" i = 0 sampled_edges = np.empty((2, num_samples), dtype=np.int64) while i < num_samples: @@ -379,22 +415,26 @@ def sample_negative_edges(self, num_samples): return sampled_edges def adj(self, node): - return self.edge_index[1][self.adj_index[node]:self.adj_index[node+1]] + """TODO: docstring for adj.""" + return self.edge_index[1][self.adj_index[node] : self.adj_index[node + 1]] def neighbours(self, nodes): + """TODO: docstring for neighbours.""" size = self.degree[nodes].sum() out = np.empty((size,), dtype=np.int64) it = 0 for node in nodes: - out[it:it+self.degree[node]] = self.adj(node) + out[it : it + self.degree[node]] = self.adj(node) it += self.degree[node] return np.unique(out) def sample_positive_edges(self, num_samples): + """TODO: docstring for sample_positive_edges.""" index = np.random.randint(self.num_edges, (num_samples,)) return self.edge_index[:, index] def subgraph_edges(self, sources): + """TODO: docstring for subgraph_edges.""" max_edges = self.degree[sources].sum() subgraph_edge_index = np.empty((2, max_edges), dtype=np.int64) index = np.empty((max_edges,), dtype=np.int64) @@ -402,8 +442,8 @@ def subgraph_edges(self, sources): target_index[sources] = np.arange(len(sources)) count = 0 - for s in range(len(sources)): - for ei in range(self.adj_index[sources[s]], self.adj_index[sources[s]+1]): + for s, source in enumerate(sources): + for ei in range(self.adj_index[source], self.adj_index[source + 1]): t = target_index[self.edge_index[1][ei]] if t >= 0: subgraph_edge_index[0, count] = s @@ -413,13 +453,15 @@ def subgraph_edges(self, sources): return subgraph_edge_index[:, :count], index[:count] def subgraph(self, sources): + """TODO: docstring for subgraph.""" edge_index, _ = self.subgraph_edges(sources) return JitGraph(edge_index, len(sources), None, None) def partition_graph_edges(self, partition, self_loops): + """TODO: docstring for partition_graph_edges.""" num_edges = self.num_edges with numba.objmode: - print('finding partition edges') + print("finding partition edges") progress.reset(num_edges) num_clusters = partition.max() + 1 edge_counts = np.zeros((num_clusters, num_clusters), dtype=np.int64) @@ -441,20 +483,21 @@ def partition_graph_edges(self, partition, self_loops): return partition_edges, weights def partition_graph(self, partition, self_loops): + """TODO: docstring for partition_graph.""" edge_index, _ = self.partition_graph_edges(partition, self_loops) return JitGraph(edge_index, None, None, None) def connected_component_ids(self): """ - return nodes in breadth-first-search order + return nodes in breadth-first-search order - Args: - start: index of starting node (default: 0) + Args: + start: index of starting node (default: 0) - Returns: - tensor of node indeces + Returns: + tensor of node indeces - """ + """ components = np.full((self.num_nodes,), -1, dtype=np.int64) not_visited = np.ones(self.num_nodes, dtype=np.bool) component_id = 0 @@ -478,7 +521,7 @@ def connected_component_ids(self): not_visited[new_nodes] = False bfs_list.extend(new_nodes) - num_components = components.max()+1 + num_components = components.max() + 1 component_size = np.zeros((num_components,), dtype=np.int64) for i in components: component_size[i] += 1 @@ -493,4 +536,5 @@ def nodes_in_lcc(self): @property def num_edges(self): + """TODO: docstring for num_edges.""" return self.edge_index.shape[1] diff --git a/l2gv2/network/tgraph.py b/l2gv2/network/tgraph.py index aa611ae..7fbf4d3 100644 --- a/l2gv2/network/tgraph.py +++ b/l2gv2/network/tgraph.py @@ -17,7 +17,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -import typing as _t +"""TODO: module docstring for network/tgraph.py""" import networkx as nx import torch @@ -26,40 +26,56 @@ from .graph import Graph - +# pylint: disable=too-many-instance-attributes class TGraph(Graph): """Wrapper class for pytorch-geometric edge_index providing fast adjacency look-up.""" + @staticmethod - def _convert_input(input): - if input is None: + def _convert_input(inp): + if inp is None: return None - else: - return torch.as_tensor(input) + + return torch.as_tensor(inp) def __init__(self, *args, ensure_sorted=False, **kwargs): super().__init__(*args, **kwargs) if self.num_nodes is None: - self.num_nodes = int(torch.max(self.edge_index)+1) #: number of nodes + self.num_nodes = int(torch.max(self.edge_index) + 1) #: number of nodes if ensure_sorted: - index = torch.argsort(self.edge_index[0]*self.num_nodes+self.edge_index[1]) + index = torch.argsort( + self.edge_index[0] * self.num_nodes + self.edge_index[1] + ) self.edge_index = self.edge_index[:, index] if self.edge_attr is not None: self.edge_attr = self.edge_attr[index] if self.adj_index is None: - self.degree = torch.zeros(self.num_nodes, dtype=torch.long, device=self.device) #: tensor of node degrees - self.degree.index_add_(0, self.edge_index[0], - torch.ones(1, dtype=torch.long, device=self.device).expand(self.num_edges)) # use expand to avoid actually allocating large array - self.adj_index = torch.zeros(self.num_nodes + 1, dtype=torch.long) #: adjacency index such that edges starting at node ``i`` are given by ``edge_index[:, adj_index[i]:adj_index[i+1]]`` + self.degree = torch.zeros( + self.num_nodes, dtype=torch.long, device=self.device + ) #: tensor of node degrees + self.degree.index_add_( + 0, + self.edge_index[0], + torch.ones(1, dtype=torch.long, device=self.device).expand( + self.num_edges + ), + ) # use expand to avoid actually allocating large array + self.adj_index = torch.zeros( + self.num_nodes + 1, dtype=torch.long + ) + #: adjacency index such that edges starting at node ``i`` + # are given by ``edge_index[:, adj_index[i]:adj_index[i+1]]`` self.adj_index[1:] = torch.cumsum(self.degree, 0) else: self.degree = self.adj_index[1:] - self.adj_index[:-1] if self.weighted: self.weights = self.edge_attr - self.strength = torch.zeros(self.num_nodes, device=self.device, dtype=self.weights.dtype) #: tensor of node strength + self.strength = torch.zeros( + self.num_nodes, device=self.device, dtype=self.weights.dtype + ) #: tensor of node strength self.strength.index_add_(0, self.edge_index[0], self.weights) else: # use expand to avoid actually allocating large array @@ -67,10 +83,16 @@ def __init__(self, *args, ensure_sorted=False, **kwargs): self.strength = self.degree if self.undir is None: - index = torch.argsort(self.edge_index[1]*self.num_nodes+self.edge_index[0]) - self.undir = torch.equal(self.edge_index, self.edge_index[:, index].flip((0,))) + index = torch.argsort( + self.edge_index[1] * self.num_nodes + self.edge_index[0] + ) + self.undir = torch.equal( + self.edge_index, self.edge_index[:, index].flip((0,)) + ) if self.weighted: - self.undir = self.undir and torch.equal(self.weights, self.weights[index]) + self.undir = self.undir and torch.equal( + self.weights, self.weights[index] + ) @property def device(self): @@ -81,21 +103,38 @@ def edges(self): """ return list of edges where each edge is a tuple ``(source, target)`` """ - return ((self.edge_index[0, e].item(), self.edge_index[1, e].item()) for e in range(self.num_edges)) + return ( + (self.edge_index[0, e].item(), self.edge_index[1, e].item()) + for e in range(self.num_edges) + ) def edges_weighted(self): """ return list of edges where each edge is a tuple ``(source, target, weight)`` """ - return ((self.edge_index[0, e].item(), self.edge_index[1, e].item(), self.weights[e].cpu().numpy() - if self.weights.ndim > 1 else self.weights[e].item()) for e in range(self.num_edges)) + return ( + ( + self.edge_index[0, e].item(), + self.edge_index[1, e].item(), + self.weights[e].cpu().numpy() + if self.weights.ndim > 1 + else self.weights[e].item(), + ) + for e in range(self.num_edges) + ) def is_edge(self, source, target): - index = torch.bucketize(target, self.edge_index[1, self.adj_index[source]:self.adj_index[source+1]]) - if index < self.degree[source] and self.edge_index[1, self.adj_index[source]+index] == target: + index = torch.bucketize( + target, + self.edge_index[1, self.adj_index[source] : self.adj_index[source + 1]], + ) + if ( + index < self.degree[source] + and self.edge_index[1, self.adj_index[source] + index] == target + ): return True - else: - return False + + return False def neighbourhood(self, nodes: torch.Tensor, hops: int = 1): """ @@ -133,7 +172,14 @@ def subgraph(self, nodes: torch.Tensor, relabel=False, keep_x=True, keep_y=True) subgraph """ - index = torch.cat([torch.arange(self.adj_index[node], self.adj_index[node + 1], dtype=torch.long) for node in nodes]) + index = torch.cat( + [ + torch.arange( + self.adj_index[node], self.adj_index[node + 1], dtype=torch.long + ) + for node in nodes + ] + ) node_mask = torch.zeros(self.num_nodes, dtype=torch.bool, device=self.device) node_mask[nodes] = True node_ids = torch.zeros(self.num_nodes, dtype=torch.long, device=self.device) @@ -155,29 +201,46 @@ def subgraph(self, nodes: torch.Tensor, relabel=False, keep_x=True, keep_y=True) else: y = None - return self.__class__(edge_index=node_ids[self.edge_index[:, index]], - edge_attr=edge_attr[index] if edge_attr is not None else None, - num_nodes=len(nodes), - ensure_sorted=True, - undir=self.undir, - x=x, - y=y, - nodes=node_labels - ) + return self.__class__( + edge_index=node_ids[self.edge_index[:, index]], + edge_attr=edge_attr[index] if edge_attr is not None else None, + num_nodes=len(nodes), + ensure_sorted=True, + undir=self.undir, + x=x, + y=y, + nodes=node_labels, + ) def connected_component_ids(self): - """Find the (weakly)-connected components. Component ids are sorted by size, such that id=0 corresponds - to the largest connected component""" + """ Find the (weakly)-connected components. + Component ids are sorted by size, such that id=0 corresponds + to the largest connected component + """ edge_index = self.edge_index is_undir = self.undir - last_components = torch.full((self.num_nodes,), self.num_nodes, dtype=torch.long, device=self.device) + last_components = torch.full( + (self.num_nodes,), self.num_nodes, dtype=torch.long, device=self.device + ) components = torch.arange(self.num_nodes, dtype=torch.long, device=self.device) while not torch.equal(last_components, components): last_components[:] = components - components = ts.scatter(last_components[edge_index[0]], edge_index[1], out=components, reduce='min') + components = ts.scatter( + last_components[edge_index[0]], + edge_index[1], + out=components, + reduce="min", + ) if not is_undir: - components = ts.scatter(last_components[edge_index[1]], edge_index[0], out=components, reduce='min') - component_id, inverse, component_size = torch.unique(components, return_counts=True, return_inverse=True) + components = ts.scatter( + last_components[edge_index[1]], + edge_index[0], + out=components, + reduce="min", + ) + _, inverse, component_size = torch.unique( + components, return_counts=True, return_inverse=True + ) new_id = torch.argsort(component_size, descending=True) return new_id[inverse] @@ -194,10 +257,10 @@ def to_networkx(self): nxgraph.add_nodes_from(range(self.num_nodes)) if self.x is not None: for i in range(self.num_nodes): - nxgraph.nodes[i]['x'] = self.x[i, :] + nxgraph.nodes[i]["x"] = self.x[i, :] if self.y is not None: for i in range(self.num_nodes): - nxgraph.nodes[i]['y'] = self.y[i] + nxgraph.nodes[i]["y"] = self.y[i] if self.weighted: nxgraph.add_weighted_edges_from(self.edges_weighted()) else: @@ -216,22 +279,27 @@ def to(self, *args, graph_cls=None, **kwargs): """ if args: - if not (graph_cls is None): - raise ValueError("Both positional and graph_cls keyword argument specified.") - elif len(args) == 1: + if graph_cls is not None: + raise ValueError( + "Both positional and graph_cls keyword argument specified." + ) + if len(args) == 1: arg = args[0] if isinstance(arg, type) and issubclass(arg, Graph): graph_cls = arg if kwargs: - raise ValueError("Cannot specify additional keyword arguments when converting between graph classes.") + raise ValueError( + "Cannot specify additional keyword arguments " + "when converting between graph classes." + ) if graph_cls is not None: return super().to(graph_cls) - else: - for key, value in self.__dict__.items(): - if isinstance(value, torch.Tensor): - self.__dict__[key] = value.to(*args, **kwargs) - return self + + for key, value in self.__dict__.items(): + if isinstance(value, torch.Tensor): + self.__dict__[key] = value.to(*args, **kwargs) + return self def bfs_order(self, start=0): """ @@ -244,7 +312,9 @@ def bfs_order(self, start=0): tensor of node indeces """ - bfs_list = torch.full((self.num_nodes,), -1, dtype=torch.long, device=self.device) + bfs_list = torch.full( + (self.num_nodes,), -1, dtype=torch.long, device=self.device + ) not_visited = torch.ones(self.num_nodes, dtype=torch.bool, device=self.device) bfs_list[0] = start not_visited[start] = False @@ -262,20 +332,29 @@ def bfs_order(self, start=0): new_nodes = new_nodes[not_visited[new_nodes]] number_new_nodes = len(new_nodes) not_visited[new_nodes] = False - bfs_list[append_pointer:append_pointer+number_new_nodes] = new_nodes + bfs_list[append_pointer : append_pointer + number_new_nodes] = new_nodes append_pointer += number_new_nodes return bfs_list def partition_graph(self, partition, self_loops=True): num_clusters = torch.max(partition) + 1 - pe_index = partition[self.edge_index[0]]*num_clusters + partition[self.edge_index[1]] + pe_index = ( + partition[self.edge_index[0]] * num_clusters + partition[self.edge_index[1]] + ) partition_edges, weights = torch.unique(pe_index, return_counts=True) - partition_edges = torch.stack((partition_edges // num_clusters, partition_edges % num_clusters), dim=0) + partition_edges = torch.stack( + (partition_edges // num_clusters, partition_edges % num_clusters), dim=0 + ) if not self_loops: valid = partition_edges[0] != partition_edges[1] partition_edges = partition_edges[:, valid] weights = weights[valid] - return self.__class__(edge_index=partition_edges, edge_attr=weights, num_nodes=num_clusters, undir=self.undir) + return self.__class__( + edge_index=partition_edges, + edge_attr=weights, + num_nodes=num_clusters, + undir=self.undir, + ) def sample_negative_edges(self, num_samples): return tg.utils.negative_sampling(self.edge_index, self.num_nodes, num_samples) @@ -283,3 +362,4 @@ def sample_negative_edges(self, num_samples): def sample_positive_edges(self, num_samples): index = torch.randint(self.num_edges, (num_samples,), dtype=torch.long) return self.edge_index[:, index] +# pylint: enable=too-many-instance-attributes diff --git a/l2gv2/network/utils.py b/l2gv2/network/utils.py index 630eedb..0fe94db 100644 --- a/l2gv2/network/utils.py +++ b/l2gv2/network/utils.py @@ -52,6 +52,7 @@ class UnionFind: http://www.ics.uci.edu/~eppstein/PADS/UnionFind.py """ + parents: numba.int64[:] weights: numba.int64[:] @@ -98,8 +99,11 @@ def conductance(graph: Graph, source, target=None): Args: graph: input graph + source: set of source nodes - target: set of target nodes (if ``target=None``, consider all nodes that are not in ``source`` as target) + + target: set of target nodes (if ``target=None``, + consider all nodes that are not in ``source`` as target) Returns: conductance @@ -114,7 +118,7 @@ def conductance(graph: Graph, source, target=None): out = torch.cat([graph.adj(node) for node in source]) cond = torch.sum(target_mask[out]).float() s_deg = graph.degree[source].sum() - t_deg = graph.num_edges-s_deg if target is None else graph.degree[target].sum() + t_deg = graph.num_edges - s_deg if target is None else graph.degree[target].sum() cond /= torch.minimum(s_deg, t_deg) return cond @@ -136,7 +140,12 @@ def spanning_tree(graph: TGraph, maximise=False): weights = graph.edge_attr[edge_mask] else: weights = None - return TGraph(edge_index=edge_index, edge_attr=weights, num_nodes=graph.num_nodes, ensure_sorted=False) + return TGraph( + edge_index=edge_index, + edge_attr=weights, + num_nodes=graph.num_nodes, + ensure_sorted=False, + ) def spanning_tree_mask(graph: Graph, maximise=False): @@ -152,7 +161,9 @@ def spanning_tree_mask(graph: Graph, maximise=False): # find positions of reverse edges if graph.undir: - reverse_edge_index = np.argsort(graph.edge_index[1]*graph.num_nodes+graph.edge_index[0]) + reverse_edge_index = np.argsort( + graph.edge_index[1] * graph.num_nodes + graph.edge_index[0] + ) forward_edge_index = np.flatnonzero(graph.edge_index[0] < graph.edge_index[1]) edges = graph.edge_index[:, forward_edge_index] weights = graph.weights[forward_edge_index] @@ -168,17 +179,20 @@ def spanning_tree_mask(graph: Graph, maximise=False): index = index[::-1] edge_mask = np.zeros(graph.num_edges, dtype=bool) - edge_mask = _spanning_tree_mask(edge_mask, edges, index, graph.num_nodes, forward_edge_index, reverse_edge_index) + edge_mask = _spanning_tree_mask( + edge_mask, edges, index, graph.num_nodes, forward_edge_index, reverse_edge_index + ) if convert_to_tensor: edge_mask = torch.as_tensor(edge_mask) return edge_mask @numba.njit -def _spanning_tree_mask(edge_mask, edges, index, num_nodes, forward_edge_index, reverse_edge_index): +def _spanning_tree_mask( + edge_mask, edges, index, num_nodes, forward_edge_index, reverse_edge_index +): subtrees = UnionFind(num_nodes) - for it in range(len(index)): - i = index[it] + for i in index: u = edges[0, i] v = edges[1, i] if subtrees.find(u) != subtrees.find(v):