Skip to content

Commit

Permalink
result table formatting and tabulate dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
danielbinschmid committed Aug 19, 2024
1 parent 8c0b8d3 commit 8ef4adb
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 14 deletions.
2 changes: 1 addition & 1 deletion code/experiments/generate_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def process_betti(result_csv_prefix: str):


def process():
results_csv_prefix = "ignore_temp_"
results_csv_prefix = "results_"

res_betti = process_betti(results_csv_prefix)
res_name_or = process_name_orientability(results_csv_prefix)
Expand Down
8 changes: 3 additions & 5 deletions code/experiments/utils/results_processing/per_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from models.models import ModelType
from datasets.transforms import TransformType
import numpy as np
from .utils import get_matching_indeces, get_metric_col_names, get_result_path
from .utils import get_matching_indeces, get_metric_col_names, get_result_path, format_res_val


def reduce(
Expand Down Expand Up @@ -40,10 +40,8 @@ def reduce(
get_matching_indeces(df, model_type, transform_type)
]
metric_results.append(filtered_results[metric].max())
row[metric[5:]] = (
f"{np.max(metric_results):.4f} (std:{np.std(metric_results):.4f})"
)

row[metric[5:]] = format_res_val(np.max(metric_results), np.std(metric_results))

new_row_df = pd.DataFrame([row])
concat_df = (
[df_results, new_row_df] if len(df_results) > 0 else [new_row_df]
Expand Down
5 changes: 2 additions & 3 deletions code/experiments/utils/results_processing/per_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from models.models import ModelType
from datasets.transforms import TransformType
import numpy as np
from .utils import get_matching_indeces, get_metric_col_names, get_result_path
from .utils import get_matching_indeces, get_metric_col_names, get_result_path, format_res_val


def reduce(
Expand Down Expand Up @@ -36,7 +36,6 @@ def reduce(
]
max_ = filtered_results[metric].max()
model_results.append(max_)

metric_res_list.append(np.max(model_results))
return metric_res_list

Expand Down Expand Up @@ -84,7 +83,7 @@ def per_task(
)

if metric != "test_loss":
row = {"Metric": metric, "Mean": np.mean(res_for_metric)}
row = {"Metric": metric, "Mean": format_res_val(np.mean(res_for_metric))}
new_row_df = pd.DataFrame([row])
concat_df = (
[df_results, new_row_df]
Expand Down
6 changes: 2 additions & 4 deletions code/experiments/utils/results_processing/per_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from models.models import ModelType
from datasets.transforms import TransformType
import numpy as np
from .utils import get_matching_indeces, get_metric_col_names, get_result_path
from .utils import get_matching_indeces, get_metric_col_names, get_result_path, format_res_val


def reduce(
Expand Down Expand Up @@ -37,9 +37,7 @@ def reduce(
]
results.append(filtered_results[metric].max())

row[transform_type.name.lower()] = (
f"{np.max(results):.4f} (std:{np.std(results):.4f})"
)
row[transform_type.name.lower()] = format_res_val(np.max(results), np.std(results))

return row

Expand Down
9 changes: 8 additions & 1 deletion code/experiments/utils/results_processing/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
from typing import List, Tuple
from typing import List, Tuple, Optional
import pandas as pd
from metrics.tasks import TaskType
from models.models import ModelType
from datasets.transforms import TransformType

def format_res_val(
value: float, std: Optional[float] = None
):
if std is None:
return f"{value:.2f}"
else:
return f"{value:.2f} ({std:.2f} SD)"

def get_matching_indeces(
df: pd.DataFrame, model_type: ModelType, transform_type: TransformType
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ torchvision = "0.18.0"
gensim = "4.3.2"
omegaconf = "2.3.0"
pydantic-settings = "2.2.1"
tabulate = "0.9.0"

[tool.poetry.dev-dependencies]
black = {version = "^21.12b0", allow-prereleases = true}
Expand Down

0 comments on commit 8ef4adb

Please sign in to comment.