diff --git a/experiments/betti_numbers/graphs/GATSimplex2Vec.py b/experiments/betti_numbers/graphs/GATSimplex2Vec.py new file mode 100644 index 0000000..8c4b8fd --- /dev/null +++ b/experiments/betti_numbers/graphs/GATSimplex2Vec.py @@ -0,0 +1,87 @@ +import torch +import torchvision.transforms as transforms +from torch_geometric.transforms import FaceToEdge + +from experiments.experiment_utils import perform_experiment +from experiments.lightning_modules.GraphCommonModuleBettiNumbers import ( + GraphCommonModuleBettiNumbers, +) +from mantra.simplicial import SimplicialDataset +from mantra.transforms import ( + TriangulationToFaceTransform, + SetNumNodesTransform, + DegreeTransform, + Simplex2VecTransform, +) +from models.graphs.GAT import GATNetwork + + +class GATSimplexToVecModule(GraphCommonModuleBettiNumbers): + def __init__( + self, + hidden_channels, + num_node_features, + out_channels, + num_heads, + num_hidden_layers, + learning_rate=0.0001, + ): + base_model = GATNetwork( + hidden_channels=hidden_channels, + num_node_features=num_node_features, + out_channels=out_channels, + num_heads=num_heads, + num_hidden_layers=num_hidden_layers, + ) + super().__init__(base_model=base_model) + self.learning_rate = learning_rate + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + return optimizer + + +def load_dataset_with_transformations(): + tr = transforms.Compose( + [ + TriangulationToFaceTransform(), + SetNumNodesTransform(), + FaceToEdge(remove_faces=False), + DegreeTransform(), + Simplex2VecTransform(), + ] + ) + dataset = SimplicialDataset(root="./data", transform=tr) + return dataset + + +def single_experiment_betti_numbers_gat_simplex2vec(): + # =============================== + # Training parameters + # =============================== + hidden_channels = 64 + num_hidden_layers = 2 + num_heads = 4 + batch_size = 32 + max_epochs = 100 + learning_rate = 0.0001 + num_workers = 0 + # =============================== + dataset = load_dataset_with_transformations() + model = GATSimplexToVecModule( + hidden_channels=hidden_channels, + num_node_features=dataset.num_node_features, + out_channels=3, # Three different Betti numbers + num_heads=num_heads, + num_hidden_layers=num_hidden_layers, + learning_rate=learning_rate, + ) + perform_experiment( + task="betti_numbers", + model=model, + model_name="GATSimplex2Vec", + dataset=dataset, + batch_size=batch_size, + max_epochs=max_epochs, + num_workers=num_workers, + ) diff --git a/experiments/name/graphs/GATSimplex2Vec.py b/experiments/name/graphs/GATSimplex2Vec.py new file mode 100644 index 0000000..15ea447 --- /dev/null +++ b/experiments/name/graphs/GATSimplex2Vec.py @@ -0,0 +1,89 @@ +import torch +import torchvision.transforms as transforms +from torch_geometric.transforms import FaceToEdge + +from experiments.experiment_utils import perform_experiment +from experiments.lightning_modules.GraphCommonModuleName import ( + GraphCommonModuleName, +) +from mantra.simplicial import SimplicialDataset +from mantra.transforms import ( + TriangulationToFaceTransform, + SetNumNodesTransform, + DegreeTransform, + Simplex2VecTransform, + NameToClassTransform, +) +from models.graphs.GAT import GATNetwork + + +class GATSimplexToVecModule(GraphCommonModuleName): + def __init__( + self, + hidden_channels, + num_node_features, + out_channels, + num_heads, + num_hidden_layers, + learning_rate=0.0001, + ): + base_model = GATNetwork( + hidden_channels=hidden_channels, + num_node_features=num_node_features, + out_channels=out_channels, + num_heads=num_heads, + num_hidden_layers=num_hidden_layers, + ) + super().__init__(base_model=base_model) + self.learning_rate = learning_rate + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + return optimizer + + +def load_dataset_with_transformations(): + tr = transforms.Compose( + [ + TriangulationToFaceTransform(), + SetNumNodesTransform(), + FaceToEdge(remove_faces=False), + DegreeTransform(), + NameToClassTransform(), + Simplex2VecTransform(), + ] + ) + dataset = SimplicialDataset(root="./data", transform=tr) + return dataset + + +def single_experiment_name_gat_simplex2vec(): + # =============================== + # Training parameters + # =============================== + hidden_channels = 64 + num_hidden_layers = 2 + num_heads = 4 + batch_size = 32 + max_epochs = 100 + learning_rate = 0.0001 + num_workers = 0 + # =============================== + dataset = load_dataset_with_transformations() + model = GATSimplexToVecModule( + hidden_channels=hidden_channels, + num_node_features=dataset.num_node_features, + out_channels=5, # Five different name classes + num_heads=num_heads, + num_hidden_layers=num_hidden_layers, + learning_rate=learning_rate, + ) + perform_experiment( + task="name", + model=model, + model_name="GATSimplex2Vec", + dataset=dataset, + batch_size=batch_size, + max_epochs=max_epochs, + num_workers=num_workers, + ) diff --git a/main.py b/main.py index eca20bd..f190506 100644 --- a/main.py +++ b/main.py @@ -1,8 +1,8 @@ -from experiments.betti_numbers.simplicial_complexes.SCNN import ( - single_experiment_betti_numbers_scnn, +from experiments.betti_numbers.graphs.GATSimplex2Vec import ( + single_experiment_betti_numbers_gat_simplex2vec, ) -from experiments.name.simplicial_complexes.SCNN import ( - single_experiment_name_scnn, +from experiments.name.graphs.GATSimplex2Vec import ( + single_experiment_name_gat_simplex2vec, ) if __name__ == "__main__": @@ -21,4 +21,6 @@ # single_experiment_orientability_transformer_conv() # single_experiment_orientability_mlp_constant_shape() # single_experiment_name_scnn() - single_experiment_betti_numbers_scnn() + # single_experiment_betti_numbers_scnn() + single_experiment_name_gat_simplex2vec() + single_experiment_betti_numbers_gat_simplex2vec()