-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adapting experiments of GCN and Simplex2Vec to PyTorch Lightning expe…
…riments. Rearranging entries and cleaning up files.
- Loading branch information
Showing
24 changed files
with
477 additions
and
131 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
name: Enforce coding style | ||
|
||
on: [push, pull_request] | ||
on: [ push, pull_request ] | ||
|
||
jobs: | ||
lint: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
33 changes: 33 additions & 0 deletions
33
experiments/lightning_modules/GraphCommonModuleOrientability.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from torch import nn | ||
|
||
from experiments.lightning_modules.BaseModuleOrientability import ( | ||
BaseOrientabilityModule, | ||
) | ||
|
||
|
||
class GraphCommonModuleOrientability(BaseOrientabilityModule): | ||
def __init__(self, base_model): | ||
super().__init__() | ||
self.base_model = base_model | ||
|
||
def forward(self, x, edge_index, batch): | ||
x = self.base_model(x, edge_index, batch) | ||
return x | ||
|
||
def general_step(self, batch, batch_idx, step: str): | ||
batch_len = len(batch.y) | ||
x_hat = self(batch.x, batch.edge_index, batch.batch) | ||
# Squeeze x_hat to match the shape of y | ||
x_hat = x_hat.squeeze() | ||
y = batch.y.float() | ||
loss = nn.functional.binary_cross_entropy_with_logits(x_hat, y) | ||
self.log( | ||
f"{step}_loss", | ||
loss, | ||
prog_bar=True, | ||
batch_size=batch_len, | ||
on_step=False, | ||
on_epoch=True, | ||
) | ||
self.log_accuracies(x_hat, y, batch_len, step) | ||
return loss |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import lightning as L | ||
import torch | ||
import torchvision.transforms as transforms | ||
from torch.utils.data import Subset | ||
from torch_geometric.data import DataLoader | ||
from torch_geometric.transforms import FaceToEdge | ||
|
||
from experiments.lightning_modules.GraphCommonModuleOrientability import ( | ||
GraphCommonModuleOrientability, | ||
) | ||
from mantra.simplicial import SimplicialDataset | ||
from mantra.transforms import ( | ||
TriangulationToFaceTransform, | ||
SetNumNodesTransform, | ||
DegreeTransform, | ||
OrientableToClassTransform, | ||
Simplex2VecTransform, | ||
) | ||
from models.graphs.GAT import GATNetwork | ||
|
||
|
||
class GATSimplexToVecModule(GraphCommonModuleOrientability): | ||
def __init__( | ||
self, | ||
hidden_channels, | ||
num_node_features, | ||
out_channels, | ||
num_heads, | ||
num_hidden_layers, | ||
learning_rate=0.0001, | ||
): | ||
base_model = GATNetwork( | ||
hidden_channels=hidden_channels, | ||
num_node_features=num_node_features, | ||
out_channels=out_channels, | ||
num_heads=num_heads, | ||
num_hidden_layers=num_hidden_layers, | ||
) | ||
super().__init__(base_model=base_model) | ||
self.learning_rate = learning_rate | ||
|
||
def configure_optimizers(self): | ||
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) | ||
return optimizer | ||
|
||
|
||
def load_dataset_with_transformations(): | ||
tr = transforms.Compose( | ||
[ | ||
TriangulationToFaceTransform(), | ||
SetNumNodesTransform(), | ||
FaceToEdge(remove_faces=False), | ||
DegreeTransform(), | ||
OrientableToClassTransform(), | ||
Simplex2VecTransform(), | ||
] | ||
) | ||
dataset = SimplicialDataset(root="./data", transform=tr) | ||
return dataset | ||
|
||
|
||
def single_experiment_orientability_gat_simplex2vec(): | ||
# =============================== | ||
# Training parameters | ||
# =============================== | ||
hidden_channels = 64 | ||
num_hidden_layers = 2 | ||
num_heads = 4 | ||
batch_size = 32 | ||
max_epochs = 100 | ||
learning_rate = 0.0001 | ||
# =============================== | ||
dataset = load_dataset_with_transformations() | ||
model = GATSimplexToVecModule( | ||
hidden_channels=hidden_channels, | ||
num_node_features=dataset.num_node_features, | ||
out_channels=1, | ||
num_heads=num_heads, | ||
num_hidden_layers=num_hidden_layers, | ||
learning_rate=learning_rate, | ||
) | ||
train_ds = Subset(dataset, dataset.train_orientability_indices) | ||
test_ds = Subset(dataset, dataset.test_orientability_indices) | ||
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True) | ||
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False) | ||
trainer = L.Trainer(max_epochs=max_epochs, log_every_n_steps=1) | ||
trainer.fit(model, train_dl, test_dl) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,81 @@ | ||
# TODO | ||
import lightning as L | ||
import torch | ||
import torchvision.transforms as transforms | ||
from torch.utils.data import Subset | ||
from torch_geometric.loader import DataLoader | ||
from torch_geometric.transforms import FaceToEdge | ||
|
||
from experiments.lightning_modules.GraphCommonModuleOrientability import ( | ||
GraphCommonModuleOrientability, | ||
) | ||
from mantra.simplicial import SimplicialDataset | ||
from mantra.transforms import ( | ||
TriangulationToFaceTransform, | ||
DegreeTransform, | ||
OrientableToClassTransform, | ||
) | ||
from models.graphs.GCN import GCNetwork | ||
|
||
|
||
class GCNModule(GraphCommonModuleOrientability): | ||
def __init__( | ||
self, | ||
hidden_channels, | ||
num_node_features, | ||
out_channels, | ||
num_hidden_layers, | ||
learning_rate=0.01, | ||
): | ||
base_model = GCNetwork( | ||
hidden_channels=hidden_channels, | ||
num_node_features=num_node_features, | ||
out_channels=out_channels, | ||
num_hidden_layers=num_hidden_layers, | ||
) | ||
super().__init__(base_model=base_model) | ||
self.learning_rate = learning_rate | ||
|
||
def configure_optimizers(self): | ||
optimizer = torch.optim.Adam( | ||
self.base_model.parameters(), lr=self.learning_rate | ||
) | ||
return optimizer | ||
|
||
|
||
def load_dataset_with_transformations(): | ||
tr = transforms.Compose( | ||
[ | ||
SimplicialComplexTransform(), | ||
SimplicialComplexOnesTransform(ones_length=10), | ||
DimZeroHodgeLaplacianSimplicialComplexTransform(), | ||
DimOneHodgeLaplacianUpSimplicialComplexTransform(), | ||
DimOneHodgeLaplacianDownSimplicialComplexTransform(), | ||
DimTwoHodgeLaplacianSimplicialComplexTransform(), | ||
OrientableToClassSimplicialComplexTransform(), | ||
TriangulationToFaceTransform(), | ||
FaceToEdge(remove_faces=False), | ||
DegreeTransform(), | ||
OrientableToClassTransform(), | ||
] | ||
) | ||
dataset = SimplicialDataset(root="./data", transform=tr) | ||
return dataset | ||
|
||
|
||
def single_experiment_orientability_gcn(): | ||
def single_experiment_orientability_gnn(): | ||
# =============================== | ||
# Training parameters | ||
# =============================== | ||
hidden_channels = 64 | ||
num_hidden_layers = 2 | ||
batch_size = 32 | ||
max_epochs = 100 | ||
learning_rate = 0.1 | ||
# =============================== | ||
dataset = load_dataset_with_transformations() | ||
model = GCNModule( | ||
hidden_channels=hidden_channels, | ||
num_node_features=dataset.num_node_features, | ||
out_channels=1, # Binary classification | ||
num_hidden_layers=num_hidden_layers, | ||
learning_rate=learning_rate, | ||
) | ||
train_ds = Subset(dataset, dataset.train_orientability_indices) | ||
test_ds = Subset(dataset, dataset.test_orientability_indices) | ||
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True) | ||
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False) | ||
trainer = L.Trainer(max_epochs=max_epochs, log_every_n_steps=1) | ||
trainer.fit(model, train_dl, test_dl) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.