Added batching in transductive setting #128

Status: Open. Wants to merge 28 commits into base: main.

Commits (28), all changes shown:
fe2d5e7  some random error  (levtelyatnikov, Oct 31, 2024)
82b3ede  run  (levtelyatnikov, Oct 31, 2024)
c225cc6  resolver_error  (levtelyatnikov, Oct 31, 2024)
8a3714b  Merge branch 'main' of https://github.com/geometric-intelligence/Topo…  (levtelyatnikov, Oct 31, 2024)
9cb5def  start the developments of node level batching  (levtelyatnikov, Oct 31, 2024)
c988b2d  Marco - added batching functions  (Coerulatus, Nov 14, 2024)
877fce1  Marco - get_sampled_neighborhood reworked  (Coerulatus, Nov 14, 2024)
09bab78  added proper plot function  (levtelyatnikov, Nov 15, 2024)
72f92ec  Marco - batching done  (Coerulatus, Nov 15, 2024)
76e1d03  added some comments  (levtelyatnikov, Nov 15, 2024)
fb392ce  get rid of random files  (levtelyatnikov, Nov 15, 2024)
f623d24  added just sampling over the graph  (levtelyatnikov, Nov 15, 2024)
8519fef  Marco - defined NeighborCellsLoader  (Coerulatus, Nov 16, 2024)
9488259  merged changes  (Coerulatus, Nov 16, 2024)
5ec69b8  Marco - fixed conflict  (Coerulatus, Nov 16, 2024)
64c2c9f  fixed __repr__ of readout  (Coerulatus, Nov 26, 2024)
e154231  support for multiple hops  (Coerulatus, Nov 27, 2024)
7bddf5d  changed DataloadDataset call  (Coerulatus, Nov 27, 2024)
69a0e94  test batching with multiple hops  (Coerulatus, Nov 27, 2024)
54d428b  Merge remote-tracking branch 'origin/main' into batching  (Coerulatus, Dec 16, 2024)
6a0e4ec  hydra already initialized  (Coerulatus, Dec 18, 2024)
78f5ed2  added test  (Coerulatus, Dec 18, 2024)
9758ff4  added batching in transductive setting  (Coerulatus, Dec 18, 2024)
4d80241  test mse when batching  (Coerulatus, Dec 18, 2024)
503512f  formatting  (Coerulatus, Dec 18, 2024)
f92e378  changed batch size for new TBDataloader  (Coerulatus, Dec 18, 2024)
4a53d8f  ruff fixes  (Coerulatus, Dec 18, 2024)
3f44880  fix temp folder  (Coerulatus, Dec 18, 2024)
configs/dataset/graph/US-county-demos.yaml (2 changes: 1 addition & 1 deletion)
@@ -30,6 +30,6 @@ split_params:
 
 # Dataloader parameters
 dataloader_params:
-  batch_size: 1 # Fixed
+  batch_size: -1 # Fixed
   num_workers: 0
   pin_memory: False
configs/dataset/graph/amazon_ratings.yaml (2 changes: 1 addition & 1 deletion)
@@ -27,6 +27,6 @@ split_params:
 
 # Dataloader parameters
 dataloader_params:
-  batch_size: 1 # Fixed
+  batch_size: -1 # Fixed
   num_workers: 0
   pin_memory: False
configs/dataset/graph/cocitation_citeseer.yaml (2 changes: 1 addition & 1 deletion)
@@ -28,6 +28,6 @@ split_params:
 
 # Dataloader parameters
 dataloader_params:
-  batch_size: 1 # Fixed
+  batch_size: -1 # Fixed
   num_workers: 1
   pin_memory: False
configs/dataset/graph/cocitation_cora.yaml (2 changes: 1 addition & 1 deletion)
@@ -27,6 +27,6 @@ split_params:
 
 # Dataloader parameters
 dataloader_params:
-  batch_size: 1 # Fixed
+  batch_size: -1 # Fixed
   num_workers: 1
   pin_memory: False
configs/dataset/graph/cocitation_pubmed.yaml (2 changes: 1 addition & 1 deletion)
@@ -27,6 +27,6 @@ split_params:
 
 # Dataloader parameters
 dataloader_params:
-  batch_size: 1 # Fixed
+  batch_size: -1 # Fixed
   num_workers: 1
   pin_memory: False
configs/dataset/graph/manual_dataset.yaml (2 changes: 1 addition & 1 deletion)
@@ -28,6 +28,6 @@ split_params:
 
 # Dataloader parameters
 dataloader_params:
-  batch_size: 1
+  batch_size: -1
   num_workers: 1
   pin_memory: False
configs/dataset/graph/minesweeper.yaml (2 changes: 1 addition & 1 deletion)
@@ -28,6 +28,6 @@ split_params:
 
 # Dataloader parameters
 dataloader_params:
-  batch_size: 1 # Fixed
+  batch_size: -1 # Fixed
   num_workers: 0
   pin_memory: False
configs/dataset/graph/questions.yaml (2 changes: 1 addition & 1 deletion)
@@ -27,6 +27,6 @@ split_params:
 
 # Dataloader parameters
 dataloader_params:
-  batch_size: 1 # Fixed
+  batch_size: -1 # Fixed
   num_workers: 1
   pin_memory: False
configs/dataset/graph/roman_empire.yaml (2 changes: 1 addition & 1 deletion)
@@ -27,6 +27,6 @@ split_params:
 
 # Dataloader parameters
 dataloader_params:
-  batch_size: 1 # Fixed
+  batch_size: -1 # Fixed
   num_workers: 0
   pin_memory: False
configs/dataset/graph/tolokers.yaml (2 changes: 1 addition & 1 deletion)
@@ -27,6 +27,6 @@ split_params:
 
 # Dataloader parameters
 dataloader_params:
-  batch_size: 1 # Fixed
+  batch_size: -1 # Fixed
   num_workers: 1
   pin_memory: False
configs/dataset/hypergraph/coauthorship_cora.yaml (2 changes: 1 addition & 1 deletion)
@@ -27,6 +27,6 @@ split_params:
 
 # Dataloader parameters
 dataloader_params:
-  batch_size: 1 # Fixed
+  batch_size: -1 # Fixed
   num_workers: 1
   pin_memory: False
configs/dataset/hypergraph/coauthorship_dblp.yaml (2 changes: 1 addition & 1 deletion)
@@ -27,6 +27,6 @@ split_params:
 
 # Dataloader parameters
 dataloader_params:
-  batch_size: 1 # Fixed
+  batch_size: -1 # Fixed
   num_workers: 1
   pin_memory: False
configs/dataset/hypergraph/cocitation_citeseer.yaml (2 changes: 1 addition & 1 deletion)
@@ -27,6 +27,6 @@ split_params:
 
 # Dataloader parameters
 dataloader_params:
-  batch_size: 1 # Fixed
+  batch_size: -1 # Fixed
   num_workers: 1
   pin_memory: False
configs/dataset/hypergraph/cocitation_cora.yaml (2 changes: 1 addition & 1 deletion)
@@ -27,6 +27,6 @@ split_params:
 
 # Dataloader parameters
 dataloader_params:
-  batch_size: 1 # Fixed
+  batch_size: -1 # Fixed
   num_workers: 1
   pin_memory: False
configs/dataset/hypergraph/cocitation_pubmed.yaml (2 changes: 1 addition & 1 deletion)
@@ -27,6 +27,6 @@ split_params:
 
 # Dataloader parameters
 dataloader_params:
-  batch_size: 1 # Fixed
+  batch_size: -1 # Fixed
   num_workers: 1
   pin_memory: False
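
Across these configs, batch_size changes from 1 to -1. Given the commit "changed batch size for new TBDataloader", -1 is presumably a sentinel telling the dataloader to treat the whole split as one full batch in the transductive setting. The sketch below is only a hypothetical illustration of how such a sentinel is commonly resolved; the function name and exact behavior are assumptions, not code from this PR.

# Hypothetical sketch, not part of this PR: resolving a batch_size of -1 to
# "use every cell in one batch", the usual convention for full-batch training.
def resolve_batch_size(batch_size: int, num_cells: int) -> int:
    """Return the effective batch size; -1 means all cells at once."""
    return num_cells if batch_size == -1 else batch_size

assert resolve_batch_size(-1, 2708) == 2708   # full-batch: every cell of the split
assert resolve_batch_size(128, 2708) == 128   # a positive value is kept as-is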
test/data/batching/test_neighbor_cells_loader.py (new file, 132 additions & 0 deletions)
@@ -0,0 +1,132 @@
""" Test for the NeighborCellsLoader class."""
import os
import shutil
import rootutils
from hydra import compose
import torch

from topobenchmark.data.preprocessor import PreProcessor
from topobenchmark.data.utils.utils import load_manual_graph
from topobenchmark.data.batching import NeighborCellsLoader
from topobenchmark.run import initialize_hydra

initialize_hydra()

path = "./graph2simplicial_lifting/"
if os.path.isdir(path):
shutil.rmtree(path)
cfg = compose(config_name="run.yaml",
overrides=["dataset=graph/manual_dataset", "model=simplicial/san"],
return_hydra_config=True)

data = load_manual_graph()
preprocessed_dataset = PreProcessor(data, path, cfg['transforms'])
data = preprocessed_dataset[0]

batch_size=2

rank = 0
n_cells = data[f'x_{rank}'].shape[0]
train_prop = 0.5
n_train = int(train_prop * n_cells)
train_mask = torch.zeros(n_cells, dtype=torch.bool)
train_mask[:n_train] = 1

y = torch.zeros(n_cells, dtype=torch.long)
data.y = y

loader = NeighborCellsLoader(data,
rank=rank,
num_neighbors=[-1],
input_nodes=train_mask,
batch_size=batch_size,
shuffle=False)
train_nodes = []
for batch in loader:
train_nodes += [n for n in batch.n_id[:batch_size]]
for i in range(n_train):
assert i in train_nodes

rank = 1
n_cells = data[f'x_{rank}'].shape[0]
train_prop = 0.5
n_train = int(train_prop * n_cells)
train_mask = torch.zeros(n_cells, dtype=torch.bool)
train_mask[:n_train] = 1

y = torch.zeros(n_cells, dtype=torch.long)
data.y = y

loader = NeighborCellsLoader(data,
rank=rank,
num_neighbors=[-1,-1],
input_nodes=train_mask,
batch_size=batch_size,
shuffle=False)

train_nodes = []
for batch in loader:
train_nodes += [n for n in batch.n_id[:batch_size]]
for i in range(n_train):
assert i in train_nodes
shutil.rmtree(path)


path = "./graph2hypergraph_lifting/"
if os.path.isdir(path):
shutil.rmtree(path)
cfg = compose(config_name="run.yaml",
overrides=["dataset=graph/manual_dataset", "model=hypergraph/allsettransformer"],
return_hydra_config=True)

data = load_manual_graph()
preprocessed_dataset = PreProcessor(data, path, cfg['transforms'])
data = preprocessed_dataset[0]

batch_size=2

rank = 0
n_cells = data[f'x_0'].shape[0]
train_prop = 0.5
n_train = int(train_prop * n_cells)
train_mask = torch.zeros(n_cells, dtype=torch.bool)
train_mask[:n_train] = 1

y = torch.zeros(n_cells, dtype=torch.long)
data.y = y

loader = NeighborCellsLoader(data,
rank=rank,
num_neighbors=[-1],
input_nodes=train_mask,
batch_size=batch_size,
shuffle=False)
train_nodes = []
for batch in loader:
train_nodes += [n for n in batch.n_id[:batch_size]]
for i in range(n_train):
assert i in train_nodes

rank = 1
n_cells = data[f'x_hyperedges'].shape[0]
train_prop = 0.5
n_train = int(train_prop * n_cells)
train_mask = torch.zeros(n_cells, dtype=torch.bool)
train_mask[:n_train] = 1

y = torch.zeros(n_cells, dtype=torch.long)
data.y = y

loader = NeighborCellsLoader(data,
rank=rank,
num_neighbors=[-1,-1],
input_nodes=train_mask,
batch_size=batch_size,
shuffle=False)

train_nodes = []
for batch in loader:
train_nodes += [n for n in batch.n_id[:batch_size]]
for i in range(n_train):
assert i in train_nodes
shutil.rmtree(path)
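
Distilled from the test above, a minimal usage sketch of the new NeighborCellsLoader. It only restates what the test exercises: `data` is assumed to be a preprocessed complex with per-rank features (x_0, x_1, ...), `num_neighbors` follows the test's pattern of one -1 entry per hop, and the helper function below is illustrative rather than part of the PR.

# Illustrative helper, not part of this PR; mirrors the usage in the test above.
import torch
from topobenchmark.data.batching import NeighborCellsLoader


def iterate_seed_cells(data, rank=0, batch_size=2):
    """Yield the seed cells of each mini-batch at the given rank."""
    n_cells = data[f"x_{rank}"].shape[0]
    train_mask = torch.zeros(n_cells, dtype=torch.bool)
    train_mask[: n_cells // 2] = True          # train on the first half of the cells

    loader = NeighborCellsLoader(
        data,
        rank=rank,
        num_neighbors=[-1] * (rank + 1),       # as in the test: keep all neighbors per hop
        input_nodes=train_mask,
        batch_size=batch_size,
        shuffle=False,
    )
    for batch in loader:
        # the test asserts that the first `batch_size` entries of n_id are the seed cells
        yield batch.n_id[:batch_size]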
test/data/dataload/test_Dataloaders.py (4 changes: 0 additions & 4 deletions)
@@ -20,10 +20,6 @@ class TestCollateFunction:
 
     def setup_method(self):
         """Setup the test."""
-
-        hydra.initialize(
-            version_base="1.3", config_path="../../../configs", job_name="run"
-        )
         cfg = hydra.compose(config_name="run.yaml", overrides=["dataset=graph/NCI1"])
 
         graph_loader = hydra.utils.instantiate(cfg.dataset.loader, _recursive_=False)
topobenchmark/data/batching/__init__.py (new file, 7 additions & 0 deletions)
@@ -0,0 +1,7 @@
"""Init file for batching module."""

from .neighbor_cells_loader import NeighborCellsLoader

__all__ = [
    "NeighborCellsLoader",
]