-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding betti number task and GCN architectures on it for pytorch ligh…
…tning.
- Loading branch information
Showing
8 changed files
with
231 additions
and
3 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import lightning as L | ||
import torch | ||
import wandb | ||
from torch.utils.data import Subset | ||
from torch_geometric.loader import DataLoader | ||
from torch_geometric.transforms import FaceToEdge, OneHotDegree | ||
from torchvision import transforms | ||
|
||
from experiments.experiment_utils import get_wandb_logger | ||
from experiments.lightning_modules.GraphCommonModuleBettiNumbers import ( | ||
GraphCommonModuleBettiNumbers, | ||
) | ||
from mantra.simplicial import SimplicialDataset | ||
from mantra.transforms import TriangulationToFaceTransform, DegreeTransform | ||
from models.graphs.GCN import GCNetwork | ||
|
||
|
||
class GCNModule(GraphCommonModuleBettiNumbers): | ||
def __init__( | ||
self, | ||
hidden_channels, | ||
num_node_features, | ||
out_channels, | ||
num_hidden_layers, | ||
learning_rate=0.01, | ||
): | ||
base_model = GCNetwork( | ||
hidden_channels=hidden_channels, | ||
num_node_features=num_node_features, | ||
out_channels=out_channels, | ||
num_hidden_layers=num_hidden_layers, | ||
) | ||
super().__init__(base_model=base_model) | ||
self.learning_rate = learning_rate | ||
|
||
def configure_optimizers(self): | ||
optimizer = torch.optim.Adam( | ||
self.base_model.parameters(), lr=self.learning_rate | ||
) | ||
return optimizer | ||
|
||
|
||
def load_dataset_with_transformations(): | ||
tr = transforms.Compose( | ||
[ | ||
TriangulationToFaceTransform(), | ||
FaceToEdge(remove_faces=False), | ||
DegreeTransform(), | ||
OneHotDegree(max_degree=8, cat=False), | ||
] | ||
) | ||
dataset = SimplicialDataset(root="./data", transform=tr) | ||
return dataset | ||
|
||
|
||
def single_experiment_betti_numbers_gnn(): | ||
# =============================== | ||
# Training parameters | ||
# =============================== | ||
hidden_channels = 64 | ||
num_hidden_layers = 2 | ||
batch_size = 32 | ||
max_epochs = 100 | ||
learning_rate = 0.1 | ||
num_workers = 0 | ||
# =============================== | ||
dataset = load_dataset_with_transformations() | ||
model = GCNModule( | ||
hidden_channels=hidden_channels, | ||
num_node_features=dataset.num_node_features, | ||
out_channels=3, # Three different Betti numbers | ||
num_hidden_layers=num_hidden_layers, | ||
learning_rate=learning_rate, | ||
) | ||
train_ds = Subset(dataset, dataset.train_betti_numbers_indices) | ||
test_ds = Subset(dataset, dataset.test_betti_numbers_indices) | ||
train_dl = DataLoader( | ||
train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers | ||
) | ||
test_dl = DataLoader( | ||
test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers | ||
) | ||
logger = get_wandb_logger(task_name="betti_numbers", model_name="GCN") | ||
trainer = L.Trainer( | ||
max_epochs=max_epochs, log_every_n_steps=1, logger=logger | ||
) | ||
trainer.fit( | ||
model, | ||
train_dl, | ||
test_dl, | ||
) | ||
wandb.finish() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from typing import Literal | ||
|
||
import lightning as L | ||
from torch.nn import ModuleList | ||
|
||
from experiments.metrics import GeneralAccuracy | ||
|
||
|
||
class BaseBettiNumbersModule(L.LightningModule): | ||
def __init__(self): | ||
super().__init__() | ||
# 3 accuracies per modality: one for each betti number | ||
self.training_accuracies = ModuleList( | ||
[GeneralAccuracy() for _ in range(3)] | ||
) | ||
self.validation_accuracies = ModuleList( | ||
[GeneralAccuracy() for _ in range(3)] | ||
) | ||
self.test_accuracies = ModuleList( | ||
[GeneralAccuracy() for _ in range(3)] | ||
) | ||
|
||
def log_scores( | ||
self, x_hat, y, batch_len, step: Literal["train", "test", "validation"] | ||
): | ||
# x_hat is a float tensor of shape (batch_len, 3), one column per betti number and row per sample | ||
# y is a long tensor of shape (batch_len, 3), one row per sample and column per betti number | ||
if step == "train": | ||
for dim in range(3): | ||
self.training_accuracies[dim]( | ||
x_hat[:, dim].round().long(), y[:, dim].long() | ||
) | ||
self.log( | ||
f"training_accuracy_betti_{dim}", | ||
self.training_accuracies[dim], | ||
prog_bar=True, | ||
on_step=False, | ||
on_epoch=True, | ||
batch_size=batch_len, | ||
) | ||
|
||
elif step == "test": | ||
for dim in range(3): | ||
self.test_accuracies[dim]( | ||
x_hat[:, dim].round().long(), y[:, dim].long() | ||
) | ||
self.log( | ||
f"test_accuracy_betti_{dim}", | ||
self.test_accuracies[dim], | ||
prog_bar=True, | ||
on_step=False, | ||
on_epoch=True, | ||
batch_size=batch_len, | ||
) | ||
elif step == "validation": | ||
for dim in range(3): | ||
self.validation_accuracies[dim]( | ||
x_hat[:, dim].round().long(), y[:, dim].long() | ||
) | ||
self.log( | ||
f"validation_accuracy_betti_{dim}", | ||
self.validation_accuracies[dim], | ||
prog_bar=True, | ||
on_step=False, | ||
on_epoch=True, | ||
batch_size=batch_len, | ||
) | ||
|
||
def test_step(self, batch, batch_idx): | ||
return self.general_step(batch, batch_idx, "test") | ||
|
||
def validation_step(self, batch, batch_idx): | ||
return self.general_step(batch, batch_idx, "validation") | ||
|
||
def training_step(self, batch, batch_idx): | ||
return self.general_step(batch, batch_idx, "train") |
34 changes: 34 additions & 0 deletions
34
experiments/lightning_modules/GraphCommonModuleBettiNumbers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import torch | ||
from torch import nn | ||
|
||
from experiments.lightning_modules.BaseModuleBettiNumbers import ( | ||
BaseBettiNumbersModule, | ||
) | ||
|
||
|
||
class GraphCommonModuleBettiNumbers(BaseBettiNumbersModule): | ||
def __init__(self, base_model): | ||
super().__init__() | ||
self.base_model = base_model | ||
|
||
def forward(self, x, edge_index, batch): | ||
x = self.base_model(x, edge_index, batch) | ||
return x | ||
|
||
def general_step(self, batch, batch_idx, step: str): | ||
x_hat = self(batch.x, batch.edge_index, batch.batch) | ||
y = torch.tensor( | ||
batch.betti_numbers, device=x_hat.device, dtype=x_hat.dtype | ||
) | ||
batch_len = len(y) | ||
loss = nn.functional.mse_loss(x_hat, y) | ||
self.log( | ||
f"{step}_loss", | ||
loss, | ||
prog_bar=True, | ||
batch_size=batch_len, | ||
on_step=False, | ||
on_epoch=True, | ||
) | ||
self.log_scores(x_hat, y, batch_len, step) | ||
return loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import torch | ||
from torch import Tensor | ||
from torchmetrics import Metric | ||
|
||
|
||
class GeneralAccuracy(Metric): | ||
def __init__(self, **kwargs): | ||
super().__init__(**kwargs) | ||
self.add_state( | ||
"correct", default=torch.tensor(0), dist_reduce_fx="sum" | ||
) | ||
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") | ||
|
||
def update(self, preds: Tensor, target: Tensor) -> None: | ||
if preds.shape != target.shape: | ||
raise ValueError("preds and target must have the same shape") | ||
|
||
self.correct += torch.sum(preds == target) | ||
self.total += target.numel() | ||
|
||
def compute(self) -> Tensor: | ||
return self.correct.float() / self.total |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters