Skip to content

Commit

Permalink
Finished implementation for DECT
Browse files Browse the repository at this point in the history
  • Loading branch information
ErnstRoell committed Nov 19, 2024
1 parent 1193be6 commit 643bf88
Show file tree
Hide file tree
Showing 3 changed files with 332 additions and 6 deletions.
4 changes: 3 additions & 1 deletion code/datasets/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class TriangulationToFaceTransform:
the FaceToEdge transform by default creates undirected edges.
"""

def __init__(self, remove_triangulation: bool = True) -> None:
def __init__(self, remove_triangulation: bool = False) -> None:
self.remove_triangulation = remove_triangulation

def __call__(self, data):
Expand All @@ -87,6 +87,8 @@ def __call__(self, data):

if self.remove_triangulation:
data.triangulation = None
else:
data.triangulation = torch.tensor(data.triangulation)

return data

Expand Down
237 changes: 237 additions & 0 deletions code/experiments/generate_configs_dect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
import sys
import os

sys.path.append(os.curdir)
from metrics.tasks import TaskType
from datasets.transforms import TransformType
from datasets.dataset_types import DatasetType
from models import ModelType, model_cfg_lookup
from experiments.utils.configs import ConfigExperimentRun, TrainerConfig
import yaml
import json
import os
import shutil
import argparse
from typing import List, Dict

# ARGS ------------------------------------------------------------------------
# Command-line interface of the config generator. Every option has a default,
# so the script can be run without arguments.
parser = argparse.ArgumentParser(
    description="Argument parser for experiment configurations."
)
parser.add_argument(
    "--max_epochs",
    type=int,
    default=10,
    help="Maximum number of epochs.",
)

parser.add_argument(
    "--config_dir",
    type=str,
    default="./configs",
    help="Directory where config files shall be stored",
)

parser.add_argument(
    "--lr",
    type=float,
    default=0.01,
    # The previous help text was a copy of the --max_epochs description.
    help="Learning rate.",
)

parser.add_argument(
    "--use_imbalance_weighting",
    action="store_true",
    help="Whether to weight loss terms with imbalance weights.",
)

# Parsed once at import time; the values below are read as module-level
# globals by the generation loop further down in this file.
args = parser.parse_args()
max_epochs: int = args.max_epochs
lr: float = args.lr
config_dir: str = args.config_dir
use_imbalance_weights: bool = args.use_imbalance_weighting
# -----------------------------------------------------------------------------

# CONFIGS ---------------------------------------------------------------------

# DS TYPE ###
# Dataset variants for which configuration files are generated (outermost
# loop of the generation section).
dataset_types = [
    DatasetType.FULL_2D,
    DatasetType.FULL_3D,
    DatasetType.NO_NAMELESS_2D,
]
# ###########

# TASKS #####
# Task lists returned by get_tasks(): the first for FULL_2D / NO_NAMELESS_2D,
# the second for every other dataset type.
tasks_mantra2 = [TaskType.ORIENTABILITY, TaskType.NAME, TaskType.BETTI_NUMBERS]
tasks_mantra3 = [TaskType.BETTI_NUMBERS, TaskType.ORIENTABILITY]
# ###########

# TRANSFORMS
# Feature transforms swept for models listed in `graph_models`.
graph_features = [
    TransformType.degree_transform,
    TransformType.degree_transform_onehot,
    TransformType.random_node_features,
]
# Feature transforms returned by get_feature_types() for any model that is
# not in `graph_models`.
simplicial_features = [
    TransformType.degree_transform_sc,
    TransformType.random_simplices_features,
]
# ###########

# MODELS ####
# Only DECT is enabled in this script; the other entries are commented out.
graph_models = {
    # ModelType.GCN,
    # ModelType.GAT,
    # ModelType.MLP,
    # ModelType.TAG,
    # ModelType.TransfConv,
    ModelType.DECT
}
# NOTE: with every entry commented out this literal is an empty *dict*, not a
# set. Nothing below depends on the difference while it stays empty.
simplicial_models = {
    # ModelType.SAN,
    # ModelType.SCCN,
    # ModelType.SCCNN,
    # ModelType.SCN,
}
# Models that configs are actually generated for.
models = list(graph_models)  # + list(simplicial_models)
# ###########

# MISC ######
# Transform -> value handed to get_model_config() as `dim_features`. Only the
# graph transforms have entries; the simplicial ones are commented out, so
# looking them up here would raise KeyError.
feature_dim_dict = {
    TransformType.degree_transform: 1,
    TransformType.degree_transform_onehot: 10,
    TransformType.random_node_features: 8,
    # TransformType.degree_transform_sc: [1, 2, 1],
    # TransformType.random_simplices_features: [8, 8, 8],
}

# Task -> `out_channels` value for the model config, one table per dataset
# type (selected by get_out_channels_dict()).
out_channels_dict_mantra2_full = {
    TaskType.ORIENTABILITY: 1,
    TaskType.NAME: 5,
    TaskType.BETTI_NUMBERS: 3,
}
out_channels_dict_mantra2_no_nameless = {
    TaskType.ORIENTABILITY: 1,
    TaskType.NAME: 4,
    TaskType.BETTI_NUMBERS: 3,
}
out_channels_dict_mantra3 = {
    TaskType.ORIENTABILITY: 1,
    TaskType.BETTI_NUMBERS: 4,
}
# ###########

# -----------------------------------------------------------------------------


# UTILS -----------------------------------------------------------------------
def get_feature_types(model: ModelType):
    """Return the feature transforms to sweep over for ``model``.

    Models contained in ``graph_models`` get ``graph_features``; every other
    model gets ``simplicial_features``. The module-level list itself is
    returned, not a copy.
    """
    return graph_features if model in graph_models else simplicial_features


def get_model_config(
    model: ModelType, out_channels: int, dim_features: int | tuple[int]
):
    """Instantiate the config class registered for ``model``.

    The config class is looked up in ``model_cfg_lookup``. Models in
    ``graph_models`` receive ``dim_features`` as ``num_node_features``; all
    other models receive it, converted to a tuple, as ``in_channels``.

    Raises whatever ``model_cfg_lookup[model]`` raises for an unregistered
    model.
    """
    config_cls = model_cfg_lookup[model]

    if model not in graph_models:
        # Non-graph models take their feature dimensions as a tuple.
        return config_cls(
            out_channels=out_channels, in_channels=tuple(dim_features)
        )

    return config_cls(
        out_channels=out_channels, num_node_features=dim_features
    )


def manage_directory(path: str):
    """
    Reset ``path`` to a freshly created, empty directory.

    An existing directory at ``path`` is deleted recursively, an existing
    non-directory entry is unlinked, and then the directory (including any
    missing parents) is created.
    """
    if os.path.isdir(path):
        shutil.rmtree(path)
    elif os.path.exists(path):
        os.remove(path)

    os.makedirs(path)


def get_tasks(ds_type: DatasetType) -> List[TaskType]:
    """Return a fresh copy of the task list that applies to ``ds_type``.

    ``FULL_2D`` and ``NO_NAMELESS_2D`` use ``tasks_mantra2``; any other
    dataset type falls back to ``tasks_mantra3``.
    """
    if (
        ds_type == DatasetType.FULL_2D
        or ds_type == DatasetType.NO_NAMELESS_2D
    ):
        return tasks_mantra2.copy()
    return tasks_mantra3.copy()


def get_out_channels_dict(ds_type: DatasetType) -> Dict[TaskType, int]:
    """Return a copy of the task -> ``out_channels`` table for ``ds_type``.

    Raises
    ------
    ValueError
        If ``ds_type`` is none of ``FULL_2D``, ``FULL_3D`` or
        ``NO_NAMELESS_2D``.
    """
    if ds_type == DatasetType.FULL_2D:
        table = out_channels_dict_mantra2_full
    elif ds_type == DatasetType.FULL_3D:
        table = out_channels_dict_mantra3
    elif ds_type == DatasetType.NO_NAMELESS_2D:
        table = out_channels_dict_mantra2_no_nameless
    else:
        raise ValueError("Unknown dataset type")

    # Hand out a copy so callers cannot modify the module-level tables.
    return table.copy()


# -----------------------------------------------------------------------------

# GENERATE --------------------------------------------------------------------
# Start from an empty output directory so configs from earlier runs do not
# linger next to the newly generated ones.
manage_directory(config_dir)

# One YAML file is written per (dataset type, model, feature transform, task)
# combination.
for ds_type in dataset_types:
    tasks = get_tasks(ds_type)
    # The task -> out_channels table depends only on the dataset type, so it
    # is looked up once per dataset instead of once per generated config.
    out_channels_by_task = get_out_channels_dict(ds_type)

    for model in models:

        # if ds_type == DatasetType.FULL_3D and model in simplicial_models:
        #     continue

        for feature in get_feature_types(model):
            # Independent of the task, hence resolved outside the task loop.
            dim_features = feature_dim_dict[feature]

            for task in tasks:
                model_config = get_model_config(
                    model, out_channels_by_task[task], dim_features
                )
                trainer_config = TrainerConfig(
                    accelerator="auto",
                    max_epochs=max_epochs,
                    log_every_n_steps=1,
                )
                config = ConfigExperimentRun(
                    ds_type=ds_type,
                    task_type=task,
                    seed=1234,
                    transforms=feature,
                    # Stratification is switched off only for Betti numbers.
                    use_stratified=task != TaskType.BETTI_NUMBERS,
                    use_imbalance_weighting=use_imbalance_weights,
                    learning_rate=lr,
                    trainer_config=trainer_config,
                    conf_model=model_config,
                )

                # Round-trip through JSON so that yaml.dump only sees plain
                # built-in types (dict/list/str/number/bool/None).
                python_dict = json.loads(config.model_dump_json())
                yaml_string = yaml.dump(python_dict)

                yaml_file_path = os.path.join(
                    config_dir,
                    f"{config.get_identifier()}.yaml",
                )
                with open(yaml_file_path, "w", encoding="utf-8") as file:
                    file.write(yaml_string)
# -----------------------------------------------------------------------------
97 changes: 92 additions & 5 deletions code/models/DECT.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class EctConfig:
num_thetas: int = 32
bump_steps: int = 32
r: float = 1.1
ect_type: str = "points"
normalized: bool = True


Expand Down Expand Up @@ -65,6 +64,88 @@ def compute_ect_points(data, index, v, lin, out, scale):
return compute_ecc(nh, index, lin, out, scale)


def compute_ect_faces(batch, index, v, lin, out, scale):
    """Compute the Euler Characteristic Transform of a batch of meshes.

    The result is the alternating sum ``nodes - edges + faces`` of the curves
    that ``compute_ecc`` returns for each kind of simplex.

    Parameters
    ----------
    batch : Batch
        Batch of data; this function reads ``x`` (node coordinates),
        ``edge_index``, ``face`` and ``batch`` (graph id of each node).
    index :
        Not used by this function; kept so all ``compute_ect_*`` functions
        share one signature.
    v : torch.FloatTensor
        The directions; node heights are obtained as ``batch.x @ v``.
    lin : torch.FloatTensor
        Forwarded to ``compute_ecc``. Described by the original author as the
        discretization of the interval [-1, 1] that the node heights fall
        into after normalization -- not verifiable from this function.
    out :
        Forwarded unchanged to every ``compute_ecc`` call.
    scale :
        Forwarded unchanged to every ``compute_ecc`` call.
    """
    graph_of_node = batch.batch
    edges = batch.edge_index
    faces = batch.face

    # Height of every node along every direction.
    node_heights = batch.x @ v

    # Indexing with the vertex ids replaces each id by that node's heights;
    # the maximum over dim 0 (the vertex slots of a simplex) is the height of
    # the simplex itself.
    edge_heights = node_heights[edges].max(dim=0).values
    face_heights = node_heights[faces].max(dim=0).values

    # A simplex belongs to the same graph as its first vertex.
    graph_of_edge = graph_of_node[edges[0]]
    graph_of_face = graph_of_node[faces[0]]

    node_curve = compute_ecc(node_heights, graph_of_node, lin, out, scale)
    edge_curve = compute_ecc(edge_heights, graph_of_edge, lin, out, scale)
    face_curve = compute_ecc(face_heights, graph_of_face, lin, out, scale)
    return node_curve - edge_curve + face_curve


def compute_ect_tetrahedra(batch, index, v, lin, out, scale):
    """Compute the Euler Characteristic Transform including tetrahedra.

    The result is the alternating sum ``nodes - edges + faces - tetrahedra``
    of the curves that ``compute_ecc`` returns for each kind of simplex.

    Parameters
    ----------
    batch : Batch
        Batch of data; this function reads ``x`` (node coordinates),
        ``edge_index``, ``face``, ``triangulation`` and ``batch`` (graph id of
        each node).
    index :
        Not used by this function; kept so all ``compute_ect_*`` functions
        share one signature.
    v : torch.FloatTensor
        The directions; node heights are obtained as ``batch.x @ v``.
    lin : torch.FloatTensor
        Forwarded to ``compute_ecc``. Described by the original author as the
        discretization of the interval [-1, 1] that the node heights fall
        into after normalization -- not verifiable from this function.
    out :
        Forwarded unchanged to every ``compute_ecc`` call.
    scale :
        Forwarded unchanged to every ``compute_ecc`` call.
    """
    # Height of every node along every direction.
    node_heights = batch.x @ v
    graph_of_node = batch.batch

    def _simplex_curve(vertex_rows):
        # Height of a simplex is the maximum height over its vertices (dim 0
        # enumerates the vertex slots); its graph is that of its first vertex.
        heights = node_heights[vertex_rows].max(dim=0).values
        return compute_ecc(
            heights, graph_of_node[vertex_rows[0]], lin, out, scale
        )

    node_curve = compute_ecc(node_heights, graph_of_node, lin, out, scale)
    edge_curve = _simplex_curve(batch.edge_index)
    face_curve = _simplex_curve(batch.face)
    # `triangulation` is transposed so that, like `edge_index` and `face`, the
    # vertex slots run along dim 0.
    # NOTE(review): assumes triangulation is (num_tetrahedra, 4) and that its
    # vertex ids are already offset per graph in the batch -- confirm.
    tetrahedron_curve = _simplex_curve(batch.triangulation.T)

    return node_curve - edge_curve + face_curve - tetrahedron_curve


class EctLayer(nn.Module):
"""Docstring for EctLayer."""

Expand All @@ -83,9 +164,6 @@ def __init__(self, config: EctConfig, v=None):
raise AttributeError("Please provide the directions")
self.v = nn.Parameter(v, requires_grad=False)

if config.ect_type == "points":
self.compute_ect = compute_ect_points

def forward(self, data: Data, index, scale=500):
"""Forward method"""
out = torch.zeros(
Expand All @@ -96,7 +174,16 @@ def forward(self, data: Data, index, scale=500):
),
device=data.x.device,
)
ect = self.compute_ect(data, index, self.v, self.lin, out, scale)

if data.triangulation.shape[1] == 3:
ect = compute_ect_faces(data, index, self.v, self.lin, out, scale)
elif data.triangulation.shape[1] == 4:
ect = compute_ect_tetrahedra(
data, index, self.v, self.lin, out, scale
)
else:
raise ValueError("The triangulation is not correct.")

if self.config.normalized:
return ect / torch.amax(ect, dim=(1, 2)).unsqueeze(1).unsqueeze(1)
return ect
Expand Down

0 comments on commit 643bf88

Please sign in to comment.