From e4c98a5e776e7ad072883c48579657f96e7e3e86 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 9 Jan 2024 12:38:10 +0800 Subject: [PATCH 01/19] dev(narugo): add generic classify method --- imgutils/generic/__init__.py | 1 + imgutils/generic/classify.py | 113 +++++++++++++++++++++++++++++++++++ imgutils/validate/real.py | 96 ++--------------------------- 3 files changed, 120 insertions(+), 90 deletions(-) create mode 100644 imgutils/generic/__init__.py create mode 100644 imgutils/generic/classify.py diff --git a/imgutils/generic/__init__.py b/imgutils/generic/__init__.py new file mode 100644 index 00000000000..bbae43f1e65 --- /dev/null +++ b/imgutils/generic/__init__.py @@ -0,0 +1 @@ +from .classify import * diff --git a/imgutils/generic/classify.py b/imgutils/generic/classify.py new file mode 100644 index 00000000000..751423b53b3 --- /dev/null +++ b/imgutils/generic/classify.py @@ -0,0 +1,113 @@ +import json +import os +from functools import lru_cache +from typing import Tuple, Optional, List, Dict + +import numpy as np +from PIL import Image +from huggingface_hub import hf_hub_download, HfFileSystem + +from ..data import rgb_encode, ImageTyping, load_image +from ..utils import open_onnx_model + +__all__ = [ + 'ClassifyModel', + 'classify_predict_score', + 'classify_predict', +] + + +def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), + normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): + image = image.resize(size, Image.BILINEAR) + data = rgb_encode(image, order_='CHW') + + if normalize is not None: + mean_, std_ = normalize + mean = np.asarray([mean_]).reshape((-1, 1, 1)) + std = np.asarray([std_]).reshape((-1, 1, 1)) + data = (data - mean) / std + + return data.astype(np.float32) + + +class ClassifyModel: + def __init__(self, repo_id: str): + self.repo_id = repo_id + self._model_names = None + self._models = {} + self._labels = {} + + @classmethod + def _get_hf_token(cls): + return os.environ.get('HF_TOKEN') + + @property + def model_names(self) -> List[str]: + if self._model_names is None: + hf_fs = HfFileSystem(token=self._get_hf_token()) + self._model_names = [ + os.path.dirname(os.path.relpath(item, self.repo_id)) for item in + hf_fs.glob(f'{self.repo_id}/*/model.onnx') + ] + + return self._model_names + + def _check_model_name(self, model_name: str): + if model_name not in self.model_names: + raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, ' + f'models {self.model_names!r} are available.') + + def _open_model(self, model_name: str): + if model_name not in self._models: + self._check_model_name(model_name) + self._models[model_name] = open_onnx_model(hf_hub_download( + self.repo_id, + f'{model_name}/model.onnx', + token=self._get_hf_token(), + )) + return self._models[model_name] + + def _open_label(self, model_name: str) -> List[str]: + if model_name not in self._labels: + self._check_model_name(model_name) + with open(hf_hub_download( + self.repo_id, + f'{model_name}/meta.json', + token=self._get_hf_token(), + ), 'r') as f: + self._labels[model_name] = json.load(f)['labels'] + return self._labels[model_name] + + def _raw_predict(self, image: ImageTyping, model_name: str): + image = load_image(image, force_background='white', mode='RGB') + input_ = _img_encode(image)[None, ...] + output, = self._open_model(model_name).run(['output'], {'input': input_}) + return output + + def predict_score(self, image: ImageTyping, model_name: str) -> Dict[str, float]: + output = self._raw_predict(image, model_name) + values = dict(zip(self._open_label(model_name), map(lambda x: x.item(), output[0]))) + return values + + def predict(self, image: ImageTyping, model_name: str) -> Tuple[str, float]: + output = self._raw_predict(image, model_name)[0] + max_id = np.argmax(output) + return self._open_label(model_name)[max_id], output[max_id].item() + + def clear(self): + self._models.clear() + self._labels.clear() + + +@lru_cache() +def _open_models_for_repo_id(repo_id: str) -> ClassifyModel: + return ClassifyModel(repo_id) + + +def classify_predict_score(image: ImageTyping, repo_id: str, model_name: str) -> Dict[str, float]: + return _open_models_for_repo_id(repo_id).predict_score(image, model_name) + + +def classify_predict(image: ImageTyping, repo_id: str, model_name: str) -> Tuple[str, float]: + return _open_models_for_repo_id(repo_id).predict(image, model_name) diff --git a/imgutils/validate/real.py b/imgutils/validate/real.py index ddf935ce0a7..33aca57473c 100644 --- a/imgutils/validate/real.py +++ b/imgutils/validate/real.py @@ -15,16 +15,10 @@ The models are hosted on `huggingface - deepghs/anime_real_cls `_. """ -import json -from functools import lru_cache -from typing import Tuple, Optional, Dict, List +from typing import Tuple, Dict -import numpy as np -from PIL import Image -from huggingface_hub import hf_hub_download - -from imgutils.data import rgb_encode, ImageTyping, load_image -from imgutils.utils import open_onnx_model +from ..data import ImageTyping +from ..generic import classify_predict, classify_predict_score __all__ = [ 'anime_real_score', @@ -32,81 +26,7 @@ ] _DEFAULT_MODEL_NAME = 'mobilenetv3_v0_dist' - - -@lru_cache() -def _open_anime_real_model(model_name): - """ - Open the anime real model. - - :param model_name: The model name. - :type model_name: str - :return: The ONNX model. - """ - return open_onnx_model(hf_hub_download( - f'deepghs/anime_real_cls', - f'{model_name}/model.onnx', - )) - - -@lru_cache() -def _get_anime_real_labels(model_name) -> List[str]: - """ - Get the labels for the anime real model. - - :param model_name: The model name. - :type model_name: str - :return: The list of labels. - :rtype: List[str] - """ - with open(hf_hub_download( - f'deepghs/anime_real_cls', - f'{model_name}/meta.json', - ), 'r') as f: - return json.load(f)['labels'] - - -def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), - normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): - """ - Encode the input image. - - :param image: The input image. - :type image: Image.Image - :param size: The desired size of the image. - :type size: Tuple[int, int] - :param normalize: Mean and standard deviation for normalization. Default is (0.5, 0.5). - :type normalize: Optional[Tuple[float, float]] - :return: The encoded image data. - :rtype: np.ndarray - """ - image = image.resize(size, Image.BILINEAR) - data = rgb_encode(image, order_='CHW') - - if normalize is not None: - mean_, std_ = normalize - mean = np.asarray([mean_]).reshape((-1, 1, 1)) - std = np.asarray([std_]).reshape((-1, 1, 1)) - data = (data - mean) / std - - return data.astype(np.float32) - - -def _raw_anime_real(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME): - """ - Perform raw anime real processing on the input image. - - :param image: The input image. - :type image: ImageTyping - :param model_name: The model name. Default is 'mobilenetv3_v0_dist'. - :type model_name: str - :return: The processed image data. - :rtype: np.ndarray - """ - image = load_image(image, force_background='white', mode='RGB') - input_ = _img_encode(image)[None, ...] - output, = _open_anime_real_model(model_name).run(['output'], {'input': input_}) - return output +_REPO_ID = 'deepghs/anime_real_cls' def anime_real_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Dict[str, float]: @@ -156,9 +76,7 @@ def anime_real_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) >>> anime_real_score('real/real/16.jpg') {'anime': 1.5513256585109048e-05, 'real': 0.9999845027923584} """ - output = _raw_anime_real(image, model_name) - values = dict(zip(_get_anime_real_labels(model_name), map(lambda x: x.item(), output[0]))) - return values + return classify_predict_score(image, _REPO_ID, model_name) def anime_real(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[str, float]: @@ -208,6 +126,4 @@ def anime_real(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tup >>> anime_real('real/real/16.jpg') ('real', 0.9999845027923584) """ - output = _raw_anime_real(image, model_name)[0] - max_id = np.argmax(output) - return _get_anime_real_labels(model_name)[max_id], output[max_id].item() + return classify_predict(image, _REPO_ID, model_name) From a7ea60548309176a4f050aa6b6a8a928bd18628a Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 9 Jan 2024 12:44:47 +0800 Subject: [PATCH 02/19] dev(narugo): fix bug in unittest --- test/validate/test_real.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/validate/test_real.py b/test/validate/test_real.py index 1d13c5a14d0..dabb4be7442 100644 --- a/test/validate/test_real.py +++ b/test/validate/test_real.py @@ -3,8 +3,9 @@ import pytest -from imgutils.validate import anime_real -from imgutils.validate.real import _open_anime_real_model, anime_real_score +from imgutils.generic.classify import _open_models_for_repo_id +from imgutils.validate import anime_real, anime_real_score +from imgutils.validate.real import _REPO_ID from test.testings import get_testfile _ROOT_DIR = get_testfile('real') @@ -19,7 +20,7 @@ def _release_model_after_run(): try: yield finally: - _open_anime_real_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() @pytest.mark.unittest From c2487d921fbc0fe7988ff1cfd27dd1eb16fae5d1 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 9 Jan 2024 12:47:38 +0800 Subject: [PATCH 03/19] dev(narugo): regenerate doc --- .../api_doc/validate/real_benchmark.plot.py | 19 +- .../validate/real_benchmark.plot.py.svg | 2296 ----------------- 2 files changed, 5 insertions(+), 2310 deletions(-) delete mode 100644 docs/source/api_doc/validate/real_benchmark.plot.py.svg diff --git a/docs/source/api_doc/validate/real_benchmark.plot.py b/docs/source/api_doc/validate/real_benchmark.plot.py index 8975e83a7ba..3572031d2f8 100644 --- a/docs/source/api_doc/validate/real_benchmark.plot.py +++ b/docs/source/api_doc/validate/real_benchmark.plot.py @@ -1,18 +1,11 @@ -import os import random -from huggingface_hub import HfFileSystem - from benchmark import BaseBenchmark, create_plot_cli +from imgutils.generic.classify import _open_models_for_repo_id from imgutils.validate import anime_real +from imgutils.validate.real import _REPO_ID -hf_fs = HfFileSystem() - -_REPOSITORY = 'deepghs/anime_real_cls' -_MODEL_NAMES = [ - os.path.relpath(file, _REPOSITORY).split('/')[0] for file in - hf_fs.glob(f'{_REPOSITORY}/*/model.onnx') -] +_MODEL_NAMES = _open_models_for_repo_id(_REPO_ID).model_names class AnimeRealBenchmark(BaseBenchmark): @@ -21,12 +14,10 @@ def __init__(self, model): self.model = model def load(self): - from imgutils.validate.real import _open_anime_real_model - _ = _open_anime_real_model(self.model) + _open_models_for_repo_id(_REPO_ID)._open_model(self.model) def unload(self): - from imgutils.validate.real import _open_anime_real_model - _open_anime_real_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() def run(self): image_file = random.choice(self.all_images) diff --git a/docs/source/api_doc/validate/real_benchmark.plot.py.svg b/docs/source/api_doc/validate/real_benchmark.plot.py.svg deleted file mode 100644 index 6cacfa121e8..00000000000 --- a/docs/source/api_doc/validate/real_benchmark.plot.py.svg +++ /dev/null @@ -1,2296 +0,0 @@ - - - - - - - - 2023-12-16T10:54:00.694168 - image/svg+xml - - - Matplotlib v3.7.4, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - From 1012fd6393979025e8248284b26fb6827003fc6a Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 9 Jan 2024 04:52:33 +0000 Subject: [PATCH 04/19] dev(narugo): auto sync Tue, 09 Jan 2024 04:52:33 +0000 --- .../validate/real_benchmark.plot.py.svg | 2296 +++++++++++++++++ 1 file changed, 2296 insertions(+) create mode 100644 docs/source/api_doc/validate/real_benchmark.plot.py.svg diff --git a/docs/source/api_doc/validate/real_benchmark.plot.py.svg b/docs/source/api_doc/validate/real_benchmark.plot.py.svg new file mode 100644 index 00000000000..b97b807fb0e --- /dev/null +++ b/docs/source/api_doc/validate/real_benchmark.plot.py.svg @@ -0,0 +1,2296 @@ + + + + + + + + 2024-01-09T04:52:16.025388 + image/svg+xml + + + Matplotlib v3.7.4, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From 988b1d666a507e097fc93e0e6ac3404339adeb3b Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 9 Jan 2024 13:44:43 +0800 Subject: [PATCH 05/19] dev(narugo): save it --- .../validate/portrait_benchmark.plot.py | 19 +- .../validate/portrait_benchmark.plot.py.svg | 2418 ----------------- .../validate/style_age_benchmark.plot.py | 19 +- .../validate/style_age_benchmark.plot.py.svg | 2274 ---------------- imgutils/validate/portrait.py | 96 +- imgutils/validate/style_age.py | 96 +- test/validate/test_portrait.py | 7 +- test/validate/test_style_age.py | 7 +- 8 files changed, 30 insertions(+), 4906 deletions(-) delete mode 100644 docs/source/api_doc/validate/portrait_benchmark.plot.py.svg delete mode 100644 docs/source/api_doc/validate/style_age_benchmark.plot.py.svg diff --git a/docs/source/api_doc/validate/portrait_benchmark.plot.py b/docs/source/api_doc/validate/portrait_benchmark.plot.py index 7bd41b3e414..3560bb16f53 100644 --- a/docs/source/api_doc/validate/portrait_benchmark.plot.py +++ b/docs/source/api_doc/validate/portrait_benchmark.plot.py @@ -1,18 +1,11 @@ -import os import random -from huggingface_hub import HfFileSystem - from benchmark import BaseBenchmark, create_plot_cli +from imgutils.generic.classify import _open_models_for_repo_id from imgutils.validate import anime_portrait +from imgutils.validate.portrait import _REPO_ID -hf_fs = HfFileSystem() - -_REPOSITORY = 'deepghs/anime_portrait' -_MODEL_NAMES = [ - os.path.relpath(file, _REPOSITORY).split('/')[0] for file in - hf_fs.glob(f'{_REPOSITORY}/*/model.onnx') -] +_MODEL_NAMES = _open_models_for_repo_id(_REPO_ID).model_names class AnimePortraitBenchmark(BaseBenchmark): @@ -21,12 +14,10 @@ def __init__(self, model): self.model = model def load(self): - from imgutils.validate.portrait import _open_anime_portrait_model - _ = _open_anime_portrait_model(self.model) + _open_models_for_repo_id(_REPO_ID)._open_model(self.model) def unload(self): - from imgutils.validate.portrait import _open_anime_portrait_model - _open_anime_portrait_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() def run(self): image_file = random.choice(self.all_images) diff --git a/docs/source/api_doc/validate/portrait_benchmark.plot.py.svg b/docs/source/api_doc/validate/portrait_benchmark.plot.py.svg deleted file mode 100644 index 3cd6bb50938..00000000000 --- a/docs/source/api_doc/validate/portrait_benchmark.plot.py.svg +++ /dev/null @@ -1,2418 +0,0 @@ - - - - - - - - 2023-10-13T12:10:57.047588 - image/svg+xml - - - Matplotlib v3.7.3, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/source/api_doc/validate/style_age_benchmark.plot.py b/docs/source/api_doc/validate/style_age_benchmark.plot.py index 2bb15217460..8fb5fbe744d 100644 --- a/docs/source/api_doc/validate/style_age_benchmark.plot.py +++ b/docs/source/api_doc/validate/style_age_benchmark.plot.py @@ -1,18 +1,11 @@ -import os import random -from huggingface_hub import HfFileSystem - from benchmark import BaseBenchmark, create_plot_cli +from imgutils.generic.classify import _open_models_for_repo_id from imgutils.validate import anime_style_age +from imgutils.validate.style_age import _REPO_ID -hf_fs = HfFileSystem() - -_REPOSITORY = 'deepghs/anime_style_ages' -_MODEL_NAMES = [ - os.path.relpath(file, _REPOSITORY).split('/')[0] for file in - hf_fs.glob(f'{_REPOSITORY}/*/model.onnx') -] +_MODEL_NAMES = _open_models_for_repo_id(_REPO_ID).model_names class AnimeStyleAgeBenchmark(BaseBenchmark): @@ -21,12 +14,10 @@ def __init__(self, model): self.model = model def load(self): - from imgutils.validate.style_age import _open_anime_style_age_model - _ = _open_anime_style_age_model(self.model) + _open_models_for_repo_id(_REPO_ID)._open_model(self.model) def unload(self): - from imgutils.validate.style_age import _open_anime_style_age_model - _open_anime_style_age_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() def run(self): image_file = random.choice(self.all_images) diff --git a/docs/source/api_doc/validate/style_age_benchmark.plot.py.svg b/docs/source/api_doc/validate/style_age_benchmark.plot.py.svg deleted file mode 100644 index 49c7c70607a..00000000000 --- a/docs/source/api_doc/validate/style_age_benchmark.plot.py.svg +++ /dev/null @@ -1,2274 +0,0 @@ - - - - - - - - 2023-12-16T09:42:47.339987 - image/svg+xml - - - Matplotlib v3.7.4, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/imgutils/validate/portrait.py b/imgutils/validate/portrait.py index ffd3ff225d7..79d78f68f22 100644 --- a/imgutils/validate/portrait.py +++ b/imgutils/validate/portrait.py @@ -15,16 +15,10 @@ The models are hosted on `huggingface - deepghs/anime_portrait `_. """ -import json -from functools import lru_cache -from typing import Tuple, Optional, Dict, List +from typing import Tuple, Dict -import numpy as np -from PIL import Image -from huggingface_hub import hf_hub_download - -from imgutils.data import rgb_encode, ImageTyping, load_image -from imgutils.utils import open_onnx_model +from ..data import ImageTyping +from ..generic import classify_predict, classify_predict_score __all__ = [ 'anime_portrait_score', @@ -32,81 +26,7 @@ ] _DEFAULT_MODEL_NAME = 'mobilenetv3_v0_dist' - - -@lru_cache() -def _open_anime_portrait_model(model_name): - """ - Open the anime portrait model. - - :param model_name: The model name. - :type model_name: str - :return: The ONNX model. - """ - return open_onnx_model(hf_hub_download( - f'deepghs/anime_portrait', - f'{model_name}/model.onnx', - )) - - -@lru_cache() -def _get_anime_portrait_labels(model_name) -> List[str]: - """ - Get the labels for the anime portrait model. - - :param model_name: The model name. - :type model_name: str - :return: The list of labels. - :rtype: List[str] - """ - with open(hf_hub_download( - f'deepghs/anime_portrait', - f'{model_name}/meta.json', - ), 'r') as f: - return json.load(f)['labels'] - - -def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), - normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): - """ - Encode the input image. - - :param image: The input image. - :type image: Image.Image - :param size: The desired size of the image. - :type size: Tuple[int, int] - :param normalize: Mean and standard deviation for normalization. Default is (0.5, 0.5). - :type normalize: Optional[Tuple[float, float]] - :return: The encoded image data. - :rtype: np.ndarray - """ - image = image.resize(size, Image.BILINEAR) - data = rgb_encode(image, order_='CHW') - - if normalize is not None: - mean_, std_ = normalize - mean = np.asarray([mean_]).reshape((-1, 1, 1)) - std = np.asarray([std_]).reshape((-1, 1, 1)) - data = (data - mean) / std - - return data.astype(np.float32) - - -def _raw_anime_portrait(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME): - """ - Perform raw anime portrait processing on the input image. - - :param image: The input image. - :type image: ImageTyping - :param model_name: The model name. Default is 'mobilenetv3_v0_dist'. - :type model_name: str - :return: The processed image data. - :rtype: np.ndarray - """ - image = load_image(image, force_background='white', mode='RGB') - input_ = _img_encode(image)[None, ...] - output, = _open_anime_portrait_model(model_name).run(['output'], {'input': input_}) - return output +_REPO_ID = 'deepghs/anime_portrait' def anime_portrait_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Dict[str, float]: @@ -142,9 +62,7 @@ def anime_portrait_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NA >>> anime_portrait_score('head/9.jpg') {'person': 5.736660568800289e-07, 'halfbody': 7.199210472208506e-08, 'head': 0.9999992847442627} """ - output = _raw_anime_portrait(image, model_name) - values = dict(zip(_get_anime_portrait_labels(model_name), map(lambda x: x.item(), output[0]))) - return values + return classify_predict_score(image, _REPO_ID, model_name) def anime_portrait(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[str, float]: @@ -180,6 +98,4 @@ def anime_portrait(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> >>> anime_portrait('head/9.jpg') ('head', 0.9999992847442627) """ - output = _raw_anime_portrait(image, model_name)[0] - max_id = np.argmax(output) - return _get_anime_portrait_labels(model_name)[max_id], output[max_id].item() + return classify_predict(image, _REPO_ID, model_name) diff --git a/imgutils/validate/style_age.py b/imgutils/validate/style_age.py index 81a861b5ea7..78624715686 100644 --- a/imgutils/validate/style_age.py +++ b/imgutils/validate/style_age.py @@ -16,16 +16,10 @@ The models are hosted on `huggingface - deepghs/anime_style_ages `_. """ -import json -from functools import lru_cache -from typing import Tuple, Optional, Dict, List +from typing import Tuple, Dict -import numpy as np -from PIL import Image -from huggingface_hub import hf_hub_download - -from imgutils.data import rgb_encode, ImageTyping, load_image -from imgutils.utils import open_onnx_model +from ..data import ImageTyping +from ..generic import classify_predict, classify_predict_score __all__ = [ 'anime_style_age_score', @@ -33,81 +27,7 @@ ] _DEFAULT_MODEL_NAME = 'mobilenetv3_v0_dist' - - -@lru_cache() -def _open_anime_style_age_model(model_name): - """ - Open the anime style age model. - - :param model_name: The model name. - :type model_name: str - :return: The ONNX model. - """ - return open_onnx_model(hf_hub_download( - f'deepghs/anime_style_ages', - f'{model_name}/model.onnx', - )) - - -@lru_cache() -def _get_anime_style_age_labels(model_name) -> List[str]: - """ - Get the labels for the anime style age model. - - :param model_name: The model name. - :type model_name: str - :return: The list of labels. - :rtype: List[str] - """ - with open(hf_hub_download( - f'deepghs/anime_style_ages', - f'{model_name}/meta.json', - ), 'r') as f: - return json.load(f)['labels'] - - -def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), - normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): - """ - Encode the input image. - - :param image: The input image. - :type image: Image.Image - :param size: The desired size of the image. - :type size: Tuple[int, int] - :param normalize: Mean and standard deviation for normalization. Default is (0.5, 0.5). - :type normalize: Optional[Tuple[float, float]] - :return: The encoded image data. - :rtype: np.ndarray - """ - image = image.resize(size, Image.BILINEAR) - data = rgb_encode(image, order_='CHW') - - if normalize is not None: - mean_, std_ = normalize - mean = np.asarray([mean_]).reshape((-1, 1, 1)) - std = np.asarray([std_]).reshape((-1, 1, 1)) - data = (data - mean) / std - - return data.astype(np.float32) - - -def _raw_anime_style_age(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME): - """ - Perform raw anime style age processing on the input image. - - :param image: The input image. - :type image: ImageTyping - :param model_name: The model name. Default is 'mobilenetv3_v0_dist'. - :type model_name: str - :return: The processed image data. - :rtype: np.ndarray - """ - image = load_image(image, force_background='white', mode='RGB') - input_ = _img_encode(image)[None, ...] - output, = _open_anime_style_age_model(model_name).run(['output'], {'input': input_}) - return output +_REPO_ID = 'deepghs/anime_style_ages' def anime_style_age_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Dict[str, float]: @@ -139,9 +59,7 @@ def anime_style_age_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_N >>> anime_style_age_score('style_age/2020s/25.jpg') {'1970s-': 1.9200742826797068e-05, '1980s': 0.00017117452807724476, '1990s': 9.518441947875544e-05, '2000s': 2.885544381570071e-05, '2010s': 1.4389253010449465e-05, '2015s': 3.1696006772108376e-05, '2020s': 0.9996393918991089} """ - output = _raw_anime_style_age(image, model_name) - values = dict(zip(_get_anime_style_age_labels(model_name), map(lambda x: x.item(), output[0]))) - return values + return classify_predict_score(image, _REPO_ID, model_name) def anime_style_age(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[str, float]: @@ -173,6 +91,4 @@ def anime_style_age(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) - >>> anime_style_age('style_age/2020s/25.jpg') ('2020s', 0.9996393918991089) """ - output = _raw_anime_style_age(image, model_name)[0] - max_id = np.argmax(output) - return _get_anime_style_age_labels(model_name)[max_id], output[max_id].item() + return classify_predict(image, _REPO_ID, model_name) diff --git a/test/validate/test_portrait.py b/test/validate/test_portrait.py index 586731f8785..477facc28dd 100644 --- a/test/validate/test_portrait.py +++ b/test/validate/test_portrait.py @@ -3,8 +3,9 @@ import pytest -from imgutils.validate import anime_portrait -from imgutils.validate.portrait import _open_anime_portrait_model, anime_portrait_score +from imgutils.generic.classify import _open_models_for_repo_id +from imgutils.validate import anime_portrait, anime_portrait_score +from imgutils.validate.portrait import _REPO_ID from test.testings import get_testfile _ROOT_DIR = get_testfile('anime_portrait') @@ -19,7 +20,7 @@ def _release_model_after_run(): try: yield finally: - _open_anime_portrait_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() @pytest.mark.unittest diff --git a/test/validate/test_style_age.py b/test/validate/test_style_age.py index d4c32b019d4..93cb9f9f066 100644 --- a/test/validate/test_style_age.py +++ b/test/validate/test_style_age.py @@ -3,8 +3,9 @@ import pytest -from imgutils.validate import anime_style_age -from imgutils.validate.style_age import _open_anime_style_age_model, anime_style_age_score +from imgutils.generic.classify import _open_models_for_repo_id +from imgutils.validate import anime_style_age, anime_style_age_score +from imgutils.validate.style_age import _REPO_ID from test.testings import get_testfile _ROOT_DIR = get_testfile('anime_style_age') @@ -19,7 +20,7 @@ def _release_model_after_run(): try: yield finally: - _open_anime_style_age_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() @pytest.mark.unittest From 33ebedabd9a38ff9889b66aeee0be33569c775a2 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Tue, 9 Jan 2024 05:51:51 +0000 Subject: [PATCH 06/19] dev(narugo): auto sync Tue, 09 Jan 2024 05:51:51 +0000 --- .../validate/portrait_benchmark.plot.py.svg | 2346 +++++++++++++++++ .../validate/style_age_benchmark.plot.py.svg | 2298 ++++++++++++++++ 2 files changed, 4644 insertions(+) create mode 100644 docs/source/api_doc/validate/portrait_benchmark.plot.py.svg create mode 100644 docs/source/api_doc/validate/style_age_benchmark.plot.py.svg diff --git a/docs/source/api_doc/validate/portrait_benchmark.plot.py.svg b/docs/source/api_doc/validate/portrait_benchmark.plot.py.svg new file mode 100644 index 00000000000..9ce48f85475 --- /dev/null +++ b/docs/source/api_doc/validate/portrait_benchmark.plot.py.svg @@ -0,0 +1,2346 @@ + + + + + + + + 2024-01-09T05:51:35.692569 + image/svg+xml + + + Matplotlib v3.7.4, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/source/api_doc/validate/style_age_benchmark.plot.py.svg b/docs/source/api_doc/validate/style_age_benchmark.plot.py.svg new file mode 100644 index 00000000000..1227ddeea19 --- /dev/null +++ b/docs/source/api_doc/validate/style_age_benchmark.plot.py.svg @@ -0,0 +1,2298 @@ + + + + + + + + 2024-01-09T05:49:13.389097 + image/svg+xml + + + Matplotlib v3.7.4, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From 1d9ef8f14014ef5cd8c18e1afc63ab46820ae317 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 24 Jan 2024 15:00:37 +0800 Subject: [PATCH 07/19] dev(narugo): use new real models --- .../validate/real_benchmark.plot.py.svg | 2296 ----------------- imgutils/validate/real.py | 6 +- 2 files changed, 3 insertions(+), 2299 deletions(-) delete mode 100644 docs/source/api_doc/validate/real_benchmark.plot.py.svg diff --git a/docs/source/api_doc/validate/real_benchmark.plot.py.svg b/docs/source/api_doc/validate/real_benchmark.plot.py.svg deleted file mode 100644 index b97b807fb0e..00000000000 --- a/docs/source/api_doc/validate/real_benchmark.plot.py.svg +++ /dev/null @@ -1,2296 +0,0 @@ - - - - - - - - 2024-01-09T04:52:16.025388 - image/svg+xml - - - Matplotlib v3.7.4, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/imgutils/validate/real.py b/imgutils/validate/real.py index 33aca57473c..c97eb9b421b 100644 --- a/imgutils/validate/real.py +++ b/imgutils/validate/real.py @@ -25,7 +25,7 @@ 'anime_real', ] -_DEFAULT_MODEL_NAME = 'mobilenetv3_v0_dist' +_DEFAULT_MODEL_NAME = 'mobilenetv3_v1_dist_ls0.1' _REPO_ID = 'deepghs/anime_real_cls' @@ -35,7 +35,7 @@ def anime_real_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) :param image: The input image. :type image: ImageTyping - :param model_name: The model name. Default is 'mobilenetv3_v0_dist'. + :param model_name: The model name. Default is 'mobilenetv3_v1_dist_ls0.1'. :type model_name: str :return: A dictionary with type scores. :rtype: Dict[str, float] @@ -85,7 +85,7 @@ def anime_real(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tup :param image: The input image. :type image: ImageTyping - :param model_name: The model name. Default is 'mobilenetv3_v0_dist'. + :param model_name: The model name. Default is 'mobilenetv3_v1_dist_ls0.1'. :type model_name: str :return: A tuple with the primary type and its score. :rtype: Tuple[str, float] From 6a6f07ce66a283207a92204bfaf618ee644a9e57 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 24 Jan 2024 07:14:21 +0000 Subject: [PATCH 08/19] dev(narugo): auto sync Wed, 24 Jan 2024 07:14:21 +0000 --- .../validate/real_benchmark.plot.py.svg | 2604 +++++++++++++++++ 1 file changed, 2604 insertions(+) create mode 100644 docs/source/api_doc/validate/real_benchmark.plot.py.svg diff --git a/docs/source/api_doc/validate/real_benchmark.plot.py.svg b/docs/source/api_doc/validate/real_benchmark.plot.py.svg new file mode 100644 index 00000000000..b40cd47a4ef --- /dev/null +++ b/docs/source/api_doc/validate/real_benchmark.plot.py.svg @@ -0,0 +1,2604 @@ + + + + + + + + 2024-01-24T07:14:03.658057 + image/svg+xml + + + Matplotlib v3.7.4, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From 295af1f48cf27dd5b33ba1778a70c2e74481511a Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 24 Jan 2024 15:44:27 +0800 Subject: [PATCH 09/19] dev(narugo): replace aicheck --- .../validate/aicheck_benchmark.plot.py | 11 +- .../validate/aicheck_benchmark.plot.py.svg | 2454 ----------------- imgutils/validate/aicheck.py | 133 +- test/validate/test_aicheck.py | 5 +- 4 files changed, 57 insertions(+), 2546 deletions(-) delete mode 100644 docs/source/api_doc/validate/aicheck_benchmark.plot.py.svg diff --git a/docs/source/api_doc/validate/aicheck_benchmark.plot.py b/docs/source/api_doc/validate/aicheck_benchmark.plot.py index 0561e998ccf..8fe293d0481 100644 --- a/docs/source/api_doc/validate/aicheck_benchmark.plot.py +++ b/docs/source/api_doc/validate/aicheck_benchmark.plot.py @@ -1,8 +1,11 @@ import random from benchmark import BaseBenchmark, create_plot_cli +from imgutils.generic.classify import _open_models_for_repo_id from imgutils.validate import get_ai_created_score -from imgutils.validate.aicheck import _MODEL_NAMES +from imgutils.validate.aicheck import _REPO_ID + +_MODEL_NAMES = _open_models_for_repo_id(_REPO_ID).model_names class AnimeAICheckBenchmark(BaseBenchmark): @@ -11,12 +14,10 @@ def __init__(self, model): self.model = model def load(self): - from imgutils.validate.aicheck import _open_anime_aicheck_model - _ = _open_anime_aicheck_model(self.model) + _open_models_for_repo_id(_REPO_ID)._open_model(self.model) def unload(self): - from imgutils.validate.aicheck import _open_anime_aicheck_model - _open_anime_aicheck_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() def run(self): image_file = random.choice(self.all_images) diff --git a/docs/source/api_doc/validate/aicheck_benchmark.plot.py.svg b/docs/source/api_doc/validate/aicheck_benchmark.plot.py.svg deleted file mode 100644 index 8881ea4f5db..00000000000 --- a/docs/source/api_doc/validate/aicheck_benchmark.plot.py.svg +++ /dev/null @@ -1,2454 +0,0 @@ - - - - - - - - 2023-06-07T12:30:16.620124 - image/svg+xml - - - Matplotlib v3.5.3, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/imgutils/validate/aicheck.py b/imgutils/validate/aicheck.py index 4034bf10733..88437e2e041 100644 --- a/imgutils/validate/aicheck.py +++ b/imgutils/validate/aicheck.py @@ -15,62 +15,72 @@ The models are hosted on `huggingface - deepghs/anime_ai_check `_. """ -from functools import lru_cache -from typing import Tuple, Optional - -import numpy as np -from PIL import Image -from huggingface_hub import hf_hub_download - -from imgutils.data import rgb_encode, ImageTyping, load_image -from imgutils.utils import open_onnx_model +from ..data import ImageTyping +from ..generic import classify_predict, classify_predict_score __all__ = [ 'get_ai_created_score', 'is_ai_created', ] -_LABELS = ['ai', 'human'] -_MODEL_NAMES = [ - 'caformer_s36_plus_sce', - 'mobilenetv3_sce', - 'mobilenetv3_sce_dist', -] _DEFAULT_MODEL_NAME = 'mobilenetv3_sce_dist' +_REPO_ID = 'deepghs/anime_ai_check' -@lru_cache() -def _open_anime_aicheck_model(model_name): - return open_onnx_model(hf_hub_download( - f'deepghs/anime_ai_check', - f'{model_name}/model.onnx', - )) - - -def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), - normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): - image = image.resize(size, Image.BILINEAR) - data = rgb_encode(image, order_='CHW') +def get_ai_created_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> float: + """ + Overview: + Predict if the given image is created by AI (mainly by stable diffusion), given a score. - if normalize is not None: - mean_, std_ = normalize - mean = np.asarray([mean_]).reshape((-1, 1, 1)) - std = np.asarray([std_]).reshape((-1, 1, 1)) - data = (data - mean) / std + :param image: Image to be predicted. + :param model_name: Name of the model. Default is ``mobilenetv3_sce_dist``. + If you need better accuracy, use ``caformer_s36_plus_sce``. + All the available values are listed on the benchmark graph. + :return: A float number which represent the score of AI-check. - return data.astype(np.float32) + Examples:: + >>> from imgutils.validate import get_ai_created_score + >>> + >>> get_ai_created_score('aicheck/ai/1.jpg') + 0.9996960163116455 + >>> get_ai_created_score('aicheck/ai/2.jpg') + 0.9999125003814697 + >>> get_ai_created_score('aicheck/ai/3.jpg') + 0.997803270816803 + >>> get_ai_created_score('aicheck/ai/4.jpg') + 0.9960069060325623 + >>> get_ai_created_score('aicheck/ai/5.jpg') + 0.9887709021568298 + >>> get_ai_created_score('aicheck/ai/6.jpg') + 0.9998629093170166 + >>> get_ai_created_score('aicheck/human/7.jpg') + 0.0013722758740186691 + >>> get_ai_created_score('aicheck/human/8.jpg') + 0.00020673229300882667 + >>> get_ai_created_score('aicheck/human/9.jpg') + 0.0001895089662866667 + >>> get_ai_created_score('aicheck/human/10.jpg') + 0.0008857478387653828 + >>> get_ai_created_score('aicheck/human/11.jpg') + 4.552320024231449e-05 + >>> get_ai_created_score('aicheck/human/12.jpg') + 0.001168627175502479 + """ + return classify_predict_score(image, _REPO_ID, model_name)['ai'] -def get_ai_created_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> float: +def is_ai_created(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME, threshold: float = 0.5) -> bool: """ Overview: - Predict if the given image is created by AI (mainly by stable diffusion), given a score. + Predict if the given image is created by AI (mainly by stable diffusion). :param image: Image to be predicted. :param model_name: Name of the model. Default is ``mobilenetv3_sce_dist``. If you need better accuracy, use ``caformer_s36_plus_sce``. All the available values are listed on the benchmark graph. - :return: A float number which represent the score of AI-check. + :param threshold: Threshold of the score. When the score is no less than ``threshold``, this image + will be predicted as ``AI-created``. Default is ``0.5``. + :return: This image is ``AI-created`` or not. Examples:: >>> from imgutils.validate import is_ai_created @@ -100,52 +110,5 @@ def get_ai_created_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NA >>> is_ai_created('aicheck/human/12.jpg') False """ - image = load_image(image, force_background='white', mode='RGB') - input_ = _img_encode(image)[None, ...] - output, = _open_anime_aicheck_model(model_name).run(['output'], {'input': input_}) - - return output[0][0].item() - - -def is_ai_created(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME, threshold: float = 0.5) -> bool: - """ - Overview: - Predict if the given image is created by AI (mainly by stable diffusion). - - :param image: Image to be predicted. - :param model_name: Name of the model. Default is ``mobilenetv3_sce_dist``. - If you need better accuracy, use ``caformer_s36_plus_sce``. - All the available values are listed on the benchmark graph. - :param threshold: Threshold of the score. When the score is no less than ``threshold``, this image - will be predicted as ``AI-created``. Default is ``0.5``. - :return: This image is ``AI-created`` or not. - - Examples:: - >>> from imgutils.validate import get_ai_created_score - >>> - >>> get_ai_created_score('aicheck/ai/1.jpg') - 0.9996960163116455 - >>> get_ai_created_score('aicheck/ai/2.jpg') - 0.9999125003814697 - >>> get_ai_created_score('aicheck/ai/3.jpg') - 0.997803270816803 - >>> get_ai_created_score('aicheck/ai/4.jpg') - 0.9960069060325623 - >>> get_ai_created_score('aicheck/ai/5.jpg') - 0.9887709021568298 - >>> get_ai_created_score('aicheck/ai/6.jpg') - 0.9998629093170166 - >>> get_ai_created_score('aicheck/human/7.jpg') - 0.0013722758740186691 - >>> get_ai_created_score('aicheck/human/8.jpg') - 0.00020673229300882667 - >>> get_ai_created_score('aicheck/human/9.jpg') - 0.0001895089662866667 - >>> get_ai_created_score('aicheck/human/10.jpg') - 0.0008857478387653828 - >>> get_ai_created_score('aicheck/human/11.jpg') - 4.552320024231449e-05 - >>> get_ai_created_score('aicheck/human/12.jpg') - 0.001168627175502479 - """ - return get_ai_created_score(image, model_name) >= threshold + type_, _ = classify_predict(image, _REPO_ID, model_name) + return type_ == 'ai' diff --git a/test/validate/test_aicheck.py b/test/validate/test_aicheck.py index 96cadb550da..8720514878b 100644 --- a/test/validate/test_aicheck.py +++ b/test/validate/test_aicheck.py @@ -3,7 +3,8 @@ import pytest -from imgutils.validate.aicheck import _open_anime_aicheck_model, is_ai_created, get_ai_created_score +from imgutils.generic.classify import _open_models_for_repo_id +from imgutils.validate.aicheck import is_ai_created, get_ai_created_score, _REPO_ID from test.testings import get_testfile _ROOT_DIR = get_testfile('anime_aicheck') @@ -18,7 +19,7 @@ def _release_model_after_run(): try: yield finally: - _open_anime_aicheck_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() @pytest.mark.unittest From 7f034b313ed8d66a011606b97ad13210ddfe8d19 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 24 Jan 2024 07:51:20 +0000 Subject: [PATCH 10/19] dev(narugo): auto sync Wed, 24 Jan 2024 07:51:20 +0000 --- .../validate/aicheck_benchmark.plot.py.svg | 2598 +++++++++++++++++ 1 file changed, 2598 insertions(+) create mode 100644 docs/source/api_doc/validate/aicheck_benchmark.plot.py.svg diff --git a/docs/source/api_doc/validate/aicheck_benchmark.plot.py.svg b/docs/source/api_doc/validate/aicheck_benchmark.plot.py.svg new file mode 100644 index 00000000000..015432e822f --- /dev/null +++ b/docs/source/api_doc/validate/aicheck_benchmark.plot.py.svg @@ -0,0 +1,2598 @@ + + + + + + + + 2024-01-24T07:51:04.522857 + image/svg+xml + + + Matplotlib v3.7.4, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From 6cac3382d4c9867ff8f71ea909daae2778eb1a9a Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 24 Jan 2024 16:02:35 +0800 Subject: [PATCH 11/19] dev(narugo): add bangumi char --- .../validate/bangumi_char_benchmark.plot.py | 19 +- .../bangumi_char_benchmark.plot.py.svg | 2260 ----------------- imgutils/validate/bangumi_char.py | 96 +- test/validate/test_bangumi_char.py | 5 +- 4 files changed, 14 insertions(+), 2366 deletions(-) delete mode 100644 docs/source/api_doc/validate/bangumi_char_benchmark.plot.py.svg diff --git a/docs/source/api_doc/validate/bangumi_char_benchmark.plot.py b/docs/source/api_doc/validate/bangumi_char_benchmark.plot.py index 1faf132d134..19a642042ba 100644 --- a/docs/source/api_doc/validate/bangumi_char_benchmark.plot.py +++ b/docs/source/api_doc/validate/bangumi_char_benchmark.plot.py @@ -1,18 +1,11 @@ -import os import random -from huggingface_hub import HfFileSystem - from benchmark import BaseBenchmark, create_plot_cli +from imgutils.generic.classify import _open_models_for_repo_id from imgutils.validate import anime_bangumi_char +from imgutils.validate.bangumi_char import _REPO_ID -hf_fs = HfFileSystem() - -_REPOSITORY = 'deepghs/bangumi_char_type' -_MODEL_NAMES = [ - os.path.relpath(file, _REPOSITORY).split('/')[0] for file in - hf_fs.glob(f'{_REPOSITORY}/*/model.onnx') -] +_MODEL_NAMES = _open_models_for_repo_id(_REPO_ID).model_names class AnimeBangumiCharacterBenchmark(BaseBenchmark): @@ -21,12 +14,10 @@ def __init__(self, model): self.model = model def load(self): - from imgutils.validate.bangumi_char import _open_anime_bangumi_char_model - _ = _open_anime_bangumi_char_model(self.model) + _open_models_for_repo_id(_REPO_ID)._open_model(self.model) def unload(self): - from imgutils.validate.bangumi_char import _open_anime_bangumi_char_model - _open_anime_bangumi_char_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() def run(self): image_file = random.choice(self.all_images) diff --git a/docs/source/api_doc/validate/bangumi_char_benchmark.plot.py.svg b/docs/source/api_doc/validate/bangumi_char_benchmark.plot.py.svg deleted file mode 100644 index 66c2f231ae8..00000000000 --- a/docs/source/api_doc/validate/bangumi_char_benchmark.plot.py.svg +++ /dev/null @@ -1,2260 +0,0 @@ - - - - - - - - 2023-12-16T09:42:18.370684 - image/svg+xml - - - Matplotlib v3.7.4, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/imgutils/validate/bangumi_char.py b/imgutils/validate/bangumi_char.py index 71074817953..7b37594e335 100644 --- a/imgutils/validate/bangumi_char.py +++ b/imgutils/validate/bangumi_char.py @@ -28,16 +28,10 @@ If you are looking for a classification model that judges the proportion of the head in an image, please use the :func:`imgutils.validate.anime_portrait` function. """ -import json -from functools import lru_cache -from typing import Tuple, Optional, Dict, List +from typing import Tuple, Dict -import numpy as np -from PIL import Image -from huggingface_hub import hf_hub_download - -from imgutils.data import rgb_encode, ImageTyping, load_image -from imgutils.utils import open_onnx_model +from ..data import ImageTyping +from ..generic import classify_predict_score, classify_predict __all__ = [ 'anime_bangumi_char_score', @@ -45,81 +39,7 @@ ] _DEFAULT_MODEL_NAME = 'mobilenetv3_v0_dist' - - -@lru_cache() -def _open_anime_bangumi_char_model(model_name): - """ - Open the anime bangumi character model. - - :param model_name: The model name. - :type model_name: str - :return: The ONNX model. - """ - return open_onnx_model(hf_hub_download( - f'deepghs/bangumi_char_type', - f'{model_name}/model.onnx', - )) - - -@lru_cache() -def _get_anime_bangumi_char_labels(model_name) -> List[str]: - """ - Get the labels for the anime bangumi character model. - - :param model_name: The model name. - :type model_name: str - :return: The list of labels. - :rtype: List[str] - """ - with open(hf_hub_download( - f'deepghs/bangumi_char_type', - f'{model_name}/meta.json', - ), 'r') as f: - return json.load(f)['labels'] - - -def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), - normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): - """ - Encode the input image. - - :param image: The input image. - :type image: Image.Image - :param size: The desired size of the image. - :type size: Tuple[int, int] - :param normalize: Mean and standard deviation for normalization. Default is (0.5, 0.5). - :type normalize: Optional[Tuple[float, float]] - :return: The encoded image data. - :rtype: np.ndarray - """ - image = image.resize(size, Image.BILINEAR) - data = rgb_encode(image, order_='CHW') - - if normalize is not None: - mean_, std_ = normalize - mean = np.asarray([mean_]).reshape((-1, 1, 1)) - std = np.asarray([std_]).reshape((-1, 1, 1)) - data = (data - mean) / std - - return data.astype(np.float32) - - -def _raw_anime_bangumi_char(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME): - """ - Perform raw anime bangumi character processing on the input image. - - :param image: The input image. - :type image: ImageTyping - :param model_name: The model name. Default is 'mobilenetv3_v0_dist'. - :type model_name: str - :return: The processed image data. - :rtype: np.ndarray - """ - image = load_image(image, force_background='white', mode='RGB') - input_ = _img_encode(image)[None, ...] - output, = _open_anime_bangumi_char_model(model_name).run(['output'], {'input': input_}) - return output +_REPO_ID = 'deepghs/bangumi_char_type' def anime_bangumi_char_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Dict[str, float]: @@ -169,9 +89,7 @@ def anime_bangumi_char_score(image: ImageTyping, model_name: str = _DEFAULT_MODE >>> anime_bangumi_char_score('bangumi_char/face/16.jpg') {'vision': 1.066640925273532e-05, 'imagery': 9.529400813335087e-06, 'halfbody': 4.089402500540018e-05, 'face': 0.9999388456344604} """ - output = _raw_anime_bangumi_char(image, model_name) - values = dict(zip(_get_anime_bangumi_char_labels(model_name), map(lambda x: x.item(), output[0]))) - return values + return classify_predict_score(image, _REPO_ID, model_name) def anime_bangumi_char(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[str, float]: @@ -221,6 +139,4 @@ def anime_bangumi_char(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME >>> anime_bangumi_char('bangumi_char/face/16.jpg') ('face', 0.9999388456344604) """ - output = _raw_anime_bangumi_char(image, model_name)[0] - max_id = np.argmax(output) - return _get_anime_bangumi_char_labels(model_name)[max_id], output[max_id].item() + return classify_predict(image, _REPO_ID, model_name) diff --git a/test/validate/test_bangumi_char.py b/test/validate/test_bangumi_char.py index 1295137c402..48df09dcfd0 100644 --- a/test/validate/test_bangumi_char.py +++ b/test/validate/test_bangumi_char.py @@ -3,8 +3,9 @@ import pytest +from imgutils.generic.classify import _open_models_for_repo_id from imgutils.validate import anime_bangumi_char -from imgutils.validate.bangumi_char import _open_anime_bangumi_char_model, anime_bangumi_char_score +from imgutils.validate.bangumi_char import anime_bangumi_char_score, _REPO_ID from test.testings import get_testfile _ROOT_DIR = get_testfile('bangumi_char') @@ -19,7 +20,7 @@ def _release_model_after_run(): try: yield finally: - _open_anime_bangumi_char_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() @pytest.mark.unittest From 27adaabfd649628ee8101d55135ed423772c573a Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 24 Jan 2024 16:06:13 +0800 Subject: [PATCH 12/19] dev(narugo): update classify --- .../validate/classify_benchmark.plot.py | 11 +- .../validate/classify_benchmark.plot.py.svg | 2782 ----------------- imgutils/validate/classify.py | 60 +- test/validate/test_classify.py | 5 +- 4 files changed, 15 insertions(+), 2843 deletions(-) delete mode 100644 docs/source/api_doc/validate/classify_benchmark.plot.py.svg diff --git a/docs/source/api_doc/validate/classify_benchmark.plot.py b/docs/source/api_doc/validate/classify_benchmark.plot.py index b378b2dfa26..a300b75d8bc 100644 --- a/docs/source/api_doc/validate/classify_benchmark.plot.py +++ b/docs/source/api_doc/validate/classify_benchmark.plot.py @@ -1,8 +1,11 @@ import random from benchmark import BaseBenchmark, create_plot_cli +from imgutils.generic.classify import _open_models_for_repo_id from imgutils.validate import anime_classify -from imgutils.validate.classify import _MODEL_NAMES +from imgutils.validate.classify import _REPO_ID + +_MODEL_NAMES = _open_models_for_repo_id(_REPO_ID).model_names class AnimeClassifyBenchmark(BaseBenchmark): @@ -11,12 +14,10 @@ def __init__(self, model): self.model = model def load(self): - from imgutils.validate.classify import _open_anime_classify_model - _ = _open_anime_classify_model(self.model) + _open_models_for_repo_id(_REPO_ID)._open_model(self.model) def unload(self): - from imgutils.validate.classify import _open_anime_classify_model - _open_anime_classify_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() def run(self): image_file = random.choice(self.all_images) diff --git a/docs/source/api_doc/validate/classify_benchmark.plot.py.svg b/docs/source/api_doc/validate/classify_benchmark.plot.py.svg deleted file mode 100644 index 3fd8b3bde56..00000000000 --- a/docs/source/api_doc/validate/classify_benchmark.plot.py.svg +++ /dev/null @@ -1,2782 +0,0 @@ - - - - - - - - 2023-06-07T11:32:14.231803 - image/svg+xml - - - Matplotlib v3.5.3, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/imgutils/validate/classify.py b/imgutils/validate/classify.py index 71076479426..d9cd7d7bdb3 100644 --- a/imgutils/validate/classify.py +++ b/imgutils/validate/classify.py @@ -15,62 +15,18 @@ The models are hosted on `huggingface - deepghs/anime_classification `_. """ -from functools import lru_cache -from typing import Tuple, Optional, Dict +from typing import Tuple, Dict -import numpy as np -from PIL import Image -from huggingface_hub import hf_hub_download - -from imgutils.data import rgb_encode, ImageTyping, load_image -from imgutils.utils import open_onnx_model +from ..data import ImageTyping +from ..generic import classify_predict, classify_predict_score __all__ = [ 'anime_classify_score', 'anime_classify', ] -_LABELS = ['3d', 'bangumi', 'comic', 'illustration'] -_MODEL_NAMES = [ - 'caformer_s36', - 'caformer_s36_plus', - 'mobilenetv3', - 'mobilenetv3_dist', - 'mobilenetv3_sce', - 'mobilenetv3_sce_dist', - 'mobilevitv2_150', -] _DEFAULT_MODEL_NAME = 'mobilenetv3_sce_dist' - - -@lru_cache() -def _open_anime_classify_model(model_name): - return open_onnx_model(hf_hub_download( - f'deepghs/anime_classification', - f'{model_name}/model.onnx', - )) - - -def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), - normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): - image = image.resize(size, Image.BILINEAR) - data = rgb_encode(image, order_='CHW') - - if normalize is not None: - mean_, std_ = normalize - mean = np.asarray([mean_]).reshape((-1, 1, 1)) - std = np.asarray([std_]).reshape((-1, 1, 1)) - data = (data - mean) / std - - return data.astype(np.float32) - - -def _raw_anime_classify(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME): - image = load_image(image, force_background='white', mode='RGB') - input_ = _img_encode(image)[None, ...] - output, = _open_anime_classify_model(model_name).run(['output'], {'input': input_}) - - return output +_REPO_ID = 'deepghs/anime_classification' def anime_classify_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Dict[str, float]: @@ -111,9 +67,7 @@ def anime_classify_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NA >>> anime_classify_score('classify/illustration/12.jpg') {'3d': 3.153582292725332e-05, 'bangumi': 0.0001071861624950543, 'comic': 5.665345452143811e-05, 'illustration': 0.999804675579071} """ - output = _raw_anime_classify(image, model_name) - values = dict(zip(_LABELS, map(lambda x: x.item(), output[0]))) - return values + return classify_predict_score(image, _REPO_ID, model_name) def anime_classify(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[str, float]: @@ -154,6 +108,4 @@ def anime_classify(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> >>> anime_classify('classify/illustration/12.jpg') ('illustration', 0.999804675579071) """ - output = _raw_anime_classify(image, model_name)[0] - max_id = np.argmax(output) - return _LABELS[max_id], output[max_id].item() + return classify_predict(image, _REPO_ID, model_name) diff --git a/test/validate/test_classify.py b/test/validate/test_classify.py index 3285c8aee59..54e131c3579 100644 --- a/test/validate/test_classify.py +++ b/test/validate/test_classify.py @@ -3,8 +3,9 @@ import pytest +from imgutils.generic.classify import _open_models_for_repo_id from imgutils.validate import anime_classify -from imgutils.validate.classify import _open_anime_classify_model, anime_classify_score +from imgutils.validate.classify import anime_classify_score, _REPO_ID from test.testings import get_testfile _ROOT_DIR = get_testfile('anime_cls') @@ -19,7 +20,7 @@ def _release_model_after_run(): try: yield finally: - _open_anime_classify_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() @pytest.mark.unittest From bfa22b1f7ab366f1f92f1d5ba7e98a0082102a0a Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 24 Jan 2024 08:07:34 +0000 Subject: [PATCH 13/19] dev(narugo): auto sync Wed, 24 Jan 2024 08:07:34 +0000 --- .../bangumi_char_benchmark.plot.py.svg | 2284 +++++++++++++++++ 1 file changed, 2284 insertions(+) create mode 100644 docs/source/api_doc/validate/bangumi_char_benchmark.plot.py.svg diff --git a/docs/source/api_doc/validate/bangumi_char_benchmark.plot.py.svg b/docs/source/api_doc/validate/bangumi_char_benchmark.plot.py.svg new file mode 100644 index 00000000000..83b59829706 --- /dev/null +++ b/docs/source/api_doc/validate/bangumi_char_benchmark.plot.py.svg @@ -0,0 +1,2284 @@ + + + + + + + + 2024-01-24T08:07:18.507982 + image/svg+xml + + + Matplotlib v3.7.4, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From d4dee18cc7b49b4d388e14ca48de6c38c1032950 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 24 Jan 2024 16:18:50 +0800 Subject: [PATCH 14/19] dev(narugo): update monochrome --- .../validate/monochrome_benchmark.plot.py | 23 +- .../validate/monochrome_benchmark.plot.py.svg | 2724 ----------------- imgutils/validate/monochrome.py | 70 +- test/validate/test_monochrome.py | 33 +- 4 files changed, 32 insertions(+), 2818 deletions(-) delete mode 100644 docs/source/api_doc/validate/monochrome_benchmark.plot.py.svg diff --git a/docs/source/api_doc/validate/monochrome_benchmark.plot.py b/docs/source/api_doc/validate/monochrome_benchmark.plot.py index d80aea63935..2498b3d5caa 100644 --- a/docs/source/api_doc/validate/monochrome_benchmark.plot.py +++ b/docs/source/api_doc/validate/monochrome_benchmark.plot.py @@ -1,37 +1,34 @@ import random from benchmark import BaseBenchmark, create_plot_cli +from imgutils.generic.classify import _open_models_for_repo_id from imgutils.validate import get_monochrome_score +from imgutils.validate.monochrome import _REPO_ID + +_MODEL_NAMES = _open_models_for_repo_id(_REPO_ID).model_names class MonochromeBenchmark(BaseBenchmark): - def __init__(self, model, safe): + def __init__(self, model): BaseBenchmark.__init__(self) self.model = model - self.safe = safe def load(self): - from imgutils.validate.monochrome import _monochrome_validate_model - _ = _monochrome_validate_model(self.model, self.safe) + _open_models_for_repo_id(_REPO_ID)._open_model(self.model) def unload(self): - from imgutils.validate.monochrome import _monochrome_validate_model - _monochrome_validate_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() def run(self): image_file = random.choice(self.all_images) - _ = get_monochrome_score(image_file, model=self.model, safe=self.safe) + _ = get_monochrome_score(image_file, model_name=self.model, safe=self.safe) if __name__ == '__main__': create_plot_cli( [ - ('caformer_s36 (unsafe)', MonochromeBenchmark('caformer_s36', False)), - ('caformer_s36 (safe)', MonochromeBenchmark('caformer_s36', True)), - ('mobilenetv3 (unsafe)', MonochromeBenchmark('mobilenetv3', False)), - ('mobilenetv3 (safe)', MonochromeBenchmark('mobilenetv3', True)), - ('mobilenetv3_dist (unsafe)', MonochromeBenchmark('mobilenetv3_dist', False)), - ('mobilenetv3_dist (safe)', MonochromeBenchmark('mobilenetv3_dist', True)), + (name, MonochromeBenchmark(name)) + for name in _MODEL_NAMES ], title='Benchmark for Monochrome Check Models', run_times=10, diff --git a/docs/source/api_doc/validate/monochrome_benchmark.plot.py.svg b/docs/source/api_doc/validate/monochrome_benchmark.plot.py.svg deleted file mode 100644 index 4889a265b10..00000000000 --- a/docs/source/api_doc/validate/monochrome_benchmark.plot.py.svg +++ /dev/null @@ -1,2724 +0,0 @@ - - - - - - - - 2023-06-07T11:12:11.923950 - image/svg+xml - - - Matplotlib v3.5.3, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/imgutils/validate/monochrome.py b/imgutils/validate/monochrome.py index 2b09bfcc999..bcc45d6b4bf 100644 --- a/imgutils/validate/monochrome.py +++ b/imgutils/validate/monochrome.py @@ -15,63 +15,26 @@ The models are hosted on `huggingface - deepghs/monochrome_detect `_. """ -from functools import lru_cache -from typing import Optional, Tuple, Mapping - -import numpy as np -from PIL import Image -from huggingface_hub import hf_hub_download - -from ..data import ImageTyping, load_image, rgb_encode -from ..utils import open_onnx_model +from ..data import ImageTyping +from ..generic import classify_predict_score, classify_predict __all__ = [ 'get_monochrome_score', 'is_monochrome', ] -_MODELS: Mapping[Tuple[str, bool], str] = { - ('caformer_s36', False): 'caformer_s36_plus', - ('caformer_s36', True): 'caformer_s36_plus_safe2', - ('mobilenetv3', False): 'mobilenetv3_large_100', - ('mobilenetv3', True): 'mobilenetv3_large_100_safe2', - ('mobilenetv3_dist', False): 'mobilenetv3_large_100_dist', - ('mobilenetv3_dist', True): 'mobilenetv3_large_100_dist_safe2', -} - - -@lru_cache() -def _monochrome_validate_model(model: str, safe: bool): - return open_onnx_model(hf_hub_download( - f'deepghs/monochrome_detect', - f'{_MODELS[(model, safe)]}/model.onnx', - )) - +_DEFAULT_MODEL_NAME = 'mobilenetv3_large_100_dist_safe2' +_REPO_ID = 'deepghs/monochrome_detect' -def _2d_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), - normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): - image = image.resize(size, Image.BILINEAR) - data = rgb_encode(image, order_='CHW') - if normalize is not None: - mean_, std_ = normalize - mean = np.asarray([mean_]).reshape((-1, 1, 1)) - std = np.asarray([std_]).reshape((-1, 1, 1)) - data = (data - mean) / std - - return data - - -def get_monochrome_score(image: ImageTyping, model: str = 'mobilenetv3_dist', safe: bool = True) -> float: +def get_monochrome_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> float: """ Overview: Get monochrome score of the given image. :param image: Image to predict, can be a ``PIL.Image`` object or the path of the image file. - :param model: The model used for inference. The default value is ``mobilenetv3_dist``, + :param model_name: The model used for inference. The default value is ``mobilenetv3_dist``, which offers high runtime performance. If you need better accuracy, just use ``caformer_s36``. - :param safe: Whether to enable the safe mode. When enabled, calculations will be performed using a model - with higher precision but lower recall. The default value is ``True``. Examples:: >>> from imgutils.validate import get_monochrome_score @@ -102,19 +65,11 @@ def get_monochrome_score(image: ImageTyping, model: str = 'mobilenetv3_dist', sa >>> get_monochrome_score('colored/12.jpg') 0.0315730981528759 """ - safe = bool(safe) - if (model, safe) not in _MODELS: - raise ValueError(f'Unknown model for monochrome detection - {model!r}, {safe!r}.') - - image = load_image(image, mode='RGB') - input_data = _2d_encode(image).astype(np.float32) - input_data = np.stack([input_data]) - output_data, = _monochrome_validate_model(model, safe).run(['output'], {'input': input_data}) - return output_data[0][0].item() + return classify_predict_score(image, _REPO_ID, model_name)['monochrome'] def is_monochrome(image: ImageTyping, threshold: float = 0.5, - model: str = 'mobilenetv3_dist', safe: bool = True) -> bool: + model_name: str = _DEFAULT_MODEL_NAME) -> bool: """ Overview: Predict if the image is monochrome. @@ -122,12 +77,8 @@ def is_monochrome(image: ImageTyping, threshold: float = 0.5, :param image: Image to predict, can be a ``PIL.Image`` object or the path of the image file. :param threshold: Threshold value during prediction. If the score is higher than the threshold, the image will be classified as monochrome. - :param model: The model used for inference. The default value is ``mobilenetv3_dist``, + :param model_name: The model used for inference. The default value is ``mobilenetv3_dist``, which offers high runtime performance. If you need better accuracy, just use ``caformer_s36``. - :param safe: Safe level, with optional values including ``0``, ``2``, and ``4``, - corresponding to different levels of the model. The default value is 2. - For more technical details about this model, please refer to: - https://huggingface.co/deepghs/imgutils-models#monochrome . Examples: >>> import os @@ -158,4 +109,5 @@ def is_monochrome(image: ImageTyping, threshold: float = 0.5, >>> is_monochrome('colored/12.jpg') False """ - return get_monochrome_score(image, model, safe) >= threshold + type_, _ = classify_predict(image, _REPO_ID, model_name) + return type_ == 'monochrome' diff --git a/test/validate/test_monochrome.py b/test/validate/test_monochrome.py index 31aa900e53d..d6e168f2722 100644 --- a/test/validate/test_monochrome.py +++ b/test/validate/test_monochrome.py @@ -3,7 +3,10 @@ import pytest from hbutils.testing import tmatrix -from imgutils.validate.monochrome import get_monochrome_score, is_monochrome, _monochrome_validate_model +from imgutils.generic.classify import _open_models_for_repo_id +from imgutils.validate.monochrome import get_monochrome_score, is_monochrome, _REPO_ID + +_MODEL_NAMES = _open_models_for_repo_id(_REPO_ID).model_names @pytest.fixture(scope='module', autouse=True) @@ -11,7 +14,7 @@ def _release_model_after_run(): try: yield finally: - _monochrome_validate_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() def get_samples(): @@ -36,27 +39,13 @@ def get_samples(): class TestValidateMonochrome: @pytest.mark.parametrize(*tmatrix({ ('type_', 'file'): get_samples(), - ('model', 'safe'): [ - ('caformer_s36', False), - ('caformer_s36', True), - ('mobilenetv3', False), - ('mobilenetv3', True), - ('mobilenetv3_dist', False), - ('mobilenetv3_dist', True), - ], + 'model_name': _MODEL_NAMES, })) - def test_monochrome_test(self, type_: str, file: str, model: str, safe: bool): + def test_monochrome_test(self, type_: str, file: str, model_name: str): filename = os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', type_, file) if type_ == 'monochrome': - assert get_monochrome_score(filename, model=model, safe=safe) >= 0.5 - assert is_monochrome(filename, model=model, safe=safe) + assert get_monochrome_score(filename, model_name=model_name) >= 0.5 + assert is_monochrome(filename, model_name=model_name) else: - assert get_monochrome_score(filename, model=model, safe=safe) <= 0.5 - assert not is_monochrome(filename, model=model, safe=safe) - - def test_monochrome_test_with_unknown_safe(self): - with pytest.raises(ValueError): - _ = get_monochrome_score( - os.path.join('test', 'testfile', 'dataset', 'monochrome_danbooru', 'normal', '2475192.jpg'), - model='Model not found', - ) + assert get_monochrome_score(filename, model_name=model_name) <= 0.5 + assert not is_monochrome(filename, model_name=model_name) From 22aeec4f649c0fdbf157d7ecb57d46d5b2b98acc Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 24 Jan 2024 16:28:41 +0800 Subject: [PATCH 15/19] dev(narugo): update monochrome --- docs/source/api_doc/validate/monochrome_benchmark.plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/api_doc/validate/monochrome_benchmark.plot.py b/docs/source/api_doc/validate/monochrome_benchmark.plot.py index 2498b3d5caa..1f27cb971f0 100644 --- a/docs/source/api_doc/validate/monochrome_benchmark.plot.py +++ b/docs/source/api_doc/validate/monochrome_benchmark.plot.py @@ -21,7 +21,7 @@ def unload(self): def run(self): image_file = random.choice(self.all_images) - _ = get_monochrome_score(image_file, model_name=self.model, safe=self.safe) + _ = get_monochrome_score(image_file, model_name=self.model) if __name__ == '__main__': From 4c9d8a7f8fc7661b3b3656679be8c97fe2b81b71 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 24 Jan 2024 08:46:27 +0000 Subject: [PATCH 16/19] dev(narugo): auto sync Wed, 24 Jan 2024 08:46:27 +0000 --- .../validate/classify_benchmark.plot.py.svg | 2708 +++++++++++++++ .../validate/monochrome_benchmark.plot.py.svg | 2910 +++++++++++++++++ 2 files changed, 5618 insertions(+) create mode 100644 docs/source/api_doc/validate/classify_benchmark.plot.py.svg create mode 100644 docs/source/api_doc/validate/monochrome_benchmark.plot.py.svg diff --git a/docs/source/api_doc/validate/classify_benchmark.plot.py.svg b/docs/source/api_doc/validate/classify_benchmark.plot.py.svg new file mode 100644 index 00000000000..8bbbc626d7f --- /dev/null +++ b/docs/source/api_doc/validate/classify_benchmark.plot.py.svg @@ -0,0 +1,2708 @@ + + + + + + + + 2024-01-24T08:37:24.525815 + image/svg+xml + + + Matplotlib v3.7.4, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/source/api_doc/validate/monochrome_benchmark.plot.py.svg b/docs/source/api_doc/validate/monochrome_benchmark.plot.py.svg new file mode 100644 index 00000000000..2d2b77b42e3 --- /dev/null +++ b/docs/source/api_doc/validate/monochrome_benchmark.plot.py.svg @@ -0,0 +1,2910 @@ + + + + + + + + 2024-01-24T08:46:09.911930 + image/svg+xml + + + Matplotlib v3.7.4, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From c90127c5f4235d0d3217743139a12feda5b38363 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 24 Jan 2024 16:55:00 +0800 Subject: [PATCH 17/19] dev(narugo): replace for rating and teen --- .../api_doc/validate/rating_benchmark.plot.py | 11 +- .../validate/rating_benchmark.plot.py.svg | 2508 ----------------- .../api_doc/validate/teen_benchmark.plot.py | 11 +- .../validate/teen_benchmark.plot.py.svg | 2238 --------------- imgutils/validate/rating.py | 68 +- imgutils/validate/teen.py | 55 +- test/validate/test_rating.py | 5 +- test/validate/test_teen.py | 5 +- 8 files changed, 31 insertions(+), 4870 deletions(-) delete mode 100644 docs/source/api_doc/validate/rating_benchmark.plot.py.svg delete mode 100644 docs/source/api_doc/validate/teen_benchmark.plot.py.svg diff --git a/docs/source/api_doc/validate/rating_benchmark.plot.py b/docs/source/api_doc/validate/rating_benchmark.plot.py index 74c5b4564fb..9d5169139cd 100644 --- a/docs/source/api_doc/validate/rating_benchmark.plot.py +++ b/docs/source/api_doc/validate/rating_benchmark.plot.py @@ -1,8 +1,11 @@ import random from benchmark import BaseBenchmark, create_plot_cli +from imgutils.generic.classify import _open_models_for_repo_id from imgutils.validate import anime_rating -from imgutils.validate.rating import _MODEL_NAMES +from imgutils.validate.rating import _REPO_ID + +_MODEL_NAMES = _open_models_for_repo_id(_REPO_ID).model_names class AnimeRatingBenchmark(BaseBenchmark): @@ -11,12 +14,10 @@ def __init__(self, model): self.model = model def load(self): - from imgutils.validate.rating import _open_anime_rating_model - _ = _open_anime_rating_model(self.model) + _open_models_for_repo_id(_REPO_ID)._open_model(self.model) def unload(self): - from imgutils.validate.rating import _open_anime_rating_model - _open_anime_rating_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() def run(self): image_file = random.choice(self.all_images) diff --git a/docs/source/api_doc/validate/rating_benchmark.plot.py.svg b/docs/source/api_doc/validate/rating_benchmark.plot.py.svg deleted file mode 100644 index fc628ce51e6..00000000000 --- a/docs/source/api_doc/validate/rating_benchmark.plot.py.svg +++ /dev/null @@ -1,2508 +0,0 @@ - - - - - - - - 2023-06-26T08:26:25.152202 - image/svg+xml - - - Matplotlib v3.7.1, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/docs/source/api_doc/validate/teen_benchmark.plot.py b/docs/source/api_doc/validate/teen_benchmark.plot.py index e2fe744f162..80ec1a700b2 100644 --- a/docs/source/api_doc/validate/teen_benchmark.plot.py +++ b/docs/source/api_doc/validate/teen_benchmark.plot.py @@ -1,8 +1,11 @@ import random from benchmark import BaseBenchmark, create_plot_cli +from imgutils.generic.classify import _open_models_for_repo_id from imgutils.validate import anime_teen -from imgutils.validate.teen import _MODEL_NAMES +from imgutils.validate.teen import _REPO_ID + +_MODEL_NAMES = _open_models_for_repo_id(_REPO_ID).model_names class AnimeTeenBenchmark(BaseBenchmark): @@ -11,12 +14,10 @@ def __init__(self, model): self.model = model def load(self): - from imgutils.validate.teen import _open_anime_teen_model - _ = _open_anime_teen_model(self.model) + _open_models_for_repo_id(_REPO_ID)._open_model(self.model) def unload(self): - from imgutils.validate.teen import _open_anime_teen_model - _open_anime_teen_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() def run(self): image_file = random.choice(self.all_images) diff --git a/docs/source/api_doc/validate/teen_benchmark.plot.py.svg b/docs/source/api_doc/validate/teen_benchmark.plot.py.svg deleted file mode 100644 index cb95da5990f..00000000000 --- a/docs/source/api_doc/validate/teen_benchmark.plot.py.svg +++ /dev/null @@ -1,2238 +0,0 @@ - - - - - - - - 2023-08-27T10:05:50.143385 - image/svg+xml - - - Matplotlib v3.7.2, https://matplotlib.org/ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/imgutils/validate/rating.py b/imgutils/validate/rating.py index c1d8ceb9048..3ef1ca52976 100644 --- a/imgutils/validate/rating.py +++ b/imgutils/validate/rating.py @@ -25,68 +25,18 @@ it is recommended to consider using object detection-based methods**, such as using :func:`imgutils.detect.censor.detect_censors` to detect sensitive regions as the basis for judgment. """ -import json -from functools import lru_cache -from typing import Tuple, Optional, Dict, List +from typing import Tuple, Dict -import numpy as np -from PIL import Image -from huggingface_hub import hf_hub_download - -from imgutils.data import rgb_encode, ImageTyping, load_image -from imgutils.utils import open_onnx_model +from ..data import ImageTyping +from ..generic import classify_predict, classify_predict_score __all__ = [ 'anime_rating_score', 'anime_rating', ] -_MODEL_NAMES = [ - 'caformer_s36_plus', - 'mobilenetv3', - 'mobilenetv3_sce', - 'mobilenetv3_sce_dist', -] -_DEFAULT_MODEL_NAME = 'mobilenetv3_sce_dist' - - -@lru_cache() -def _open_anime_rating_model(model_name): - return open_onnx_model(hf_hub_download( - f'deepghs/anime_rating', - f'{model_name}/model.onnx', - )) - - -@lru_cache() -def _open_anime_rating_labels(model_name) -> List[str]: - with open(hf_hub_download( - f'deepghs/anime_rating', - f'{model_name}/meta.json', - ), 'r') as f: - return json.load(f)['labels'] - - -def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), - normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): - image = image.resize(size, Image.BILINEAR) - data = rgb_encode(image, order_='CHW') - - if normalize is not None: - mean_, std_ = normalize - mean = np.asarray([mean_]).reshape((-1, 1, 1)) - std = np.asarray([std_]).reshape((-1, 1, 1)) - data = (data - mean) / std - - return data.astype(np.float32) - - -def _raw_anime_rating(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME): - image = load_image(image, force_background='white', mode='RGB') - input_ = _img_encode(image)[None, ...] - output, = _open_anime_rating_model(model_name).run(['output'], {'input': input_}) - - return output +_DEFAULT_MODEL_NAME = 'mobilenetv3_v1_pruned_ls0.1' +_REPO_ID = 'deepghs/anime_rating' def anime_rating_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Dict[str, float]: @@ -127,9 +77,7 @@ def anime_rating_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME >>> anime_rating_score('rating/r18/12.jpg') {'safe': 6.902020231791539e-06, 'r15': 0.0005639699520543218, 'r18': 0.9994290471076965} """ - output = _raw_anime_rating(image, model_name) - values = dict(zip(_open_anime_rating_labels(model_name), map(lambda x: x.item(), output[0]))) - return values + return classify_predict_score(image, _REPO_ID, model_name) def anime_rating(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[str, float]: @@ -170,6 +118,4 @@ def anime_rating(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> T >>> anime_rating('rating/r18/12.jpg') ('r18', 0.9994290471076965) """ - output = _raw_anime_rating(image, model_name)[0] - max_id = np.argmax(output) - return _open_anime_rating_labels(model_name)[max_id], output[max_id].item() + return classify_predict(image, _REPO_ID, model_name) diff --git a/imgutils/validate/teen.py b/imgutils/validate/teen.py index 346360cac03..caae3cb4dba 100644 --- a/imgutils/validate/teen.py +++ b/imgutils/validate/teen.py @@ -15,57 +15,18 @@ The models are hosted on `huggingface - deepghs/anime_teen `_. """ -from functools import lru_cache -from typing import Tuple, Optional, Dict +from typing import Tuple, Dict -import numpy as np -from PIL import Image -from huggingface_hub import hf_hub_download - -from imgutils.data import rgb_encode, ImageTyping, load_image -from imgutils.utils import open_onnx_model +from ..data import ImageTyping +from ..generic import classify_predict, classify_predict_score __all__ = [ 'anime_teen_score', 'anime_teen', ] -_LABELS = ["contentious", "safe_teen", "non_teen"] -_MODEL_NAMES = [ - 'caformer_s36_v0', - 'mobilenetv3_v0_dist', -] _DEFAULT_MODEL_NAME = 'mobilenetv3_v0_dist' - - -@lru_cache() -def _open_anime_teen_model(model_name): - return open_onnx_model(hf_hub_download( - f'deepghs/anime_teen', - f'{model_name}/model.onnx', - )) - - -def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), - normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): - image = image.resize(size, Image.BILINEAR) - data = rgb_encode(image, order_='CHW') - - if normalize is not None: - mean_, std_ = normalize - mean = np.asarray([mean_]).reshape((-1, 1, 1)) - std = np.asarray([std_]).reshape((-1, 1, 1)) - data = (data - mean) / std - - return data.astype(np.float32) - - -def _raw_anime_teen(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME): - image = load_image(image, force_background='white', mode='RGB') - input_ = _img_encode(image)[None, ...] - output, = _open_anime_teen_model(model_name).run(['output'], {'input': input_}) - - return output +_REPO_ID = 'deepghs/anime_teen' def anime_teen_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Dict[str, float]: @@ -100,9 +61,7 @@ def anime_teen_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) >>> anime_teen_score('teen/non_teen/9.jpg') {'contentious': 0.0001218809193233028, 'safe_teen': 0.00013706681784242392, 'non_teen': 0.9997410178184509} """ - output = _raw_anime_teen(image, model_name) - values = dict(zip(_LABELS, map(lambda x: x.item(), output[0]))) - return values + return classify_predict_score(image, _REPO_ID, model_name) def anime_teen(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[str, float]: @@ -137,6 +96,4 @@ def anime_teen(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tup >>> anime_teen('teen/non_teen/9.jpg') ('non_teen', 0.9997410178184509) """ - output = _raw_anime_teen(image, model_name)[0] - max_id = np.argmax(output) - return _LABELS[max_id], output[max_id].item() + return classify_predict(image, _REPO_ID, model_name) diff --git a/test/validate/test_rating.py b/test/validate/test_rating.py index a0cbbd20834..278c681e0db 100644 --- a/test/validate/test_rating.py +++ b/test/validate/test_rating.py @@ -3,8 +3,9 @@ import pytest +from imgutils.generic.classify import _open_models_for_repo_id from imgutils.validate import anime_rating -from imgutils.validate.rating import _open_anime_rating_model, anime_rating_score +from imgutils.validate.rating import anime_rating_score, _REPO_ID from test.testings import get_testfile _ROOT_DIR = get_testfile('rating') @@ -19,7 +20,7 @@ def _release_model_after_run(): try: yield finally: - _open_anime_rating_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() @pytest.mark.unittest diff --git a/test/validate/test_teen.py b/test/validate/test_teen.py index 39703f88a6c..98f5d5f7e5a 100644 --- a/test/validate/test_teen.py +++ b/test/validate/test_teen.py @@ -3,8 +3,9 @@ import pytest +from imgutils.generic.classify import _open_models_for_repo_id from imgutils.validate import anime_teen -from imgutils.validate.teen import _open_anime_teen_model, anime_teen_score +from imgutils.validate.teen import anime_teen_score, _REPO_ID from test.testings import get_testfile _ROOT_DIR = get_testfile('anime_teen') @@ -19,7 +20,7 @@ def _release_model_after_run(): try: yield finally: - _open_anime_teen_model.cache_clear() + _open_models_for_repo_id(_REPO_ID).clear() @pytest.mark.unittest From cb7f0c6e9f22607571597784d2dec2e598e4c0f6 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 24 Jan 2024 09:10:38 +0000 Subject: [PATCH 18/19] dev(narugo): auto sync Wed, 24 Jan 2024 09:10:38 +0000 --- .../validate/rating_benchmark.plot.py.svg | 2860 +++++++++++++++++ .../validate/teen_benchmark.plot.py.svg | 2262 +++++++++++++ 2 files changed, 5122 insertions(+) create mode 100644 docs/source/api_doc/validate/rating_benchmark.plot.py.svg create mode 100644 docs/source/api_doc/validate/teen_benchmark.plot.py.svg diff --git a/docs/source/api_doc/validate/rating_benchmark.plot.py.svg b/docs/source/api_doc/validate/rating_benchmark.plot.py.svg new file mode 100644 index 00000000000..be01ede4830 --- /dev/null +++ b/docs/source/api_doc/validate/rating_benchmark.plot.py.svg @@ -0,0 +1,2860 @@ + + + + + + + + 2024-01-24T09:10:21.671070 + image/svg+xml + + + Matplotlib v3.7.4, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/docs/source/api_doc/validate/teen_benchmark.plot.py.svg b/docs/source/api_doc/validate/teen_benchmark.plot.py.svg new file mode 100644 index 00000000000..298c343daec --- /dev/null +++ b/docs/source/api_doc/validate/teen_benchmark.plot.py.svg @@ -0,0 +1,2262 @@ + + + + + + + + 2024-01-24T08:59:30.421082 + image/svg+xml + + + Matplotlib v3.7.4, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + From ee1850bd13e6bafe8241c514d2da1cbdd12b9a59 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 24 Jan 2024 17:28:10 +0800 Subject: [PATCH 19/19] dev(narugo): fix docs --- imgutils/validate/rating.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/imgutils/validate/rating.py b/imgutils/validate/rating.py index 3ef1ca52976..675f996b2cf 100644 --- a/imgutils/validate/rating.py +++ b/imgutils/validate/rating.py @@ -4,8 +4,10 @@ The following are sample images for testing. - .. image:: rating.plot.py.svg - :align: center + .. collapse:: The following are sample images for testing. (WARNING: NSFW!!!) + + .. image:: rating.plot.py.svg + :align: center This is an overall benchmark of all the rating validation models: