Skip to content

Commit

Permalink
dev(narugo): add support for anime_teen
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Aug 27, 2023
1 parent 9b269f8 commit bba9d23
Show file tree
Hide file tree
Showing 25 changed files with 251 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/api_doc/validate/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ imgutils.validate
monochrome
nsfw
rating
teen
truncate
14 changes: 14 additions & 0 deletions docs/source/api_doc/validate/teen.plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import glob
import os.path

from natsort import natsorted

from plot import image_plot

if __name__ == '__main__':
image_plot(
*natsorted(glob.glob(os.path.join('teen', 'contentious', '*.jpg'))),
*natsorted(glob.glob(os.path.join('teen', 'safe_teen', '*.jpg'))),
*natsorted(glob.glob(os.path.join('teen', 'non_teen', '*.jpg'))),
columns=3, figsize=(8, 10),
)
21 changes: 21 additions & 0 deletions docs/source/api_doc/validate/teen.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
imgutils.validate.teen
=============================================

.. currentmodule:: imgutils.validate.teen

.. automodule:: imgutils.validate.teen


anime_teen_score
-----------------------------

.. autofunction:: anime_teen_score



anime_teen
-----------------------------

.. autofunction:: anime_teen


Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/api_doc/validate/teen/non_teen/7.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/api_doc/validate/teen/non_teen/8.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/api_doc/validate/teen/non_teen/9.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/api_doc/validate/teen/safe_teen/4.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/api_doc/validate/teen/safe_teen/6.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
35 changes: 35 additions & 0 deletions docs/source/api_doc/validate/teen_benchmark.plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import random

from benchmark import BaseBenchmark, create_plot_cli
from imgutils.validate import anime_teen
from imgutils.validate.teen import _MODEL_NAMES


class AnimeTeenBenchmark(BaseBenchmark):
def __init__(self, model):
BaseBenchmark.__init__(self)
self.model = model

def load(self):
from imgutils.validate.teen import _open_anime_teen_model
_ = _open_anime_teen_model(self.model)

def unload(self):
from imgutils.validate.teen import _open_anime_teen_model
_open_anime_teen_model.cache_clear()

def run(self):
image_file = random.choice(self.all_images)
_ = anime_teen(image_file, self.model)


if __name__ == '__main__':
create_plot_cli(
[
(name, AnimeTeenBenchmark(name))
for name in _MODEL_NAMES
],
title='Benchmark for Anime teen Models',
run_times=10,
try_times=20,
)()
1 change: 1 addition & 0 deletions imgutils/validate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
from .monochrome import *
from .nsfw import *
from .rating import *
from .teen import *
from .truncate import *
142 changes: 142 additions & 0 deletions imgutils/validate/teen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""
Overview:
A model for classifying teen anime images into 4 classes (``contentious``, ``safe_teen``, ``non_teen"``).
The following are sample images for testing.
.. image:: teen.plot.py.svg
:align: center
This is an overall benchmark of all the classification validation models:
.. image:: teen_benchmark.plot.py.svg
:align: center
The models are hosted on
`huggingface - deepghs/anime_teen <https://huggingface.co/deepghs/anime_teen>`_.
"""
from functools import lru_cache
from typing import Tuple, Optional, Dict

import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download

from imgutils.data import rgb_encode, ImageTyping, load_image
from imgutils.utils import open_onnx_model

__all__ = [
'anime_teen_score',
'anime_teen',
]

_LABELS = ["contentious", "safe_teen", "non_teen"]
_MODEL_NAMES = [
'caformer_s36_v0',
'mobilenetv3_v0_dist',
]
_DEFAULT_MODEL_NAME = 'mobilenetv3_v0_dist'


@lru_cache()
def _open_anime_teen_model(model_name):
return open_onnx_model(hf_hub_download(
f'deepghs/anime_teen',
f'{model_name}/model.onnx',
))


def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
image = image.resize(size, Image.BILINEAR)
data = rgb_encode(image, order_='CHW')

if normalize is not None:
mean_, std_ = normalize
mean = np.asarray([mean_]).reshape((-1, 1, 1))
std = np.asarray([std_]).reshape((-1, 1, 1))
data = (data - mean) / std

return data.astype(np.float32)


def _raw_anime_teen(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME):
image = load_image(image, force_background='white', mode='RGB')
input_ = _img_encode(image)[None, ...]
output, = _open_anime_teen_model(model_name).run(['output'], {'input': input_})

return output


def anime_teen_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Dict[str, float]:
"""
Overview:
Predict the class of the given image, return the score with as a dict object.
:param image: Image to teen.
:param model_name: Model to use. Default is ``mobilenetv3_v0_dist``. All available models are listed
on the benchmark plot above. If you need better accuracy, just set this to ``caformer_s36_plus``.
:return: A dict with classes and scores.
Examples::
>>> from imgutils.validate import anime_teen_score
>>>
>>> anime_teen_score('teen/contentious/1.jpg')
{'contentious': 0.9998493194580078, 'safe_teen': 3.0378791052498855e-05, 'non_teen': 0.00012023092131130397}
>>> anime_teen_score('teen/contentious/2.jpg')
{'contentious': 0.9790042638778687, 'safe_teen': 0.0017522255657240748, 'non_teen': 0.01924353837966919}
>>> anime_teen_score('teen/contentious/3.jpg')
{'contentious': 0.9998124241828918, 'safe_teen': 4.19778298237361e-05, 'non_teen': 0.0001456339523429051}
>>> anime_teen_score('teen/safe_teen/4.jpg')
{'contentious': 0.0008521362324245274, 'safe_teen': 0.9989691972732544, 'non_teen': 0.00017870066221803427}
>>> anime_teen_score('teen/safe_teen/5.jpg')
{'contentious': 6.0992944781901315e-05, 'safe_teen': 0.9994398951530457, 'non_teen': 0.0004991036257706583}
>>> anime_teen_score('teen/safe_teen/6.jpg')
{'contentious': 5.2035720727872103e-05, 'safe_teen': 0.9994019269943237, 'non_teen': 0.0005460577667690814}
>>> anime_teen_score('teen/non_teen/7.jpg')
{'contentious': 3.0478151529678144e-05, 'safe_teen': 3.524079147609882e-05, 'non_teen': 0.999934196472168}
>>> anime_teen_score('teen/non_teen/8.jpg')
{'contentious': 9.786742884898558e-05, 'safe_teen': 8.653994154883549e-05, 'non_teen': 0.9998156428337097}
>>> anime_teen_score('teen/non_teen/9.jpg')
{'contentious': 0.0001218809193233028, 'safe_teen': 0.00013706681784242392, 'non_teen': 0.9997410178184509}
"""
output = _raw_anime_teen(image, model_name)
values = dict(zip(_LABELS, map(lambda x: x.item(), output[0])))
return values


def anime_teen(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[str, float]:
"""
Overview:
Predict the class of the given image, return the class and its score.
:param image: Image to teen.
:param model_name: Model to use. Default is ``mobilenetv3_sce_dist``. All available models are listed
on the benchmark plot above. If you need better accuracy, just set this to ``caformer_s36_plus``.
:return: A tuple contains the class and its score.
Examples::
>>> from imgutils.validate import anime_teen
>>>
>>> anime_teen('teen/contentious/1.jpg')
('contentious', 0.9998493194580078)
>>> anime_teen('teen/contentious/2.jpg')
('contentious', 0.9790042638778687)
>>> anime_teen('teen/contentious/3.jpg')
('contentious', 0.9998124241828918)
>>> anime_teen('teen/safe_teen/4.jpg')
('safe_teen', 0.9989691972732544)
>>> anime_teen('teen/safe_teen/5.jpg')
('safe_teen', 0.9994398951530457)
>>> anime_teen('teen/safe_teen/6.jpg')
('safe_teen', 0.9994019269943237)
>>> anime_teen('teen/non_teen/7.jpg')
('non_teen', 0.999934196472168)
>>> anime_teen('teen/non_teen/8.jpg')
('non_teen', 0.9998156428337097)
>>> anime_teen('teen/non_teen/9.jpg')
('non_teen', 0.9997410178184509)
"""
output = _raw_anime_teen(image, model_name)[0]
max_id = np.argmax(output)
return _LABELS[max_id], output[max_id].item()
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
37 changes: 37 additions & 0 deletions test/validate/test_teen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import glob
import os.path

import pytest

from imgutils.validate import anime_teen
from imgutils.validate.teen import _open_anime_teen_model, anime_teen_score
from test.testings import get_testfile

_ROOT_DIR = get_testfile('anime_teen')
_EXAMPLE_FILES = [
(os.path.relpath(file, _ROOT_DIR), os.path.basename(os.path.dirname(file)))
for file in glob.glob(get_testfile('anime_teen', '**', '*.jpg'), recursive=True)
]


@pytest.fixture(scope='module', autouse=True)
def _release_model_after_run():
try:
yield
finally:
_open_anime_teen_model.cache_clear()


@pytest.mark.unittest
class TestValidateTeen:
@pytest.mark.parametrize(['image', 'label'], _EXAMPLE_FILES)
def test_anime_teen(self, image, label):
image_file = get_testfile('anime_teen', image)
tag, score = anime_teen(image_file)
assert tag == label

@pytest.mark.parametrize(['image', 'label'], _EXAMPLE_FILES)
def test_anime_teen_score(self, image, label):
image_file = get_testfile('anime_teen', image)
scores = anime_teen_score(image_file)
assert scores[label] > 0.5

0 comments on commit bba9d23

Please sign in to comment.