Skip to content

Commit

Permalink
Added GATSimplex2Vec for betti numbers and name tasks.
Browse files Browse the repository at this point in the history
  • Loading branch information
rballeba committed May 6, 2024
1 parent 0c60859 commit 295ccf0
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 5 deletions.
87 changes: 87 additions & 0 deletions experiments/betti_numbers/graphs/GATSimplex2Vec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import torch
import torchvision.transforms as transforms
from torch_geometric.transforms import FaceToEdge

from experiments.experiment_utils import perform_experiment
from experiments.lightning_modules.GraphCommonModuleBettiNumbers import (
GraphCommonModuleBettiNumbers,
)
from mantra.simplicial import SimplicialDataset
from mantra.transforms import (
TriangulationToFaceTransform,
SetNumNodesTransform,
DegreeTransform,
Simplex2VecTransform,
)
from models.graphs.GAT import GATNetwork


class GATSimplexToVecModule(GraphCommonModuleBettiNumbers):
def __init__(
self,
hidden_channels,
num_node_features,
out_channels,
num_heads,
num_hidden_layers,
learning_rate=0.0001,
):
base_model = GATNetwork(
hidden_channels=hidden_channels,
num_node_features=num_node_features,
out_channels=out_channels,
num_heads=num_heads,
num_hidden_layers=num_hidden_layers,
)
super().__init__(base_model=base_model)
self.learning_rate = learning_rate

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
return optimizer


def load_dataset_with_transformations():
tr = transforms.Compose(
[
TriangulationToFaceTransform(),
SetNumNodesTransform(),
FaceToEdge(remove_faces=False),
DegreeTransform(),
Simplex2VecTransform(),
]
)
dataset = SimplicialDataset(root="./data", transform=tr)
return dataset


def single_experiment_betti_numbers_gat_simplex2vec():
# ===============================
# Training parameters
# ===============================
hidden_channels = 64
num_hidden_layers = 2
num_heads = 4
batch_size = 32
max_epochs = 100
learning_rate = 0.0001
num_workers = 0
# ===============================
dataset = load_dataset_with_transformations()
model = GATSimplexToVecModule(
hidden_channels=hidden_channels,
num_node_features=dataset.num_node_features,
out_channels=3, # Three different Betti numbers
num_heads=num_heads,
num_hidden_layers=num_hidden_layers,
learning_rate=learning_rate,
)
perform_experiment(
task="betti_numbers",
model=model,
model_name="GATSimplex2Vec",
dataset=dataset,
batch_size=batch_size,
max_epochs=max_epochs,
num_workers=num_workers,
)
89 changes: 89 additions & 0 deletions experiments/name/graphs/GATSimplex2Vec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import torch
import torchvision.transforms as transforms
from torch_geometric.transforms import FaceToEdge

from experiments.experiment_utils import perform_experiment
from experiments.lightning_modules.GraphCommonModuleName import (
GraphCommonModuleName,
)
from mantra.simplicial import SimplicialDataset
from mantra.transforms import (
TriangulationToFaceTransform,
SetNumNodesTransform,
DegreeTransform,
Simplex2VecTransform,
NameToClassTransform,
)
from models.graphs.GAT import GATNetwork


class GATSimplexToVecModule(GraphCommonModuleName):
def __init__(
self,
hidden_channels,
num_node_features,
out_channels,
num_heads,
num_hidden_layers,
learning_rate=0.0001,
):
base_model = GATNetwork(
hidden_channels=hidden_channels,
num_node_features=num_node_features,
out_channels=out_channels,
num_heads=num_heads,
num_hidden_layers=num_hidden_layers,
)
super().__init__(base_model=base_model)
self.learning_rate = learning_rate

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
return optimizer


def load_dataset_with_transformations():
tr = transforms.Compose(
[
TriangulationToFaceTransform(),
SetNumNodesTransform(),
FaceToEdge(remove_faces=False),
DegreeTransform(),
NameToClassTransform(),
Simplex2VecTransform(),
]
)
dataset = SimplicialDataset(root="./data", transform=tr)
return dataset


def single_experiment_name_gat_simplex2vec():
# ===============================
# Training parameters
# ===============================
hidden_channels = 64
num_hidden_layers = 2
num_heads = 4
batch_size = 32
max_epochs = 100
learning_rate = 0.0001
num_workers = 0
# ===============================
dataset = load_dataset_with_transformations()
model = GATSimplexToVecModule(
hidden_channels=hidden_channels,
num_node_features=dataset.num_node_features,
out_channels=5, # Five different name classes
num_heads=num_heads,
num_hidden_layers=num_hidden_layers,
learning_rate=learning_rate,
)
perform_experiment(
task="name",
model=model,
model_name="GATSimplex2Vec",
dataset=dataset,
batch_size=batch_size,
max_epochs=max_epochs,
num_workers=num_workers,
)
12 changes: 7 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from experiments.betti_numbers.simplicial_complexes.SCNN import (
single_experiment_betti_numbers_scnn,
from experiments.betti_numbers.graphs.GATSimplex2Vec import (
single_experiment_betti_numbers_gat_simplex2vec,
)
from experiments.name.simplicial_complexes.SCNN import (
single_experiment_name_scnn,
from experiments.name.graphs.GATSimplex2Vec import (
single_experiment_name_gat_simplex2vec,
)

if __name__ == "__main__":
Expand All @@ -21,4 +21,6 @@
# single_experiment_orientability_transformer_conv()
# single_experiment_orientability_mlp_constant_shape()
# single_experiment_name_scnn()
single_experiment_betti_numbers_scnn()
# single_experiment_betti_numbers_scnn()
single_experiment_name_gat_simplex2vec()
single_experiment_betti_numbers_gat_simplex2vec()

0 comments on commit 295ccf0

Please sign in to comment.