Showing 12 changed files with 337 additions and 26 deletions.
@@ -0,0 +1,3 @@
/logs
/datasets/graph
.ruff_cache
@@ -0,0 +1,16 @@
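# Illustrative usage (the image tag is an assumption, not part of this commit):
#   docker build -t topobenchmarkx .
#   docker run --rm -it --gpus all topobenchmarkx bash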
FROM python:3.11.3

WORKDIR /TopoBenchmarkX

COPY . .

RUN pip install --upgrade pip

RUN pip install -e '.[all]'
RUN pip install --no-dependencies git+https://github.com/pyt-team/TopoNetX.git
RUN pip install --no-dependencies git+https://github.com/pyt-team/TopoModelX.git
# PyTorch 2.0.1 ships cu117/cu118 wheels (cu115 predates torch 2.x), and the
# PyG wheel index must match the installed torch version.
RUN pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/cu117
RUN pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cu117.html
RUN pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.1+cu117.html
#RUN pip install lightning>=2.0.0
#RUN pip install numpy pre-commit jupyterlab notebook ipykernel
@@ -0,0 +1,21 @@
#!/bin/bash

#conda create -n topoxx python=3.11.3
#conda activate topoxx

pip install --upgrade pip
pip install -e '.[all]'

pip install git+https://github.com/pyt-team/TopoNetX.git
pip install git+https://github.com/pyt-team/TopoModelX.git
pip install git+https://github.com/pyt-team/TopoEmbedX.git

CUDA="cu117" # if available, select the CUDA version suitable for your system
             # (torch 2.0.1 wheels exist for cpu, cu117 and cu118)
pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/${CUDA}
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html
pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html
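
# Optional sanity check (illustrative addition, not part of the original script):
# python -c 'import torch; print(torch.__version__, torch.cuda.is_available())'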

pytest

pre-commit install
@@ -0,0 +1,7 @@
,num_hyperedges,zero_cell,one_cell,two_cell,three_cell,dataset,domain
0,0,3224,9483,6266,0,US-county-demos,cell
1,0,2708,5278,2648,0,Cora,cell
2,0,3327,4552,1663,0,citeseer,cell
3,0,19717,44324,23605,0,PubMed,cell
4,0,277864,298985,33121,0,ZINC,cell
5,0,22662,32927,10266,0,roman_empire,cell
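
For reference, a minimal sketch (not part of this commit) for loading the statistics table above, as produced by the dataset_statistics script later in this commit; the relative path assumes the repository root as working directory:

import pandas as pd

# The unnamed first CSV column is the integer index written by to_csv
df = pd.read_csv("tables/dataset_statistics.csv", index_col=0)
print(df[["dataset", "domain", "zero_cell", "one_cell", "two_cell"]])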
File renamed without changes.
@@ -0,0 +1,224 @@
import os
import random
from typing import Any

import hydra
import lightning as L
import numpy as np
import pandas as pd
import rootutils

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
import torch
from lightning import Callback, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig, OmegaConf

from topobenchmarkx.data.dataloaders import DefaultDataModule
from topobenchmarkx.utils import (
    RankedLogger,
    extras,
    get_metric_value,
    instantiate_callbacks,
    instantiate_loggers,
    log_hyperparameters,
    task_wrapper,
)
from topobenchmarkx.utils.config_resolvers import (
    get_default_transform,
    get_monitor_metric,
    get_monitor_mode,
    infer_in_channels,
    infere_list_length,
)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
#   (so you don't need to force user to install project as a package)
#   (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
#   (which is used as a base for paths in "configs/paths/default.yaml")
#   (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #

OmegaConf.register_new_resolver("get_default_transform", get_default_transform)
OmegaConf.register_new_resolver("get_monitor_metric", get_monitor_metric)
OmegaConf.register_new_resolver("get_monitor_mode", get_monitor_mode)
OmegaConf.register_new_resolver("infer_in_channels", infer_in_channels)
OmegaConf.register_new_resolver("infere_list_length", infere_list_length)
OmegaConf.register_new_resolver(
    "parameter_multiplication", lambda x, y: int(int(x) * int(y))
)

torch.set_num_threads(1)
log = RankedLogger(__name__, rank_zero_only=True)
def train(cfg: DictConfig) -> None:
    """Collect dataset statistics for the configured dataset and model domain.

    Despite the name, this script does not fit a model: it instantiates the
    dataset, iterates over its dataloaders, and accumulates cell/hyperedge
    counts that are written to a CSV table.

    :param cfg: A DictConfig configuration composed by Hydra.
    """
    # Set seed for random number generators in pytorch, numpy and python.random
    # (L.seed_everything already covers all three; the explicit calls are kept
    # for robustness)
    L.seed_everything(cfg.seed, workers=True)
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)
    random.seed(cfg.seed)

    if cfg.model.model_domain == "cell":
        cfg.dataset.transforms.graph2cell_lifting.max_cell_length = 1000

    # Instantiate and load dataset
    dataset = hydra.utils.instantiate(cfg.dataset, _recursive_=False)
    dataset = dataset.load()

    # Single-graph (node-level) datasets use batch_size == 1; for them the
    # train dataloader alone already covers every cell.
    one_graph_flag = cfg.dataset.parameters.batch_size == 1

    log.info(f"Instantiating datamodule <{cfg.dataset._target_}>")

    if cfg.dataset.parameters.task_level == "node":
        datamodule = DefaultDataModule(dataset_train=dataset)
    elif cfg.dataset.parameters.task_level == "graph":
        datamodule = DefaultDataModule(
            dataset_train=dataset[0],
            dataset_val=dataset[1],
            dataset_test=dataset[2],
            batch_size=cfg.dataset.parameters.batch_size,
        )
    else:
        raise ValueError("Invalid task_level")

    if one_graph_flag:
        dataloaders = [datamodule.train_dataloader()]
    else:
        dataloaders = [
            datamodule.train_dataloader(),
            datamodule.val_dataloader(),
            datamodule.test_dataloader(),
        ]

    # Running totals of cells per rank (and hyperedges, where applicable)
    dict_collector = {
        "num_hyperedges": 0,
        "zero_cell": 0,
        "one_cell": 0,
        "two_cell": 0,
        "three_cell": 0,
    }

    for loader in dataloaders:
        for batch in loader:
            if cfg.model.model_domain == "hypergraph":
                dict_collector["zero_cell"] += batch.x.shape[0]
                dict_collector["num_hyperedges"] += batch.x_hyperedges.shape[0]

            elif cfg.model.model_domain == "simplicial":
                dict_collector["zero_cell"] += batch.x_0.shape[0]
                dict_collector["one_cell"] += batch.x_1.shape[0]
                dict_collector["two_cell"] += batch.x_2.shape[0]
                dict_collector["three_cell"] += batch.x_3.shape[0]

            elif cfg.model.model_domain == "cell":
                dict_collector["zero_cell"] += batch.x_0.shape[0]
                dict_collector["one_cell"] += batch.x_1.shape[0]
                dict_collector["two_cell"] += batch.x_2.shape[0]
                # Sizes of 2-cells: each column of incidence_2 sums to the
                # number of boundary edges of that cell. This is overwritten on
                # every batch and only consumed by the commented-out block below.
                cell_sizes, cell_counts = torch.unique(
                    batch.incidence_2.to_dense().sum(0), return_counts=True
                )

    # Build the output path under the project root
    filename = f"{cfg.paths['root_dir']}/tables/dataset_statistics.csv"

    dict_collector["dataset"] = cfg.dataset.parameters.data_name
    dict_collector["domain"] = cfg.model.model_domain

    df = pd.DataFrame.from_dict(dict_collector, orient="index")
    if not os.path.exists(filename):
        # First run: create the CSV with the statistic names as header
        df.T.to_csv(filename, header=True)
    else:
        # Append the new row to the existing CSV (read with its header)
        df_saved = pd.read_csv(filename, index_col=0)
        df_saved = pd.concat(
            [df_saved, pd.DataFrame([dict_collector])], ignore_index=True
        )
        df_saved.to_csv(filename)

    # if cfg.model.model_domain == "cell":
    #     filename = f"{cfg.paths['root_dir']}/tables/cell_statistics.csv"
    #     # Create a dict mapping cell size -> count from the two arrays
    #     cell_dict = dict(zip(cell_sizes.long().tolist(), cell_counts.long().tolist()))

    #     # Collapse cells of size greater than 10 into a single bucket
    #     n_large_cells = 0
    #     subset_keys = [key for key in sorted(cell_dict.keys()) if key > 10]
    #     for key in subset_keys:
    #         n_large_cells += cell_dict.pop(key)
    #     cell_dict["greater_than_10"] = n_large_cells

    #     cell_dict['dataset'] = cfg.dataset.parameters.data_name
    #     cell_dict['domain'] = cfg.model.model_domain

    #     df = pd.DataFrame.from_dict(cell_dict, orient='index')
    #     if not os.path.exists(filename):
    #         # First run: create the CSV with a header row
    #         df.T.to_csv(filename, header=True)
    #     else:
    #         # Append the new row to the existing CSV (read with its header)
    #         df_saved = pd.read_csv(filename, index_col=0)
    #         df_saved = pd.concat([df_saved, df.T], ignore_index=True)
    #         df_saved.to_csv(filename)

    return

@hydra.main(
    version_base="1.3", config_path="../configs", config_name="train.yaml"
)
def main(cfg: DictConfig) -> None:
    """Main entry point for the dataset statistics script.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    train(cfg)


if __name__ == "__main__":
    main()
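
As a side note, a runnable sketch (not part of the commit) of how the custom OmegaConf resolvers registered in the script behave; the config key is illustrative:

from omegaconf import OmegaConf

OmegaConf.register_new_resolver(
    "parameter_multiplication", lambda x, y: int(int(x) * int(y))
)
# Interpolations resolve lazily, on first access
cfg = OmegaConf.create({"hidden_dim": "${parameter_multiplication:32,4}"})
print(cfg.hidden_dim)  # 128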
@@ -0,0 +1,61 @@

# Description: Compute dataset statistics across simplicial, cell and hypergraph models.
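# Note: each `key=value` argument below is a Hydra override; `model=...` and
# `dataset=...` select config groups resolved against the config_path
# ("../configs") declared in the statistics script above.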
# ----Node regression datasets: US County Demographics----
models=( 'simplicial/scn' 'cell/cwn' 'hypergraph/unignn2' )
for model in "${models[@]}"
do
    python dataset_statistics.py \
        dataset=us_country_demos \
        model=$model

    # ----Cocitation datasets----
    datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' )
    for dataset in "${datasets[@]}"
    do
        python dataset_statistics.py \
            dataset=$dataset \
            model=$model
    done

    # ----Graph regression dataset----
    # Compute statistics on the ZINC dataset
    python dataset_statistics.py \
        dataset=ZINC \
        model=$model \
        dataset.transforms.one_hot_node_degree_features.degrees_fields=x

    # ----Heterophilic datasets----
    datasets=( roman_empire amazon_ratings tolokers questions minesweeper )
    for dataset in "${datasets[@]}"
    do
        python dataset_statistics.py \
            dataset=$dataset \
            model=$model
    done

    # ----TU graph datasets----
    # Compute statistics on MUTAG (it has very few samples)
    python dataset_statistics.py \
        dataset=MUTAG \
        model=$model

    # Compute statistics on the rest of the TU graph datasets
    datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'REDDIT-BINARY' 'IMDB-BINARY' 'IMDB-MULTI' )
    for dataset in "${datasets[@]}"
    do
        python dataset_statistics.py \
            dataset=$dataset \
            model=$model
    done

done