From 643bf8895b26d1b5bea199620b55b2a250d78fb0 Mon Sep 17 00:00:00 2001 From: ErnstRoell Date: Tue, 19 Nov 2024 10:52:01 +0100 Subject: [PATCH] Finished implementation for DECT --- code/datasets/transforms.py | 4 +- code/experiments/generate_configs_dect.py | 237 ++++++++++++++++++++++ code/models/DECT.py | 97 ++++++++- 3 files changed, 332 insertions(+), 6 deletions(-) create mode 100644 code/experiments/generate_configs_dect.py diff --git a/code/datasets/transforms.py b/code/datasets/transforms.py index 9f04262..d4f3830 100644 --- a/code/datasets/transforms.py +++ b/code/datasets/transforms.py @@ -71,7 +71,7 @@ class TriangulationToFaceTransform: the FaceToEdge transform by default creates undirected edges. """ - def __init__(self, remove_triangulation: bool = True) -> None: + def __init__(self, remove_triangulation: bool = False) -> None: self.remove_triangulation = remove_triangulation def __call__(self, data): @@ -87,6 +87,8 @@ def __call__(self, data): if self.remove_triangulation: data.triangulation = None + else: + data.triangulation = torch.tensor(data.triangulation) return data diff --git a/code/experiments/generate_configs_dect.py b/code/experiments/generate_configs_dect.py new file mode 100644 index 0000000..f274a31 --- /dev/null +++ b/code/experiments/generate_configs_dect.py @@ -0,0 +1,237 @@ +import sys +import os + +sys.path.append(os.curdir) +from metrics.tasks import TaskType +from datasets.transforms import TransformType +from datasets.dataset_types import DatasetType +from models import ModelType, model_cfg_lookup +from experiments.utils.configs import ConfigExperimentRun, TrainerConfig +import yaml +import json +import os +import shutil +import argparse +from typing import List, Dict + +# ARGS ------------------------------------------------------------------------ +parser = argparse.ArgumentParser( + description="Argument parser for experiment configurations." 
+) +parser.add_argument( + "--max_epochs", + type=int, + default=10, + help="Maximum number of epochs.", +) + +parser.add_argument( + "--config_dir", + type=str, + default="./configs", + help="Directory where config files shall be stored", +) + +parser.add_argument( + "--lr", + type=float, + default=0.01, + help="Learning rate.", +) + +parser.add_argument( + "--use_imbalance_weighting", + action="store_true", + help="Whether to weight loss terms with imbalance weights.", +) + +args = parser.parse_args() +max_epochs: int = args.max_epochs +lr: float = args.lr +config_dir: str = args.config_dir +use_imbalance_weights: bool = args.use_imbalance_weighting +# ----------------------------------------------------------------------------- + +# CONFIGS --------------------------------------------------------------------- + +# DS TYPE ### +dataset_types = [ + DatasetType.FULL_2D, + DatasetType.FULL_3D, + DatasetType.NO_NAMELESS_2D, +] +# ########### + +# TASKS ##### +tasks_mantra2 = [TaskType.ORIENTABILITY, TaskType.NAME, TaskType.BETTI_NUMBERS] +tasks_mantra3 = [TaskType.BETTI_NUMBERS, TaskType.ORIENTABILITY] +# ########### + +# TRANSFORMS +graph_features = [ + TransformType.degree_transform, + TransformType.degree_transform_onehot, + TransformType.random_node_features, +] +simplicial_features = [ + TransformType.degree_transform_sc, + TransformType.random_simplices_features, +] +# ########### + +# MODELS #### +graph_models = { + # ModelType.GCN, + # ModelType.GAT, + # ModelType.MLP, + # ModelType.TAG, + # ModelType.TransfConv, + ModelType.DECT +} +simplicial_models = { + # ModelType.SAN, + # ModelType.SCCN, + # ModelType.SCCNN, + # ModelType.SCN, +} +models = list(graph_models) # + list(simplicial_models) +# ########### + +# MISC ###### +feature_dim_dict = { + TransformType.degree_transform: 1, + TransformType.degree_transform_onehot: 10, + TransformType.random_node_features: 8, + # TransformType.degree_transform_sc: [1, 2, 1], + # 
TransformType.random_simplices_features: [8, 8, 8], +} + +out_channels_dict_mantra2_full = { + TaskType.ORIENTABILITY: 1, + TaskType.NAME: 5, + TaskType.BETTI_NUMBERS: 3, +} +out_channels_dict_mantra2_no_nameless = { + TaskType.ORIENTABILITY: 1, + TaskType.NAME: 4, + TaskType.BETTI_NUMBERS: 3, +} +out_channels_dict_mantra3 = { + TaskType.ORIENTABILITY: 1, + TaskType.BETTI_NUMBERS: 4, +} +# ########### + +# ----------------------------------------------------------------------------- + + +# UTILS ----------------------------------------------------------------------- +def get_feature_types(model: ModelType): + if model in graph_models: + return graph_features + else: + return simplicial_features + + +def get_model_config( + model: ModelType, out_channels: int, dim_features: int | tuple[int] +): + model_config_cls = model_cfg_lookup[model] + if model in graph_models: + model_config = model_config_cls( + out_channels=out_channels, num_node_features=dim_features + ) + else: + model_config = model_config_cls( + out_channels=out_channels, in_channels=tuple(dim_features) + ) + return model_config + + +def manage_directory(path: str): + """ + Removes directory if exists and creates an empty directory + """ + if os.path.exists(path): + if os.path.isdir(path): + shutil.rmtree(path) + else: + os.remove(path) + os.makedirs(path) + + +def get_tasks(ds_type: DatasetType) -> List[TaskType]: + tasks = ( + tasks_mantra2.copy() + if ( + ds_type == DatasetType.FULL_2D + or ds_type == DatasetType.NO_NAMELESS_2D + ) + else tasks_mantra3.copy() + ) + return tasks + + +def get_out_channels_dict(ds_type: DatasetType) -> Dict[TaskType, int]: + if ds_type == DatasetType.FULL_2D: + return out_channels_dict_mantra2_full.copy() + elif ds_type == DatasetType.FULL_3D: + return out_channels_dict_mantra3.copy() + elif ds_type == DatasetType.NO_NAMELESS_2D: + return out_channels_dict_mantra2_no_nameless.copy() + else: + raise ValueError("Unknown dataset type") + + +# 
----------------------------------------------------------------------------- + +# GENERATE -------------------------------------------------------------------- +manage_directory(config_dir) + +for ds_type in dataset_types: + tasks = get_tasks(ds_type) + for model in models: + + # if ds_type == DatasetType.FULL_3D and model in simplicial_models: + # continue + + features = get_feature_types(model) + for feature in features: + for task in tasks: + dim_features = feature_dim_dict[feature] + out_channels = get_out_channels_dict(ds_type)[task] + model_config = get_model_config( + model, out_channels, dim_features + ) + model_config_cls = model_cfg_lookup[model] + trainer_config = TrainerConfig( + accelerator="auto", + max_epochs=max_epochs, + log_every_n_steps=1, + ) + config = ConfigExperimentRun( + ds_type=ds_type, + task_type=task, + seed=1234, + transforms=feature, + use_stratified=( + False if task == TaskType.BETTI_NUMBERS else True + ), + use_imbalance_weighting=use_imbalance_weights, + learning_rate=lr, + trainer_config=trainer_config, + conf_model=model_config, + ) + + json_string = config.model_dump_json() + + python_dict = json.loads(json_string) + yaml_string = yaml.dump(python_dict) + config_identifier = config.get_identifier() + yaml_file_path = os.path.join( + config_dir, + f"{config_identifier}.yaml", + ) + with open(yaml_file_path, "w") as file: + file.write(yaml_string) +# ----------------------------------------------------------------------------- diff --git a/code/models/DECT.py b/code/models/DECT.py index fb0d03c..7d1f508 100644 --- a/code/models/DECT.py +++ b/code/models/DECT.py @@ -26,7 +26,6 @@ class EctConfig: num_thetas: int = 32 bump_steps: int = 32 r: float = 1.1 - ect_type: str = "points" normalized: bool = True @@ -65,6 +64,88 @@ def compute_ect_points(data, index, v, lin, out, scale): return compute_ecc(nh, index, lin, out, scale) +def compute_ect_faces(batch, index, v, lin, out, scale): + """Computes the Euler Characteristic Transform 
of a batch of meshes. + + Parameters + ---------- + batch : Batch + A batch of data containing the node coordinates, edges, faces and batch + index. + v: torch.FloatTensor + The direction vector that contains the directions. + lin: torch.FloatTensor + The discretization of the interval [-1,1] each node height falls in this + range due to rescaling in normalizing the data. + """ + # Compute the node heights + nh = batch.x @ v + + # Perform a lookup with the edge indices on node heights, this replaces the + # node index with its node height and then compute the maximum over the + # columns to compute the edge height. + eh, _ = nh[batch.edge_index].max(dim=0) + + # Do the same thing for the faces. + fh, _ = nh[batch.face].max(dim=0) + + # Compute which batch an edge belongs to. We take the first index of the + # edge (or faces) and do a lookup on the batch index of that node in the + # batch indices of the nodes. + batch_index_nodes = batch.batch + batch_index_edges = batch.batch[batch.edge_index[0]] + batch_index_faces = batch.batch[batch.face[0]] + + return ( + compute_ecc(nh, batch_index_nodes, lin, out, scale) + - compute_ecc(eh, batch_index_edges, lin, out, scale) + + compute_ecc(fh, batch_index_faces, lin, out, scale) + ) + + +def compute_ect_tetrahedra(batch, index, v, lin, out, scale): + """Computes the Euler Characteristic Transform of a batch of meshes. + + Parameters + ---------- + batch : Batch + A batch of data containing the node coordinates, edges, faces and batch + index. + v: torch.FloatTensor + The direction vector that contains the directions. + lin: torch.FloatTensor + The discretization of the interval [-1,1] each node height falls in this + range due to rescaling in normalizing the data. + """ + # Compute the node heights + nh = batch.x @ v + + # Perform a lookup with the edge indices on node heights, this replaces the + # node index with its node height and then compute the maximum over the + # columns to compute the edge height. 
+ eh, _ = nh[batch.edge_index].max(dim=0) + + # Do the same thing for the faces. + fh, _ = nh[batch.face].max(dim=0) + + # Do the same thing for the tetrahedra. + th, _ = nh[batch.triangulation.T].max(dim=0) + + # Compute which batch an edge belongs to. We take the first index of the + # edge (or faces) and do a lookup on the batch index of that node in the + # batch indices of the nodes. + batch_index_nodes = batch.batch + batch_index_edges = batch.batch[batch.edge_index[0]] + batch_index_faces = batch.batch[batch.face[0]] + batch_index_triangulation = batch.batch[batch.triangulation.T[0]] + return ( + compute_ecc(nh, batch_index_nodes, lin, out, scale) + - compute_ecc(eh, batch_index_edges, lin, out, scale) + + compute_ecc(fh, batch_index_faces, lin, out, scale) + - compute_ecc(th, batch_index_triangulation, lin, out, scale) + ) + + class EctLayer(nn.Module): """Docstring for EctLayer.""" @@ -83,9 +164,6 @@ def __init__(self, config: EctConfig, v=None): raise AttributeError("Please provide the directions") self.v = nn.Parameter(v, requires_grad=False) - if config.ect_type == "points": - self.compute_ect = compute_ect_points - def forward(self, data: Data, index, scale=500): """Forward method""" out = torch.zeros( @@ -96,7 +174,16 @@ def forward(self, data: Data, index, scale=500): ), device=data.x.device, ) - ect = self.compute_ect(data, index, self.v, self.lin, out, scale) + + if data.triangulation.shape[1] == 3: + ect = compute_ect_faces(data, index, self.v, self.lin, out, scale) + elif data.triangulation.shape[1] == 4: + ect = compute_ect_tetrahedra( + data, index, self.v, self.lin, out, scale + ) + else: + raise ValueError("The triangulation is not correct.") + if self.config.normalized: return ect / torch.amax(ect, dim=(1, 2)).unsqueeze(1).unsqueeze(1) return ect