From a6db5d899349d62766a863d000589e6b5ac4ebb3 Mon Sep 17 00:00:00 2001 From: Daniel Bin Schmid Date: Mon, 24 Jun 2024 13:00:36 +0200 Subject: [PATCH] separate script for training and benchmarking --- .gitignore | 3 +- experiments/configs.py | 5 + experiments/run_experiment.py | 30 ++++- experiments/vis/__init__.py | 2 +- experiments/vis/overview_barplot.py | 41 +++++-- experiments/vis/plotting.py | 10 +- experiments/vis/result_handler.py | 174 +++++++++++++++++----------- experiments/vis/us_cmap.py | 13 ++- run.py | 40 ++++++- test.py | 47 ++++++++ test.sh | 4 + train.sh | 7 ++ 12 files changed, 275 insertions(+), 101 deletions(-) create mode 100644 test.py create mode 100755 test.sh create mode 100755 train.sh diff --git a/.gitignore b/.gitignore index 1c84865..35281a7 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,6 @@ .ignore* raw_data data -test*.* **/__pycache__ /lightning_logs/ k-simplex* @@ -9,6 +8,8 @@ RandomWalks* results wandb configs +*.lock +*.vscode # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/experiments/configs.py b/experiments/configs.py index 2cde7be..7a1970b 100644 --- a/experiments/configs.py +++ b/experiments/configs.py @@ -10,6 +10,7 @@ from models.models import ModelType import yaml from typing import Any, List +import os class TrainerConfig(BaseSettings): @@ -43,6 +44,10 @@ class ConfigExperimentRun(BaseSettings): discriminator=Discriminator(get_discriminator_value) ) + def get_checkpoint_path(self, base_folder: str): + identifier = f"{self.transforms.name.lower()}_{self.task_type.name.lower()}_{self.conf_model.type.name.lower()}_seed_{self.seed}.ckpt" + return os.path.join(base_folder, identifier) + def load_config(config_fpath: str) -> ConfigExperimentRun: with open(config_fpath, "r") as file: diff --git a/experiments/run_experiment.py b/experiments/run_experiment.py index bb71ede..99e26c5 100644 --- a/experiments/run_experiment.py +++ b/experiments/run_experiment.py @@ -10,17 +10,19 @@ from experiments.configs import ConfigExperimentRun from models import model_lookup from metrics.tasks import TaskType -from typing import Dict +from typing import Dict, Optional, Tuple from datasets.simplicial import SimplicialDataModule from models.base import BaseModel from experiments.loggers import get_wandb_logger import lightning as L import uuid from datasets.transforms import transforms_lookup +from lightning.pytorch.loggers import WandbLogger -def run_configuration(config: ConfigExperimentRun): - +def get_setup( + config: ConfigExperimentRun, +) -> Tuple[SimplicialDataModule, BaseModel, L.Trainer, WandbLogger]: run_id = str(uuid.uuid4()) transforms = transforms_lookup[config.transforms] task_lookup: Dict[TaskType, Task] = get_task_lookup(transforms) @@ -54,7 +56,7 @@ def run_configuration(config: ConfigExperimentRun): task_lookup[config.task_type].accuracies, task_lookup[config.task_type].loss_fn, learning_rate=config.learning_rate, - imbalance=imbalance + imbalance=imbalance, ) logger = get_wandb_logger( @@ -73,8 +75,28 @@ def run_configuration(config: ConfigExperimentRun): fast_dev_run=False, ) + return dm, lit_model, trainer, logger + + +def run_configuration( + config: ConfigExperimentRun, save_checkpoint_path: Optional[str] = None +): + dm, lit_model, trainer, logger = get_setup(config) + # run trainer.fit(lit_model, dm) logger.experiment.finish() + if save_checkpoint_path: + print(f"[INFO] Saving checkpoint here {save_checkpoint_path}") + trainer.save_checkpoint(save_checkpoint_path) + return trainer + + +def benchmark_configuration( + 
config: ConfigExperimentRun, save_checkpoint_path: str +): + dm, lit_model, trainer, logger = get_setup(config) + + trainer.test(lit_model, dm, save_checkpoint_path) diff --git a/experiments/vis/__init__.py b/experiments/vis/__init__.py index 50bf83e..5d66ae4 100644 --- a/experiments/vis/__init__.py +++ b/experiments/vis/__init__.py @@ -1 +1 @@ -from .result_handler import ResultHandler \ No newline at end of file +from .result_handler import ResultHandler diff --git a/experiments/vis/overview_barplot.py b/experiments/vis/overview_barplot.py index 2e3f7bd..56a5b62 100644 --- a/experiments/vis/overview_barplot.py +++ b/experiments/vis/overview_barplot.py @@ -3,28 +3,45 @@ import numpy as np from .result_handler import ResultHandler + def overview_barplot(result_handler: ResultHandler): - other_colors = ['lightblue', 'darkgray'] + other_colors = ["lightblue", "darkgray"] - categories = ['betti_0', 'betti_1', 'betti_2', 'name', 'orientability'] + categories = ["betti_0", "betti_1", "betti_2", "name", "orientability"] values, errors = result_handler.get_task_means() plotting.prepare_for_latex() fig, ax = plt.subplots(figsize=(8, 6)) - bar_width = 0.6 - betti_positions = np.arange(3) - other_positions = np.arange(3, 5) + 1 + bar_width = 0.6 + betti_positions = np.arange(3) + other_positions = np.arange(3, 5) + 1 for i, pos in enumerate(betti_positions): - ax.bar(pos, values[i], yerr=errors[i], color='lightcoral', width=bar_width, label=categories[i]) + ax.bar( + pos, + values[i], + yerr=errors[i], + color="lightcoral", + width=bar_width, + label=categories[i], + ) for i, pos in enumerate(other_positions): - ax.bar(pos, values[i + 3], yerr=errors[i + 3], color=other_colors[i], width=bar_width, label=categories[i + 3]) - - ax.set_xticklabels( ['','betti_0', 'betti_1', 'betti_2', '','name', 'orientability']) - ax.set_ylabel("Mean Accuracy") + ax.bar( + pos, + values[i + 3], + yerr=errors[i + 3], + color=other_colors[i], + width=bar_width, + label=categories[i + 3], + ) + + ax.set_xticklabels( + ["", "betti_0", "betti_1", "betti_2", "", "name", "orientability"] + ) + ax.set_ylabel("Mean Accuracy/ F1 Score") ax.set_xlabel("Task") - ax.grid(True, linestyle='--', linewidth=0.5) - plt.savefig("plot.pdf") \ No newline at end of file + ax.grid(True, linestyle="--", linewidth=0.5) + plt.savefig("plot_overview.pdf") diff --git a/experiments/vis/plotting.py b/experiments/vis/plotting.py index 7bd049b..2905110 100644 --- a/experiments/vis/plotting.py +++ b/experiments/vis/plotting.py @@ -24,7 +24,9 @@ def above_legend_args(ax): ) -def add_single_row_legend(ax: matplotlib.pyplot.Axes, title: str, **legend_args): +def add_single_row_legend( + ax: matplotlib.pyplot.Axes, title: str, **legend_args +): # Extracting handles and labels try: h, l = legend_args.pop("legs") @@ -49,7 +51,9 @@ def filter_duplicate_handles(ax): handles, labels = ax.get_legend_handles_labels() unique = [ - (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i] + (h, l) + for i, (h, l) in enumerate(zip(handles, labels)) + if l not in labels[:i] ] return zip(*unique) @@ -144,4 +148,4 @@ def prepare_for_latex(preamble=""): # \begin{figure*} # \currentpage\pagedesign # \end{figure*} -# \end{document} \ No newline at end of file +# \end{document} diff --git a/experiments/vis/result_handler.py b/experiments/vis/result_handler.py index 3831bf6..1cccc64 100644 --- a/experiments/vis/result_handler.py +++ b/experiments/vis/result_handler.py @@ -1,99 +1,121 @@ -import pandas as pd +import pandas as pd import wandb from typing 
import Tuple, List -def setup_wandb(wandb_project_id: str= "mantra-dev-run-3"): + +def setup_wandb(wandb_project_id: str = "mantra-dev-run-3"): wandb.login() api = wandb.Api() runs = api.runs(wandb_project_id) full_history, config_list, name_list = [], [], [] - for run in runs: + for run in runs: h = run._full_history() - h = [ r | run.config for r in h] - full_history.extend(h) + h = [r | run.config for r in h] + full_history.extend(h) import json - with open("raw_results.json","w") as f: - json.dump(full_history,f) - return full_history - - - + with open("raw_results.json", "w") as f: + json.dump(full_history, f) + return full_history def convert_history(full_history): df = pd.DataFrame(full_history) - df = df.set_index([ "task","model_name","node_features", "run_id"],inplace=False) - mean_df = df.groupby(level=[0,1,2]).mean() + df = df.set_index( + ["task", "model_name", "node_features", "run_id"], inplace=False + ) + mean_df = df.groupby(level=[0, 1, 2]).mean() mean_df = mean_df.add_suffix("_mean") - std_df = df.groupby(level=[0,1,2]).std() + std_df = df.groupby(level=[0, 1, 2]).std() std_df = std_df.add_suffix("_std") - res_df = pd.concat([mean_df,std_df],join="outer",axis=1) + res_df = pd.concat([mean_df, std_df], join="outer", axis=1) res_df.to_csv("results.csv") - df__ = res_df[[ - "validation_accuracy_mean", - "validation_accuracy_std", - "train_accuracy_mean", - "train_accuracy_std", - "validation_accuracy_betti_0_mean", - "validation_accuracy_betti_0_std", - "validation_accuracy_betti_1_mean", - "validation_accuracy_betti_1_std", - "validation_accuracy_betti_2_mean", - "validation_accuracy_betti_2_std", - "train_accuracy_betti_0_mean", - "train_accuracy_betti_0_std", - "train_accuracy_betti_1_mean", - "train_accuracy_betti_1_std", - "train_accuracy_betti_2_mean", - "train_accuracy_betti_2_std", - ]] - + df__ = res_df[ + [ + "validation_accuracy_mean", + "validation_accuracy_std", + "train_accuracy_mean", + "train_accuracy_std", + "validation_accuracy_betti_0_mean", + "validation_accuracy_betti_0_std", + "validation_accuracy_betti_1_mean", + "validation_accuracy_betti_1_std", + "validation_accuracy_betti_2_mean", + "validation_accuracy_betti_2_std", + "train_accuracy_betti_0_mean", + "train_accuracy_betti_0_std", + "train_accuracy_betti_1_mean", + "train_accuracy_betti_1_std", + "train_accuracy_betti_2_mean", + "train_accuracy_betti_2_std", + ] + ] + return df__ def process_df(df): - reshaped_df = pd.DataFrame(columns=["Task", "Model Name", "Node Features", "Mean Accuracy", "Std Accuracy", "Mean Train Accuracy", "Std Train Accuracy"]) - - if 'task' not in df.columns: + reshaped_df = pd.DataFrame( + columns=[ + "Task", + "Model Name", + "Node Features", + "Mean Accuracy", + "Std Accuracy", + "Mean Train Accuracy", + "Std Train Accuracy", + ] + ) + + if "task" not in df.columns: df.reset_index(inplace=True) for index, row in df.iterrows(): - if pd.notna(row['validation_accuracy_betti_0_mean']): - val_acc_std_betti = lambda i: f'validation_accuracy_betti_{int(i)}_std' - val_acc_mean_betti = lambda i: f'validation_accuracy_betti_{int(i)}_mean' - train_acc_betti_mean = lambda i: f'train_accuracy_betti_{int(i)}_std' - train_acc_betti_std = lambda i: f'validation_accuracy_betti_{int(i)}_mean' - betti_task = lambda i: f'betti_{int(i)}' + if pd.notna(row["validation_accuracy_betti_0_mean"]): + val_acc_std_betti = ( + lambda i: f"validation_accuracy_betti_{int(i)}_std" + ) + val_acc_mean_betti = ( + lambda i: f"validation_accuracy_betti_{int(i)}_mean" + ) + train_acc_betti_mean = ( + 
lambda i: f"train_accuracy_betti_{int(i)}_std" + ) + train_acc_betti_std = ( + lambda i: f"validation_accuracy_betti_{int(i)}_mean" + ) + betti_task = lambda i: f"betti_{int(i)}" for i in range(3): new_row_dict = { - "Task": [betti_task(i)], - "Model Name": [row['model_name']], - "Node Features": [row['node_features']], - "Mean Accuracy": [row[val_acc_mean_betti(i)]], - "Std Accuracy": [row[val_acc_std_betti(i)]], - "Mean Train Accuracy": [row[train_acc_betti_mean(i)]], - "Std Train Accuracy": [row[train_acc_betti_std(i)]] + "Task": [betti_task(i)], + "Model Name": [row["model_name"]], + "Node Features": [row["node_features"]], + "Mean Accuracy": [row[val_acc_mean_betti(i)]], + "Std Accuracy": [row[val_acc_std_betti(i)]], + "Mean Train Accuracy": [row[train_acc_betti_mean(i)]], + "Std Train Accuracy": [row[train_acc_betti_std(i)]], } new_row = pd.DataFrame(new_row_dict, index=[0]) - reshaped_df = pd.concat([reshaped_df, new_row], ignore_index=True) + reshaped_df = pd.concat( + [reshaped_df, new_row], ignore_index=True + ) else: new_row_dict = { - "Task": [row['task']], - "Model Name": [row['model_name']], - "Node Features": [row['node_features']], - "Mean Accuracy": [row['validation_accuracy_mean']], - "Std Accuracy": [row['validation_accuracy_std']], - "Mean Train Accuracy": [row['train_accuracy_mean']], - "Std Train Accuracy": [row['train_accuracy_std']] + "Task": [row["task"]], + "Model Name": [row["model_name"]], + "Node Features": [row["node_features"]], + "Mean Accuracy": [row["validation_accuracy_mean"]], + "Std Accuracy": [row["validation_accuracy_std"]], + "Mean Train Accuracy": [row["train_accuracy_mean"]], + "Std Train Accuracy": [row["train_accuracy_std"]], } new_row = pd.DataFrame(new_row_dict, index=[0]) reshaped_df = pd.concat([reshaped_df, new_row], ignore_index=True) @@ -115,23 +137,35 @@ def __init__(self, wandb_project_id: str = "mantra-dev-run-3") -> None: df = convert_history(full_history) df = process_df(df) self.df = df - + def get(self): return self.df - + def get_task_means(self) -> Tuple[List[float], List[float]]: - categories = ['betti_0', 'betti_1', 'betti_2', 'name', 'orientability']# categories = self.df.to_numpy() + categories = [ + "betti_0", + "betti_1", + "betti_2", + "name", + "orientability", + ] # categories = self.df.to_numpy() values = [] - errors = [] + errors = [] for c in categories: - indeces = self.df['Task'] == c - filtered_df = self.df.iloc[indeces.to_numpy()] - mean_accuracy = filtered_df['Mean Accuracy'].mean() - std_accuracy = filtered_df['Std Accuracy'].mean() - mean_train_accuracy = filtered_df['Mean Train Accuracy'].mean() - std_train_accuracy = filtered_df['Std Train Accuracy'].mean() + indeces = self.df["Task"] == c + filtered_df = self.df.iloc[indeces.to_numpy()] + mean_accuracy = filtered_df["Mean Accuracy"].mean() + std_accuracy = filtered_df["Std Accuracy"].mean() + mean_train_accuracy = filtered_df["Mean Train Accuracy"].mean() + std_train_accuracy = filtered_df["Std Train Accuracy"].mean() values.append(mean_accuracy) errors.append(std_accuracy) - - return values, errors \ No newline at end of file + + return values, errors + + def save(self, path: str): + raise NotImplementedError() + + def load(self, path: str): + raise NotImplementedError() diff --git a/experiments/vis/us_cmap.py b/experiments/vis/us_cmap.py index cc032a5..ecdb2f8 100644 --- a/experiments/vis/us_cmap.py +++ b/experiments/vis/us_cmap.py @@ -70,7 +70,9 @@ def activate(): scale_white_amount(values, step) ) -matplotlib.cm.register_cmap(name="US", 
cmap=matplotlib.colors.ListedColormap(list_cmap)) +matplotlib.cm.register_cmap( + name="US", cmap=matplotlib.colors.ListedColormap(list_cmap) +) register_name("others", (200, 200, 200)) register_name("graytext", (120, 120, 120)) @@ -94,7 +96,9 @@ def activate(): for base_name in cdict: step_labels = [f"{step * 100:.0f}" for step in steps] for name in [base_name] + [f"{base_name}!{sl}" for sl in step_labels]: - colors[name] = ", ".join(f"{v:.3f}" for v in matplotlib.colors.to_rgb(name)) + colors[name] = ", ".join( + f"{v:.3f}" for v in matplotlib.colors.to_rgb(name) + ) out = Path("uni-stuttgart-colors.sty") out.write_text( @@ -104,7 +108,8 @@ def activate(): ) out.write_text( "\n".join( - f"\\definecolor{{{name}}}{{rgb}}{{{rgb}}}" for name, rgb in colors.items() + f"\\definecolor{{{name}}}{{rgb}}{{{rgb}}}" + for name, rgb in colors.items() ) ) - print(f"wrote color map to {out.as_posix()}") \ No newline at end of file + print(f"wrote color map to {out.as_posix()}") diff --git a/run.py b/run.py index 126ae67..b3b9b36 100644 --- a/run.py +++ b/run.py @@ -1,19 +1,35 @@ -from experiments.configs import load_config +from experiments.configs import load_config, ConfigExperimentRun from experiments.run_experiment import run_configuration import os import argparse -from typing import Dict, Any +from typing import Dict, Any, Optional -def run_configs_folder(args_dict: Dict[str, Any]): +def print_info(config: ConfigExperimentRun): + print("[INFO] Using configuration:", config) + + +def run_configs_folder( + args_dict: Dict[str, Any], checkpoint_folder: Optional[str] = None +): config_dir = "./configs" files = os.listdir(config_dir) for file in files: for _ in range(5): config_file = os.path.join(config_dir, file) config = load_config(config_file) + print("[INFO] Using configuration file:", config_file) + print_info(config) + + checkpoint_path = None + if checkpoint_folder: + checkpoint_path = config.get_checkpoint_path(checkpoint_folder) + print("[INFO] Using checkpoint folder:", checkpoint_path) + else: + print("[INFO] No checkpoint folder.") + config.logging.wandb_project_id = args_dict["wandb"] - run_configuration(config) + run_configuration(config, save_checkpoint_path=checkpoint_path) if __name__ == "__main__": @@ -32,6 +48,12 @@ def run_configs_folder(args_dict: Dict[str, Any]): default="./config.yaml", help="Path to .yaml configuration for experiment if running 'single mode.", ) + parser.add_argument( + "--checkpoints", + type=str, + default=None, + help="Path where the model checkpoints are stored.", + ) parser.add_argument( "--wandb", type=str, default="mantra-dev", help="Wandb project id." 
) @@ -40,11 +62,17 @@ def run_configs_folder(args_dict: Dict[str, Any]): args_dict = vars(args) if args_dict["mode"] == "all": - run_configs_folder() + run_configs_folder( + args_dict, checkpoint_folder=args_dict["checkpoints"] + ) exit(0) if args_dict["mode"] == "single": config = load_config(args_dict["config"]) config.logging.wandb_project_id = args_dict["wandb"] - run_configuration(config) + print_info(config) + config_path = None + if args_dict["checkpoints"]: + config_path = config.get_checkpoint_path(args_dict["checkpoints"]) + run_configuration(config, save_checkpoint_path=config_path) exit(0) diff --git a/test.py b/test.py new file mode 100644 index 0000000..bd92e7a --- /dev/null +++ b/test.py @@ -0,0 +1,47 @@ +from experiments.configs import load_config, ConfigExperimentRun +from experiments.run_experiment import ( + run_configuration, + benchmark_configuration, +) +import os +import argparse +from typing import Dict, Any, Optional + + +def test(config: ConfigExperimentRun, checkpoint_path: str): + print("[INFO] Testing with config", config) + print("[INFO] Testing with checkpoint path:", checkpoint_path) + + benchmark_configuration( + config=config, save_checkpoint_path=checkpoint_path + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Argument parser for experiment configurations." + ) + parser.add_argument( + "--mode", + type=str, + default="single", + help="'all' for running all configurations in the ./configs folder, or 'single' for running a single model.", + ) + parser.add_argument( + "--config", + type=str, + default="./config.yaml", + help="Path to .yaml configuration for experiment if running 'single mode.", + ) + parser.add_argument( + "--checkpoints", + type=str, + help="Path where the model checkpoints are stored.", + ) + + args = parser.parse_args() + args_dict = vars(args) + + config = load_config(args_dict["config"]) + checkpoint_path = config.get_checkpoint_path(args_dict["checkpoints"]) + test(config=config, checkpoint_path=checkpoint_path) diff --git a/test.sh b/test.sh new file mode 100755 index 0000000..34821c6 --- /dev/null +++ b/test.sh @@ -0,0 +1,4 @@ +python ./test.py \ + --mode "single" \ + --config "./configs/gat_betti_numbers_degree_transform_onehot.yaml" \ + --checkpoints "./checkpoints" \ No newline at end of file diff --git a/train.sh b/train.sh new file mode 100755 index 0000000..301a188 --- /dev/null +++ b/train.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +python ./run.py \ + --mode "single" \ + --config "./configs/gat_betti_numbers_degree_transform_onehot.yaml" \ + --wandb "mantra-proj" \ + --checkpoints "./checkpoints"
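
Usage sketch, not part of the patch: the new train.sh/test.sh pair boils down to roughly the Python below. The config file name and the "mantra-dev" wandb project id are the values assumed by the shell scripts above and may differ in practice; everything else uses only the helpers introduced in this patch, so training and benchmarking resolve the same checkpoint file via get_checkpoint_path.

    from experiments.configs import load_config
    from experiments.run_experiment import run_configuration, benchmark_configuration

    # Load one experiment configuration (same file as referenced in train.sh/test.sh).
    config = load_config("./configs/gat_betti_numbers_degree_transform_onehot.yaml")
    config.logging.wandb_project_id = "mantra-dev"

    # Checkpoint file name is derived from the config:
    # <transforms>_<task>_<model>_seed_<seed>.ckpt inside the given folder.
    ckpt_path = config.get_checkpoint_path("./checkpoints")

    # Train and save the checkpoint (the run.py / train.sh path) ...
    run_configuration(config, save_checkpoint_path=ckpt_path)

    # ... then benchmark against the saved checkpoint (the test.py / test.sh path).
    benchmark_configuration(config, save_checkpoint_path=ckpt_path)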