Commit 665eb48

code for result interpretation

danielbinschmid committed Jun 18, 2024
1 parent 4b204a8 commit 665eb48
Showing 6 changed files with 530 additions and 0 deletions.
1 change: 1 addition & 0 deletions experiments/vis/__init__.py
@@ -0,0 +1 @@
from .result_handler import ResultHandler
30 changes: 30 additions & 0 deletions experiments/vis/overview_barplot.py
@@ -0,0 +1,30 @@
import matplotlib.pyplot as plt
import numpy as np

from . import plotting
from .result_handler import ResultHandler

def overview_barplot(result_handler: ResultHandler):
    other_colors = ['lightblue', 'darkgray']

    categories = ['betti_0', 'betti_1', 'betti_2', 'name', 'orientability']
    values, errors = result_handler.get_task_means()

    plotting.prepare_for_latex()

    fig, ax = plt.subplots(figsize=(8, 6))

    bar_width = 0.6
    betti_positions = np.arange(3)
    other_positions = np.arange(3, 5) + 1

    for i, pos in enumerate(betti_positions):
        ax.bar(pos, values[i], yerr=errors[i], color='lightcoral', width=bar_width, label=categories[i])

    for i, pos in enumerate(other_positions):
        ax.bar(pos, values[i + 3], yerr=errors[i + 3], color=other_colors[i], width=bar_width, label=categories[i + 3])

    # Fix the tick positions before labelling them; calling set_xticklabels
    # alone would mislabel whatever default ticks matplotlib chose.
    ax.set_xticks(np.concatenate([betti_positions, other_positions]))
    ax.set_xticklabels(categories)
    ax.set_ylabel("Mean Accuracy")
    ax.set_xlabel("Task")
    ax.grid(True, linestyle='--', linewidth=0.5)
    plt.savefig("plot.pdf")
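A minimal usage sketch for this module, assuming the W&B project configured in result_handler.py (shown further down in this diff) is reachable:

from experiments.vis import ResultHandler
from experiments.vis.overview_barplot import overview_barplot

# Pulls run histories from W&B, aggregates them, and writes plot.pdf.
handler = ResultHandler()
overview_barplot(handler)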
147 changes: 147 additions & 0 deletions experiments/vis/plotting.py
@@ -0,0 +1,147 @@
import random

import matplotlib
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib.ticker
import numpy as np

import experiments.vis.us_cmap as us_cmap

"""
> \TU/OpenSansLight(0)/m/n/10.95 .
<recently read> \font
"""


def above_legend_args(ax):
    return dict(
        loc="lower center",
        bbox_to_anchor=(0.5, 1.0),
        bbox_transform=ax.transAxes,
        borderaxespad=0.25,
    )
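
# Example usage (a sketch): ax.legend(**above_legend_args(ax)) places the
# legend centered above the axes.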


def add_single_row_legend(ax: matplotlib.pyplot.Axes, title: str, **legend_args):
    # Extract handles and labels, either passed in via "legs" or taken from
    # the axes itself.
    try:
        h, l = legend_args.pop("legs")
    except KeyError:
        h, l = ax.get_legend_handles_labels()
    ph = mlines.Line2D([], [], color="white")  # invisible placeholder carrying the title
    handles = [ph] + h
    labels = [title] + l
    legend_args["ncol"] = legend_args.get("ncol", len(handles))
    leg = ax.legend(handles, labels, **legend_args)
    # Tighten the spacing of the title placeholder; this reaches into private
    # legend internals and may break across matplotlib versions.
    for vpack in leg._legend_handle_box.get_children()[:1]:
        for hpack in vpack.get_children()[:1]:
            hpack.get_children()[0].set_width(-30)
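
# Example usage (a sketch, with a hypothetical title string):
#   add_single_row_legend(ax, "Model:", **above_legend_args(ax))
# renders the title inline with the legend entries on a single row.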


def filter_duplicate_handles(ax):
    """
    usage: ax.legend(*filter_duplicate_handles(ax), kwargs...)
    :param ax:
    :return:
    """
    handles, labels = ax.get_legend_handles_labels()
    unique = [
        (h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]
    ]
    return zip(*unique)


class MaxTickSciFormatter(matplotlib.ticker.Formatter):
    """
    Only formats ticks at or above a given value. Useful for log plots,
    where the last tick label is otherwise not shown.
    Usage: ax.yaxis.set_minor_formatter(MaxTickSciFormatter(last_tick_value))
    """

    def __init__(self, last_tick_value):
        """
        :param last_tick_value: format all labels with an x/y value equal to or above this value
        """
        super().__init__()
        self.last_tick_value = last_tick_value
        self._sci_formatter = matplotlib.ticker.LogFormatterSciNotation()

    def __call__(self, x, pos=None):
        if x >= self.last_tick_value:
            return self._sci_formatter(x, pos)
        else:
            return ""


def get_dimensions(height=140, num_cols=1, half_size=False):
    # Column width in pt, measured in the target document via \showthe\columnwidth
    fac = 0.48 if half_size else 1
    single_col_pts = 426.79134999999251932 * fac
    double_col_pts = 426.79134999999251932 * fac
    inches_per_pt = 1 / 72.27

    if num_cols == 1:
        width_inches = (
            single_col_pts * inches_per_pt + 0.23
        )  # added default matplotlib padding
    elif num_cols == 2:
        width_inches = double_col_pts * inches_per_pt + 0.23
    else:
        width_inches = single_col_pts * num_cols * inches_per_pt + 0.23

    height_inches = height * inches_per_pt
    return width_inches, height_inches
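
# Example usage (a sketch): figsize for a half-column, 140pt-tall figure,
#   fig, ax = plt.subplots(figsize=get_dimensions(height=140, half_size=True))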


def prepare_matplotlib():
    us_cmap.activate()
    params = {
        "savefig.pad_inches": 0.0,
        "savefig.bbox": "tight",
        "savefig.transparent": True,
        "font.family": "sans-serif",
        "mathtext.fontset": "dejavuserif",
        "font.size": 10.95,
        "xtick.labelsize": 10.95,
        "ytick.labelsize": 10.95,
        "axes.titlesize": 10.95,
        "axes.labelsize": 10.95,
        "legend.fontsize": 10.95,
        "figure.titlesize": 10.95,
        "figure.autolayout": True,
        "axes.labelweight": "normal",
        "axes.titleweight": "normal",
        "legend.columnspacing": 0.75,
        "legend.handlelength": 1,
        "legend.handletextpad": 0.2,
        "legend.frameon": False,
        "legend.borderpad": 0,
    }
    matplotlib.rcParams.update(params)


def prepare_for_latex(preamble=""):
    if "siunitx" not in preamble:
        preamble += "\n" + r"\usepackage{siunitx}"
    prepare_matplotlib()
    params = {
        "backend": "pgf",
        "text.usetex": True,
        "text.latex.preamble": preamble,
        "pgf.texsystem": "pdflatex",
        "pgf.rcfonts": True,
        "pgf.preamble": preamble,
        "axes.unicode_minus": False,
    }
    matplotlib.rcParams.update(params)
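
# Note: with the pgf backend active, plt.savefig() typesets figures through
# pdflatex, so a LaTeX installation providing siunitx must be available.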


# LaTeX snippet for inspecting the target document's page layout (via the
# layouts package), from which the dimensions above were taken:
# \documentclass{article}
# \usepackage{layouts}
# \begin{document}
# \begin{figure*}
# \currentpage\pagedesign
# \end{figure*}
# \end{document}
137 changes: 137 additions & 0 deletions experiments/vis/result_handler.py
@@ -0,0 +1,137 @@
import json
from typing import List, Tuple

import pandas as pd
import wandb


def setup_wandb(wandb_project_id: str = "mantra-dev-run-3"):
    wandb.login()
    api = wandb.Api()
    runs = api.runs(wandb_project_id)

    full_history = []
    for run in runs:
        # run._full_history() is a private wandb API and may change between
        # client versions.
        h = run._full_history()
        # Merge the run's config into every history row.
        h = [r | run.config for r in h]
        full_history.extend(h)

    with open("raw_results.json", "w") as f:
        json.dump(full_history, f)
    return full_history


def convert_history(full_history):
    df = pd.DataFrame(full_history)
    df = df.set_index(["task", "model_name", "node_features", "run_id"])

    # Aggregate over runs; numeric_only guards against the non-numeric config
    # columns that were merged into each history row.
    mean_df = df.groupby(level=[0, 1, 2]).mean(numeric_only=True)
    mean_df = mean_df.add_suffix("_mean")

    std_df = df.groupby(level=[0, 1, 2]).std(numeric_only=True)
    std_df = std_df.add_suffix("_std")

    res_df = pd.concat([mean_df, std_df], join="outer", axis=1)

    res_df.to_csv("results.csv")

    acc_df = res_df[[
        "validation_accuracy_mean",
        "validation_accuracy_std",
        "train_accuracy_mean",
        "train_accuracy_std",
        "validation_accuracy_betti_0_mean",
        "validation_accuracy_betti_0_std",
        "validation_accuracy_betti_1_mean",
        "validation_accuracy_betti_1_std",
        "validation_accuracy_betti_2_mean",
        "validation_accuracy_betti_2_std",
        "train_accuracy_betti_0_mean",
        "train_accuracy_betti_0_std",
        "train_accuracy_betti_1_mean",
        "train_accuracy_betti_1_std",
        "train_accuracy_betti_2_mean",
        "train_accuracy_betti_2_std",
    ]]

    return acc_df


def process_df(df):
    reshaped_df = pd.DataFrame(columns=["Task", "Model Name", "Node Features", "Mean Accuracy", "Std Accuracy", "Mean Train Accuracy", "Std Train Accuracy"])

    if 'task' not in df.columns:
        df.reset_index(inplace=True)

    for _, row in df.iterrows():
        if pd.notna(row['validation_accuracy_betti_0_mean']):
            # Betti-number runs log one accuracy per Betti dimension; expand
            # them into one row per betti_i task.
            val_acc_mean_betti = lambda i: f'validation_accuracy_betti_{int(i)}_mean'
            val_acc_std_betti = lambda i: f'validation_accuracy_betti_{int(i)}_std'
            train_acc_betti_mean = lambda i: f'train_accuracy_betti_{int(i)}_mean'
            train_acc_betti_std = lambda i: f'train_accuracy_betti_{int(i)}_std'
            betti_task = lambda i: f'betti_{int(i)}'

            for i in range(3):
                new_row_dict = {
                    "Task": [betti_task(i)],
                    "Model Name": [row['model_name']],
                    "Node Features": [row['node_features']],
                    "Mean Accuracy": [row[val_acc_mean_betti(i)]],
                    "Std Accuracy": [row[val_acc_std_betti(i)]],
                    "Mean Train Accuracy": [row[train_acc_betti_mean(i)]],
                    "Std Train Accuracy": [row[train_acc_betti_std(i)]]
                }
                new_row = pd.DataFrame(new_row_dict, index=[0])
                reshaped_df = pd.concat([reshaped_df, new_row], ignore_index=True)
        else:
            new_row_dict = {
                "Task": [row['task']],
                "Model Name": [row['model_name']],
                "Node Features": [row['node_features']],
                "Mean Accuracy": [row['validation_accuracy_mean']],
                "Std Accuracy": [row['validation_accuracy_std']],
                "Mean Train Accuracy": [row['train_accuracy_mean']],
                "Std Train Accuracy": [row['train_accuracy_std']]
            }
            new_row = pd.DataFrame(new_row_dict, index=[0])
            reshaped_df = pd.concat([reshaped_df, new_row], ignore_index=True)
    return reshaped_df


def pull_from_wandb():
    full_history = setup_wandb()
    df = convert_history(full_history)
    df = process_df(df)
    return df


class ResultHandler:
    df: pd.DataFrame

    def __init__(self, wandb_project_id: str = "mantra-dev-run-3") -> None:
        full_history = setup_wandb(wandb_project_id=wandb_project_id)
        df = convert_history(full_history)
        df = process_df(df)
        self.df = df

    def get(self):
        return self.df

    def get_task_means(self) -> Tuple[List[float], List[float]]:
        categories = ['betti_0', 'betti_1', 'betti_2', 'name', 'orientability']
        values = []
        errors = []

        for c in categories:
            # Average the per-configuration means and stds for this task.
            mask = self.df['Task'] == c
            filtered_df = self.df[mask]
            values.append(filtered_df['Mean Accuracy'].mean())
            errors.append(filtered_df['Std Accuracy'].mean())

        return values, errors
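
# Example usage (a sketch, assuming W&B credentials are configured):
#   handler = ResultHandler("mantra-dev-run-3")
#   print(handler.get().head())
#   values, errors = handler.get_task_means()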