Skip to content

Commit

Permalink
Merge pull request #66 from deepghs/dev/classify
Browse files Browse the repository at this point in the history
dev(narugo): add generic classify method
  • Loading branch information
narugo1992 authored Jan 24, 2024
2 parents 8813aa0 + ee1850b commit d716669
Show file tree
Hide file tree
Showing 38 changed files with 4,813 additions and 4,385 deletions.
11 changes: 6 additions & 5 deletions docs/source/api_doc/validate/aicheck_benchmark.plot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import random

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.generic.classify import _open_models_for_repo_id
from imgutils.validate import get_ai_created_score
from imgutils.validate.aicheck import _MODEL_NAMES
from imgutils.validate.aicheck import _REPO_ID

_MODEL_NAMES = _open_models_for_repo_id(_REPO_ID).model_names


class AnimeAICheckBenchmark(BaseBenchmark):
Expand All @@ -11,12 +14,10 @@ def __init__(self, model):
self.model = model

def load(self):
from imgutils.validate.aicheck import _open_anime_aicheck_model
_ = _open_anime_aicheck_model(self.model)
_open_models_for_repo_id(_REPO_ID)._open_model(self.model)

def unload(self):
from imgutils.validate.aicheck import _open_anime_aicheck_model
_open_anime_aicheck_model.cache_clear()
_open_models_for_repo_id(_REPO_ID).clear()

def run(self):
image_file = random.choice(self.all_images)
Expand Down
Loading

0 comments on commit d716669

Please sign in to comment.