Skip to content

Commit

Permalink
separate script for training and benchmarking
Browse files Browse the repository at this point in the history
  • Loading branch information
danielbinschmid committed Jun 24, 2024
1 parent 4a12f9e commit a6db5d8
Show file tree
Hide file tree
Showing 12 changed files with 275 additions and 101 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
.ignore*
raw_data
data
test*.*
**/__pycache__
/lightning_logs/
k-simplex*
RandomWalks*
results
wandb
configs
*.lock
*.vscode

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
5 changes: 5 additions & 0 deletions experiments/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from models.models import ModelType
import yaml
from typing import Any, List
import os


class TrainerConfig(BaseSettings):
Expand Down Expand Up @@ -43,6 +44,10 @@ class ConfigExperimentRun(BaseSettings):
discriminator=Discriminator(get_discriminator_value)
)

def get_checkpoint_path(self, base_folder: str):
    """Build the checkpoint file path for this experiment run.

    The file name encodes transform, task type, model type and seed, so
    checkpoints from different runs never collide.

    Args:
        base_folder: Directory in which the checkpoint file should live.

    Returns:
        ``<base_folder>/<transform>_<task>_<model>_seed_<seed>.ckpt``.
    """
    name_parts = (
        self.transforms.name.lower(),
        self.task_type.name.lower(),
        self.conf_model.type.name.lower(),
    )
    file_name = f"{'_'.join(name_parts)}_seed_{self.seed}.ckpt"
    return os.path.join(base_folder, file_name)


def load_config(config_fpath: str) -> ConfigExperimentRun:
with open(config_fpath, "r") as file:
Expand Down
30 changes: 26 additions & 4 deletions experiments/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,19 @@
from experiments.configs import ConfigExperimentRun
from models import model_lookup
from metrics.tasks import TaskType
from typing import Dict
from typing import Dict, Optional, Tuple
from datasets.simplicial import SimplicialDataModule
from models.base import BaseModel
from experiments.loggers import get_wandb_logger
import lightning as L
import uuid
from datasets.transforms import transforms_lookup
from lightning.pytorch.loggers import WandbLogger


def run_configuration(config: ConfigExperimentRun):

def get_setup(
config: ConfigExperimentRun,
) -> Tuple[SimplicialDataModule, BaseModel, L.Trainer, WandbLogger]:
run_id = str(uuid.uuid4())
transforms = transforms_lookup[config.transforms]
task_lookup: Dict[TaskType, Task] = get_task_lookup(transforms)
Expand Down Expand Up @@ -54,7 +56,7 @@ def run_configuration(config: ConfigExperimentRun):
task_lookup[config.task_type].accuracies,
task_lookup[config.task_type].loss_fn,
learning_rate=config.learning_rate,
imbalance=imbalance
imbalance=imbalance,
)

logger = get_wandb_logger(
Expand All @@ -73,8 +75,28 @@ def run_configuration(config: ConfigExperimentRun):
fast_dev_run=False,
)

return dm, lit_model, trainer, logger


def run_configuration(
    config: ConfigExperimentRun, save_checkpoint_path: Optional[str] = None
):
    """Train a model for the given experiment configuration.

    Args:
        config: Experiment configuration describing data, model and task.
        save_checkpoint_path: When given, the trained weights are written
            to this path after fitting.

    Returns:
        The lightning ``Trainer`` used for the run.
    """
    data_module, model, trainer, wandb_logger = get_setup(config)

    # Fit the model, then close the wandb run so all logs are flushed.
    trainer.fit(model, data_module)
    wandb_logger.experiment.finish()

    if save_checkpoint_path:
        print(f"[INFO] Saving checkpoint here {save_checkpoint_path}")
        trainer.save_checkpoint(save_checkpoint_path)

    return trainer


def benchmark_configuration(
    config: ConfigExperimentRun, save_checkpoint_path: str
):
    """Evaluate a previously trained checkpoint for the given configuration.

    Args:
        config: Experiment configuration describing data, model and task.
        save_checkpoint_path: Path of the checkpoint to load before testing.

    Returns:
        The list of metric dictionaries produced by ``Trainer.test``.
    """
    dm, lit_model, trainer, logger = get_setup(config)

    # Pass the checkpoint path by keyword: the bare third positional argument
    # of Trainer.test is ckpt_path, but being explicit makes the intent clear
    # and guards against signature differences across lightning versions.
    results = trainer.test(lit_model, dm, ckpt_path=save_checkpoint_path)

    # Close the wandb run, mirroring run_configuration, so runs don't leak.
    logger.experiment.finish()

    return results
2 changes: 1 addition & 1 deletion experiments/vis/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .result_handler import ResultHandler
from .result_handler import ResultHandler
41 changes: 29 additions & 12 deletions experiments/vis/overview_barplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,45 @@
import numpy as np
from .result_handler import ResultHandler


def overview_barplot(result_handler: ResultHandler):
    """Plot mean benchmark performance per task as a grouped bar chart.

    Betti-number tasks are drawn in one color, name/orientability tasks in
    others, with a one-slot gap between the two groups. The figure is saved
    to ``plot_overview.pdf``.

    Args:
        result_handler: Aggregated results; must provide per-task means and
            error bars via ``get_task_means``.
    """
    other_colors = ["lightblue", "darkgray"]

    categories = ["betti_0", "betti_1", "betti_2", "name", "orientability"]
    values, errors = result_handler.get_task_means()

    plotting.prepare_for_latex()

    fig, ax = plt.subplots(figsize=(8, 6))

    bar_width = 0.6
    betti_positions = np.arange(3)
    # Shift by one slot to leave a visual gap between the two task groups.
    other_positions = np.arange(3, 5) + 1

    for i, pos in enumerate(betti_positions):
        ax.bar(
            pos,
            values[i],
            yerr=errors[i],
            color="lightcoral",
            width=bar_width,
            label=categories[i],
        )

    for i, pos in enumerate(other_positions):
        ax.bar(
            pos,
            values[i + 3],
            yerr=errors[i + 3],
            color=other_colors[i],
            width=bar_width,
            label=categories[i + 3],
        )

    # Fix: set tick positions explicitly before labeling them. Calling
    # set_xticklabels alone relies on the default locator producing exactly
    # as many ticks as labels, which matplotlib does not guarantee (and
    # warns about); pin the ticks to the actual bar positions instead.
    ax.set_xticks(list(betti_positions) + list(other_positions))
    ax.set_xticklabels(categories)
    ax.set_ylabel("Mean Accuracy/ F1 Score")
    ax.set_xlabel("Task")
    ax.grid(True, linestyle="--", linewidth=0.5)
    plt.savefig("plot_overview.pdf")
10 changes: 7 additions & 3 deletions experiments/vis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def above_legend_args(ax):
)


def add_single_row_legend(ax: matplotlib.pyplot.Axes, title: str, **legend_args):
def add_single_row_legend(
ax: matplotlib.pyplot.Axes, title: str, **legend_args
):
# Extracting handles and labels
try:
h, l = legend_args.pop("legs")
Expand All @@ -49,7 +51,9 @@ def filter_duplicate_handles(ax):

handles, labels = ax.get_legend_handles_labels()
unique = [
(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]
(h, l)
for i, (h, l) in enumerate(zip(handles, labels))
if l not in labels[:i]
]
return zip(*unique)

Expand Down Expand Up @@ -144,4 +148,4 @@ def prepare_for_latex(preamble=""):
# \begin{figure*}
# \currentpage\pagedesign
# \end{figure*}
# \end{document}
# \end{document}
Loading

0 comments on commit a6db5d8

Please sign in to comment.