diff --git a/assets/benchmark/DAMAGE_S8_P1_AMERICAS/key.json b/assets/benchmark/DAMAGE_S8_P1_AMERICAS/key.json new file mode 100644 index 00000000..c7d6b0d0 --- /dev/null +++ b/assets/benchmark/DAMAGE_S8_P1_AMERICAS/key.json @@ -0,0 +1,13 @@ +{"answers": [ + ["Pharah", "Sojourn", "Echo"], + ["Sojourn", "Pharah", "Echo"], + ["Sojourn", "Widowmaker", "Genji"], + ["Sojourn", "Ashe", "Tracer"], + ["Sojourn", "Ashe", "Soldier 76"], + ["Tracer", "Genji", "Echo"], + ["Tracer", "Sojourn", "Echo"], + ["Soldier 76", "Sojourn", "Ashe"], + ["Widowmaker", "Tracer", "Cassidy"], + ["Sojourn", "Ashe", "Sombra"] + ] +} diff --git a/assets/benchmark/DAMAGE_S8_P2_AMERICAS/key.json b/assets/benchmark/DAMAGE_S8_P2_AMERICAS/key.json new file mode 100644 index 00000000..73004c42 --- /dev/null +++ b/assets/benchmark/DAMAGE_S8_P2_AMERICAS/key.json @@ -0,0 +1,14 @@ +{ + "answers": [ + ["Sojourn", "Widowmaker", "Tracer"], + ["Sojourn", "Widowmaker", "Tracer"], + ["Sojourn", "Widowmaker", "Tracer"], + ["Tracer", "Genji", "Echo"], + ["Sojourn", "Hanzo", "Ashe"], + ["Tracer", "Sojourn", "Widowmaker"], + ["Genji", "Tracer", "Echo"], + ["Tracer", "Sojourn", "Ashe"], + ["Soldier 76", "Sojourn", "Ashe"], + ["Tracer", "Echo", "Genji"] + ] +} diff --git a/assets/benchmark/DAMAGE_S8_P3_AMERICAS/key.json b/assets/benchmark/DAMAGE_S8_P3_AMERICAS/key.json new file mode 100644 index 00000000..658f37f2 --- /dev/null +++ b/assets/benchmark/DAMAGE_S8_P3_AMERICAS/key.json @@ -0,0 +1,54 @@ +{ + "answers": [ + [ + "Tracer", + "Sojourn", + "Genji" + ], + [ + "Tracer", + "Hanzo", + "Widowmaker" + ], + [ + "Tracer", + "Sojourn", + "Genji" + ], + [ + "Tracer", + "Sojourn", + "Widowmaker" + ], + [ + "Sojourn", + "Widowmaker", + "Hanzo" + ], + [ + "Tracer", + "Sojourn", + "Mei" + ], + [ + "Sojourn", + "Tracer", + "Echo" + ], + [ + "Sojourn", + "Tracer", + "Hanzo" + ], + [ + "Sojourn", + "Sombra", + "Widowmaker" + ], + [ + "Sojourn", + "Sombra", + "Widowmaker" + ] + ] +} diff --git a/assets/benchmark/DAMAGE_S8_P4_AMERICAS/key.json b/assets/benchmark/DAMAGE_S8_P4_AMERICAS/key.json new file mode 100644 index 00000000..82b63be3 --- /dev/null +++ b/assets/benchmark/DAMAGE_S8_P4_AMERICAS/key.json @@ -0,0 +1,53 @@ +{"answers": [ + [ + "Bastion", + "Blank", + "Blank" + ], + [ + "Genji", + "Sojourn", + "Echo" + ], + [ + "Genji", + "Blank", + "Blank" + ], + [ + "Sojourn", + "Widowmaker", + "Hanzo" + ], + [ + "Sojourn", + "Tracer", + "Genji" + ], + [ + "Sojourn", + "Ashe", + "Tracer" + ], + [ + "Genji", + "Sojourn", + "Echo" + ], + [ + "Tracer", + "Sojourn", + "Genji" + ], + [ + "Sojourn", + "Widowmaker", + "Tracer" + ], + [ + "Genji", + "Tracer", + "Sojourn" + ] +] +} diff --git a/assets/benchmark/SUPPORT_S8_P1_AMERICAS/key.json b/assets/benchmark/SUPPORT_S8_P1_AMERICAS/key.json new file mode 100644 index 00000000..5d59e795 --- /dev/null +++ b/assets/benchmark/SUPPORT_S8_P1_AMERICAS/key.json @@ -0,0 +1,15 @@ +{ + "answers": [ + ["Ana", "Kiriko", "Zenyatta"], + ["Baptiste", "Kiriko", "Ana"], + ["Ana", "Kiriko", "Baptiste"], + ["Baptiste", "Kiriko", "Ana"], + ["Kiriko", "Baptiste", "Ana"], + ["Ana", "Baptiste", "Kiriko"], + ["Ana", "Kiriko", "Brigitte"], + ["Kiriko", "Baptiste", "Ana"], + ["Lucio", "Brigitte", "Kiriko"], + ["Lucio", "Brigitte", "Kiriko"] +] + +} diff --git a/assets/benchmark/SUPPORT_S8_P2_AMERICAS/key.json b/assets/benchmark/SUPPORT_S8_P2_AMERICAS/key.json new file mode 100644 index 00000000..81780af5 --- /dev/null +++ b/assets/benchmark/SUPPORT_S8_P2_AMERICAS/key.json @@ -0,0 +1,14 @@ +{ + "answers": [ + ["Kiriko", "Ana", "Baptiste"], + ["Ana", "Baptiste", "Kiriko"], + ["Lucio", "Brigitte", "Baptiste"], + ["Baptiste", "Kiriko", "Ana"], + ["Brigitte", "Blank", "Blank"], + ["Kiriko", "Baptiste", "Zenyatta"], + ["Ana", "Kiriko", "Illari"], + ["Ana", "Kiriko", "Illari"], + ["Lucio", "Zenyatta", "Ana"], + ["Kiriko", "Mercy", "Moira"] + ] +} diff --git a/assets/benchmark/SUPPORT_S8_P3_AMERICAS/key.json b/assets/benchmark/SUPPORT_S8_P3_AMERICAS/key.json new file mode 100644 index 00000000..9c5936ab --- /dev/null +++ b/assets/benchmark/SUPPORT_S8_P3_AMERICAS/key.json @@ -0,0 +1,14 @@ +{ + "answers": [ + ["Kiriko", "Mercy", "Moira"], + ["Ana", "Baptiste", "Kiriko"], + ["Lucio", "Brigitte", "Kiriko"], + ["Ana", "Baptiste", "Kiriko"], + ["Baptiste", "Kiriko", "Ana"], + ["Lucio", "Kiriko", "Ana"], + ["Kiriko", "Lucio", "Brigitte"], + ["Mercy", "Lucio", "Zenyatta"], + ["Kiriko", "Ana", "Baptiste"], + ["Lucio", "Brigitte", "Baptiste"] + ] +} diff --git a/assets/benchmark/SUPPORT_S8_P4_AMERICAS/key.json b/assets/benchmark/SUPPORT_S8_P4_AMERICAS/key.json new file mode 100644 index 00000000..a5b79109 --- /dev/null +++ b/assets/benchmark/SUPPORT_S8_P4_AMERICAS/key.json @@ -0,0 +1,15 @@ +{ + "answers": [ + ["Lucio", "Brigitte", "Baptiste"], + ["Kiriko", "Baptiste", "Lucio"], + ["Ana", "LifeWeaver", "Illari"], + ["Kiriko", "Baptiste", "Ana"], + ["Ana", "Kiriko", "Zenyatta"], + ["Baptiste", "Lucio", "Ana"], + ["Kiriko", "Baptiste", "Ana"], + ["Kiriko", "Lucio", "Baptiste"], + ["Kiriko", "Brigitte", "Zenyatta"], + ["Baptiste", "Ana", "Kiriko"] +] + +} diff --git a/assets/benchmark/TANK_S8_P1_AMERICAS/key.json b/assets/benchmark/TANK_S8_P1_AMERICAS/key.json new file mode 100644 index 00000000..4484cd5f --- /dev/null +++ b/assets/benchmark/TANK_S8_P1_AMERICAS/key.json @@ -0,0 +1,52 @@ +{"answers": [ + [ + "Ramattra", + "Sigma", + "Junker Queen" + ], + [ + "Mauga", + "Junker Queen", + "Roadhog" + ], + [ + "Junker Queen", + "Sigma", + "Mauga" + ], + [ + "Roadhog", + "Mauga", + "Doomfist" + ], + [ + "Orisa", + "Junker Queen", + "Zarya" + ], + [ + "Sigma", + "Junker Queen", + "Ramattra" + ], + [ + "Mauga", + "Ramattra", + "Roadhog" + ], + [ + "Roadhog", + "Sigma", + "Zarya" + ], + [ + "Junker Queen", + "Mauga", + "Roadhog" + ], + [ + "Junker Queen", + "Roadhog", + "Ramattra" + ] +]} diff --git a/assets/benchmark/TANK_S8_P2_AMERICAS/key.json b/assets/benchmark/TANK_S8_P2_AMERICAS/key.json new file mode 100644 index 00000000..5fd441b6 --- /dev/null +++ b/assets/benchmark/TANK_S8_P2_AMERICAS/key.json @@ -0,0 +1,52 @@ +{"answers": [ + [ + "Mauga", + "Wrecking Ball", + "Ramattra" + ], + [ + "Mauga", + "Sigma", + "Zarya" + ], + [ + "Junker Queen", + "Zarya", + "Sigma" + ], + [ + "Junker Queen", + "Zarya", + "Sigma" + ], + [ + "Roadhog", + "Sigma", + "Junker Queen" + ], + [ + "Roadhog", + "Sigma", + "Junker Queen" + ], + [ + "Junker Queen", + "Roadhog", + "Ramattra" + ], + [ + "Roadhog", + "Blank", + "Blank" + ], + [ + "Mauga", + "Junker Queen", + "Sigma" + ], + [ + "Zarya", + "Ramattra", + "D.Va" + ] +]} diff --git a/assets/benchmark/TANK_S8_P3_AMERICAS/key.json b/assets/benchmark/TANK_S8_P3_AMERICAS/key.json new file mode 100644 index 00000000..7861ccd3 --- /dev/null +++ b/assets/benchmark/TANK_S8_P3_AMERICAS/key.json @@ -0,0 +1,52 @@ +{"answers": [ + [ + "Junker Queen", + "Sigma", + "Mauga" + ], + [ + "Doomfist", + "Sigma", + "Roadhog" + ], + [ + "Sigma", + "Winston", + "Junker Queen" + ], + [ + "Mauga", + "Winston", + "Sigma" + ], + [ + "Junker Queen", + "Sigma", + "Zarya" + ], + [ + "Sigma", + "Junker Queen", + "Mauga" + ], + [ + "Doomfist", + "Junker Queen", + "Winston" + ], + [ + "Ramattra", + "Junker Queen", + "Sigma" + ], + [ + "Mauga", + "Ramattra", + "Winston" + ], + [ + "Sigma", + "Ramattra", + "Junker Queen" + ] +]} diff --git a/assets/benchmark/TANK_S8_P4_AMERICAS/key.json b/assets/benchmark/TANK_S8_P4_AMERICAS/key.json new file mode 100644 index 00000000..b7a48bee --- /dev/null +++ b/assets/benchmark/TANK_S8_P4_AMERICAS/key.json @@ -0,0 +1,52 @@ +{"answers": [ + [ + "Mauga", + "Sigma", + "Doomfist" + ], + [ + "Sigma", + "Junker Queen", + "Winston" + ], + [ + "Sigma", + "Ramattra", + "Junker Queen" + ], + [ + "Sigma", + "Ramattra", + "Junker Queen" + ], + [ + "Mauga", + "Zarya", + "Ramattra" + ], + [ + "Junker Queen", + "Ramattra", + "Reinhardt" + ], + [ + "Mauga", + "Blank", + "Blank" + ], + [ + "Sigma", + "Junker Queen", + "Ramattra" + ], + [ + "Roadhog", + "Mauga", + "Sigma" + ], + [ + "Junker Queen", + "Sigma", + "D.Va" + ] +]} diff --git a/heroes/hero_comparison.py b/heroes/hero_comparison.py index 9b03ad48..de4c1780 100644 --- a/heroes/hero_comparison.py +++ b/heroes/hero_comparison.py @@ -3,8 +3,8 @@ import torch import importlib from pathlib import Path - - +from typing import Any +model_cache: dict[Any, Any] = dict() class Heroes: def __init__(self): self.hero_labels: dict[int, str] = { @@ -137,12 +137,18 @@ def __init__(self): "Mauga": "#DC847D", } + def predict_hero_name(self, image: Image, model_directory: Path) -> str: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") NNModel = importlib.import_module(f"models.{model_directory.name}").FrozenNeuralNetworkModel transformer = importlib.import_module(f"models.{model_directory.name}").transformer - st_dict = torch.load(model_directory / "model.pth") + + if model_directory.name not in model_cache.keys(): + model_cache[model_directory.name] = torch.load(model_directory / "model.pth") + st_dict = model_cache[model_directory.name] + + model = NNModel(num_classes=40) model.to(device) model.load_state_dict(st_dict) diff --git a/utils/benchmarks.py b/utils/benchmarks.py index 5a034d44..a7584266 100644 --- a/utils/benchmarks.py +++ b/utils/benchmarks.py @@ -1,5 +1,7 @@ import sys +from PIL import Image + sys.path.append("./") # used to import from root directory import json import os @@ -77,17 +79,16 @@ def main(): passed_tests: int = 0 failed_tests: int = 0 - answers = load_answers("./assets/test_leaderboard_images") + answers = load_answers("./assets/benchmark") heroes_present: set[str] = set() for image, heroes in answers.items(): # iter answer key-value pairs result: list[ leaderboards.LeaderboardEntry - ] = leaderboards.parse( # parse leaderboard - image_path=f"./assets/test_leaderboard_images/{image}/LB-IMG.png", - assets_path="./assets/hero_images", + ] = leaderboards.parse_leaderboard_to_leaderboard_entries( # parse leaderboard + leaderboard_image=Image.open(f"./assets/benchmark/{image}/LB-IMG.png"), region=leaderboards.Region.AMERICAS, # doesnt matter role=leaderboards.Role.DAMAGE, # doesnt matter - model_name="thearyadev-2023-12-20", + model_name="thearyadev-initial-15-02-2024", ) for entry, answer in zip(result, heroes): # validate results @@ -119,9 +120,8 @@ def main(): print(f"Passed tests: {passed_tests}") print(f"Failed tests: {failed_tests}") print(f"[yellow bold]Success rate: {round(passed_tests / total_tests * 100, 2)}%") - if len(heroes_present) == 0: - print( - f"Heroes not present in answer set: {heroes_present.symmetric_difference(set(HeroComparisonClass.Heroes('./assets/hero_images').hero_labels.values()))}" + print( + f"Heroes not present in answer set: {heroes_present.symmetric_difference(set(HeroComparisonClass.Heroes().hero_labels.values()))}" )