diff --git a/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py b/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py index 41b8ac45..fc2f89f8 100644 --- a/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py +++ b/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py @@ -3,21 +3,24 @@ import torch from topobenchmarkx.transforms.liftings.graph2simplicial import ( - SimplicialCliqueLifting, + SimplicialCliqueLifting ) - +from topobenchmarkx.transforms.converters import Data2NxGraph, Complex2Dict +from topobenchmarkx.transforms.liftings.base import LiftingTransform class TestSimplicialCliqueLifting: """Test the SimplicialCliqueLifting class.""" def setup_method(self): # Initialise the SimplicialCliqueLifting class - self.lifting_signed = SimplicialCliqueLifting( - complex_dim=3, signed=True - ) - self.lifting_unsigned = SimplicialCliqueLifting( - complex_dim=3, signed=False - ) + data2graph = Data2NxGraph() + simplicial2dict_signed = Complex2Dict(signed=True) + simplicial2dict_unsigned = Complex2Dict(signed=False) + + lifting_map = SimplicialCliqueLifting(complex_dim=3) + + self.lifting_signed = LiftingTransform(data2graph, simplicial2dict_signed, lifting_map) + self.lifting_unsigned = LiftingTransform(data2graph, simplicial2dict_unsigned, lifting_map) def test_lift_topology(self, simple_graph_1): """Test the lift_topology method.""" diff --git a/topobenchmarkx/complex.py b/topobenchmarkx/complex.py new file mode 100644 index 00000000..8a2949f2 --- /dev/null +++ b/topobenchmarkx/complex.py @@ -0,0 +1,85 @@ +import torch + + +class PlainComplex: + def __init__( + self, + incidence, + down_laplacian, + up_laplacian, + adjacency, + coadjacency, + hodge_laplacian, + features=None, + ): + # TODO: allow None with nice error message if callable? + + # TODO: make this private? do not allow for changes in these values? + self.incidence = incidence + self.down_laplacian = down_laplacian + self.up_laplacian = up_laplacian + self.adjacency = adjacency + self.coadjacency = coadjacency + self.hodge_laplacian = hodge_laplacian + + if features is None: + features = [None for _ in range(len(self.incidence))] + else: + for rank, dim in enumerate(self.shape): + # TODO: make error message more informative + if ( + features[rank] is not None + and features[rank].shape[0] != dim + ): + raise ValueError("Features have wrong shape.") + + self.features = features + + @property + def shape(self): + """Shape of the complex. + + Returns + ------- + list[int] + """ + return [incidence.shape[-1] for incidence in self.incidence] + + @property + def max_rank(self): + """Maximum rank of the complex. + + Returns + ------- + int + """ + return len(self.incidence) + + def update_features(self, rank, values): + """Update features. + + Parameters + ---------- + rank : int + Rank of simplices the features belong to. + values : array-like + New features for the rank-simplices. + """ + self.features[rank] = values + + def reset_features(self): + """Reset features.""" + self.features = [None for _ in self.features] + + def propagate_values(self, rank, values): + """Propagate features from a rank to an upper one. + + Parameters + ---------- + rank : int + Rank of the simplices the values belong to. + values : array-like + Features for the rank-simplices. + """ + # TODO: can be made much better + return torch.matmul(torch.abs(self.incidence[rank + 1].t()), values) diff --git a/topobenchmarkx/transforms/converters.py b/topobenchmarkx/transforms/converters.py new file mode 100644 index 00000000..96920b74 --- /dev/null +++ b/topobenchmarkx/transforms/converters.py @@ -0,0 +1,313 @@ +import abc + +import networkx as nx +import numpy as np +import torch +import torch_geometric +from topomodelx.utils.sparse import from_sparse +from torch_geometric.utils.undirected import is_undirected, to_undirected + +from topobenchmarkx.complex import PlainComplex +from topobenchmarkx.data.utils.utils import ( + generate_zero_sparse_connectivity, + select_neighborhoods_of_interest, +) + + +class Converter(abc.ABC): + """Convert between data structures representing the same domain.""" + + def __call__(self, domain): + """Convert domain's data structure.""" + return self.convert(domain) + + @abc.abstractmethod + def convert(self, domain): + """Convert domain's data structure.""" + + +class IdentityConverter(Converter): + """Identity conversion. + + Retrieves same data structure for domain. + """ + + def convert(self, domain): + """Convert domain.""" + return domain + + +class Data2NxGraph(Converter): + """Data to nx.Graph conversion. + + Parameters + ---------- + preserve_edge_attr : bool + Whether to preserve edge attributes. + """ + + def __init__(self, preserve_edge_attr=False): + self.preserve_edge_attr = preserve_edge_attr + + def _data_has_edge_attr(self, data: torch_geometric.data.Data) -> bool: + r"""Check if the input data object has edge attributes. + + Parameters + ---------- + data : torch_geometric.data.Data + The input data. + + Returns + ------- + bool + Whether the data object has edge attributes. + """ + return hasattr(data, "edge_attr") and data.edge_attr is not None + + def convert(self, domain: torch_geometric.data.Data) -> nx.Graph: + r"""Generate a NetworkX graph from the input data object. + + Parameters + ---------- + domain : torch_geometric.data.Data + The input data. + + Returns + ------- + nx.Graph + The generated NetworkX graph. + """ + # Check if data object have edge_attr, return list of tuples as [(node_id, {'features':data}, 'dim':1)] or ?? + nodes = [ + (n, dict(features=domain.x[n], dim=0)) + for n in range(domain.x.shape[0]) + ] + + if self.preserve_edge_attr and self._data_has_edge_attr(domain): + # In case edge features are given, assign features to every edge + edge_index, edge_attr = ( + domain.edge_index, + ( + domain.edge_attr + if is_undirected(domain.edge_index, domain.edge_attr) + else to_undirected(domain.edge_index, domain.edge_attr) + ), + ) + edges = [ + (i.item(), j.item(), dict(features=edge_attr[edge_idx], dim=1)) + for edge_idx, (i, j) in enumerate( + zip(edge_index[0], edge_index[1], strict=False) + ) + ] + + else: + # If edge_attr is not present, return list list of edges + edges = [ + (i.item(), j.item(), {}) + for i, j in zip( + domain.edge_index[0], domain.edge_index[1], strict=False + ) + ] + graph = nx.Graph() + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +class Complex2PlainComplex(Converter): + """toponetx.Complex to PlainComplex conversion. + + NB: order of features plays a crucial role, as ``PlainComplex`` + simply stores them as lists (i.e. the reference to the indices + of the simplex are lost). + + Parameters + ---------- + max_rank : int + Maximum rank of the complex. + neighborhoods : list, optional + List of neighborhoods of interest. + signed : bool, optional + If True, returns signed connectivity matrices. + transfer_features : bool, optional + Whether to transfer features. + """ + + def __init__( + self, + max_rank=None, + neighborhoods=None, + signed=False, + transfer_features=True, + ): + super().__init__() + self.max_rank = max_rank + self.neighborhoods = neighborhoods + self.signed = signed + self.transfer_features = transfer_features + + def convert(self, domain): + """Convert toponetx.Complex to PlainComplex. + + Parameters + ---------- + domain : toponetx.Complex + + Returns + ------- + PlainComplex + """ + # NB: just a slightly rewriting of get_complex_connectivity + + max_rank = self.max_rank or domain.dim + signed = self.signed + neighborhoods = self.neighborhoods + + connectivity_infos = [ + "incidence", + "down_laplacian", + "up_laplacian", + "adjacency", + "coadjacency", + "hodge_laplacian", + ] + + practical_shape = list( + np.pad(list(domain.shape), (0, max_rank + 1 - len(domain.shape))) + ) + data = { + connectivity_info: [] for connectivity_info in connectivity_infos + } + for rank_idx in range(max_rank + 1): + for connectivity_info in connectivity_infos: + try: + data[connectivity_info].append( + from_sparse( + getattr(domain, f"{connectivity_info}_matrix")( + rank=rank_idx, signed=signed + ) + ) + ) + except ValueError: + if connectivity_info == "incidence": + data[connectivity_info].append( + generate_zero_sparse_connectivity( + m=practical_shape[rank_idx - 1], + n=practical_shape[rank_idx], + ) + ) + else: + data[connectivity_info].append( + generate_zero_sparse_connectivity( + m=practical_shape[rank_idx], + n=practical_shape[rank_idx], + ) + ) + + # TODO: handle this + if neighborhoods is not None: + data = select_neighborhoods_of_interest(data, neighborhoods) + + # TODO: simplex specific? + # TODO: how to do this for other? + if self.transfer_features and hasattr( + domain, "get_simplex_attributes" + ): + # TODO: confirm features are in the right order; update this + data["features"] = [] + for rank in range(max_rank + 1): + rank_features_dict = domain.get_simplex_attributes( + "features", rank + ) + if rank_features_dict: + rank_features = torch.stack( + list(rank_features_dict.values()) + ) + else: + rank_features = None + data["features"].append(rank_features) + + return PlainComplex(**data) + + +class PlainComplex2Dict(Converter): + """PlainComplex to dict conversion.""" + + def convert(self, domain): + """Convert PlainComplex to dict. + + Parameters + ---------- + domain : toponetx.Complex + + Returns + ------- + dict + """ + data = {} + connectivity_infos = [ + "incidence", + "down_laplacian", + "up_laplacian", + "adjacency", + "coadjacency", + "hodge_laplacian", + ] + for connectivity_info in connectivity_infos: + info = getattr(domain, connectivity_info) + for rank, rank_info in enumerate(info): + data[f"{connectivity_info}_{rank}"] = rank_info + + # TODO: handle neighborhoods + data["shape"] = domain.shape + + for index, values in enumerate(domain.features): + if values is not None: + data[f"x_{index}"] = values + + return data + + +class ConverterComposition(Converter): + def __init__(self, converters): + super().__init__() + self.converters = converters + + def convert(self, domain): + """Convert domain""" + for converter in self.converters: + domain = converter(domain) + + return domain + + +class Complex2Dict(ConverterComposition): + """Complex to dict conversion. + + Parameters + ---------- + max_rank : int + Maximum rank of the complex. + neighborhoods : list, optional + List of neighborhoods of interest. + signed : bool, optional + If True, returns signed connectivity matrices. + transfer_features : bool, optional + Whether to transfer features. + """ + + def __init__( + self, + max_rank=None, + neighborhoods=None, + signed=False, + transfer_features=True, + ): + complex2plain = Complex2PlainComplex( + max_rank=max_rank, + neighborhoods=neighborhoods, + signed=signed, + transfer_features=transfer_features, + ) + plain2dict = PlainComplex2Dict() + super().__init__(converters=(complex2plain, plain2dict)) diff --git a/topobenchmarkx/transforms/feature_liftings/base.py b/topobenchmarkx/transforms/feature_liftings/base.py new file mode 100644 index 00000000..c5969398 --- /dev/null +++ b/topobenchmarkx/transforms/feature_liftings/base.py @@ -0,0 +1,13 @@ +import abc + + +class FeatureLiftingMap(abc.ABC): + """Feature lifting map.""" + + def __call__(self, domain): + """Lift features of a domain.""" + return self.lift_features(domain) + + @abc.abstractmethod + def lift_features(self, domain): + """Lift features of a domain.""" diff --git a/topobenchmarkx/transforms/feature_liftings/identity.py b/topobenchmarkx/transforms/feature_liftings/identity.py index 93806f1d..9abf4e5d 100644 --- a/topobenchmarkx/transforms/feature_liftings/identity.py +++ b/topobenchmarkx/transforms/feature_liftings/identity.py @@ -1,36 +1,13 @@ """Identity transform that does nothing to the input data.""" -import torch_geometric +from .base import FeatureLiftingMap -class Identity(torch_geometric.transforms.BaseTransform): - r"""An identity transform that does nothing to the input data. +class Identity(FeatureLiftingMap): + """Identity feature lifting map.""" - Parameters - ---------- - **kwargs : optional - Parameters for the base transform. - """ + # TODO: rename to IdentityFeatureLifting - def __init__(self, **kwargs): - super().__init__() - self.type = "domain2domain" - self.parameters = kwargs - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" - - def forward(self, data: torch_geometric.data.Data): - r"""Apply the transform to the input data. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The same data. - """ - return data + def lift_features(self, domain): + """Lift features of a domain using identity map.""" + return domain diff --git a/topobenchmarkx/transforms/feature_liftings/projection_sum.py b/topobenchmarkx/transforms/feature_liftings/projection_sum.py index 3cce03eb..4d0c04b5 100644 --- a/topobenchmarkx/transforms/feature_liftings/projection_sum.py +++ b/topobenchmarkx/transforms/feature_liftings/projection_sum.py @@ -1,69 +1,30 @@ """ProjectionSum class.""" -import torch -import torch_geometric +from .base import FeatureLiftingMap -class ProjectionSum(torch_geometric.transforms.BaseTransform): - r"""Lift r-cell features to r+1-cells by projection. +class ProjectionSum(FeatureLiftingMap): + r"""Lift r-cell features to r+1-cells by projection.""" - Parameters - ---------- - **kwargs : optional - Additional arguments for the class. - """ - - def __init__(self, **kwargs): - super().__init__() - - def __repr__(self) -> str: - return f"{self.__class__.__name__}()" - - def lift_features( - self, data: torch_geometric.data.Data | dict - ) -> torch_geometric.data.Data | dict: + def lift_features(self, domain): r"""Project r-cell features of a graph to r+1-cell structures. Parameters ---------- - data : torch_geometric.data.Data | dict + data : PlainComplex The input data to be lifted. Returns ------- - torch_geometric.data.Data | dict - The data with the lifted features. + PlainComplex + Domain with the lifted features. """ - keys = sorted( - [ - key.split("_")[1] - for key in data - if ("incidence" in key and "-" not in key) - ] - ) - for elem in keys: - if f"x_{elem}" not in data: - idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1 - data["x_" + elem] = torch.matmul( - abs(data["incidence_" + elem].t()), - data[f"x_{idx_to_project}"], - ) - return data + for rank in range(domain.max_rank - 1): + if domain.features[rank + 1] is not None: + continue - def forward( - self, data: torch_geometric.data.Data | dict - ) -> torch_geometric.data.Data | dict: - r"""Apply the lifting to the input data. - - Parameters - ---------- - data : torch_geometric.data.Data | dict - The input data to be lifted. + domain.features[rank + 1] = domain.propagate_values( + rank, domain.features[rank] + ) - Returns - ------- - torch_geometric.data.Data | dict - The lifted data. - """ - data = self.lift_features(data) - return data + return domain diff --git a/topobenchmarkx/transforms/liftings/base.py b/topobenchmarkx/transforms/liftings/base.py index c08a54e5..fa00e40e 100644 --- a/topobenchmarkx/transforms/liftings/base.py +++ b/topobenchmarkx/transforms/liftings/base.py @@ -1,10 +1,99 @@ """Abstract class for topological liftings.""" -from abc import abstractmethod +import abc import torch_geometric +from topobenchmarkx.transforms.converters import IdentityConverter from topobenchmarkx.transforms.feature_liftings import FEATURE_LIFTINGS +from topobenchmarkx.transforms.feature_liftings.identity import ( + Identity, +) + + +class LiftingTransform(torch_geometric.transforms.BaseTransform): + """Lifting transform. + + Parameters + ---------- + data2domain : Converter + Conversion between ``torch_geometric.Data`` into + domain for consumption by lifting. + domain2dict : Converter + Conversion between output domain of feature lifting + and ``torch_geometric.Data``. + lifting : LiftingMap + Lifting map. + domain2domain : Converter + Conversion between output domain of lifting + and input domain for feature lifting. + feature_lifting : FeatureLiftingMap + Feature lifting map. + """ + + # NB: emulates previous AbstractLifting + def __init__( + self, + data2domain, + domain2dict, + lifting, + domain2domain=None, + feature_lifting=None, + ): + if feature_lifting is None: + feature_lifting = Identity() + + if domain2domain is None: + domain2domain = IdentityConverter() + + self.data2domain = data2domain + self.domain2domain = domain2domain + self.domain2dict = domain2dict + self.lifting = lifting + self.feature_lifting = feature_lifting + + def forward( + self, data: torch_geometric.data.Data + ) -> torch_geometric.data.Data: + r"""Apply the full lifting (topology + features) to the input data. + + Parameters + ---------- + data : torch_geometric.data.Data + The input data to be lifted. + + Returns + ------- + torch_geometric.data.Data + The lifted data. + """ + initial_data = data.to_dict() + + domain = self.data2domain(data) + lifted_topology = self.lifting(domain) + lifted_topology = self.domain2domain(lifted_topology) + lifted_topology = self.feature_lifting(lifted_topology) + lifted_topology_dict = self.domain2dict(lifted_topology) + + # TODO: make this line more clear + return torch_geometric.data.Data( + **initial_data, **lifted_topology_dict + ) + + +class LiftingMap(abc.ABC): + """Lifting map. + + Lifts a domain into another. + """ + + def __call__(self, domain): + """Lift domain.""" + return self.lift(domain) + + @abc.abstractmethod + def lift(self, domain): + """Lift domain.""" class AbstractLifting(torch_geometric.transforms.BaseTransform): @@ -18,12 +107,14 @@ class AbstractLifting(torch_geometric.transforms.BaseTransform): Additional arguments for the class. """ + # TODO: delete + def __init__(self, feature_lifting=None, **kwargs): super().__init__() self.feature_lifting = FEATURE_LIFTINGS[feature_lifting]() self.neighborhoods = kwargs.get("neighborhoods") - @abstractmethod + @abc.abstractmethod def lift_topology(self, data: torch_geometric.data.Data) -> dict: r"""Lift the topology of a graph to higher-order topological domains. diff --git a/topobenchmarkx/transforms/liftings/graph2simplicial/clique.py b/topobenchmarkx/transforms/liftings/graph2simplicial/clique.py index af7d5cdf..990d2e6e 100755 --- a/topobenchmarkx/transforms/liftings/graph2simplicial/clique.py +++ b/topobenchmarkx/transforms/liftings/graph2simplicial/clique.py @@ -1,32 +1,30 @@ """This module implements the CliqueLifting class, which lifts graphs to simplicial complexes.""" from itertools import combinations -from typing import Any import networkx as nx -import torch_geometric from toponetx.classes import SimplicialComplex -from topobenchmarkx.transforms.liftings.graph2simplicial import ( - Graph2SimplicialLifting, -) +from topobenchmarkx.transforms.liftings.base import LiftingMap -class SimplicialCliqueLifting(Graph2SimplicialLifting): +class SimplicialCliqueLifting(LiftingMap): r"""Lift graphs to simplicial complex domain. The algorithm creates simplices by identifying the cliques and considering them as simplices of the same dimension. Parameters ---------- - **kwargs : optional - Additional arguments for the class. + complex_dim : int + Maximum rank of the complex. """ - def __init__(self, **kwargs): - super().__init__(**kwargs) + def __init__(self, complex_dim=2): + super().__init__() + # TODO: better naming + self.complex_dim = complex_dim - def lift_topology(self, data: torch_geometric.data.Data) -> dict: + def lift(self, domain): r"""Lift the topology of a graph to a simplicial complex. Parameters @@ -39,12 +37,11 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: dict The lifted topology. """ - graph = self._generate_graph_from_data(data) + graph = domain + simplicial_complex = SimplicialComplex(graph) cliques = nx.find_cliques(graph) - simplices: list[set[tuple[Any, ...]]] = [ - set() for _ in range(2, self.complex_dim + 1) - ] + simplices = [set() for _ in range(2, self.complex_dim + 1)] for clique in cliques: for i in range(2, self.complex_dim + 1): for c in combinations(clique, i + 1): @@ -53,4 +50,5 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: for set_k_simplices in simplices: simplicial_complex.add_simplices_from(list(set_k_simplices)) - return self._get_lifted_topology(simplicial_complex, graph) + # TODO: need to check for edge preservation + return simplicial_complex