-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4b204a8
commit 665eb48
Showing
6 changed files
with
530 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .result_handler import ResultHandler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import matplotlib.pyplot as plt | ||
import plotting | ||
import numpy as np | ||
from .result_handler import ResultHandler | ||
|
||
def overview_barplot(result_handler: ResultHandler): | ||
other_colors = ['lightblue', 'darkgray'] | ||
|
||
categories = ['betti_0', 'betti_1', 'betti_2', 'name', 'orientability'] | ||
values, errors = result_handler.get_task_means() | ||
|
||
plotting.prepare_for_latex() | ||
|
||
fig, ax = plt.subplots(figsize=(8, 6)) | ||
|
||
bar_width = 0.6 | ||
betti_positions = np.arange(3) | ||
other_positions = np.arange(3, 5) + 1 | ||
|
||
for i, pos in enumerate(betti_positions): | ||
ax.bar(pos, values[i], yerr=errors[i], color='lightcoral', width=bar_width, label=categories[i]) | ||
|
||
for i, pos in enumerate(other_positions): | ||
ax.bar(pos, values[i + 3], yerr=errors[i + 3], color=other_colors[i], width=bar_width, label=categories[i + 3]) | ||
|
||
ax.set_xticklabels( ['','betti_0', 'betti_1', 'betti_2', '','name', 'orientability']) | ||
ax.set_ylabel("Mean Accuracy") | ||
ax.set_xlabel("Task") | ||
ax.grid(True, linestyle='--', linewidth=0.5) | ||
plt.savefig("plot.pdf") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import random | ||
|
||
import matplotlib | ||
import matplotlib.lines as mlines | ||
import matplotlib.patches as mpatches | ||
import matplotlib.pyplot as plt | ||
import matplotlib.ticker | ||
import numpy as np | ||
|
||
import experiments.vis.us_cmap as us_cmap | ||
|
||
""" | ||
> \TU/OpenSansLight(0)/m/n/10.95 . | ||
<recently read> \font | ||
""" | ||
|
||
|
||
def above_legend_args(ax): | ||
return dict( | ||
loc="lower center", | ||
bbox_to_anchor=(0.5, 1.0), | ||
bbox_transform=ax.transAxes, | ||
borderaxespad=0.25, | ||
) | ||
|
||
|
||
def add_single_row_legend(ax: matplotlib.pyplot.Axes, title: str, **legend_args): | ||
# Extracting handles and labels | ||
try: | ||
h, l = legend_args.pop("legs") | ||
except KeyError: | ||
h, l = ax.get_legend_handles_labels() | ||
ph = mlines.Line2D([], [], color="white") | ||
handles = [ph] + h | ||
labels = [title] + l | ||
legend_args["ncol"] = legend_args.get("ncol", len(handles)) | ||
leg = ax.legend(handles, labels, **legend_args) | ||
for vpack in leg._legend_handle_box.get_children()[:1]: | ||
for hpack in vpack.get_children()[:1]: | ||
hpack.get_children()[0].set_width(-30) | ||
|
||
|
||
def filter_duplicate_handles(ax): | ||
""" | ||
usage: ax.legend(*filter_duplicate_handles(ax), kwargs...) | ||
:param ax: | ||
:return: | ||
""" | ||
|
||
handles, labels = ax.get_legend_handles_labels() | ||
unique = [ | ||
(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i] | ||
] | ||
return zip(*unique) | ||
|
||
|
||
class MaxTickSciFormatter(matplotlib.ticker.Formatter): | ||
""" | ||
Only formats ticks that are above a given maximum. Useful for log plots, where the last tick label is not shown. | ||
Usage: ax.yaxis.set_minor_formatter(MaxTickSciFormatter(last_tick_value)) | ||
""" | ||
|
||
def __init__(self, last_tick_value): | ||
""" | ||
:param last_tick_value: format all labels with an x/y value equal or above this value | ||
""" | ||
super().__init__() | ||
self.last_tick_value = last_tick_value | ||
self._sci_formatter = matplotlib.ticker.LogFormatterSciNotation() | ||
|
||
def __call__(self, x, pos=None): | ||
if x >= self.last_tick_value: | ||
return self._sci_formatter(x, pos) | ||
else: | ||
return "" | ||
|
||
|
||
def get_dimensions(height=140, num_cols=1, half_size=False): | ||
# \showthe\columnwidth | ||
fac = 0.48 if half_size else 1 | ||
single_col_pts = 426.79134999999251932 * fac | ||
double_col_pts = 426.79134999999251932 * fac | ||
inches_per_pt = 1 / 72.27 | ||
|
||
if num_cols == 1: | ||
width_inches = ( | ||
single_col_pts * inches_per_pt + 0.23 | ||
) # added default matplotlib padding | ||
elif num_cols == 2: | ||
width_inches = double_col_pts * inches_per_pt + 0.23 | ||
else: | ||
width_inches = single_col_pts * num_cols * inches_per_pt + 0.23 | ||
|
||
height_inches = height * inches_per_pt | ||
return width_inches, height_inches | ||
|
||
|
||
def prepare_matplotlib(): | ||
us_cmap.activate() | ||
params = { | ||
"savefig.pad_inches": 0.0, | ||
"savefig.bbox": "tight", | ||
"savefig.transparent": True, | ||
"font.family": "sans-serif", | ||
"mathtext.fontset": "dejavuserif", | ||
"font.size": 10.95, | ||
"xtick.labelsize": 10.95, | ||
"ytick.labelsize": 10.95, | ||
"axes.titlesize": 10.95, | ||
"axes.labelsize": 10.95, | ||
"legend.fontsize": 10.95, | ||
"figure.titlesize": 10.95, | ||
"figure.autolayout": True, | ||
"axes.labelweight": "normal", | ||
"axes.titleweight": "normal", | ||
"legend.columnspacing": 0.75, | ||
"legend.handlelength": 1, | ||
"legend.handletextpad": 0.2, | ||
"legend.frameon": False, | ||
"legend.borderpad": 0, | ||
} | ||
matplotlib.rcParams.update(params) | ||
|
||
|
||
def prepare_for_latex(preamble=""): | ||
if "siunitx" not in preamble: | ||
preamble += "\n" + r"\usepackage{siunitx}" | ||
prepare_matplotlib() | ||
params = { | ||
"backend": "pgf", | ||
"text.usetex": True, | ||
"text.latex.preamble": preamble, | ||
"pgf.texsystem": "pdflatex", | ||
"pgf.rcfonts": True, | ||
"pgf.preamble": preamble, | ||
"axes.unicode_minus": False, | ||
} | ||
matplotlib.rcParams.update(params) | ||
|
||
|
||
# \documentclass{article} | ||
# \usepackage{layouts} | ||
# \begin{document} | ||
# \begin{figure*} | ||
# \currentpage\pagedesign | ||
# \end{figure*} | ||
# \end{document} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import pandas as pd | ||
import wandb | ||
from typing import Tuple, List | ||
|
||
def setup_wandb(wandb_project_id: str= "mantra-dev-run-3"): | ||
wandb.login() | ||
api = wandb.Api() | ||
runs = api.runs(wandb_project_id) | ||
|
||
full_history, config_list, name_list = [], [], [] | ||
for run in runs: | ||
h = run._full_history() | ||
h = [ r | run.config for r in h] | ||
full_history.extend(h) | ||
|
||
import json | ||
with open("raw_results.json","w") as f: | ||
json.dump(full_history,f) | ||
return full_history | ||
|
||
|
||
|
||
|
||
|
||
|
||
def convert_history(full_history): | ||
df = pd.DataFrame(full_history) | ||
df = df.set_index([ "task","model_name","node_features", "run_id"],inplace=False) | ||
mean_df = df.groupby(level=[0,1,2]).mean() | ||
mean_df = mean_df.add_suffix("_mean") | ||
|
||
std_df = df.groupby(level=[0,1,2]).std() | ||
std_df = std_df.add_suffix("_std") | ||
|
||
res_df = pd.concat([mean_df,std_df],join="outer",axis=1) | ||
|
||
res_df.to_csv("results.csv") | ||
|
||
df__ = res_df[[ | ||
"validation_accuracy_mean", | ||
"validation_accuracy_std", | ||
"train_accuracy_mean", | ||
"train_accuracy_std", | ||
"validation_accuracy_betti_0_mean", | ||
"validation_accuracy_betti_0_std", | ||
"validation_accuracy_betti_1_mean", | ||
"validation_accuracy_betti_1_std", | ||
"validation_accuracy_betti_2_mean", | ||
"validation_accuracy_betti_2_std", | ||
"train_accuracy_betti_0_mean", | ||
"train_accuracy_betti_0_std", | ||
"train_accuracy_betti_1_mean", | ||
"train_accuracy_betti_1_std", | ||
"train_accuracy_betti_2_mean", | ||
"train_accuracy_betti_2_std", | ||
]] | ||
|
||
return df__ | ||
|
||
|
||
def process_df(df): | ||
reshaped_df = pd.DataFrame(columns=["Task", "Model Name", "Node Features", "Mean Accuracy", "Std Accuracy", "Mean Train Accuracy", "Std Train Accuracy"]) | ||
|
||
if 'task' not in df.columns: | ||
df.reset_index(inplace=True) | ||
|
||
for index, row in df.iterrows(): | ||
if pd.notna(row['validation_accuracy_betti_0_mean']): | ||
val_acc_std_betti = lambda i: f'validation_accuracy_betti_{int(i)}_std' | ||
val_acc_mean_betti = lambda i: f'validation_accuracy_betti_{int(i)}_mean' | ||
train_acc_betti_mean = lambda i: f'train_accuracy_betti_{int(i)}_std' | ||
train_acc_betti_std = lambda i: f'validation_accuracy_betti_{int(i)}_mean' | ||
betti_task = lambda i: f'betti_{int(i)}' | ||
|
||
for i in range(3): | ||
|
||
new_row_dict = { | ||
"Task": [betti_task(i)], | ||
"Model Name": [row['model_name']], | ||
"Node Features": [row['node_features']], | ||
"Mean Accuracy": [row[val_acc_mean_betti(i)]], | ||
"Std Accuracy": [row[val_acc_std_betti(i)]], | ||
"Mean Train Accuracy": [row[train_acc_betti_mean(i)]], | ||
"Std Train Accuracy": [row[train_acc_betti_std(i)]] | ||
} | ||
new_row = pd.DataFrame(new_row_dict, index=[0]) | ||
reshaped_df = pd.concat([reshaped_df, new_row], ignore_index=True) | ||
else: | ||
new_row_dict = { | ||
"Task": [row['task']], | ||
"Model Name": [row['model_name']], | ||
"Node Features": [row['node_features']], | ||
"Mean Accuracy": [row['validation_accuracy_mean']], | ||
"Std Accuracy": [row['validation_accuracy_std']], | ||
"Mean Train Accuracy": [row['train_accuracy_mean']], | ||
"Std Train Accuracy": [row['train_accuracy_std']] | ||
} | ||
new_row = pd.DataFrame(new_row_dict, index=[0]) | ||
reshaped_df = pd.concat([reshaped_df, new_row], ignore_index=True) | ||
return reshaped_df | ||
|
||
|
||
def pull_from_wandb(): | ||
full_history = setup_wandb() | ||
df = convert_history(full_history) | ||
df = process_df(df) | ||
return df | ||
|
||
|
||
class ResultHandler: | ||
df: pd.DataFrame | ||
|
||
def __init__(self, wandb_project_id: str = "mantra-dev-run-3") -> None: | ||
full_history = setup_wandb(wandb_project_id=wandb_project_id) | ||
df = convert_history(full_history) | ||
df = process_df(df) | ||
self.df = df | ||
|
||
def get(self): | ||
return self.df | ||
|
||
def get_task_means(self) -> Tuple[List[float], List[float]]: | ||
categories = ['betti_0', 'betti_1', 'betti_2', 'name', 'orientability']# categories = self.df.to_numpy() | ||
values = [] | ||
errors = [] | ||
|
||
for c in categories: | ||
indeces = self.df['Task'] == c | ||
filtered_df = self.df.iloc[indeces.to_numpy()] | ||
mean_accuracy = filtered_df['Mean Accuracy'].mean() | ||
std_accuracy = filtered_df['Std Accuracy'].mean() | ||
mean_train_accuracy = filtered_df['Mean Train Accuracy'].mean() | ||
std_train_accuracy = filtered_df['Std Train Accuracy'].mean() | ||
values.append(mean_accuracy) | ||
errors.append(std_accuracy) | ||
|
||
return values, errors |
Oops, something went wrong.