diff --git a/docs/images/leaderboard.svg b/docs/images/leaderboard.svg index 3955b4ea..46da272c 100644 --- a/docs/images/leaderboard.svg +++ b/docs/images/leaderboard.svg @@ -6,7 +6,7 @@ - 2024-10-11T08:04:26.434282 + 2024-10-11T17:12:02.777458 image/svg+xml @@ -43,529 +43,529 @@ L 297.171875 44.961439 L 297.171875 70.161439 L 241.371875 70.161439 L 241.371875 44.961439 -" clip-path="url(#p1915b87814)" style="fill: #0000ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #0000ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #4747ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #4848ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #7373ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #3d3dff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #9090ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #0000ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #0000ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #a8a8ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #5151ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #0808ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #1c1cff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #3535ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #6969ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #2424ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #8282ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #5f5fff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #ffffff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #2e2eff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #0000ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #0707ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #acacff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #2424ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #8787ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #7575ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #5c5cff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #0b0bff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #2929ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #0000ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #2c2cff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #7272ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #d6d6ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #4646ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #7070ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #2525ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #ffffff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #4c4cff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #3737ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #2929ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #6767ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #b7b7ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #3d3dff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #4848ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #0404ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #4c4cff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #0000ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #7474ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #ccccff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #0808ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #f5f5ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #3131ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #adadff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #5454ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #4d4dff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #4242ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #7d7dff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #2323ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #0000ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #0000ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #b1b1ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #4444ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #efefff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #babaff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #ddddff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #ffffff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #aeaeff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #d4d4ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #6363ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #8d8dff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #8585ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #3f3fff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #f8f8ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #7070ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #6666ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #1717ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #c2c2ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #ababff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #ffffff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #ffffff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #e2e2ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #a3a3ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #ebebff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #ffffff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #b5b5ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #ffffff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #d9d9ff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #cbcbff"/> +" clip-path="url(#pbddcf3c616)" style="fill: #ffffff"/> @@ -1106,12 +1106,12 @@ z - - + @@ -1298,313 +1298,12 @@ z - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + @@ -1738,15 +1513,153 @@ z - - + + - + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - + + - + - - - + + + @@ -1830,28 +1743,27 @@ z - - - - - - - - - - + + + + + + + + + - - + + - + - + - + - + + - + - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -1981,28 +1937,28 @@ z - - - - - - - - - - + + + + + + + + + + - - + + - + - + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + @@ -2299,83 +2297,83 @@ z - - + + - + - + - - + + - - + + - + - - + + - - + + - - + + - + - - + + - + - - + + - + - - + + - + - - + + @@ -2459,563 +2457,563 @@ z - + - - - + + + - + - - + + - + - - + + - - + + - - - + + + - + - + - + - + - + - - + + - - + + - - + + - - + + - - - + + + - + - - + + - + - + - - + + - - - + + + - + - - + + - + - - + + - + - - + + - + - - + + - - + + - - + + - + - + - + - - + + - + - - + + - + - + - + - - + + - + - + - + - + - + - - + + - - + + - - + + - + - - + + - + - + - + - + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - - + + + - + - + - + - - + + - - + + - - + + - + - + - - + + - - + + - - + + - - + + - + - + - - + + - - - + + + - - + + - + - - + + - - - + + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - + - + - - + + - - - + + + - - + + - + - - + + - - - + + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - - + + - + - + @@ -3101,7 +3099,7 @@ z - + diff --git a/docs/images/starplot.png b/docs/images/starplot.png index 29499ca9..97e2cebd 100644 Binary files a/docs/images/starplot.png and b/docs/images/starplot.png differ diff --git a/docs/leaderboards.md b/docs/leaderboards.md index 4c3e54aa..e2aaba2a 100644 --- a/docs/leaderboards.md +++ b/docs/leaderboards.md @@ -5,7 +5,7 @@ hide: # Leaderboards -We evaluated the following FMs on the 6 supported WSI-classification tasks. We report *Balanced Accuracy* for binary & multiclass tasks and generalized Dice score (no background) for segmentation tasks. The score shows the average performance over 5 runs. +We evaluated the following FMs on the 6 supported WSI-classification tasks. We report *Balanced Accuracy* for binary & multiclass tasks and generalized Dice score (no background) for segmentation tasks. The score shows the average performance over 5 runs. Note the leaderboard orders from best to worst according to the average performance across all tasks, excluding BACH (not comparable due to much larger patch size).
diff --git a/tools/data/leaderboard.csv b/tools/data/leaderboard.csv index 2eb33307..c78416df 100644 --- a/tools/data/leaderboard.csv +++ b/tools/data/leaderboard.csv @@ -1,14 +1,12 @@ -model,bach,crc,mhist,patch_camelyon,camelyon16_small,panda_small,consep,monusac -dino_vitb8_kaiko,0.867,0.952,0.814,0.921,0.818,0.638,0.703,0.641 -bioptimus_h_optimus_0,0.767,0.951,0.836,0.942,0.82,0.645,0.69,0.588 -histai_hibou_l,0.81,0.934,0.823,0.949,0.832,0.633,0.69,0.586 -dino_vitl16_uni,0.797,0.95,0.835,0.939,0.834,0.656,0.662,0.554 -dino_vits8_kaiko,0.825,0.948,0.826,0.887,0.814,0.654,0.688,0.599 -prov_gigapath,0.758,0.953,0.814,0.948,0.814,0.664,0.661,0.558 -dino_vitl14_kaiko,0.862,0.935,0.822,0.907,0.812,0.65,0.679,0.59 -dino_vitb16_kaiko,0.846,0.959,0.839,0.906,0.816,0.621,0.636,0.551 -owkin_phikon,0.715,0.942,0.766,0.925,0.797,0.64,0.68,0.54 -dino_vits16_kaiko,0.8,0.949,0.831,0.902,0.789,0.618,0.611,0.549 -dino_vits16_lunit,0.77,0.936,0.751,0.905,0.767,0.625,0.63,0.537 -dino_vits16_imagenet,0.675,0.936,0.827,0.861,0.685,0.545,0.531,0.495 -dino_vits16_random,0.411,0.613,0.5,0.752,0.568,0.297,0.5,0.374 \ No newline at end of file +bach,crc,mhist,patch_camelyon,camelyon16_small,panda_small,consep,monusac,model +0.77,0.936,0.751,0.905,0.767,0.625,0.63,0.537,dino_vits16_lunit +0.715,0.942,0.766,0.925,0.797,0.64,0.68,0.54,owkin_phikon +0.797,0.95,0.835,0.939,0.834,0.656,0.662,0.554,dino_vitl16_uni +0.767,0.951,0.836,0.942,0.82,0.645,0.69,0.588,bioptimus_h_optimus_0 +0.758,0.953,0.814,0.948,0.814,0.664,0.661,0.558,prov_gigapath +0.81,0.934,0.823,0.949,0.832,0.633,0.69,0.586,histai_hibou_l +0.8,0.949,0.831,0.902,0.789,0.618,0.611,0.549,dino_vits16_kaiko +0.825,0.948,0.826,0.887,0.814,0.654,0.688,0.599,dino_vits8_kaiko +0.846,0.959,0.839,0.906,0.816,0.621,0.636,0.551,dino_vitb16_kaiko +0.867,0.952,0.814,0.921,0.818,0.638,0.703,0.641,dino_vitb8_kaiko +0.862,0.935,0.822,0.907,0.812,0.65,0.679,0.59,dino_vitl14_kaiko diff --git a/tools/generate_leaderboard_plot.py b/tools/generate_leaderboard_plot.py deleted file mode 100644 index 2914439e..00000000 --- a/tools/generate_leaderboard_plot.py +++ /dev/null @@ -1,124 +0,0 @@ -# Run this script with `python tools/generate_leaderboard_plot.py` -# to create the image of the leaderboard heatmap displayed in -# docs/leaderboards.md -# -# Note: the code below assumes that the eva results are stored in -# `eva/logs//`. - -import os -import json -import argparse -import pandas as pd -import matplotlib.pyplot as plt -import matplotlib.colors -import seaborn as sns - - -def main(): - _tasks_to_metric = { - "bach": "MulticlassAccuracy", - "crc": "MulticlassAccuracy", - "mhist": "BinaryBalancedAccuracy", - "patch_camelyon": "BinaryBalancedAccuracy", - "camelyon16_small": "BinaryBalancedAccuracy", - "panda_small": "MulticlassAccuracy", - "consep": "GeneralizedDiceScore", - "monusac": "GeneralizedDiceScore", - } - _fm_name_map = { - "dino_vits16_lunit": "Lunit - ViT-S16 | TCGA", - "owkin_phikon": "Owkin (Phikon) - iBOT ViT-B16 | TCGA", - "dino_vitl16_uni": "UNI - DINOv2 ViT-L16 | Mass-100k", - "dino_vits16_kaiko": "kaiko.ai - DINO ViT-S16 | TCGA", - "dino_vits8_kaiko": "kaiko.ai - DINO ViT-S8 | TCGA", - "dino_vitb16_kaiko": "kaiko.ai - DINO ViT-B16 | TCGA", - "dino_vitb8_kaiko": "kaiko.ai - DINO ViT-B8 | TCGA", - "dino_vitl14_kaiko": "kaiko.ai - DINOv2 ViT-L14 | TCGA", - "bioptimus_h_optimus_0": "H-optimus-0 - ViT-G14 | 500k slides", - "prov_gigapath": "Prov-GigaPath - DINOv2 ViT-G14 | 181k slides", - "histai_hibou_l": "hibou-L - DINOv2 ViT-B14 | 1M slides", - } - _tasks_names_map = { - "bach": "BACH", - "crc": "CRC", - "mhist": "MHIST", - "patch_camelyon": "PCam", - "camelyon16_small": "Cam16Small", - "panda_small": "PANDASmall", - "consep": "CoNSeP", - "monusac": "MoNuSAC", - } - - # get log_dir from arg parser - parser = argparse.ArgumentParser() - parser.add_argument("--logs_dir", type=str, default="logs") - parser.add_argument("--output_file", type=str, default="docs/images/leaderboard.svg") - args = parser.parse_args() - - # load existing leaderboard if available: - if os.path.isfile("tools/data/leaderboard.csv"): - df_existing = pd.read_csv("tools/data/leaderboard.csv") - else: - df_existing = pd.DataFrame() - - # load results into data frame: - if args.logs_dir: - all_scores = [] - for model in _fm_name_map.keys(): - scores = [] - for task in _tasks_to_metric.keys(): - results_folder = [ - d for d in os.listdir(f"{args.logs_dir}/{task}/{model}") if d.startswith("20") - ][0] - with open( - os.path.join(f"{args.logs_dir}/{task}/{model}/{results_folder}/results.json") - ) as f: - d = json.load(f) - split = "test" if d["metrics"]["test"] else "val" - metric = _tasks_to_metric.get(task) - if metric is None: - raise Exception(f"no metric defined for task {task}") - scores.append(round(d["metrics"][split][0][f"{split}/{metric}"]["mean"], 3)) - all_scores.append(scores) - df = pd.DataFrame(all_scores, columns=_tasks_to_metric.keys()) - df["model"] = _fm_name_map.keys() - - # combine existing and new data frame - df = pd.concat([df, df_existing]).drop_duplicates() - df.to_csv("tools/data/leaderboard.csv", index=False) - else: - df = df_existing - - df = df.set_index("model", drop=True) - # create colormap: - colors = [[0, "white"], [1, "#0000ff"]] - cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", colors) - - # prepare data frame: - df = df[[fm in _fm_name_map.keys() for fm in df.index]] - df.index = df.reset_index()["model"].apply(lambda x: _fm_name_map.get(x) or x) - df.index.names = [""] - df.loc[:, "overall_performance"] = df.mean(axis=1) - df = df.sort_values(by="overall_performance", ascending=False) - df = df.drop(columns=["overall_performance"]) - df.columns = [_tasks_names_map.get(c) or c for c in df.columns] - scaled_df = (df - df.min(axis=0)) / (df.max(axis=0) - df.min(axis=0)) - - # create plot: - fig, ax = plt.subplots(figsize=(8, 5)) - sns.heatmap(scaled_df, annot=df, cmap=cmap, ax=ax, cbar=False, fmt=".3f") - plt.tick_params( - axis="x", - which="major", - labelsize=10, - labelbottom=False, - bottom=False, - top=False, - labeltop=True, - rotation=20, - ) - plt.savefig(args.output_file, format="svg", dpi=1200, bbox_inches="tight") - - -if __name__ == "__main__": - main() diff --git a/tools/generate_leaderboard_plots.py b/tools/generate_leaderboard_plots.py new file mode 100644 index 00000000..3468fa76 --- /dev/null +++ b/tools/generate_leaderboard_plots.py @@ -0,0 +1,235 @@ +# Run this script with `python tools/generate_leaderboard_plot.py` +# to create the image of the leaderboard heatmap and starplot displayed +# in docs/leaderboards.md +# +# Note: the code below assumes that the eva results are stored in +# `eva/logs///results`. + +import os +import json +import argparse +import pandas as pd +import matplotlib.pyplot as plt +import matplotlib.colors +import numpy as np +import seaborn as sns + +from typing import Optional + + +_tasks_to_metric = { + "bach": "MulticlassAccuracy", + "crc": "MulticlassAccuracy", + "mhist": "BinaryBalancedAccuracy", + "patch_camelyon": "BinaryBalancedAccuracy", + "camelyon16_small": "BinaryBalancedAccuracy", + "panda_small": "MulticlassAccuracy", + "consep": "GeneralizedDiceScore", + "monusac": "GeneralizedDiceScore", +} +_fm_name_map = { + "dino_vits16_lunit": "Lunit - ViT-S16 | TCGA", + "owkin_phikon": "Owkin (Phikon) - iBOT ViT-B16 | TCGA", + "dino_vitl16_uni": "UNI - DINOv2 ViT-L16 | Mass-100k", + "bioptimus_h_optimus_0": "H-optimus-0 - ViT-G14 | 500k slides", + "prov_gigapath": "Prov-GigaPath - DINOv2 ViT-G14 | 181k slides", + "histai_hibou_l": "hibou-L - DINOv2 ViT-B14 | 1M slides", + "dino_vits16_kaiko": "kaiko.ai - DINO ViT-S16 | TCGA", + "dino_vits8_kaiko": "kaiko.ai - DINO ViT-S8 | TCGA", + "dino_vitb16_kaiko": "kaiko.ai - DINO ViT-B16 | TCGA", + "dino_vitb8_kaiko": "kaiko.ai - DINO ViT-B8 | TCGA", + "dino_vitl14_kaiko": "kaiko.ai - DINOv2 ViT-L14 | TCGA", +} +_tasks_names_map = { + "bach": "BACH", + "crc": "CRC", + "mhist": "MHIST", + "patch_camelyon": "PCam", + "camelyon16_small": "Cam16Small", + "panda_small": "PANDASmall", + "consep": "CoNSeP", + "monusac": "MoNuSAC", +} +_colors_for_startplot = [ + "#7F7F7F", + "#FFC000", + "#C1E814", + "#FF0000", + "#D400FF", + "#4FFF87", + "#00A735", + "#6666FF", + "#0000FF", + "#00007F", + "#0000FF", +] +_label_offsets_startplot = { + "BACH": (0, -0.1), + "CRC": (0, -0.1), + "MHIST": (0.07, -0.1), + "PCam": (0.05, 0), + "Camelyon16": (-0.05, -0.08), + "PANDA": (0, -0.1), + "CoNSeP": (0.1, -0.07), + "MoNuSAC": (0.15, 0.08), +} + + +def get_leaderboard(logs_dir: Optional[str] = None) -> pd.DataFrame: + """Get the leaderboard data frame.""" + + # load existing leaderboard if available: + if os.path.isfile("tools/data/leaderboard.csv"): + df_existing = pd.read_csv("tools/data/leaderboard.csv") + else: + df_existing = pd.DataFrame() + + # load results into data frame: + if logs_dir: + all_scores = [] + for model in _fm_name_map.keys(): + scores = [] + for task in _tasks_to_metric.keys(): + run_folder = [ + d + for d in sorted(os.listdir(f"{logs_dir}/{task}/{model}/results")) + if d.startswith("20") + ][-1] + with open( + os.path.join(f"{logs_dir}/{task}/{model}/results/{run_folder}/results.json") + ) as f: + d = json.load(f) + split = "test" if d["metrics"]["test"] else "val" + metric = _tasks_to_metric.get(task) + if metric is None: + raise Exception(f"no metric defined for task {task}") + scores.append(round(d["metrics"][split][0][f"{split}/{metric}"]["mean"], 3)) + all_scores.append(scores) + df = pd.DataFrame(all_scores, columns=_tasks_to_metric.keys()) + df["model"] = _fm_name_map.keys() + + # combine existing and new data frame + df = pd.concat([df, df_existing]).drop_duplicates() + df.to_csv("tools/data/leaderboard.csv", index=False) + else: + df = df_existing + + df = df.set_index("model", drop=True) + df = df[[fm in _fm_name_map.keys() for fm in df.index]] + df.index = df.reset_index()["model"].apply(lambda x: _fm_name_map.get(x) or x) + return df + + +def plot_leaderboard(df: pd.DataFrame, output_file: str = "docs/images/leaderboard.svg"): + """Plot the leaderboard heatmap.""" + + # create colormap: + colors = [[0, "white"], [1, "#0000ff"]] + cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", colors) + + # prepare data frame: + df.index.names = [""] + + # exclude BACH from the average + df_for_avg = df[[c for c in df.columns if c != "bach"]] + + df.loc[:, "overall_performance"] = df_for_avg.mean(axis=1) + df = df.sort_values(by="overall_performance", ascending=False) + df = df.drop(columns=["overall_performance"]) + df.columns = [_tasks_names_map.get(c) or c for c in df.columns] + scaled_df = (df - df.min(axis=0)) / (df.max(axis=0) - df.min(axis=0)) + + # create plot: + fig, ax = plt.subplots(figsize=(8, 5)) + sns.heatmap(scaled_df, annot=df, cmap=cmap, ax=ax, cbar=False, fmt=".3f") + plt.tick_params( + axis="x", + which="major", + labelsize=10, + labelbottom=False, + bottom=False, + top=False, + labeltop=True, + rotation=20, + ) + plt.savefig(output_file, format="svg", dpi=1200, bbox_inches="tight") + + +def plot_startplot(df: pd.DataFrame, output_file: str = "docs/images/starplot.png"): + """Plot the star plot.""" + + plt.style.use("seaborn-v0_8-ticks") + + df = df[_tasks_to_metric.keys()] + datasets = _label_offsets_startplot.keys() + models = df.index.tolist() + accuracy_values_new = df.to_numpy() + angles = np.linspace(0, 2 * np.pi, len(datasets), endpoint=False).tolist() + angles += angles[:1] + accuracy_values = np.concatenate((accuracy_values_new, accuracy_values_new[:, [0]]), axis=1) + fig, ax = plt.subplots(figsize=(10, 10), subplot_kw=dict(polar=True)) + + for angle, label in zip(angles[:-1], datasets): + ha = "left" if angle < np.pi else "right" + va = "bottom" if angle < np.pi else "top" + offset_x, offset_y = _label_offsets_startplot[label] + ax.text( + angle + offset_x, + 1.1 + offset_y, + label, + horizontalalignment=ha, + size=20, + color="black", + verticalalignment=va, + ) + ax.plot([angle, angle], [0, 1], color="grey", linestyle="-", linewidth=0.5) + + ax.set_yticklabels([]) + ax.xaxis.set_visible(False) + + ax.set_rlabel_position(0) + y_ticks = [0.5, 0.6, 0.7, 0.8, 0.9] + plt.ylim(0.3, 0.98) + + for idx, (model, values) in enumerate(zip(models, accuracy_values)): + # if np.any(np.isnan(values)): + if np.any(pd.isna(values)): + values = np.nan_to_num(values) + color = _colors_for_startplot[idx % len(_colors_for_startplot)] + ax.plot( + angles, values, label=model, color=color, linewidth=5, linestyle="solid", alpha=0.45 + ) + + # Annotate y tick values slightly further from the axes + for tick in y_ticks: + for angle in angles[:-1]: + alignment = "left" if angle < np.pi else "right" + ax.text( + angle, + tick + 0.02, + f"{tick}", + horizontalalignment=alignment, + size=8, + color="grey", + verticalalignment="center", + ) + + legend = ax.legend(loc="upper right", bbox_to_anchor=(1.95, 1), title="", fontsize=18) + plt.savefig(output_file, bbox_inches="tight") + + +def main(): + # get log_dir from arg parser + parser = argparse.ArgumentParser() + parser.add_argument("--logs_dir", type=str, default="logs") + parser.add_argument("--output_leaderboard", type=str, default="docs/images/leaderboard.svg") + parser.add_argument("--output_starplot", type=str, default="docs/images/starplot.png") + args = parser.parse_args() + + leaderboard_df = get_leaderboard(args.logs_dir) + plot_leaderboard(leaderboard_df, args.output_leaderboard) + plot_startplot(leaderboard_df, args.output_starplot) + + +if __name__ == "__main__": + main()