Skip to content

Commit

Permalink
preparation for 3 manifolds experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
danielbinschmid committed Aug 28, 2024
1 parent e10d013 commit eac11c3
Show file tree
Hide file tree
Showing 14 changed files with 225 additions and 129 deletions.
45 changes: 0 additions & 45 deletions code/benchmarks.md

This file was deleted.

20 changes: 20 additions & 0 deletions code/datasets/dataset_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from enum import Enum
from torch_geometric.data import Batch, Data


class DatasetType(Enum):
"""
DESCRIPTION:
- FULL_2D - Mantra on 2D manifolds
- FULL_3D - Mantra on 3D manifolds
- NO_NAMELESS_2D - Mantra on 2D manifolds only including simplicial complexes which have a name label.
"""

FULL_2D = "full_2d"
FULL_3D = "full_3d"
NO_NAMELESS_2D = "no_nameless_2d"


def filter_nameless(data: Data) -> bool:
return data.name != ""
49 changes: 34 additions & 15 deletions code/datasets/simplicial.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from typing import List
from .simplicial_ds import SimplicialDS
from metrics.tasks import TaskType
from datasets.dataset_types import DatasetType, filter_nameless
import os


def unique_counts(input_list: List[str]) -> Counter:
Expand All @@ -22,8 +24,10 @@ def __init__(
batch_size: int = 128,
seed: int = 2024,
dataloader_builder: Callable = DataLoaderGeometric,
ds_type: DatasetType = DatasetType.FULL_2D,
):
super().__init__()
self.ds_type = ds_type
self.data_dir = data_dir
self.transform = transform
self.task_type = task_type
Expand All @@ -32,12 +36,37 @@ def __init__(
self.batch_size = batch_size
self.seed = seed
self.dataloader_builder = dataloader_builder
self.split = [0.6, 0.2, 0.2]

def get_ds(self, mode: str = "train") -> SimplicialDS:
if self.ds_type == DatasetType.FULL_2D:
manifold = "2"
pre_filter = None
elif self.ds_type == DatasetType.FULL_3D:
manifold = "3"
pre_filter = None
elif self.ds_type == DatasetType.NO_NAMELESS_2D:
manifold = "2"
pre_filter = filter_nameless
else:
raise ValueError(f"Unknown dataset type {self.ds_type}")
return SimplicialDS(
root=os.path.join(self.data_dir, self.ds_type.name.lower()),
manifold=manifold,
split=self.split,
seed=self.seed,
mode=mode,
use_stratified=self.use_stratified,
task_type=self.task_type,
transform=self.transform,
pre_filter=pre_filter,
)

def prepare_data(self) -> None:
SimplicialDS(root=self.data_dir)
self.get_ds()

def class_imbalance_statistics(self) -> Counter:
dataset = SimplicialDS(root=self.data_dir, task_type=self.task_type)
dataset = self.get_ds()

statistics = None
if self.task_type == TaskType.NAME:
Expand All @@ -49,19 +78,9 @@ def class_imbalance_statistics(self) -> Counter:
return statistics

def setup(self, stage=None):
get_ds = lambda mode: SimplicialDS(
root=self.data_dir,
split=[0.7, 0.15, 0.15],
seed=self.seed,
use_stratified=self.use_stratified,
task_type=self.task_type,
mode=mode,
transform=self.transform,
)

self.train_ds = get_ds("train")
self.val_ds = get_ds("val")
self.test_ds = get_ds("test")
self.train_ds = self.get_ds("train")
self.val_ds = self.get_ds("val")
self.test_ds = self.get_ds("test")

def train_dataloader(self):
return self.dataloader_builder(
Expand Down
4 changes: 3 additions & 1 deletion code/datasets/simplicial_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def process(self):
test_size=self.split_config.split[2],
shuffle=True,
stratify=(
stratified if self.split_config.use_stratified else None
stratified.numpy()
if self.split_config.use_stratified
else None
),
random_state=self.split_config.seed,
)
Expand Down
8 changes: 1 addition & 7 deletions code/datasets/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,7 @@
)
from enum import Enum

NAME_TO_CLASS = {
"Klein bottle": 0,
"": 1,
"RP^2": 2,
"T^2": 3,
"S^2": 4,
}
NAME_TO_CLASS = {"Klein bottle": 0, "RP^2": 1, "T^2": 2, "S^2": 3, "": 4}


class SetNumNodesTransform:
Expand Down
142 changes: 104 additions & 38 deletions code/experiments/generate_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
sys.path.append(os.curdir)
from metrics.tasks import TaskType
from datasets.transforms import TransformType
from datasets.dataset_types import DatasetType
from models import ModelType, model_cfg_lookup
from experiments.utils.configs import ConfigExperimentRun, TrainerConfig
import yaml
import json
import os
import shutil
import argparse
from typing import List, Dict

# ARGS ------------------------------------------------------------------------
parser = argparse.ArgumentParser(
Expand All @@ -37,43 +39,64 @@
help="Maximum number of epochs.",
)

parser.add_argument(
"--use_imbalance_weighting",
action="store_true",
help="Whether to weight loss terms with imbalance weights.",
)

args = parser.parse_args()
max_epochs: int = args.max_epochs
lr: float = args.lr
config_dir: str = args.config_dir
use_imbalance_weights: bool = args.use_imbalance_weighting
# -----------------------------------------------------------------------------

# CONFIGS ---------------------------------------------------------------------
tasks = [TaskType.ORIENTABILITY, TaskType.NAME, TaskType.BETTI_NUMBERS]

# DS TYPE ###
dataset_types = [
DatasetType.FULL_2D,
DatasetType.FULL_3D,
DatasetType.NO_NAMELESS_2D,
]
# ###########

# TASKS #####
tasks_mantra2 = [TaskType.ORIENTABILITY, TaskType.NAME, TaskType.BETTI_NUMBERS]
tasks_mantra3 = [TaskType.ORIENTABILITY, TaskType.BETTI_NUMBERS]
# ###########

# TRANSFORMS
graph_features = [
TransformType.degree_transform,
TransformType.degree_transform_onehot,
TransformType.random_node_features,
]

simplicial_features = [
TransformType.degree_transform_sc,
TransformType.random_simplices_features,
]
# ###########

# MODELS ####
graph_models = {
ModelType.GCN,
ModelType.GAT,
ModelType.MLP,
ModelType.TAG,
ModelType.TransfConv,
}

simplicial_models = {
ModelType.SAN,
ModelType.SCCN,
ModelType.SCCNN,
ModelType.SCN,
}

models = list(graph_models) + list(simplicial_models)
# ###########

# MISC ######
feature_dim_dict = {
TransformType.degree_transform: 1,
TransformType.degree_transform_onehot: 10,
Expand All @@ -82,11 +105,22 @@
TransformType.random_simplices_features: [8, 8, 8],
}

out_channels_dict = {
out_channels_dict_mantra2_full = {
TaskType.ORIENTABILITY: 1,
TaskType.NAME: 5,
TaskType.BETTI_NUMBERS: 3,
}
out_channels_dict_mantra2_no_nameless = {
TaskType.ORIENTABILITY: 1,
TaskType.NAME: 4,
TaskType.BETTI_NUMBERS: 3,
}
out_channels_dict_mantra3 = {
TaskType.ORIENTABILITY: 1,
TaskType.BETTI_NUMBERS: 3,
}
# ###########

# -----------------------------------------------------------------------------


Expand Down Expand Up @@ -125,42 +159,74 @@ def manage_directory(path: str):
os.makedirs(path)


def get_tasks(ds_type: DatasetType) -> List[TaskType]:
tasks = (
tasks_mantra2.copy()
if (
ds_type == DatasetType.FULL_2D
or ds_type == DatasetType.NO_NAMELESS_2D
)
else tasks_mantra3.copy()
)
return tasks


def get_out_channels_dict(ds_type: DatasetType) -> Dict[TaskType, int]:
if ds_type == DatasetType.FULL_2D:
return out_channels_dict_mantra2_full.copy()
elif ds_type == DatasetType.FULL_3D:
return out_channels_dict_mantra3.copy()
elif ds_type == DatasetType.NO_NAMELESS_2D:
return out_channels_dict_mantra2_no_nameless.copy()
else:
raise ValueError("Unknown dataset type")


# -----------------------------------------------------------------------------

# GENERATE --------------------------------------------------------------------
manage_directory(config_dir)

for model in models:
features = get_feature_types(model)
for feature in features:
for task in tasks:
dim_features = feature_dim_dict[feature]
out_channels = out_channels_dict[task]
model_config = get_model_config(model, out_channels, dim_features)
model_config_cls = model_cfg_lookup[model]
trainer_config = TrainerConfig(
accelerator="auto", max_epochs=max_epochs, log_every_n_steps=1
)
config = ConfigExperimentRun(
task_type=task,
seed=1234, # any seed, will be overwritten during actual run
transforms=feature,
use_stratified=(
False if task == TaskType.BETTI_NUMBERS else True
),
learning_rate=lr,
trainer_config=trainer_config,
conf_model=model_config,
)

json_string = config.model_dump_json()

python_dict = json.loads(json_string)
yaml_string = yaml.dump(python_dict)
yaml_file_path = os.path.join(
config_dir,
f"{model.name.lower()}_{task.name.lower()}_{feature.name.lower()}.yaml",
)
with open(yaml_file_path, "w") as file:
file.write(yaml_string)
for ds_type in dataset_types:
tasks = get_tasks(ds_type)
for model in models:
features = get_feature_types(model)
for feature in features:
for task in tasks:
dim_features = feature_dim_dict[feature]
out_channels = get_out_channels_dict(ds_type)[task]
model_config = get_model_config(
model, out_channels, dim_features
)
model_config_cls = model_cfg_lookup[model]
trainer_config = TrainerConfig(
accelerator="auto",
max_epochs=max_epochs,
log_every_n_steps=1,
)
config = ConfigExperimentRun(
ds_type=ds_type,
task_type=task,
seed=1234,
transforms=feature,
use_stratified=(
False if task == TaskType.BETTI_NUMBERS else True
),
use_imbalance_weighting=use_imbalance_weights,
learning_rate=lr,
trainer_config=trainer_config,
conf_model=model_config,
)

json_string = config.model_dump_json()

python_dict = json.loads(json_string)
yaml_string = yaml.dump(python_dict)
config_identifier = config.get_identifier()
yaml_file_path = os.path.join(
config_dir,
f"{config_identifier}.yaml",
)
with open(yaml_file_path, "w") as file:
file.write(yaml_string)
# -----------------------------------------------------------------------------
Loading

0 comments on commit eac11c3

Please sign in to comment.