Skip to content

Commit

Permalink
add dhash prediction function
Browse files Browse the repository at this point in the history
  • Loading branch information
thearyadev committed Mar 11, 2024
1 parent 81065e2 commit 850d13b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
24 changes: 23 additions & 1 deletion heroes/hero_comparison.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations
from PIL.Image import Image
from PIL import Image as ImageOpen, ImageFilter
from functools import lru_cache
import torch
import importlib
from pathlib import Path
from typing import Any

import imagehash
model_cache: dict[Any, Any] = dict()
class Heroes:
def __init__(self):
Expand Down Expand Up @@ -138,7 +142,7 @@ def __init__(self):
}


def predict_hero_name(self, image: Image, model_directory: Path) -> str:
def predict_hero_name_neural_network(self, image: Image, model_directory: Path) -> str:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NNModel = importlib.import_module(f"models.{model_directory.name}").FrozenNeuralNetworkModel
transformer = importlib.import_module(f"models.{model_directory.name}").transformer
Expand All @@ -164,3 +168,21 @@ def predict_hero_name(self, image: Image, model_directory: Path) -> str:
prediction = int(classes[int(torch.argmax(output, dim=1).item())])

return self.hero_labels[prediction]
def predict_hero_name_dhash_comparison(self, image: Image, model_directory: Path) -> str:
input_hash = imagehash.dhash(transform(image))
hashes = []
heroes = []
for f in Path("./assets/heroes").iterdir():
hashes.append(get_image_hash(f))
heroes.append(f.name.replace(".png", "").replace("2", ""))

diffs: list[str] = [abs(input_hash - h) for h in hashes]
ret = sorted(zip(diffs, heroes), key=lambda x: x[0])
return ret[0][1]

@lru_cache
def get_image_hash(path: Path):
return imagehash.dhash(transform(ImageOpen.open(path)))

def transform(img: Image):
return img
2 changes: 1 addition & 1 deletion leaderboards/leaderboard_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def parse_leaderboard_to_leaderboard_entries(leaderboard_image: ImageType, regio
for row in split_column_entries: # each record (10)

results.append(LeaderboardEntry(
heroes=[hero_comp.predict_hero_name(hero_image, Path(f"./models/{model_name}")) for hero_image in row],
heroes=[hero_comp.predict_hero_name_dhash_comparison(hero_image, Path(f"./models/{model_name}")) for hero_image in row],
role=role,
region=region

Expand Down

0 comments on commit 850d13b

Please sign in to comment.