Skip to content

Commit

Permalink
Merge pull request #169 from thearyadev/168-phashing-for-image-comp
Browse files Browse the repository at this point in the history
168 dhashing for image comp
  • Loading branch information
thearyadev authored Mar 11, 2024
2 parents d0c3566 + 96c4037 commit 36276bb
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 8 deletions.
24 changes: 23 additions & 1 deletion heroes/hero_comparison.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from __future__ import annotations
from PIL.Image import Image
from PIL import Image as ImageOpen, ImageFilter
from functools import lru_cache
import torch
import importlib
from pathlib import Path
from typing import Any

import imagehash
model_cache: dict[Any, Any] = dict()
class Heroes:
def __init__(self):
Expand Down Expand Up @@ -138,7 +142,7 @@ def __init__(self):
}


def predict_hero_name(self, image: Image, model_directory: Path) -> str:
def predict_hero_name_neural_network(self, image: Image, model_directory: Path) -> str:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NNModel = importlib.import_module(f"models.{model_directory.name}").FrozenNeuralNetworkModel
transformer = importlib.import_module(f"models.{model_directory.name}").transformer
Expand All @@ -164,3 +168,21 @@ def predict_hero_name(self, image: Image, model_directory: Path) -> str:
prediction = int(classes[int(torch.argmax(output, dim=1).item())])

return self.hero_labels[prediction]
def predict_hero_name_dhash_comparison(self, image: Image, model_directory: Path) -> str:
input_hash = imagehash.dhash(transform(image))
hashes = []
heroes = []
for f in Path("./assets/heroes").iterdir():
hashes.append(get_image_hash(f))
heroes.append(f.name.replace(".png", "").replace("2", ""))

diffs: list[str] = [abs(input_hash - h) for h in hashes]
ret = sorted(zip(diffs, heroes), key=lambda x: x[0])
return ret[0][1]

@lru_cache
def get_image_hash(path: Path):
return imagehash.dhash(transform(ImageOpen.open(path)))

def transform(img: Image):
return img
2 changes: 1 addition & 1 deletion leaderboards/leaderboard_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def parse_leaderboard_to_leaderboard_entries(leaderboard_image: ImageType, regio
for row in split_column_entries: # each record (10)

results.append(LeaderboardEntry(
heroes=[hero_comp.predict_hero_name(hero_image, Path(f"./models/{model_name}")) for hero_image in row],
heroes=[hero_comp.predict_hero_name_dhash_comparison(hero_image, Path(f"./models/{model_name}")) for hero_image in row],
role=role,
region=region

Expand Down
98 changes: 97 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 6 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ readme = "README.md"

[tool.poetry.dependencies]
python = "^3.10"
imagehash = "^4.3.1"


[tool.poetry.group.dev.dependencies]
Expand All @@ -32,11 +33,6 @@ httpx = "^0.25.0"
torch = {version = "^2.2.0", source="torch"}
torchvision = {version = "^0.17.0", source="torch"}

[[tool.poetry.source]]
name = "torch"
url = "https://download.pytorch.org/whl/cu121"
secondary = true

[tool.poetry.group.server.dependencies]
hypercorn = "0.14.4"
numpy = "^1.24.3"
Expand All @@ -46,6 +42,11 @@ jinja2 = "^3.1.2"
fastapi = "0.101.1"
rich = "^13.3.1"

[[tool.poetry.source]]
name = "torch"
url = "https://download.pytorch.org/whl/cu121"
secondary = true

[tool.mypy]
check_untyped_defs = true
strict=true
Expand Down

0 comments on commit 36276bb

Please sign in to comment.