diff --git a/configs/dataset/graph/US-county-demos.yaml b/configs/dataset/graph/US-county-demos.yaml index 6e21a4a9..87b12fea 100755 --- a/configs/dataset/graph/US-county-demos.yaml +++ b/configs/dataset/graph/US-county-demos.yaml @@ -30,6 +30,6 @@ split_params: # Dataloader parameters dataloader_params: - batch_size: 1 # Fixed + batch_size: -1 # Fixed num_workers: 0 pin_memory: False diff --git a/configs/dataset/graph/amazon_ratings.yaml b/configs/dataset/graph/amazon_ratings.yaml index 3e5a9dae..149b20ea 100755 --- a/configs/dataset/graph/amazon_ratings.yaml +++ b/configs/dataset/graph/amazon_ratings.yaml @@ -27,6 +27,6 @@ split_params: # Dataloader parameters dataloader_params: - batch_size: 1 # Fixed + batch_size: -1 # Fixed num_workers: 0 pin_memory: False diff --git a/configs/dataset/graph/cocitation_citeseer.yaml b/configs/dataset/graph/cocitation_citeseer.yaml index cfb1b6fe..b92f31a9 100755 --- a/configs/dataset/graph/cocitation_citeseer.yaml +++ b/configs/dataset/graph/cocitation_citeseer.yaml @@ -28,6 +28,6 @@ split_params: # Dataloader parameters dataloader_params: - batch_size: 1 # Fixed + batch_size: -1 # Fixed num_workers: 1 pin_memory: False diff --git a/configs/dataset/graph/cocitation_cora.yaml b/configs/dataset/graph/cocitation_cora.yaml index d2b9fa3b..64de64e3 100755 --- a/configs/dataset/graph/cocitation_cora.yaml +++ b/configs/dataset/graph/cocitation_cora.yaml @@ -27,6 +27,6 @@ split_params: # Dataloader parameters dataloader_params: - batch_size: 1 # Fixed + batch_size: -1 # Fixed num_workers: 1 pin_memory: False diff --git a/configs/dataset/graph/cocitation_pubmed.yaml b/configs/dataset/graph/cocitation_pubmed.yaml index 7d901437..c974b6b1 100755 --- a/configs/dataset/graph/cocitation_pubmed.yaml +++ b/configs/dataset/graph/cocitation_pubmed.yaml @@ -27,6 +27,6 @@ split_params: # Dataloader parameters dataloader_params: - batch_size: 1 # Fixed + batch_size: -1 # Fixed num_workers: 1 pin_memory: False diff --git 
a/configs/dataset/graph/manual_dataset.yaml b/configs/dataset/graph/manual_dataset.yaml index e0357d2b..bafe272a 100755 --- a/configs/dataset/graph/manual_dataset.yaml +++ b/configs/dataset/graph/manual_dataset.yaml @@ -28,6 +28,6 @@ split_params: # Dataloader parameters dataloader_params: - batch_size: 1 + batch_size: -1 num_workers: 1 pin_memory: False diff --git a/configs/dataset/graph/minesweeper.yaml b/configs/dataset/graph/minesweeper.yaml index 19119e78..c487de79 100755 --- a/configs/dataset/graph/minesweeper.yaml +++ b/configs/dataset/graph/minesweeper.yaml @@ -28,6 +28,6 @@ split_params: # Dataloader parameters dataloader_params: - batch_size: 1 # Fixed + batch_size: -1 # Fixed num_workers: 0 pin_memory: False diff --git a/configs/dataset/graph/questions.yaml b/configs/dataset/graph/questions.yaml index 25333b75..a10d0f9a 100755 --- a/configs/dataset/graph/questions.yaml +++ b/configs/dataset/graph/questions.yaml @@ -27,6 +27,6 @@ split_params: # Dataloader parameters dataloader_params: - batch_size: 1 # Fixed + batch_size: -1 # Fixed num_workers: 1 pin_memory: False diff --git a/configs/dataset/graph/roman_empire.yaml b/configs/dataset/graph/roman_empire.yaml index 37adfb4b..e40d0e7b 100755 --- a/configs/dataset/graph/roman_empire.yaml +++ b/configs/dataset/graph/roman_empire.yaml @@ -27,6 +27,6 @@ split_params: # Dataloader parameters dataloader_params: - batch_size: 1 # Fixed + batch_size: -1 # Fixed num_workers: 0 pin_memory: False diff --git a/configs/dataset/graph/tolokers.yaml b/configs/dataset/graph/tolokers.yaml index f1657f16..2da6e9af 100755 --- a/configs/dataset/graph/tolokers.yaml +++ b/configs/dataset/graph/tolokers.yaml @@ -27,6 +27,6 @@ split_params: # Dataloader parameters dataloader_params: - batch_size: 1 # Fixed + batch_size: -1 # Fixed num_workers: 1 pin_memory: False diff --git a/configs/dataset/hypergraph/coauthorship_cora.yaml b/configs/dataset/hypergraph/coauthorship_cora.yaml index 80699bbd..2bc0ea7c 100755 --- 
a/configs/dataset/hypergraph/coauthorship_cora.yaml +++ b/configs/dataset/hypergraph/coauthorship_cora.yaml @@ -27,6 +27,6 @@ split_params: # Dataloader parameters dataloader_params: - batch_size: 1 # Fixed + batch_size: -1 # Fixed num_workers: 1 pin_memory: False diff --git a/configs/dataset/hypergraph/coauthorship_dblp.yaml b/configs/dataset/hypergraph/coauthorship_dblp.yaml index 5f4c4e25..0e378a9b 100755 --- a/configs/dataset/hypergraph/coauthorship_dblp.yaml +++ b/configs/dataset/hypergraph/coauthorship_dblp.yaml @@ -27,6 +27,6 @@ split_params: # Dataloader parameters dataloader_params: - batch_size: 1 # Fixed + batch_size: -1 # Fixed num_workers: 1 pin_memory: False diff --git a/configs/dataset/hypergraph/cocitation_citeseer.yaml b/configs/dataset/hypergraph/cocitation_citeseer.yaml index d51b884f..7823c357 100755 --- a/configs/dataset/hypergraph/cocitation_citeseer.yaml +++ b/configs/dataset/hypergraph/cocitation_citeseer.yaml @@ -27,6 +27,6 @@ split_params: # Dataloader parameters dataloader_params: - batch_size: 1 # Fixed + batch_size: -1 # Fixed num_workers: 1 pin_memory: False diff --git a/configs/dataset/hypergraph/cocitation_cora.yaml b/configs/dataset/hypergraph/cocitation_cora.yaml index 557b0a14..cbe8c613 100755 --- a/configs/dataset/hypergraph/cocitation_cora.yaml +++ b/configs/dataset/hypergraph/cocitation_cora.yaml @@ -27,6 +27,6 @@ split_params: # Dataloader parameters dataloader_params: - batch_size: 1 # Fixed + batch_size: -1 # Fixed num_workers: 1 pin_memory: False diff --git a/configs/dataset/hypergraph/cocitation_pubmed.yaml b/configs/dataset/hypergraph/cocitation_pubmed.yaml index 8aa19826..6fb00abf 100755 --- a/configs/dataset/hypergraph/cocitation_pubmed.yaml +++ b/configs/dataset/hypergraph/cocitation_pubmed.yaml @@ -27,6 +27,6 @@ split_params: # Dataloader parameters dataloader_params: - batch_size: 1 # Fixed + batch_size: -1 # Fixed num_workers: 1 pin_memory: False diff --git a/test/data/batching/test_neighbor_cells_loader.py 
b/test/data/batching/test_neighbor_cells_loader.py new file mode 100644 index 00000000..5153183b --- /dev/null +++ b/test/data/batching/test_neighbor_cells_loader.py @@ -0,0 +1,132 @@ +""" Test for the NeighborCellsLoader class.""" +import os +import shutil +import rootutils +from hydra import compose +import torch + +from topobenchmark.data.preprocessor import PreProcessor +from topobenchmark.data.utils.utils import load_manual_graph +from topobenchmark.data.batching import NeighborCellsLoader +from topobenchmark.run import initialize_hydra + +initialize_hydra() + +path = "./graph2simplicial_lifting/" +if os.path.isdir(path): + shutil.rmtree(path) +cfg = compose(config_name="run.yaml", + overrides=["dataset=graph/manual_dataset", "model=simplicial/san"], + return_hydra_config=True) + +data = load_manual_graph() +preprocessed_dataset = PreProcessor(data, path, cfg['transforms']) +data = preprocessed_dataset[0] + +batch_size=2 + +rank = 0 +n_cells = data[f'x_{rank}'].shape[0] +train_prop = 0.5 +n_train = int(train_prop * n_cells) +train_mask = torch.zeros(n_cells, dtype=torch.bool) +train_mask[:n_train] = 1 + +y = torch.zeros(n_cells, dtype=torch.long) +data.y = y + +loader = NeighborCellsLoader(data, + rank=rank, + num_neighbors=[-1], + input_nodes=train_mask, + batch_size=batch_size, + shuffle=False) +train_nodes = [] +for batch in loader: + train_nodes += [n for n in batch.n_id[:batch_size]] +for i in range(n_train): + assert i in train_nodes + +rank = 1 +n_cells = data[f'x_{rank}'].shape[0] +train_prop = 0.5 +n_train = int(train_prop * n_cells) +train_mask = torch.zeros(n_cells, dtype=torch.bool) +train_mask[:n_train] = 1 + +y = torch.zeros(n_cells, dtype=torch.long) +data.y = y + +loader = NeighborCellsLoader(data, + rank=rank, + num_neighbors=[-1,-1], + input_nodes=train_mask, + batch_size=batch_size, + shuffle=False) + +train_nodes = [] +for batch in loader: + train_nodes += [n for n in batch.n_id[:batch_size]] +for i in range(n_train): + assert i in 
train_nodes +shutil.rmtree(path) + + +path = "./graph2hypergraph_lifting/" +if os.path.isdir(path): + shutil.rmtree(path) +cfg = compose(config_name="run.yaml", + overrides=["dataset=graph/manual_dataset", "model=hypergraph/allsettransformer"], + return_hydra_config=True) + +data = load_manual_graph() +preprocessed_dataset = PreProcessor(data, path, cfg['transforms']) +data = preprocessed_dataset[0] + +batch_size=2 + +rank = 0 +n_cells = data[f'x_0'].shape[0] +train_prop = 0.5 +n_train = int(train_prop * n_cells) +train_mask = torch.zeros(n_cells, dtype=torch.bool) +train_mask[:n_train] = 1 + +y = torch.zeros(n_cells, dtype=torch.long) +data.y = y + +loader = NeighborCellsLoader(data, + rank=rank, + num_neighbors=[-1], + input_nodes=train_mask, + batch_size=batch_size, + shuffle=False) +train_nodes = [] +for batch in loader: + train_nodes += [n for n in batch.n_id[:batch_size]] +for i in range(n_train): + assert i in train_nodes + +rank = 1 +n_cells = data[f'x_hyperedges'].shape[0] +train_prop = 0.5 +n_train = int(train_prop * n_cells) +train_mask = torch.zeros(n_cells, dtype=torch.bool) +train_mask[:n_train] = 1 + +y = torch.zeros(n_cells, dtype=torch.long) +data.y = y + +loader = NeighborCellsLoader(data, + rank=rank, + num_neighbors=[-1,-1], + input_nodes=train_mask, + batch_size=batch_size, + shuffle=False) + +train_nodes = [] +for batch in loader: + train_nodes += [n for n in batch.n_id[:batch_size]] +for i in range(n_train): + assert i in train_nodes +shutil.rmtree(path) \ No newline at end of file diff --git a/test/data/dataload/test_Dataloaders.py b/test/data/dataload/test_Dataloaders.py index 35770d68..26b7de36 100644 --- a/test/data/dataload/test_Dataloaders.py +++ b/test/data/dataload/test_Dataloaders.py @@ -20,10 +20,6 @@ class TestCollateFunction: def setup_method(self): """Setup the test.""" - - hydra.initialize( - version_base="1.3", config_path="../../../configs", job_name="run" - ) cfg = hydra.compose(config_name="run.yaml", 
overrides=["dataset=graph/NCI1"]) graph_loader = hydra.utils.instantiate(cfg.dataset.loader, _recursive_=False) diff --git a/topobenchmark/data/batching/__init__.py b/topobenchmark/data/batching/__init__.py new file mode 100644 index 00000000..71952e48 --- /dev/null +++ b/topobenchmark/data/batching/__init__.py @@ -0,0 +1,7 @@ +"""Init file for batching module.""" + +from .neighbor_cells_loader import NeighborCellsLoader + +__all__ = [ + "NeighborCellsLoader", +] diff --git a/topobenchmark/data/batching/cell_loader.py b/topobenchmark/data/batching/cell_loader.py new file mode 100644 index 00000000..21593ec6 --- /dev/null +++ b/topobenchmark/data/batching/cell_loader.py @@ -0,0 +1,241 @@ +"""Cell Loader module from PyTorch Geometric with custom filter_data function.""" + +from collections.abc import Callable, Iterator +from typing import Any + +import torch +from torch import Tensor +from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData +from torch_geometric.loader.base import DataLoaderIterator +from torch_geometric.loader.mixin import ( + AffinityMixin, + LogMemoryMixin, + MultithreadingMixin, +) +from torch_geometric.loader.utils import ( + get_input_nodes, + infer_filter_per_worker, +) +from torch_geometric.sampler import ( + BaseSampler, + HeteroSamplerOutput, + NodeSamplerInput, + SamplerOutput, +) +from torch_geometric.typing import InputNodes, OptTensor + +from topobenchmark.data.batching.utils import filter_data + + +class CellLoader( + torch.utils.data.DataLoader, + AffinityMixin, + MultithreadingMixin, + LogMemoryMixin, +): + r"""A data loader that performs mini-batch sampling from cell information. + + It uses a generic :class:`~torch_geometric.sampler.BaseSampler` + implementation that defines a + :meth:`~torch_geometric.sampler.BaseSampler.sample_from_nodes` function and + is supported on the provided input :obj:`data` object. 
+ + Parameters + ---------- + data : Any + A :class:`~torch_geometric.data.Data`, + :class:`~torch_geometric.data.HeteroData`, or + (:class:`~torch_geometric.data.FeatureStore`, + :class:`~torch_geometric.data.GraphStore`) data object. + cell_sampler : torch_geometric.sampler.BaseSampler + The sampler implementation to be used with this loader. + Needs to implement + :meth:`~torch_geometric.sampler.BaseSampler.sample_from_nodes`. + The sampler implementation must be compatible with the input + :obj:`data` object. + input_cells : torch.Tensor or str or Tuple[str, torch.Tensor] + The indices of seed cells to start sampling from. + Needs to be either given as a :obj:`torch.LongTensor` or + :obj:`torch.BoolTensor`. + If set to :obj:`None`, all cells will be considered. + In heterogeneous graphs, needs to be passed as a tuple that holds + the cell type and cell indices. (default: :obj:`None`). + input_time : torch.Tensor, optional + Optional values to override the timestamp for the input cells given in + :obj:`input_cells`. If not set, will use the timestamps in + :obj:`time_attr` as default (if present). The :obj:`time_attr` needs + to be set for this to work. (default: :obj:`None`). + transform : callable, optional + A function/transform that takes in a sampled mini-batch and returns a + transformed version. (default: :obj:`None`). + transform_sampler_output : callable, optional + A function/transform that takes in a + :class:`torch_geometric.sampler.SamplerOutput` and returns a + transformed version. (default: :obj:`None`). + filter_per_worker : bool, optional + If set to :obj:`True`, will filter the returned data in each worker's + subprocess. + If set to :obj:`False`, will filter the returned data in the main + process. + If set to :obj:`None`, will automatically infer the decision based + on whether data partially lives on the GPU + (:obj:`filter_per_worker=True`) or entirely on the CPU + (:obj:`filter_per_worker=False`). 
+ There exist different trade-offs for setting this option. + Specifically, setting this option to :obj:`True` for in-memory + datasets will move all features to shared memory, which may result + in too many open file handles. (default: :obj:`None`). + custom_cls : torch_geometric.data.HeteroData, optional + A custom :class:`~torch_geometric.data.HeteroData` class to return for + mini-batches in case of remote backends. (default: :obj:`None`). + input_id : torch.Tensor, optional + The indices of the input cells in the original data object. + (default: :obj:`None`). + **kwargs : optional + Additional arguments of :class:`torch.utils.data.DataLoader`, such as + :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or + :obj:`num_workers`. + """ + + def __init__( + self, + data: Data | HeteroData | tuple[FeatureStore, GraphStore], + cell_sampler: BaseSampler, + input_cells: InputNodes = None, + input_time: OptTensor = None, + transform: Callable | None = None, + transform_sampler_output: Callable | None = None, + filter_per_worker: bool | None = None, + custom_cls: HeteroData | None = None, + input_id: OptTensor = None, + **kwargs, + ): + if filter_per_worker is None: + filter_per_worker = infer_filter_per_worker(data) + + self.data = data + self.cell_sampler = cell_sampler + self.input_cells = input_cells + self.input_time = input_time + self.transform = transform + self.transform_sampler_output = transform_sampler_output + self.filter_per_worker = filter_per_worker + self.custom_cls = custom_cls + self.input_id = input_id + + kwargs.pop("dataset", None) + kwargs.pop("collate_fn", None) + + # Get cell type (or `None` for homogeneous graphs): + input_type, input_cells, input_id = get_input_nodes( + data, input_cells, input_id + ) + + self.input_data = NodeSamplerInput( + input_id=input_id, + node=input_cells, + time=input_time, + input_type=input_type, + ) + + iterator = range(input_cells.size(0)) + super().__init__(iterator, collate_fn=self.collate_fn, **kwargs) + 
def __call__( + self, + index: Tensor | list[int], + ) -> Data | HeteroData: + r"""Sample a subgraph from a batch of input cells. + + Parameters + ---------- + index : torch.Tensor or List[int] + The indices of cells to sample. + + Returns + ------- + Union[Data, HeteroData] + The sampled subgraph. + """ + out = self.collate_fn(index) + if not self.filter_per_worker: + out = self.filter_fn(out) + return out + + def collate_fn(self, index: Tensor | list[int]) -> Any: + r"""Sample a subgraph from a batch of input cells. + + Parameters + ---------- + index : torch.Tensor or List[int] + The indices of cells to sample. + + Returns + ------- + Any + The sampled subgraph. + """ + input_data: NodeSamplerInput = self.input_data[index] + + out = self.cell_sampler.sample_from_nodes(input_data) + + if self.filter_per_worker: # Execute `filter_fn` in the worker process + out = self.filter_fn(out) + + return out + + def filter_fn( + self, + out: SamplerOutput | HeteroSamplerOutput, + ) -> Data | HeteroData: + r"""Join the sampled cells with their corresponding features. + + It returns the resulting :class:`~torch_geometric.data.Data` + object to be used downstream. + + Parameters + ---------- + out : Union[SamplerOutput, HeteroSamplerOutput] + The output of the sampler. + + Returns + ------- + Union[Data, HeteroData] + The resulting data object. + """ + if self.transform_sampler_output: + out = self.transform_sampler_output(out) + + if isinstance(out, SamplerOutput) and isinstance(self.data, Data): + data = filter_data(self.data, out.node, self.rank) + else: + raise TypeError( + f"'{self.__class__.__name__}' found invalid " + f"type: '{type(out)}' (data: '{type(self.data)}')" + ) + + return data if self.transform is None else self.transform(data) + + def _get_iterator(self) -> Iterator: + r"""Return the internal iterator to be used for sampling. + + Returns + ------- + Iterator + The internal iterator to be used for sampling. 
+ """ + if self.filter_per_worker: + return super()._get_iterator() + + # if not self.is_cuda_available and not self.cpu_affinity_enabled: + # TODO: Add manual page for best CPU practices + # link = ... + # Warning('Dataloader CPU affinity opt is not enabled, consider ' + # 'switching it on with enable_cpu_affinity() or see CPU ' + # f'best practices for PyG [{link}])') + + # Execute `filter_fn` in the main process: + return DataLoaderIterator(super()._get_iterator(), self.filter_fn) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" diff --git a/topobenchmark/data/batching/neighbor_cells_loader.py b/topobenchmark/data/batching/neighbor_cells_loader.py new file mode 100644 index 00000000..be772ef9 --- /dev/null +++ b/topobenchmark/data/batching/neighbor_cells_loader.py @@ -0,0 +1,182 @@ +"""NeighborCellsLoader class to batch in the transductive setting when working with topological domains.""" + +from collections.abc import Callable + +from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData +from torch_geometric.sampler import NeighborSampler +from torch_geometric.sampler.base import SubgraphType +from torch_geometric.typing import EdgeType, InputNodes, OptTensor + +from topobenchmark.data.batching.cell_loader import CellLoader +from topobenchmark.data.batching.utils import get_sampled_neighborhood +from topobenchmark.dataloader import DataloadDataset + + +class NeighborCellsLoader(CellLoader): + r"""A data loader that samples neighbors for each cell. Cells are considered neighbors if they are upper or lower neighbors. + + Parameters + ---------- + data : Any + A :class:`~torch_geometric.data.Data`, + :class:`~torch_geometric.data.HeteroData`, or + (:class:`~torch_geometric.data.FeatureStore`, + :class:`~torch_geometric.data.GraphStore`) data object. + rank : int + The rank of the cells to consider. 
+ num_neighbors : List[int] or Dict[Tuple[str, str, str], List[int]] + The number of neighbors to sample for each node in each iteration. + If an entry is set to :obj:`-1`, all neighbors will be included. + input_nodes : torch.Tensor or str or Tuple[str, torch.Tensor] + The indices of nodes for which neighbors are sampled to create + mini-batches. + Needs to be either given as a :obj:`torch.LongTensor` or + :obj:`torch.BoolTensor`. + If set to :obj:`None`, all nodes will be considered. + input_time : torch.Tensor, optional + Optional values to override the timestamp for the input nodes given in :obj:`input_nodes`. If not + set, will use the timestamps in :obj:`time_attr` as default (if + present). The :obj:`time_attr` needs to be set for this to work. + (default: :obj:`None`). + replace : bool, optional + If set to :obj:`True`, will sample with replacement. (default: :obj:`False`). + subgraph_type : SubgraphType or str, optional + The type of the returned subgraph. + If set to :obj:`"directional"`, the returned subgraph only holds + the sampled (directed) edges which are necessary to compute + representations for the sampled seed nodes. + If set to :obj:`"bidirectional"`, sampled edges are converted to + bidirectional edges. + If set to :obj:`"induced"`, the returned subgraph contains the + induced subgraph of all sampled nodes. + (default: :obj:`"directional"`). + disjoint : bool, optional + If set to :obj:`True`, each seed node will create its own disjoint subgraph. + If set to :obj:`True`, mini-batch outputs will have a :obj:`batch` + vector holding the mapping of nodes to their respective subgraph. + Will get automatically set to :obj:`True` in case of temporal + sampling. (default: :obj:`False`). + temporal_strategy : str, optional + The sampling strategy when using temporal sampling (:obj:`"uniform"`, :obj:`"last"`). + If set to :obj:`"uniform"`, will sample uniformly across neighbors + that fulfill temporal constraints. 
+ If set to :obj:`"last"`, will sample the last `num_neighbors` that + fulfill temporal constraints. + (default: :obj:`"uniform"`). + time_attr : str, optional + The name of the attribute that denotes timestamps for either the nodes or edges in the graph. + If set, temporal sampling will be used such that neighbors are + guaranteed to fulfill temporal constraints, *i.e.* neighbors have + an earlier or equal timestamp than the center node. + (default: :obj:`None`). + weight_attr : str, optional + The name of the attribute that denotes edge weights in the graph. + If set, weighted/biased sampling will be used such that neighbors + are more likely to get sampled the higher their edge weights are. + Edge weights do not need to sum to one, but must be non-negative, + finite and have a non-zero sum within local neighborhoods. + (default: :obj:`None`). + transform : callable, optional + A function/transform that takes in a sampled mini-batch and returns a transformed version. + (default: :obj:`None`). + transform_sampler_output : callable, optional + A function/transform that takes in a :class:`torch_geometric.sampler.SamplerOutput` and + returns a transformed version. (default: :obj:`None`). + is_sorted : bool, optional + If set to :obj:`True`, assumes that :obj:`edge_index` is sorted by column. + If :obj:`time_attr` is set, additionally requires that rows are + sorted according to time within individual neighborhoods. + This avoids internal re-sorting of the data and can improve + runtime and memory efficiency. (default: :obj:`False`). + filter_per_worker : bool, optional + If set to :obj:`True`, will filter the returned data in each worker's subprocess. + If set to :obj:`False`, will filter the returned data in the main process. + If set to :obj:`None`, will automatically infer the decision based + on whether data partially lives on the GPU + (:obj:`filter_per_worker=True`) or entirely on the CPU + (:obj:`filter_per_worker=False`). 
+ There exist different trade-offs for setting this option. + Specifically, setting this option to :obj:`True` for in-memory + datasets will move all features to shared memory, which may result + in too many open file handles. (default: :obj:`None`). + neighbor_sampler : NeighborSampler, optional + The neighbor sampler implementation to be used with this loader. + If not set, a new :class:`torch_geometric.sampler.NeighborSampler` + instance will be created. (default: :obj:`None`). + directed : bool, optional + If set to :obj:`True`, will consider the graph as directed. + If set to :obj:`False`, will consider the graph as undirected. + (default: :obj:`True`). + **kwargs : optional + Additional arguments of :class:`torch.utils.data.DataLoader`, such as + :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`. + """ + + def __init__( + self, + data: Data | HeteroData | tuple[FeatureStore, GraphStore], + rank: int, + num_neighbors: list[int] | dict[EdgeType, list[int]], + input_nodes: InputNodes = None, + input_time: OptTensor = None, + replace: bool = False, + subgraph_type: SubgraphType | str = "directional", + disjoint: bool = False, + temporal_strategy: str = "uniform", + time_attr: str | None = None, + weight_attr: str | None = None, + transform: Callable | None = None, + transform_sampler_output: Callable | None = None, + is_sorted: bool = False, + filter_per_worker: bool | None = None, + neighbor_sampler: NeighborSampler | None = None, + directed: bool = True, + **kwargs, + ): + if input_time is not None and time_attr is None: + raise ValueError( + "Received conflicting 'input_time' and " + "'time_attr' arguments: 'input_time' is set " + "while 'time_attr' is not set." 
+ ) + + data_obj = Data() + if isinstance(data, DataloadDataset): + for tensor, name in zip(data[0][0], data[0][1], strict=False): + setattr(data_obj, name, tensor) + else: + data_obj = data + is_hypergraph = hasattr(data_obj, "incidence_hyperedges") + n_hops = len(num_neighbors) + data_obj = get_sampled_neighborhood( + data_obj, rank, n_hops, is_hypergraph + ) + self.rank = rank + if self.rank != 0: + # When rank is different than 0 get_sampled_neighborhood connects cells that are up to n_hops away, meaning that the NeighborhoodSampler needs to consider only one hop. + num_neighbors = [num_neighbors[0]] + if neighbor_sampler is None: + neighbor_sampler = NeighborSampler( + data_obj, + num_neighbors=num_neighbors, + replace=replace, + subgraph_type=subgraph_type, + disjoint=disjoint, + temporal_strategy=temporal_strategy, + time_attr=time_attr, + weight_attr=weight_attr, + is_sorted=is_sorted, + share_memory=kwargs.get("num_workers", 0) > 0, + directed=directed, + ) + + super().__init__( + data=data_obj, + cell_sampler=neighbor_sampler, + input_cells=input_nodes, + input_time=input_time, + transform=transform, + transform_sampler_output=transform_sampler_output, + filter_per_worker=filter_per_worker, + **kwargs, + ) diff --git a/topobenchmark/data/batching/utils.py b/topobenchmark/data/batching/utils.py new file mode 100644 index 00000000..78df4fe9 --- /dev/null +++ b/topobenchmark/data/batching/utils.py @@ -0,0 +1,331 @@ +"""Utility functions for batching cells of different ranks.""" + +import copy + +import torch +import torch_geometric.typing +from torch import Tensor +from torch_geometric.data import Data + + +def reduce_higher_ranks_incidences( + batch, cells_ids, rank, max_rank, is_hypergraph=False +): + """Reduce the incidences with higher rank than the specified one. + + Parameters + ---------- + batch : torch_geometric.data.Data + The input data. + cells_ids : list[torch.Tensor] + List of tensors containing the ids of the cells. 
The length of the list should be equal to the maximum rank. + rank : int + The rank to select the higher order incidences. + max_rank : int + The maximum rank of the incidences. + is_hypergraph : bool + Whether the data represents an hypergraph. + + Returns + ------- + torch_geometric.data.Data + The output data with the reduced incidences. + list[torch.Tensor] + The updated indices of the cells. Each element of the list is a tensor containing the ids of the cells of the corresponding rank. + """ + for i in range(rank + 1, max_rank + 1): + if is_hypergraph: + incidence = batch.incidence_hyperedges + else: + incidence = batch[f"incidence_{i}"] + + # if i != rank+1: + incidence = torch.index_select(incidence, 0, cells_ids[i - 1]) + cells_ids[i] = torch.where(torch.sum(incidence, dim=0).to_dense() > 1)[ + 0 + ] + incidence = torch.index_select(incidence, 1, cells_ids[i]) + if is_hypergraph: + batch.incidence_hyperedges = incidence + else: + batch[f"incidence_{i}"] = incidence + + return batch, cells_ids + + +def reduce_lower_ranks_incidences(batch, cells_ids, rank, is_hypergraph=False): + """Reduce the incidences with lower rank than the specified one. + + Parameters + ---------- + batch : torch_geometric.data.Data + The input data. + cells_ids : list[torch.Tensor] + List of tensors containing the ids of the cells. The length of the list should be equal to the maximum rank. + rank : int + The rank of the cells to consider. + is_hypergraph : bool + Whether the data represents an hypergraph. + + Returns + ------- + torch.Tensor + The indices of the nodes contained by the cells. + list[torch.Tensor] + The updated indices of the cells. Each element of the list is a tensor containing the ids of the cells of the corresponding rank. 
+ """ + for i in range(rank, 0, -1): + if is_hypergraph: + incidence = batch.incidence_hyperedges + else: + incidence = batch[f"incidence_{i}"] + incidence = torch.index_select(incidence, 1, cells_ids[i]) + cells_ids[i - 1] = torch.where( + torch.sum(incidence, dim=1).to_dense() > 0 + )[0] + incidence = torch.index_select(incidence, 0, cells_ids[i - 1]) + if is_hypergraph: + batch.incidence_hyperedges = incidence + else: + batch[f"incidence_{i}"] = incidence + + if not is_hypergraph: + incidence = batch["incidence_0"] + incidence = torch.index_select(incidence, 1, cells_ids[0]) + batch["incidence_0"] = incidence + return batch, cells_ids + + +def reduce_matrices(batch, cells_ids, names, max_rank): + """Reduce the matrices using the indices in cells_ids. + + The matrices are assumed to be in the batch with the names specified in the list names. + + Parameters + ---------- + batch : torch_geometric.data.Data + The input data. + cells_ids : list[torch.Tensor] + List of tensors containing the ids of the cells. The length of the list should be equal to the maximum rank. + names : list[str] + List of names of the matrices in the batch. They should appear in the format f"{name}{i}" where i is the rank of the matrix. + max_rank : int + The maximum rank of the matrices. + + Returns + ------- + torch_geometric.data.Data + The output data with the reduced matrices. + """ + for i in range(max_rank + 1): + for name in names: + if f"{name}{i}" in batch.keys(): # noqa + matrix = batch[f"{name}{i}"] + matrix = torch.index_select(matrix, 0, cells_ids[i]) + matrix = torch.index_select(matrix, 1, cells_ids[i]) + batch[f"{name}{i}"] = matrix + return batch + + +def reduce_neighborhoods(batch, node, rank=0, remove_self_loops=True): + """Reduce the neighborhoods of the cells in the batch. + + Parameters + ---------- + batch : torch_geometric.data.Data + The input data. + node : torch.Tensor + The indices of the cells to batch over. + rank : int + The rank of the cells to batch over. 
+ remove_self_loops : bool + Whether to remove self loops from the edge_index. + + Returns + ------- + torch_geometric.data.Data + The output data with the reduced neighborhoods. + """ + is_hypergraph = False + if hasattr(batch, "incidence_hyperedges"): + is_hypergraph = True + max_rank = 1 + else: + max_rank = len([key for key in batch.keys() if "incidence" in key]) - 1 # noqa + + if rank > max_rank: + raise ValueError( + f"Rank {rank} is greater than the maximum rank {max_rank} in the dataset." + ) + + cells_ids = [None for _ in range(max_rank + 1)] + + # the indices of the cells selected by the NeighborhoodLoader are saved in the batch in the attribute n_id + cells_ids[rank] = node + + batch, cells_ids = reduce_higher_ranks_incidences( + batch, cells_ids, rank, max_rank, is_hypergraph + ) + batch, cells_ids = reduce_lower_ranks_incidences( + batch, cells_ids, rank, is_hypergraph + ) + + batch = reduce_matrices( + batch, + cells_ids, + names=[ + "down_laplacian_", + "up_laplacian_", + "hodge_laplacian_", + "adjacency_", + ], + max_rank=max_rank, + ) + + # reduce the feature matrices + for i in range(max_rank + 1): + if f"x_{i}" in batch.keys(): # noqa + batch[f"x_{i}"] = batch[f"x_{i}"][cells_ids[i]] + + # fix edge_index + if not is_hypergraph: + adjacency_0 = batch.adjacency_0.coalesce() + edge_index = adjacency_0.indices() + if remove_self_loops: + edge_index = torch_geometric.utils.remove_self_loops(edge_index)[0] + batch.edge_index = edge_index + + # fix x + batch.x = batch["x_0"] + if hasattr(batch, "num_nodes"): + batch.num_nodes = batch.x.shape[0] + + if hasattr(batch, "y"): + batch.y = batch.y[cells_ids[rank]] + + batch.cells_ids = cells_ids + return batch + + +def filter_data(data: Data, cells: Tensor, rank: int) -> Data: + """Filter the attributes of the data based on the cells passed. + + The function uses the indices passed to select the cells of the specified rank. The cells of lower or higher ranks are selected using the incidence matrices. 
def _n_hop_adjacency(adjacency, n_hops):
    """Return ``adjacency`` raised to the power ``n_hops`` (sparse product)."""
    reach = adjacency
    for _ in range(n_hops - 1):
        # Multiply by the one-hop matrix each time; squaring the running
        # product would reach 2**(n_hops - 1) hops instead of n_hops.
        reach = torch.sparse.mm(reach, adjacency)
    return reach


def get_sampled_neighborhood(data, rank=0, n_hops=1, is_hypergraph=False):
    """Update the edge_index attribute of torch_geometric.data.Data.

    The function finds cells, of the specified rank, that are either upper or lower neighbors.

    Parameters
    ----------
    data : torch_geometric.data.Data
        The input data.
    rank : int
        The rank of the cells that you want to batch over.
    n_hops : int
        Two cells are considered neighbors if they are connected by n hops in the upper or lower neighborhoods.
    is_hypergraph : bool
        Whether the data represents an hypergraph.

    Returns
    -------
    torch_geometric.data.Data
        The output data with updated edge_index.
        edge_index contains indices of connected cells of the specified rank K.
        Two cells of rank K are connected if they are either lower or upper neighbors.

    Raises
    ------
    ValueError
        If ``rank`` is greater than 1 for a hypergraph, or greater than the
        maximum rank available in ``data`` otherwise.
    """
    if rank == 0:
        # Rank 0 keeps the graph connectivity, only made symmetric.
        data.edge_index = torch_geometric.utils.to_undirected(data.edge_index)
        return data

    if is_hypergraph:
        if rank > 1:
            raise ValueError(
                "Hypergraphs are not supported for ranks greater than 1."
            )
        # Hyperedges index the COLUMNS of incidence_hyperedges, so hyperedges
        # sharing a node are given by B^T B (lower adjacency), exactly like
        # the lower adjacency of the non-hypergraph branch below.
        incidence = data.incidence_hyperedges
        adjacency = _n_hop_adjacency(
            torch.sparse.mm(incidence.T, incidence), n_hops
        )
    else:
        # get number of incidences
        max_rank = len([key for key in data.keys() if "incidence" in key]) - 1  # noqa
        if rank > max_rank:
            raise ValueError(
                f"Rank {rank} is greater than the maximum rank {max_rank} in the data."
            )

        # This considers the lower adjacencies (rank >= 1 here, so the
        # incidence towards rank - 1 is always the one to use).
        lower_incidence = data[f"incidence_{rank}"]
        adjacency = _n_hop_adjacency(
            torch.sparse.mm(lower_incidence.T, lower_incidence), n_hops
        )

        # This considers the upper adjacencies, when a higher rank exists.
        if rank < max_rank:
            upper_incidence = data[f"incidence_{rank + 1}"]
            adjacency = adjacency + _n_hop_adjacency(
                torch.sparse.mm(upper_incidence, upper_incidence.T), n_hops
            )

    # indices() is only defined on coalesced sparse tensors.
    edges = adjacency.coalesce().indices()
    # Remove self edges
    edges = edges[:, edges[0, :] != edges[1, :]]

    data.edge_index = edges

    # We need to set x to x_{rank} since NeighborLoader will take the number of nodes from the x attribute
    # The correct x is given after the reduce_neighborhoods function
    if is_hypergraph and rank == 1:
        data.x = data.x_hyperedges
    else:
        data.x = data[f"x_{rank}"]

    if hasattr(data, "num_nodes"):
        data.num_nodes = data.x.shape[0]
    return data
def _get_dataloader(self, split: str) -> DataLoader | NeighborCellsLoader:
    r"""Create and return the dataloader for the specified split.

    A plain ``DataLoader`` is used in the inductive setting, and in the
    transductive setting when ``batch_size == -1`` (full batch). Otherwise a
    ``NeighborCellsLoader`` is built, seeded with the ``f"{split}_mask"``
    entry read from the first element of the training dataset.

    Parameters
    ----------
    split : str
        The split to create the dataloader for.

    Returns
    -------
    torch.utils.data.DataLoader | NeighborCellsLoader
        The dataloader for the specified split.
    """
    shuffle = split == "train"
    dataset = getattr(self, f"dataset_{split}")

    if not self.transductive or self.batch_size == -1:
        # batch_size == -1 means "full batch": a single element per batch.
        batch_size = 1 if self.batch_size == -1 else self.batch_size

        # persistent_workers is read from kwargs in __init__ and passed
        # explicitly below; forwarding it again through **kwargs would raise
        # "got multiple values for keyword argument 'persistent_workers'".
        extra_kwargs = {
            key: value
            for key, value in self.kwargs.items()
            if key != "persistent_workers"
        }

        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
            collate_fn=collate_fn,
            persistent_workers=self.persistent_workers,
            **extra_kwargs,
        )

    # Each dataset item is indexed as (values, keys) here: the mask is looked
    # up by the position of its key in the second element.
    mask_idx = self.dataset_train[0][1].index(f"{split}_mask")
    mask = self.dataset_train[0][0][mask_idx]
    return NeighborCellsLoader(
        data=dataset,
        rank=self.rank,
        num_neighbors=self.num_neighbors,
        input_nodes=mask,
        batch_size=self.batch_size,
        shuffle=shuffle,
        **self.kwargs,
    )
@@ -121,15 +155,7 @@ def test_dataloader(self) -> DataLoader: """ if self.dataset_test is None: raise ValueError("There is no test dataloader.") - return DataLoader( - dataset=self.dataset_test, - batch_size=self.batch_size, - num_workers=self.num_workers, - pin_memory=self.pin_memory, - shuffle=False, - collate_fn=collate_fn, - persistent_workers=self.persistent_workers, - ) + return self._get_dataloader("test") def teardown(self, stage: str | None = None) -> None: r"""Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and `trainer.predict()`. diff --git a/topobenchmark/nn/readouts/propagate_signal_down.py b/topobenchmark/nn/readouts/propagate_signal_down.py index 1eafe325..d79dc466 100644 --- a/topobenchmark/nn/readouts/propagate_signal_down.py +++ b/topobenchmark/nn/readouts/propagate_signal_down.py @@ -26,23 +26,23 @@ def __init__(self, **kwargs): self.name = kwargs["readout_name"] self.dimensions = range(kwargs["num_cell_dimensions"] - 1, 0, -1) - hidden_dim = kwargs["hidden_dim"] + self.hidden_dim = kwargs["hidden_dim"] for i in self.dimensions: setattr( self, f"agg_conv_{i}", topomodelx.base.conv.Conv( - hidden_dim, hidden_dim, aggr_norm=False + self.hidden_dim, self.hidden_dim, aggr_norm=False ), ) - setattr(self, f"ln_{i}", torch.nn.LayerNorm(hidden_dim)) + setattr(self, f"ln_{i}", torch.nn.LayerNorm(self.hidden_dim)) setattr( self, f"projector_{i}", - torch.nn.Linear(2 * hidden_dim, hidden_dim), + torch.nn.Linear(2 * self.hidden_dim, self.hidden_dim), ) def forward(self, model_out: dict, batch: torch_geometric.data.Data): diff --git a/tutorials/batching.ipynb b/tutorials/batching.ipynb new file mode 100644 index 00000000..b78ea289 --- /dev/null +++ b/tutorials/batching.ipynb @@ -0,0 +1,712 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Batching\n", + "\n", + "This notebook shows how to use and test the NeighborCellsLoader, which is the topological counterpart of 
NeighborLoader from torch_geometric." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports and utilities" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_26947/1154935553.py:31: UserWarning: \n", + "The version_base parameter is not specified.\n", + "Please specify a compatability version level, or None.\n", + "Will assume defaults for version 1.1\n", + " initialize(config_path=\"../configs\", job_name=\"job\")\n" + ] + }, + { + "data": { + "text/plain": [ + "hydra.initialize()" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import os, shutil, rootutils\n", + "\n", + "rootutils.setup_root(\"./\", indicator=\".project-root\", pythonpath=True)\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "import hydra\n", + "import torch\n", + "import torch_geometric\n", + "from hydra import compose, initialize\n", + "from omegaconf import OmegaConf\n", + "\n", + "from topobenchmark.data.utils.utils import load_manual_graph\n", + "from topobenchmark.data.preprocessor import PreProcessor\n", + "from topobenchmark.dataloader.dataloader import TBDataloader\n", + "from topobenchmark.data.loaders import PlanetoidDatasetLoader\n", + "\n", + "from topobenchmark.data.batching.neighbor_cells_loader import NeighborCellsLoader\n", + "from topobenchmark.data.preprocessor import PreProcessor\n", + "from topomodelx.nn.simplicial.scn2 import SCN2\n", + "from topomodelx.nn.hypergraph.allset_transformer import AllSetTransformer\n", + "\n", + "from topobenchmark.utils.config_resolvers import (\n", + " get_default_transform,\n", + " get_monitor_metric,\n", + " get_monitor_mode,\n", + " infer_in_channels,\n", + ")\n", + "\n", + "initialize(config_path=\"../configs\", job_name=\"job\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": 
{}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import networkx as nx\n", + "import numpy as np\n", + "import torch\n", + "from matplotlib.patches import Polygon\n", + "from itertools import combinations\n", + "from typing import Optional, Dict, List\n", + "\n", + "def plot_graph(\n", + " data,\n", + " face_color_map: Optional[Dict[int, str]] = None,\n", + " node_size: int = 500,\n", + " font_size: int = 12,\n", + " seed: int = 5,\n", + " show: bool = True\n", + ") -> plt.Figure:\n", + " \"\"\"\n", + " Visualize a simplicial complex from a PyTorch Geometric Data object.\n", + " \n", + " Args:\n", + " data: torch_geometric.data.Data object containing the simplicial complex\n", + " face_color_map: Dictionary mapping number of tetrahedrons to colors\n", + " node_size: Size of nodes in the visualization\n", + " font_size: Size of font for labels\n", + " seed: Random seed for layout\n", + " show: Whether to display the plot immediately\n", + " \n", + " Returns:\n", + " matplotlib.figure.Figure: The generated figure\n", + " \"\"\"\n", + " # Default color map if none provided\n", + " if face_color_map is None:\n", + " face_color_map = {\n", + " 0: \"pink\",\n", + " 1: \"gray\",\n", + " 2: \"blue\",\n", + " 3: \"blue\",\n", + " 4: \"orange\",\n", + " 5: \"purple\",\n", + " 6: \"red\",\n", + " 7: \"brown\",\n", + " 8: \"black\",\n", + " 9: \"gray\",\n", + " }\n", + " \n", + " # Extract vertices\n", + " num_vertices = data.num_nodes if hasattr(data, 'num_nodes') else data.x.shape[0]\n", + " vertices = list(range(num_vertices))\n", + " \n", + " # Extract edges from incidence matrix\n", + " edges = []\n", + " for edge in abs(data.incidence_1.to_dense().T):\n", + " edges.append(torch.where(edge == 1)[0].numpy())\n", + " edges = np.array(edges)\n", + " \n", + " # Extract tetrahedrons if available\n", + " tetrahedrons = []\n", + " if hasattr(data, 'tetrahedrons'):\n", + " tetrahedrons = data.tetrahedrons\n", + " elif hasattr(data, 'incidence_2'):\n", 
+ " # Extract tetrahedrons from incidence_2 matrix if available\n", + " for tetra in abs(data.incidence_2.to_dense().T):\n", + " tetrahedrons.append(torch.where(tetra == 1)[0].tolist())\n", + " \n", + " # Create graph\n", + " G = nx.Graph()\n", + " G.add_nodes_from(vertices)\n", + " G.add_edges_from(edges)\n", + " \n", + " # Find triangular cliques\n", + " cliques = list(nx.enumerate_all_cliques(G))\n", + " cliques = [triangle for triangle in cliques if len(triangle) == 3]\n", + " \n", + " # Create layout\n", + " pos = nx.spring_layout(G, seed=seed)\n", + " \n", + " # Create figure\n", + " fig = plt.figure(figsize=(10, 8))\n", + " \n", + " # Draw nodes and labels\n", + " node_labels = {i: f\"v_{n.item()}\" for i,n in enumerate(data.n_id)} if hasattr(data, 'n_id') else {i: f\"v_{i}\" for i in G.nodes()}\n", + " nx.draw(\n", + " G,\n", + " pos,\n", + " labels=node_labels,\n", + " node_size=node_size,\n", + " node_color=\"skyblue\",\n", + " font_size=font_size,\n", + " )\n", + " \n", + " # Draw edges\n", + " nx.draw_networkx_edges(G, pos, edgelist=edges, edge_color=\"g\", width=2, alpha=0.5)\n", + " \n", + " # # Add edge labels\n", + " for i, (u, v) in enumerate(edges):\n", + " x = (pos[u][0] + pos[v][0]) / 2\n", + " y = (pos[u][1] + pos[v][1]) / 2\n", + " plt.text(x, y, f\"e_{i}\", fontsize=font_size - 2, color=\"r\")\n", + " \n", + " # Color the faces (cliques)\n", + " for clique in cliques:\n", + " # Count tetrahedrons containing this clique\n", + " counter = 0\n", + " for tetrahedron in tetrahedrons:\n", + " for comb in combinations(tetrahedron, 3):\n", + " if set(clique) == set(comb):\n", + " counter += 1\n", + " \n", + " # Create and add polygon\n", + " polygon = [pos[v] for v in clique]\n", + " poly = Polygon(\n", + " polygon,\n", + " closed=True,\n", + " facecolor=face_color_map.get(counter, \"gray\"), # Default to gray if counter not in map\n", + " edgecolor=\"pink\",\n", + " alpha=0.3,\n", + " )\n", + " plt.gca().add_patch(poly)\n", + " \n", + " 
plt.title(f\"Graph with cliques colored ({num_vertices} vertices)\")\n", + " \n", + " if show:\n", + " plt.show()\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Different use cases" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Manual Simplicial Graph" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Data(x=[8, 1], edge_index=[2, 13], y=[8], num_nodes=8, incidence_0=[1, 8], down_laplacian_0=[8, 8], up_laplacian_0=[8, 8], adjacency_0=[8, 8], coadjacency_0=[8, 8], hodge_laplacian_0=[8, 8], incidence_1=[8, 13], down_laplacian_1=[13, 13], up_laplacian_1=[13, 13], adjacency_1=[13, 13], coadjacency_1=[13, 13], hodge_laplacian_1=[13, 13], incidence_2=[13, 6], down_laplacian_2=[6, 6], up_laplacian_2=[6, 6], adjacency_2=[6, 6], coadjacency_2=[6, 6], hodge_laplacian_2=[6, 6], incidence_3=[6, 1], down_laplacian_3=[1, 1], up_laplacian_3=[1, 1], adjacency_3=[1, 1], coadjacency_3=[1, 1], hodge_laplacian_3=[1, 1], shape=[4], x_0=[8, 1], x_1=[13, 1], x_2=[6, 1], x_3=[1, 1])\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing...\n", + "Done!\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "path = \"./graph2simplicial_lifting/\"\n", + "if os.path.isdir(path):\n", + " shutil.rmtree(path)\n", + "cfg = compose(config_name=\"run.yaml\", \n", + " overrides=[\"dataset=graph/manual_dataset\", \"model=simplicial/san\"], \n", + " return_hydra_config=True)\n", + "\n", + "data = load_manual_graph()\n", + "preprocessed_dataset = PreProcessor(data, './', cfg['transforms'])\n", + "data = preprocessed_dataset[0]\n", + "print(data)\n", + "plot_graph(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Data(x=[8, 1], edge_index=[2, 13], y=[8], num_nodes=8, incidence_0=[1, 8], down_laplacian_0=[8, 8], up_laplacian_0=[8, 8], adjacency_0=[8, 8], coadjacency_0=[8, 8], hodge_laplacian_0=[8, 8], incidence_1=[8, 13], down_laplacian_1=[13, 13], up_laplacian_1=[13, 13], adjacency_1=[13, 13], coadjacency_1=[13, 13], hodge_laplacian_1=[13, 13], incidence_2=[13, 6], down_laplacian_2=[6, 6], up_laplacian_2=[6, 6], adjacency_2=[6, 6], coadjacency_2=[6, 6], hodge_laplacian_2=[6, 6], incidence_3=[6, 1], down_laplacian_3=[1, 1], up_laplacian_3=[1, 1], adjacency_3=[1, 1], coadjacency_3=[1, 1], hodge_laplacian_3=[1, 1], shape=[4], x_0=[8, 1], x_1=[13, 1], x_2=[6, 1], x_3=[1, 1])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Training, validation and split idxs should be defined somewhere, here we use a toy example\n", + "rank = 0\n", + "if hasattr(data, \"x_hyperedges\") and rank==1:\n", + " n_cells = data.x_hyperedges.shape[0]\n", + "else:\n", + " n_cells = data[f'x_{rank}'].shape[0]\n", + "\n", + "train_prop = 0.5\n", + "n_train = int(train_prop * n_cells)\n", + "train_mask = torch.zeros(n_cells, dtype=torch.bool)\n", + "train_mask[:n_train] = 1\n", + "\n", + "y = torch.zeros(n_cells, dtype=torch.long)\n", + "data.y = y\n", + "\n", + "data" + ] + }, + 
{ + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 1\n", + "\n", + "loader = NeighborCellsLoader(data,\n", + " rank=rank,\n", + " num_neighbors=[-1],\n", + " input_nodes=train_mask,\n", + " batch_size=batch_size,\n", + " shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Data(x=[5, 1], edge_index=[2, 16], y=[5], num_nodes=5, incidence_0=[1, 5], down_laplacian_0=[5, 5], up_laplacian_0=[5, 5], adjacency_0=[5, 5], coadjacency_0=[8, 8], hodge_laplacian_0=[5, 5], incidence_1=[5, 8], down_laplacian_1=[8, 8], up_laplacian_1=[8, 8], adjacency_1=[8, 8], coadjacency_1=[13, 13], hodge_laplacian_1=[8, 8], incidence_2=[8, 5], down_laplacian_2=[5, 5], up_laplacian_2=[5, 5], adjacency_2=[5, 5], coadjacency_2=[6, 6], hodge_laplacian_2=[5, 5], incidence_3=[5, 1], down_laplacian_3=[1, 1], up_laplacian_3=[1, 1], adjacency_3=[1, 1], coadjacency_3=[1, 1], hodge_laplacian_3=[1, 1], shape=[4], x_0=[5, 1], x_1=[8, 1], x_2=[5, 1], x_3=[1, 1], cells_ids=[4], n_id=[5])\n", + "The cells of rank 0 that were originally selected are [0]\n", + "Selected cells of rank 0: tensor([0, 7, 1, 2, 4])\n", + "Incidence 3:\n", + "tensor([[1.],\n", + " [1.],\n", + " [1.],\n", + " [0.],\n", + " [1.]])\n", + "Incidence 2:\n", + "tensor([[1., 1., 0., 0., 0.],\n", + " [1., 0., 1., 1., 0.],\n", + " [0., 1., 1., 0., 0.],\n", + " [0., 0., 0., 1., 0.],\n", + " [1., 0., 0., 0., 1.],\n", + " [0., 1., 0., 0., 1.],\n", + " [0., 0., 1., 0., 1.],\n", + " [0., 0., 0., 1., 0.]])\n", + "Incidence 1:\n", + "tensor([[1., 1., 1., 1., 0., 0., 0., 0.],\n", + " [0., 0., 0., 1., 0., 0., 0., 1.],\n", + " [1., 0., 0., 0., 1., 1., 0., 0.],\n", + " [0., 1., 0., 0., 1., 0., 1., 1.],\n", + " [0., 0., 1., 0., 0., 1., 1., 0.]])\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for batch in loader:\n", + " print(batch)\n", + " print(f\"The cells of rank {rank} that were originally selected are {batch.n_id[:batch_size].tolist()}\")\n", + " \n", + " print(f\"Selected cells of rank {rank}: {batch.n_id}\")\n", + " if hasattr(batch, 'incidence_hyperedges'):\n", + " print(\"Incidence hyperedges:\")\n", + " print(batch.incidence_hyperedges.to_dense())\n", + " else:\n", + " print(\"Incidence 3:\")\n", + " print(batch.incidence_3.to_dense())\n", + " print(\"Incidence 2:\")\n", + " print(batch.incidence_2.to_dense())\n", + " print(\"Incidence 1:\")\n", + " print(batch.incidence_1.to_dense())\n", + " if rank == 0:\n", + " plot_graph(batch)\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cora hypergraph" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing...\n", + "Done!\n" + ] + } + ], + "source": [ + "path = \"./graph2hypergraph_lifting/\"\n", + "if os.path.isdir(path):\n", + " shutil.rmtree(path)\n", + "cfg = compose(config_name=\"run.yaml\", \n", + " overrides=[\"dataset=graph/cocitation_cora\", \"model=hypergraph/allsettransformer\"], \n", + " return_hydra_config=True)\n", + "graph_loader = PlanetoidDatasetLoader(cfg.dataset.loader.parameters)\n", + "dataset, dataset_dir = graph_loader.load()\n", + "preprocessed_dataset = PreProcessor(dataset, './', cfg['transforms'])\n", + "data = preprocessed_dataset[0]\n", + "\n", + "# Training, validation and split idxs should be defined somewhere, here we use a toy example\n", + "rank = 0\n", + "if hasattr(data, \"x_hyperedges\") and rank==1:\n", + " n_cells = data.x_hyperedges.shape[0]\n", + "else:\n", + " n_cells = data[f'x_{rank}'].shape[0]\n", + "\n", + "train_prop = 0.5\n", + "n_train = int(train_prop * n_cells)\n", + "train_mask = torch.zeros(n_cells, 
dtype=torch.bool)\n", + "train_mask[:n_train] = 1\n", + "\n", + "if rank != 0:\n", + " y = torch.zeros(n_cells, dtype=torch.long)\n", + " data.y = y\n", + "batch_size = 1\n", + "\n", + "# num_neighbors controls also the number of hops (for 2 hops do num_neighbors=[-1, -1])\n", + "\n", + "\n", + "loader = NeighborCellsLoader(data,\n", + " rank=rank,\n", + " num_neighbors=[-1],\n", + " input_nodes=train_mask,\n", + " batch_size=batch_size,\n", + " shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Data(x=[4, 1433], edge_index=[2, 10556], y=[4], train_mask=[2708], val_mask=[2708], test_mask=[2708], incidence_hyperedges=[4, 5], num_hyperedges=2708, x_0=[4, 1433], x_hyperedges=[2708, 1433], num_nodes=4, cells_ids=[2], n_id=[4])\n", + "tensor([ 0, 1862, 633, 2582])\n", + "tensor([[ 0, 0, 0, ..., 2707, 2707, 2707],\n", + " [ 633, 1862, 2582, ..., 598, 1473, 2706]])\n", + "tensor([[1., 1., 0., 1., 1.],\n", + " [1., 0., 1., 1., 1.],\n", + " [1., 1., 1., 0., 0.],\n", + " [1., 0., 0., 1., 1.]])\n" + ] + } + ], + "source": [ + "for batch in loader:\n", + " print(batch)\n", + " print(batch.n_id)\n", + " print(batch.edge_index)\n", + " if hasattr(batch, 'incidence_hyperedges'):\n", + " print(batch.incidence_hyperedges.to_dense())\n", + " else:\n", + " print(batch.incidence_3.to_dense())\n", + " print(batch.incidence_2.to_dense())\n", + " print(batch.incidence_1.to_dense())\n", + " \n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Testing NeighborCellLoader" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If batching is done correctly the results on the selected cells should not change when compared to the results obtained over the whole graph.\n", + "We test this to check that our batching strategy is correct." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Testing simplicial complexes" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transform parameters are the same, using existing data_dir: /TopoBenchmark/datasets/graph/cocitation/Cora/graph2simplicial_lifting/1597083846\n", + "Testing batching for SCN2 using 2 layers.\n", + "We return the MSE between the full batch and the batched version.\n", + " n_hops = 1 MSE: 3.947006937138043\n", + " n_hops = 2 MSE: 9.035350438167982e-12\n", + " n_hops = 3 MSE: 1.0411274301838695e-11\n" + ] + } + ], + "source": [ + "path = \"./graph2simplicial_lifting/\"\n", + "if os.path.isdir(path):\n", + " shutil.rmtree(path)\n", + "cfg = compose(config_name=\"run.yaml\", \n", + " overrides=[\"dataset=graph/cocitation_cora\", \"model=simplicial/scn\"], \n", + " return_hydra_config=True)\n", + "\n", + "dataset_loader = hydra.utils.instantiate(cfg.dataset.loader)\n", + "dataset, dataset_dir = dataset_loader.load()\n", + "# Preprocess dataset and load the splits\n", + "transform_config = cfg.get(\"transforms\", None)\n", + "preprocessor = PreProcessor(dataset, dataset_dir, transform_config)\n", + "dataset_train, dataset_val, dataset_test = (\n", + " preprocessor.load_dataset_splits(cfg.dataset.split_params)\n", + ")\n", + "\n", + "### Full batch --------------------------------------------------------\n", + "cfg.dataset.dataloader_params.batch_size = -1\n", + "\n", + "datamodule = TBDataloader(\n", + " dataset_train=dataset_train,\n", + " dataset_val=dataset_val,\n", + " dataset_test=dataset_test,\n", + " **cfg.dataset.get(\"dataloader_params\", {}),\n", + " )\n", + "\n", + "input_dim = 1433\n", + "hidden_channels = 16\n", + "out_dim = 7\n", + "\n", + "model = SCN2(input_dim, input_dim, input_dim, n_layers=2)\n", + "model.eval()\n", + "\n", + "train_dataloader = datamodule.train_dataloader()\n", + "for data 
in train_dataloader:\n", + " x_0_full, x_1_full, x_2_full = model(data.x_0, data.x_1, data.x_2, data.hodge_laplacian_0, data.hodge_laplacian_1, data.hodge_laplacian_2)\n", + "\n", + "### Batched --------------------------------------------------------\n", + "print(\"Testing batching for SCN2 using 2 layers.\")\n", + "print(\"We return the MSE between the full batch and the batched version.\")\n", + "for n_hops in range(1, 4):\n", + " cfg.dataset.dataloader_params.batch_size = 32\n", + "\n", + " datamodule_batched = TBDataloader(\n", + " dataset_train=dataset_train,\n", + " dataset_val=dataset_val,\n", + " dataset_test=dataset_test,\n", + " num_neighbors = [-1] * n_hops,\n", + " **cfg.dataset.get(\"dataloader_params\", {}),\n", + " )\n", + " train_dataloader_batched = datamodule_batched.train_dataloader()\n", + " mse = 0\n", + " for i, batch in enumerate(train_dataloader_batched):\n", + " x_0_batch, x_1_batch, x_2_batch = model(batch.x_0, batch.x_1, batch.x_2, batch.hodge_laplacian_0, batch.hodge_laplacian_1, batch.hodge_laplacian_2)\n", + " n_ids = batch.n_id[:batch_size]\n", + " mse += torch.mean((x_0_full[n_ids, :] - x_0_batch[:batch_size, :]).pow(2)).item()\n", + " mse = mse / (i + 1)\n", + " \n", + " # The last element might be False since the last batch might not be full\n", + " print(f\" n_hops = {n_hops} MSE: {mse}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Testing hypergraphs" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Transform parameters are the same, using existing data_dir: /TopoBenchmark/datasets/graph/cocitation/Cora/graph2hypergraph_lifting/1010717418\n", + "Testing batching for AllSetTransformer using 3 layers.\n", + "We return the MSE between the full batch and the batched version.\n", + " n_hops = 1 MSE: 0.0002609546898304469\n", + " n_hops = 2 MSE: 1.660647607597814e-05\n", + " n_hops = 3 
MSE: 1.1712208103968767e-06\n", + " n_hops = 4 MSE: 1.5722682561617035e-08\n", + " n_hops = 5 MSE: 1.550387910249659e-11\n", + " n_hops = 6 MSE: 8.365642818938425e-15\n", + " n_hops = 7 MSE: 7.258141882427575e-15\n", + " n_hops = 8 MSE: 7.711832348623849e-15\n", + " n_hops = 9 MSE: 4.981246563409259e-15\n" + ] + } + ], + "source": [ + "path = \"./graph2hypergraph_lifting/\"\n", + "if os.path.isdir(path):\n", + " shutil.rmtree(path)\n", + "cfg = compose(config_name=\"run.yaml\", \n", + " overrides=[\"dataset=graph/cocitation_cora\", \"model=hypergraph/allsettransformer\"], \n", + " return_hydra_config=True)\n", + "\n", + "dataset_loader = hydra.utils.instantiate(cfg.dataset.loader)\n", + "dataset, dataset_dir = dataset_loader.load()\n", + "# Preprocess dataset and load the splits\n", + "transform_config = cfg.get(\"transforms\", None)\n", + "preprocessor = PreProcessor(dataset, dataset_dir, transform_config)\n", + "dataset_train, dataset_val, dataset_test = (\n", + " preprocessor.load_dataset_splits(cfg.dataset.split_params)\n", + ")\n", + "\n", + "### Full batch --------------------------------------------------------\n", + "cfg.dataset.dataloader_params.batch_size = -1\n", + "\n", + "datamodule = TBDataloader(\n", + " dataset_train=dataset_train,\n", + " dataset_val=dataset_val,\n", + " dataset_test=dataset_test,\n", + " **cfg.dataset.get(\"dataloader_params\", {}),\n", + " )\n", + "\n", + "input_dim = 1433\n", + "hidden_channels = 16\n", + "out_dim = 7\n", + "n_layers = 3\n", + "model = AllSetTransformer(input_dim, hidden_channels, n_layers=n_layers)\n", + "model.eval()\n", + "\n", + "train_dataloader = datamodule.train_dataloader()\n", + "for data in train_dataloader:\n", + " x_0_full, x_1_full = model(data.x_0, data.incidence_hyperedges)\n", + "\n", + "### Batched --------------------------------------------------------\n", + "print(f\"Testing batching for AllSetTransformer using {n_layers} layers.\")\n", + "print(\"We return the MSE between the full batch and 
the batched version.\")\n", + "for n_hops in range(1, 10):\n", + " cfg.dataset.dataloader_params.batch_size = 32\n", + "\n", + " datamodule_batched = TBDataloader(\n", + " dataset_train=dataset_train,\n", + " dataset_val=dataset_val,\n", + " dataset_test=dataset_test,\n", + " num_neighbors = [-1] * n_hops,\n", + " **cfg.dataset.get(\"dataloader_params\", {}),\n", + " )\n", + " train_dataloader_batched = datamodule_batched.train_dataloader()\n", + " mse = 0\n", + " for i, batch in enumerate(train_dataloader_batched):\n", + " x_0_batch, x_1_batch = model(batch.x_0, batch.incidence_hyperedges)\n", + " n_ids = batch.n_id[:batch_size]\n", + " mse += torch.mean((x_0_full[n_ids, :] - x_0_batch[:batch_size, :]).pow(2)).item()\n", + " mse = mse / (i + 1)\n", + " \n", + " # The last element might be False since the last batch might not be full\n", + " print(f\" n_hops = {n_hops} MSE: {mse}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}