diff --git a/heroes/hero_comparison.py b/heroes/hero_comparison.py index 669d3757..9cdcca58 100644 --- a/heroes/hero_comparison.py +++ b/heroes/hero_comparison.py @@ -1,9 +1,13 @@ from __future__ import annotations from PIL.Image import Image +from PIL import Image as ImageOpen, ImageFilter +from functools import lru_cache import torch import importlib from pathlib import Path from typing import Any + +import imagehash model_cache: dict[Any, Any] = dict() class Heroes: def __init__(self): @@ -138,7 +142,7 @@ def __init__(self): } - def predict_hero_name(self, image: Image, model_directory: Path) -> str: + def predict_hero_name_neural_network(self, image: Image, model_directory: Path) -> str: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") NNModel = importlib.import_module(f"models.{model_directory.name}").FrozenNeuralNetworkModel transformer = importlib.import_module(f"models.{model_directory.name}").transformer @@ -164,3 +168,21 @@ def predict_hero_name(self, image: Image, model_directory: Path) -> str: prediction = int(classes[int(torch.argmax(output, dim=1).item())]) return self.hero_labels[prediction] + def predict_hero_name_dhash_comparison(self, image: Image, model_directory: Path) -> str: + input_hash = imagehash.dhash(transform(image)) + hashes = [] + heroes = [] + for f in Path("./assets/heroes").iterdir(): + hashes.append(get_image_hash(f)) + heroes.append(f.name.replace(".png", "").replace("2", "")) + + diffs: list[str] = [abs(input_hash - h) for h in hashes] + ret = sorted(zip(diffs, heroes), key=lambda x: x[0]) + return ret[0][1] + +@lru_cache +def get_image_hash(path: Path): + return imagehash.dhash(transform(ImageOpen.open(path))) + +def transform(img: Image): + return img diff --git a/leaderboards/leaderboard_parser.py b/leaderboards/leaderboard_parser.py index ad13c7b1..129fdf7d 100644 --- a/leaderboards/leaderboard_parser.py +++ b/leaderboards/leaderboard_parser.py @@ -109,7 +109,7 @@ def parse_leaderboard_to_leaderboard_entries(leaderboard_image: ImageType, regio for row in split_column_entries: # each record (10) results.append(LeaderboardEntry( - heroes=[hero_comp.predict_hero_name(hero_image, Path(f"./models/{model_name}")) for hero_image in row], + heroes=[hero_comp.predict_hero_name_dhash_comparison(hero_image, Path(f"./models/{model_name}")) for hero_image in row], role=role, region=region