diff --git a/main/.buildinfo b/main/.buildinfo index 11c86e2..40fd496 100644 --- a/main/.buildinfo +++ b/main/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. -config: 57fdf4dd42e72896657d499e3bb89d8c +config: ea264bc1e360522a0c758173b4372310 tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/main/.doctrees/api_doc/config/index.doctree b/main/.doctrees/api_doc/config/index.doctree index f40c864..2e098fd 100644 Binary files a/main/.doctrees/api_doc/config/index.doctree and b/main/.doctrees/api_doc/config/index.doctree differ diff --git a/main/.doctrees/api_doc/config/meta.doctree b/main/.doctrees/api_doc/config/meta.doctree index 341b336..dcca50a 100644 Binary files a/main/.doctrees/api_doc/config/meta.doctree and b/main/.doctrees/api_doc/config/meta.doctree differ diff --git a/main/.doctrees/api_doc/controllability/bikini_plus.doctree b/main/.doctrees/api_doc/controllability/bikini_plus.doctree index ae49fa4..fec62e5 100644 Binary files a/main/.doctrees/api_doc/controllability/bikini_plus.doctree and b/main/.doctrees/api_doc/controllability/bikini_plus.doctree differ diff --git a/main/.doctrees/api_doc/controllability/index.doctree b/main/.doctrees/api_doc/controllability/index.doctree index bea8988..9ea0946 100644 Binary files a/main/.doctrees/api_doc/controllability/index.doctree and b/main/.doctrees/api_doc/controllability/index.doctree differ diff --git a/main/.doctrees/api_doc/corrupt/aicorrupt.doctree b/main/.doctrees/api_doc/corrupt/aicorrupt.doctree index df4225e..c00f187 100644 Binary files a/main/.doctrees/api_doc/corrupt/aicorrupt.doctree and b/main/.doctrees/api_doc/corrupt/aicorrupt.doctree differ diff --git a/main/.doctrees/api_doc/corrupt/index.doctree b/main/.doctrees/api_doc/corrupt/index.doctree index 76a8e02..effd74f 100644 Binary files a/main/.doctrees/api_doc/corrupt/index.doctree and 
b/main/.doctrees/api_doc/corrupt/index.doctree differ diff --git a/main/.doctrees/api_doc/fidelity/ccip.doctree b/main/.doctrees/api_doc/fidelity/ccip.doctree index e2e8637..80d018b 100644 Binary files a/main/.doctrees/api_doc/fidelity/ccip.doctree and b/main/.doctrees/api_doc/fidelity/ccip.doctree differ diff --git a/main/.doctrees/api_doc/fidelity/index.doctree b/main/.doctrees/api_doc/fidelity/index.doctree index af0d609..b0b1200 100644 Binary files a/main/.doctrees/api_doc/fidelity/index.doctree and b/main/.doctrees/api_doc/fidelity/index.doctree differ diff --git a/main/.doctrees/api_doc/utils/images.doctree b/main/.doctrees/api_doc/utils/images.doctree index 2bce83b..c7eba9c 100644 Binary files a/main/.doctrees/api_doc/utils/images.doctree and b/main/.doctrees/api_doc/utils/images.doctree differ diff --git a/main/.doctrees/api_doc/utils/index.doctree b/main/.doctrees/api_doc/utils/index.doctree index 53b3973..80b60c9 100644 Binary files a/main/.doctrees/api_doc/utils/index.doctree and b/main/.doctrees/api_doc/utils/index.doctree differ diff --git a/main/.doctrees/api_doc/utils/tqdm_.doctree b/main/.doctrees/api_doc/utils/tqdm_.doctree index 410251d..f4d9e37 100644 Binary files a/main/.doctrees/api_doc/utils/tqdm_.doctree and b/main/.doctrees/api_doc/utils/tqdm_.doctree differ diff --git a/main/.doctrees/environment.pickle b/main/.doctrees/environment.pickle index 3cb9a4e..2fa02f1 100644 Binary files a/main/.doctrees/environment.pickle and b/main/.doctrees/environment.pickle differ diff --git a/main/.doctrees/index.doctree b/main/.doctrees/index.doctree index f573a11..85bf27d 100644 Binary files a/main/.doctrees/index.doctree and b/main/.doctrees/index.doctree differ diff --git a/main/.doctrees/information/environment.doctree b/main/.doctrees/information/environment.doctree index 45c38c5..3500abd 100644 Binary files a/main/.doctrees/information/environment.doctree and b/main/.doctrees/information/environment.doctree differ diff --git 
a/main/.doctrees/information/environment.result.doctree b/main/.doctrees/information/environment.result.doctree index b170d28..1c59413 100644 Binary files a/main/.doctrees/information/environment.result.doctree and b/main/.doctrees/information/environment.result.doctree differ diff --git a/main/.doctrees/nbsphinx/information/environment.ipynb b/main/.doctrees/nbsphinx/information/environment.ipynb index 7d53a08..3c240ed 100644 --- a/main/.doctrees/nbsphinx/information/environment.ipynb +++ b/main/.doctrees/nbsphinx/information/environment.ipynb @@ -29,10 +29,10 @@ "execution_count": 1, "metadata": { "execution": { - "iopub.execute_input": "2024-01-26T08:06:18.831727Z", - "iopub.status.busy": "2024-01-26T08:06:18.831141Z", - "iopub.status.idle": "2024-01-26T08:06:19.998236Z", - "shell.execute_reply": "2024-01-26T08:06:19.997488Z" + "iopub.execute_input": "2024-01-26T08:20:05.389650Z", + "iopub.status.busy": "2024-01-26T08:20:05.389096Z", + "iopub.status.idle": "2024-01-26T08:20:06.564449Z", + "shell.execute_reply": "2024-01-26T08:20:06.563678Z" }, "pycharm": { "name": "#%%\n" @@ -53,7 +53,7 @@ "text": [ "CPU Brand: AMD EPYC 7763 64-Core Processor\n", "CPU Count: 4\n", - "CPU Freq: 2668.35625 MHz\n", + "CPU Freq: 2978.52125 MHz\n", "Memory Size: 15.607 GiB\n", "Has CUDA: No\n" ] diff --git a/main/.doctrees/nbsphinx/information/environment.result.ipynb b/main/.doctrees/nbsphinx/information/environment.result.ipynb index 1fea79d..6cf39b0 100644 --- a/main/.doctrees/nbsphinx/information/environment.result.ipynb +++ b/main/.doctrees/nbsphinx/information/environment.result.ipynb @@ -29,10 +29,10 @@ "execution_count": 1, "metadata": { "execution": { - "iopub.execute_input": "2024-01-26T08:05:29.611106Z", - "iopub.status.busy": "2024-01-26T08:05:29.610900Z", - "iopub.status.idle": "2024-01-26T08:05:30.786598Z", - "shell.execute_reply": "2024-01-26T08:05:30.785857Z" + "iopub.execute_input": "2024-01-26T08:18:43.263327Z", + "iopub.status.busy": "2024-01-26T08:18:43.263125Z", 
+ "iopub.status.idle": "2024-01-26T08:18:44.441801Z", + "shell.execute_reply": "2024-01-26T08:18:44.441088Z" }, "pycharm": { "name": "#%%\n" @@ -53,7 +53,7 @@ "text": [ "CPU Brand: AMD EPYC 7763 64-Core Processor\n", "CPU Count: 4\n", - "CPU Freq: 2723.5545 MHz\n", + "CPU Freq: 2907.154 MHz\n", "Memory Size: 15.607 GiB\n", "Has CUDA: No\n" ] diff --git a/main/.doctrees/tutorials/installation/index.doctree b/main/.doctrees/tutorials/installation/index.doctree index 328eb99..8cb6fc7 100644 Binary files a/main/.doctrees/tutorials/installation/index.doctree and b/main/.doctrees/tutorials/installation/index.doctree differ diff --git a/main/.doctrees/tutorials/quick_start/index.doctree b/main/.doctrees/tutorials/quick_start/index.doctree index 3015c04..dc8edc1 100644 Binary files a/main/.doctrees/tutorials/quick_start/index.doctree and b/main/.doctrees/tutorials/quick_start/index.doctree differ diff --git a/main/_modules/index.html b/main/_modules/index.html index 987ed1c..7eaff5d 100644 --- a/main/_modules/index.html +++ b/main/_modules/index.html @@ -213,6 +213,7 @@
CPU Brand: AMD EPYC 7763 64-Core Processor CPU Count: 4 -CPU Freq: 2668.35625 MHz +CPU Freq: 2978.52125 MHz Memory Size: 15.607 GiB Has CUDA: No
CPU Brand: AMD EPYC 7763 64-Core Processor CPU Count: 4 -CPU Freq: 2723.5545 MHz +CPU Freq: 2907.154 MHz Memory Size: 15.607 GiB Has CUDA: No
CPU Brand: AMD EPYC 7763 64-Core Processor CPU Count: 4 -CPU Freq: 2907.7349999999997 MHz +CPU Freq: 2920.74625 MHz Memory Size: 15.607 GiB Has CUDA: No
CPU Brand: AMD EPYC 7763 64-Core Processor CPU Count: 4 -CPU Freq: 2914.88725 MHz +CPU Freq: 2730.90625 MHz Memory Size: 15.607 GiB Has CUDA: No
CPU Brand: AMD EPYC 7763 64-Core Processor CPU Count: 4 -CPU Freq: 2919.1737500000004 MHz +CPU Freq: 2775.81975 MHz Memory Size: 15.607 GiB Has CUDA: No
CPU Brand: AMD EPYC 7763 64-Core Processor CPU Count: 4 -CPU Freq: 2915.3095 MHz +CPU Freq: 2966.2889999999998 MHz Memory Size: 15.607 GiB Has CUDA: No
CPU Brand: AMD EPYC 7763 64-Core Processor CPU Count: 4 -CPU Freq: 2915.3940000000002 MHz +CPU Freq: 2942.42 MHz Memory Size: 15.607 GiB Has CUDA: No
CPU Brand: AMD EPYC 7763 64-Core Processor CPU Count: 4 -CPU Freq: 2920.75925 MHz +CPU Freq: 2879.0470000000005 MHz Memory Size: 15.607 GiB Has CUDA: No
+"""
+Overview:
+ Bikini plus score.
+"""
+import json
+import os
+import re
+import warnings
+import weakref
+from functools import lru_cache, partial
+from queue import Queue
+from typing import Optional, Tuple, List, Iterator, Union
+
+import numpy as np
+from PIL import Image, UnidentifiedImageError
+from hbutils.string import singular_form
+from huggingface_hub import hf_hub_download
+from imgutils.data import load_image
+from imgutils.sd import get_sdmeta_from_image
+from imgutils.tagging import get_deepdanbooru_tags, get_wd14_tags, get_mldanbooru_tags
+
+from ..utils import tqdm
+
+try:
+ from typing import Literal
+except (ImportError, ModuleNotFoundError):
+ from typing_extensions import Literal
+
_NOTHING = object()


class ACNode:
    """A single node of the Aho-Corasick trie used for fast tag matching."""

    def __init__(self, segments: Tuple[str, ...],
                 value=_NOTHING, fail_ref: 'ACNode' = None, is_root: bool = False):
        """
        Create one AC (Aho-Corasick) trie node.

        :param segments: Tuple of segments representing the path from the root to this node.
        :type segments: Tuple[str, ...]
        :param value: Payload attached to this node; the ``_NOTHING`` sentinel means "no value".
        :param fail_ref: Node to fall back to on a mismatch (required for non-root nodes).
        :type fail_ref: ACNode
        :param is_root: Whether this node is the trie root (its fail link points to itself).
        :type is_root: bool
        """
        self.segments = segments
        # The sentinel lets ``None`` itself be stored as a legitimate value.
        self.has_value = value is not _NOTHING
        self.value = value if self.has_value else None

        self.children = {}
        self._fail_ref: Optional[weakref.ref] = None
        if is_root:
            # The root falls back onto itself.
            self.fail = self
        elif fail_ref is not None:
            self.fail = fail_ref
        else:
            raise ValueError('Fail reference not given for non-root node.')  # pragma: no cover

    @property
    def fail(self) -> Optional['ACNode']:
        """
        Dereference and return the fail node.

        :return: The fail node, or ``None`` when the weak reference is unset.
        :rtype: Optional[ACNode]
        """
        ref = self._fail_ref
        if ref is None:
            return None  # pragma: no cover
        return ref()

    @fail.setter
    def fail(self, node: 'ACNode'):
        """
        Store the fail node as a weak reference, avoiding reference cycles in the trie.

        :param node: The fail node.
        :type node: ACNode
        """
        self._fail_ref = weakref.ref(node)
+
+
@lru_cache()
def _tag_list(tagger_name: str):
    """
    Load (and memoize) the tag vocabulary for a given tagger.

    :param tagger_name: Name of the tagger.
    :type tagger_name: str

    :return: List of tag records parsed from the hosted JSON vocabulary.
    :rtype: List[dict]
    """
    vocab_file = hf_hub_download(
        'deepghs/tagger_vocabs',
        filename=f'{tagger_name}/tags.json',
        repo_type='dataset'
    )
    with open(vocab_file, 'r', encoding='utf-8') as f:
        return json.load(f)
+
+
def _tokenize(text: str):
    """
    Split text into singular-form word tokens.

    Non-word characters and underscores act as separators; empty fragments are dropped.

    :param text: Input text.
    :type text: str

    :return: List of tokens.
    :rtype: List[str]
    """
    words = (piece for piece in re.split(r'[\W_]+', text) if piece)
    return [singular_form(word) for word in words]
+
+
class TaggerACModel:
    """Aho-Corasick automaton over a tagger's vocabulary, used to recover known tags from free prompt text."""

    def __init__(self, tagger_name: str) -> None:
        """
        Aho-Corasick (AC) model for fast tag matching.

        :param tagger_name: Name of the tagger.
        :type tagger_name: str
        """
        self._root_node = ACNode((), is_root=True)
        counts = []
        # Build the trie: every alias word-sequence of every tag becomes a path;
        # the terminal node of a path carries the tag record as its value.
        for tag in _tag_list(tagger_name):
            words_list = tag['words']
            for words in words_list:
                node = self._root_node
                for i, word in enumerate(words, start=1):
                    if word not in node.children:
                        node.children[word] = ACNode(
                            segments=tuple(words[:i]),
                            value=tag if i == len(words) else _NOTHING,
                            fail_ref=self._root_node,
                        )
                    node = node.children[word]
                # NOTE(review): appended once per alias, so tags with several
                # aliases contribute their count multiple times — confirm intended.
                counts.append(tag['count'])

        # BFS pass wiring the fail links of depth-2+ nodes.
        # NOTE(review): only the parent's own fail node is consulted (no walk up
        # the fail chain here); matching compensates by following ``node.fail``
        # repeatedly at query time.
        queue = Queue()
        queue.put(self._root_node)
        while not queue.empty():
            current_node: ACNode = queue.get()
            for key, child in current_node.children.items():
                if current_node is not self._root_node and key in current_node.fail.children:
                    child.fail = current_node.fail.children[key]

                queue.put(child)

        # Log-frequency statistics; the 75th percentile serves as the baseline
        # for the per-tag weights produced by extract_tags_from_text.
        counts = np.array(counts)
        self._counts = np.log(counts[counts > 0])
        self._mean_count = np.percentile(self._counts, 75).item()

    def extract_tags_from_text(self, text: str) -> List[Tuple[str, float]]:
        """
        Extract tags from the given text.

        Runs the AC automaton over the tokenized text; each matched tag is reported
        once with a weight of ``log(count) - mean_count`` (rarer tags weigh less).

        :param text: Input text.
        :type text: str

        :return: List of tuples containing tags and their values,
            sorted by descending value then by name.
        :rtype: List[Tuple[str, float]]
        """
        tokens = _tokenize(text)
        _exist_names = set()
        retval = []
        node = self._root_node
        for token in tokens:
            # On mismatch, retreat along fail links until the token matches
            # or the root is reached.
            while node is not self._root_node:
                if token not in node.children:
                    node = node.fail
                else:
                    break

            if token in node.children:
                node = node.children[token]
                # Collect every tag ending at this node or at any of its
                # fail-chain ancestors (suffix matches), de-duplicated by name.
                cur_node = node
                while cur_node is not self._root_node:
                    if cur_node.has_value:
                        tag_info = cur_node.value
                        if tag_info['name'] not in _exist_names:
                            value = np.log(tag_info['count']).item() - self._mean_count
                            retval.append((tag_info['name'], value))
                            _exist_names.add(tag_info['name'])
                    cur_node = cur_node.fail

            else:
                node = self._root_node

        return sorted(retval, key=lambda x: (-x[1], x[0]))

    def get_tag(self, tag_text):
        """
        Get the canonical tag name for the given tag text.

        Unlike :meth:`extract_tags_from_text`, this requires the whole token
        sequence to match exactly one vocabulary entry.

        :param tag_text: Input tag text.
        :type tag_text: str

        :return: Tag name.
        :rtype: str

        :raises ValueError: If the text does not correspond to a known tag.
        """
        tokens = _tokenize(tag_text)
        node = self._root_node
        for token in tokens:
            if token not in node.children:
                raise ValueError(f'Unknown tag {tag_text!r}.')

            node = node.children[token]

        if node.has_value:
            return node.value['name']
        else:
            raise ValueError(f'Unknown tag {tag_text!r}.')
+
+
def _deepdanbooru_tagging(image: Image.Image, use_real_name: bool = False,
                          general_threshold: float = 0.0, character_threshold: float = 0.0, **kwargs):
    """Tag an image with DeepDanbooru, merging general and character tags into one dict."""
    _ = kwargs  # extra options accepted for interface uniformity, ignored here
    rating, features, characters = get_deepdanbooru_tags(image, use_real_name, general_threshold, character_threshold)
    del rating  # rating scores are not used by the metrics
    merged = dict(features)
    merged.update(characters)
    return merged
+
+
def _wd14_tagging(image: Image.Image, model_name: str,
                  general_threshold: float = 0.0, character_threshold: float = 0.0, **kwargs):
    """Tag an image with a wd14 tagger variant, merging general and character tags into one dict."""
    _ = kwargs  # extra options accepted for interface uniformity, ignored here
    rating, features, characters = get_wd14_tags(image, model_name, general_threshold, character_threshold)
    del rating  # rating scores are not used by the metrics
    merged = dict(features)
    merged.update(characters)
    return merged
+
+
def _mldanbooru_tagging(image: Image.Image, use_real_name: bool = False, general_threshold: float = 0.0, **kwargs):
    """Tag an image with the ML-Danbooru tagger (general tags only)."""
    _ = kwargs  # extra options accepted for interface uniformity, ignored here
    return get_mldanbooru_tags(image, use_real_name, general_threshold)
+
+
# Mapping from the short wd14 aliases used in this module to the full model
# repository names, used to locate the matching tag vocabulary.
_WD14_TAGGER_MODELS = {
    "wd14_swinv2": "wd-v1-4-swinv2-tagger-v2",
    "wd14_convnext": "wd-v1-4-convnext-tagger-v2",
    "wd14_convnextv2": "wd-v1-4-convnextv2-tagger-v2",
    "wd14_vit": "wd-v1-4-vit-tagger-v2",
    "wd14_moat": "wd-v1-4-moat-tagger-v2",
}
# Dispatch table from tagger name to the tagging callable defined above.
_TAGGING_METHODS = {
    'deepdanbooru': _deepdanbooru_tagging,
    'wd14_vit': partial(_wd14_tagging, model_name='ViT'),
    'wd14_convnext': partial(_wd14_tagging, model_name='ConvNext'),
    'wd14_convnextv2': partial(_wd14_tagging, model_name='ConvNextV2'),
    'wd14_swinv2': partial(_wd14_tagging, model_name='SwinV2'),
    'wd14_moat': partial(_wd14_tagging, model_name='MOAT'),
}

# Names of the supported tagging methods (keys of _TAGGING_METHODS).
TaggingMethodTyping = Literal[
    'deepdanbooru', 'mldanbooru',
    'wd14_vit', 'wd14_convnext', 'wd14_convnextv2', 'wd14_swinv2', 'wd14_moat',
]

# A single prompted image: an image (PIL object or path) optionally paired with
# a positive prompt and a negative prompt.
PromptedImageTyping = Union[
    Tuple[Image.Image, str, str], Tuple[Image.Image, str], Image.Image,
    Tuple[str, str, str], Tuple[str, str], str
]
# One prompted image, or an arbitrarily nested list of them.
PromptedImagesTyping = Union[PromptedImageTyping, List[PromptedImageTyping]]
+
+
def _yield_images(images: PromptedImagesTyping) -> Iterator[Tuple[Image.Image, str, str]]:
    """
    Yield ``(image, prompt, negative_prompt)`` triples from various sources.

    Lists are flattened recursively, directories are walked recursively, tuples
    carry explicit prompts, and bare images have their prompts read from the
    embedded Stable Diffusion metadata when present.

    :param images: Input images with prompts and negative prompts.
    :type images: PromptedImagesTyping

    :return: Iterator of image, prompt, and negative prompt tuples.
    :rtype: Iterator[Tuple[Image.Image, str, str]]
    """
    if isinstance(images, list):
        for entry in images:
            yield from _yield_images(entry)
        return

    if isinstance(images, str) and os.path.isdir(images):
        for dirpath, _, filenames in os.walk(images):
            for filename in filenames:
                yield from _yield_images(os.path.join(dirpath, filename))
        return

    if isinstance(images, tuple):
        if len(images) == 3:
            img, prompt, neg_prompt = images
        elif len(images) == 2:
            img, prompt = images
            neg_prompt = ''
        else:
            raise TypeError(f'Unknown tuple for prompted image - {images!r}.')
        img = load_image(img)
        img.load()

    else:
        try:
            img = load_image(images, force_background=None, mode=None)
            img.load()

            sdmeta = get_sdmeta_from_image(img)
            if sdmeta is None:
                prompt, neg_prompt = '', ''
            else:
                prompt, neg_prompt = sdmeta.prompt, sdmeta.neg_prompt
        except UnidentifiedImageError:
            # Non-image files (e.g. stray files inside a directory) are skipped.
            return

    yield img, prompt, neg_prompt
+
+
class BikiniPlusMetrics:
    """
    Class for evaluating the appropriateness of AI-generated images based on prompts and taggers.

    The `BikiniPlusMetrics` class assesses the compatibility of AI-generated images with given prompts using taggers.

    :param tagger: The tagging method to use. Default is 'wd14_convnextv2'.
    :type tagger: TaggingMethodTyping
    :param tagger_cfgs: Optional configuration parameters for the chosen tagger. Default is None.
    :type tagger_cfgs: Optional[dict]
    :param base_num: Base number for weighting prompt tags. Default is 1.5.
    :type base_num: float
    :param tag_blacklist: Optional list of tags to exclude from evaluation. Default is None.
    :type tag_blacklist: Optional[List[str]]
    :param silent: If True, suppresses progress bars and additional output during calculation. Default is False.
    :type silent: bool
    """

    def __init__(self, tagger: TaggingMethodTyping = 'wd14_convnextv2',
                 tagger_cfgs: Optional[dict] = None, base_num: float = 1.5,
                 tag_blacklist: Optional[List[str]] = None, silent: bool = False):
        self.tagger = tagger
        self._tagger_func = partial(_TAGGING_METHODS[tagger], **(tagger_cfgs or {}))
        # wd14 aliases map to their full model names for vocabulary lookup;
        # other taggers use their own name directly.
        self._ac_model = TaggerACModel(_WD14_TAGGER_MODELS.get(tagger, tagger))
        self._base_num = base_num

        # Blacklisted tags are canonicalized through the AC model; unknown
        # entries are collected and reported once instead of failing.
        self._tag_blacklist_set = set()
        _unknown_blacklist_tags = set()
        for tag in (tag_blacklist or []):
            try:
                self._tag_blacklist_set.add(self._ac_model.get_tag(tag))
            except ValueError:
                _unknown_blacklist_tags.add(tag)
        if _unknown_blacklist_tags:
            warnings.warn(f'Unknown tags for blacklist: {sorted(_unknown_blacklist_tags)}.')
        self.silent = silent

    def _calculate_one_image(self, img: Image.Image, prompt: str, neg_prompt: str):
        """
        Calculate the bikini plus score for a single image.

        The score is a weighted average of per-tag confidences: positive-prompt
        tags contribute their tagger confidence, negative-prompt tags contribute
        one minus it, and each tag is weighted by ``base_num ** tag_weight``.

        :param img: The input image.
        :type img: Image.Image
        :param prompt: The positive prompt for evaluation.
        :type prompt: str
        :param neg_prompt: The negative prompt for evaluation.
        :type neg_prompt: str

        :return: The calculated bikini plus score for the image.
        :rtype: float
        """
        prompt_tags = self._ac_model.extract_tags_from_text(prompt)
        prompt_tags = [(tag, value) for tag, value in prompt_tags if tag not in self._tag_blacklist_set]
        neg_prompt_tags = self._ac_model.extract_tags_from_text(neg_prompt)
        neg_prompt_tags = [(tag, value) for tag, value in neg_prompt_tags if tag not in self._tag_blacklist_set]
        tagged_tags = self._tagger_func(img)

        # No recognizable tags in either prompt: nothing to contradict.
        if not prompt_tags and not neg_prompt_tags:
            return 1.0

        vs = np.array([
            *(tagged_tags.get(tag, 0.0) for tag, value in prompt_tags),
            *((1.0 - tagged_tags.get(tag, 0.0)) for tag, value in neg_prompt_tags),
        ])
        ws = np.array([
            *((self._base_num ** value) for tag, value in prompt_tags),
            *((self._base_num ** value) for tag, value in neg_prompt_tags),
        ])
        return ((vs * ws).sum() / ws.sum()).item()

    def score(self, images: PromptedImagesTyping, silent: Optional[bool] = None,
              mode: Literal['mean', 'seq'] = 'mean') -> Union[float, np.ndarray]:
        """
        Calculate the average bikini plus score for a set of images.

        This method computes the average bikini plus score for a set of images based on the provided prompts.

        :param images: The set of images with associated positive and negative prompts.
        :type images: PromptedImagesTyping
        :param silent: If True, suppresses progress bars and additional output during calculation.
            If None (default), the instance-level ``silent`` setting is used.
        :type silent: Optional[bool]
        :param mode: Mode of the return value. Return a float value when ``mean`` is assigned,
            return a numpy array when ``seq`` is assigned. Default is ``mean``.
        :type mode: Literal['mean', 'seq']

        :return: The average bikini plus score for the set of images.
        :rtype: Union[float, np.ndarray]
        """
        image_list = list(_yield_images(images))
        if not image_list:
            raise FileNotFoundError(f'Images for calculating bikini plus score not provided - {images}.')

        # BUGFIX: the previous default ``silent=False`` meant the branch below
        # never saw None, so ``self.silent`` was silently ignored. A None
        # default defers to the instance setting, matching the sibling metrics.
        score = np.array([
            self._calculate_one_image(img, prompt, neg_prompt)
            for img, prompt, neg_prompt in tqdm(image_list, silent=self.silent if silent is None else silent)
        ])
        assert score.shape == (len(image_list),)

        if mode == 'seq':
            return score
        else:
            return score.mean().item()
+
+"""
+Overview:
+ AI image corrupt evaluation metrics.
+"""
+import json
+from functools import lru_cache
+from typing import Tuple, Optional, Mapping, Literal, Union
+
+import numpy as np
+from PIL import Image
+from huggingface_hub import hf_hub_download
+from imgutils.data import rgb_encode, ImageTyping, load_image
+from imgutils.utils import open_onnx_model
+
+from ..utils import ImagesTyping, load_images, tqdm
+
# Default checkpoint used by the AI-corrupt detection helpers below.
_DEFAULT_MODEL_NAME = 'caformer_s36_v0_focal'
+
+
@lru_cache()
def _open_anime_aicop_model(model_name: str):
    """
    Open (and memoize) the AI image corrupted detection model.

    This function downloads and opens the AI image corrupted detection model
    specified by the given model name using Hugging Face Hub.

    :param model_name: The name of the AI image corrupted detection model.
    :type model_name: str

    :return: The opened ONNX model session.
    :rtype: Model
    """
    # The repo id is a plain constant; the stray f-prefix was removed (F541).
    return open_onnx_model(hf_hub_download(
        'deepghs/ai_image_corrupted',
        f'{model_name}/model.onnx',
    ))
+
+
@lru_cache()
def _open_anime_aicop_meta(model_name: str):
    """
    Open (and memoize) the meta information of the AI image corrupted detection model.

    This function downloads and parses the meta information of the AI image corrupted
    detection model specified by the given model name using Hugging Face Hub.

    :param model_name: The name of the AI image corrupted detection model.
    :type model_name: str

    :return: The opened meta information of the AI image corrupted detection model.
    :rtype: dict
    """
    # The repo id is a plain constant; the stray f-prefix was removed (F541).
    with open(hf_hub_download(
            'deepghs/ai_image_corrupted',
            f'{model_name}/meta.json',
    ), 'r', encoding='utf-8') as f:
        return json.load(f)
+
+
@lru_cache()
def _open_anime_aicop_labels(model_name: str):
    """
    Get the output labels of the AI image corrupted detection model.

    The labels are read from the model's cached meta information.

    :param model_name: The name of the AI image corrupted detection model.
    :type model_name: str

    :return: The labels of the AI image corrupted detection model.
    :rtype: List[str]
    """
    meta = _open_anime_aicop_meta(model_name)
    return meta['labels']
+
+
def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
                normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
    """
    Encode an image into a normalized CHW float32 array for the detector.

    :param image: The input image.
    :type image: Image.Image
    :param size: The target size for encoding. Default is (384, 384).
    :type size: Tuple[int, int]
    :param normalize: ``(mean, std)`` normalization parameters applied per channel,
        or None to skip normalization. Default is (0.5, 0.5).
    :type normalize: Optional[Tuple[float, float]]

    :return: The encoded image data.
    :rtype: np.ndarray
    """
    resized = image.resize(size, Image.BILINEAR)
    data = rgb_encode(resized, order_='CHW')

    if normalize is not None:
        mean_value, std_value = normalize
        mean = np.asarray([mean_value]).reshape((-1, 1, 1))
        std = np.asarray([std_value]).reshape((-1, 1, 1))
        data = (data - mean) / std

    return data.astype(np.float32)
+
+
def get_ai_corrupted(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Mapping[str, float]:
    """
    Get AI image corrupted detection scores for an image.

    The image is flattened onto a white RGB background, encoded, and run through
    the specified ONNX detection model.

    :param image: The input image.
    :type image: ImageTyping
    :param model_name: The name of the AI image corrupted detection model. Default is 'caformer_s36_v0_focal'.
    :type model_name: str

    :return: A dictionary containing the corrupted score.
    :rtype: Mapping[str, float]
    """
    img = load_image(image, force_background='white', mode='RGB')
    batch = _img_encode(img)[None, ...]
    output, = _open_anime_aicop_model(model_name).run(['output'], {'input': batch})
    labels = _open_anime_aicop_labels(model_name)
    return dict(zip(labels, output[0].tolist()))
+
+
class AICorruptMetrics:
    """
    Class for calculating an AI image corruptness score.

    The `AICorruptMetrics` class allows you to calculate an AI image corruptness score
    using the AI image corrupted detection model.

    :param model_name: The name of the AI image corrupted detection model. Default is 'caformer_s36_v0_focal'.
    :type model_name: str
    :param silent: If True, suppresses progress bars and additional output during calculation.
    :type silent: bool
    :param tqdm_desc: Description for the tqdm progress bar during calculation.
        Defaults to the class name when not given.
    :type tqdm_desc: Optional[str]
    """

    # Annotations fixed: parameters defaulting to None are Optional (PEP 484).
    def __init__(self, model_name: str = _DEFAULT_MODEL_NAME,
                 silent: bool = False, tqdm_desc: Optional[str] = None):
        self._model_name = model_name
        self.silent = silent
        self.tqdm_desc = tqdm_desc or self.__class__.__name__

    def score(self, images: ImagesTyping, silent: Optional[bool] = None,
              mode: Literal['mean', 'seq'] = 'mean') -> Union[float, np.ndarray]:
        """
        Calculate the AI image corruptness score for a set of images.

        This method calculates the AI image corruptness score for a set of input images
        using the AI image corrupted detection model.

        :param images: The set of input images for calculating the AI image corruptness score.
        :type images: ImagesTyping
        :param silent: If True, suppresses progress bars and additional output during calculation.
            If None (default), the instance-level ``silent`` setting is used.
        :type silent: Optional[bool]
        :param mode: Mode of the return value. Return a float value when ``mean`` is assigned,
            return a numpy array when ``seq`` is assigned. Default is ``mean``.
        :type mode: Literal['mean', 'seq']

        :return: The AI image corruptness score.
        :rtype: Union[float, np.ndarray]
        """
        image_list = load_images(images)
        if not image_list:
            raise FileNotFoundError(f'Images for calculating AI corrupt score not provided - {images}.')

        # Score = 1 - corrupted probability, so higher means healthier images.
        scores = 1.0 - np.array([
            get_ai_corrupted(image, model_name=self._model_name)['corrupted']
            for image in tqdm(image_list, silent=self.silent if silent is None else silent, desc=self.tqdm_desc)
        ])
        assert scores.shape == (len(image_list),)

        if mode == 'seq':
            return scores
        else:
            return scores.mean().item()
+
+"""
+Overview:
+ CCIP-based metrics for anime character training.
+
+ See `imgutils.metrics.ccip <https://deepghs.github.io/imgutils/main/api_doc/metrics/ccip.html>`_ for more information.
+"""
+import warnings
+from typing import List, Optional, Literal, Union
+
+import numpy as np
+from PIL import Image
+from imgutils.metrics import ccip_extract_feature, ccip_default_threshold, ccip_batch_differences
+
+from ..utils import load_images, ImagesTyping, tqdm
+
# Default CCIP feature-extraction model used when none is specified.
_DEFAULT_CCIP_MODEL = 'ccip-caformer-24-randaug-pruned'
+
+
class CCIPMetrics:
    """
    Class for calculating similarity scores between images using the CCIP (Content-Consistent Image Pairwise) metric.

    The `CCIPMetrics` class allows you to calculate the similarity score between a set of images
    and a reference dataset using the CCIP metric.

    :param images: The reference dataset of images for initializing CCIP metrics.
    :type images: ImagesTyping
    :param feats: Feature data of given character, should be (B, 768). When assigned, ``images`` argument will be ignored.
    :type feats: Optional[np.ndarray]
    :param model: The CCIP model to use for feature extraction. Default is 'ccip-caformer-24-randaug-pruned'.
    :type model: str
    :param threshold: The threshold for the CCIP metric. If not provided, the default threshold for the chosen model is used.
    :type threshold: Optional[float]
    :param silent: If True, suppresses progress bars and additional output during initialization and calculation.
    :type silent: bool
    :param tqdm_desc: Description for the tqdm progress bar during initialization and calculation.
        Defaults to the class name when not given.
    :type tqdm_desc: Optional[str]
    """

    # Annotation fixed: tqdm_desc defaults to None, so it is Optional (PEP 484).
    def __init__(self, images: ImagesTyping, feats: Optional[np.ndarray] = None, model: str = _DEFAULT_CCIP_MODEL,
                 threshold: Optional[float] = None, silent: bool = False, tqdm_desc: Optional[str] = None):
        self.silent = silent
        self.tqdm_desc = tqdm_desc or self.__class__.__name__
        self._ccip_model = model
        self._threshold = ccip_default_threshold(self._ccip_model) if threshold is None else threshold

        if feats is None:
            # Build the reference feature bank from the provided images.
            image_list: List[Image.Image] = load_images(images)
            if not image_list:
                raise FileNotFoundError(f'Images for initializing CCIP metrics not provided - {images}.')
            self._features = [
                ccip_extract_feature(img, model=self._ccip_model)
                for img in tqdm(image_list, silent=self.silent, desc=f'{self.tqdm_desc} Initializing')
            ]

        else:
            # Pre-computed features take precedence over images.
            if images:
                warnings.warn(f'Binary features assigned, images {images!r} will be ignored.')
            if len(feats.shape) != 2 or feats.shape[-1] != 768:
                raise ValueError(f'Feature shape should be (B, 768), but {feats.shape!r} found actually.')
            self._features = list(feats)

    def score(self, images: ImagesTyping, silent: Optional[bool] = None,
              mode: Literal['mean', 'seq'] = 'mean') -> Union[float, np.ndarray]:
        """
        Calculate the similarity score between the reference dataset and a set of input images.

        This method calculates the similarity score between the reference dataset
        (used for initialization) and a set of input images using the CCIP metric.

        :param images: The set of input images for calculating CCIP metrics.
        :type images: ImagesTyping
        :param silent: If True, suppresses progress bars and additional output during calculation.
            If None (default), the instance-level ``silent`` setting is used.
        :type silent: Optional[bool]
        :param mode: Mode of the return value. Return a float value when ``mean`` is assigned,
            return a numpy array when ``seq`` is assigned. Default is ``mean``.
        :type mode: Literal['mean', 'seq']

        :return: The similarity score between the reference dataset and the input images.
        :rtype: Union[float, np.ndarray]
        """
        image_list: List[Image.Image] = load_images(images)
        if not image_list:
            raise FileNotFoundError(f'Images for calculating CCIP metrics not provided - {images}.')

        _features = [
            ccip_extract_feature(img, model=self._ccip_model)
            for img in tqdm(image_list, silent=self.silent if silent is None else silent,
                            desc=f'{self.tqdm_desc} Calculating')
        ]

        # Pairwise differences between reference features (rows) and query
        # features (columns); a query's score is the fraction of reference
        # images it falls within the threshold of.
        diffs = ccip_batch_differences([*self._features, *_features])
        matrix = diffs[:len(self._features), len(self._features):]
        seq = (matrix < self._threshold).mean(axis=0)
        assert seq.shape == (len(_features),)

        if mode == 'seq':
            return seq
        else:
            return seq.mean().item()
+
+import os.path
+from typing import List, Iterator, Union
+
+from PIL import UnidentifiedImageError, Image
+from imgutils.data import load_image
+
# An image source: a PIL image, a path (file or directory), or a list of either.
ImagesTyping = Union[Image.Image, str, List[Union[Image.Image, str]]]
+
+
def _yield_images(images: ImagesTyping) -> Iterator[Image.Image]:
    """
    Yield PIL.Image objects from various sources.

    Lists are flattened recursively, directory paths are walked recursively,
    and unreadable (non-image) files are silently skipped.

    :param images: An image or a list of images (PIL.Image, file paths, or a combination).
    :type images: ImagesTyping

    :return: An iterator yielding PIL.Image objects.
    :rtype: Iterator[Image.Image]
    """
    if isinstance(images, list):
        for entry in images:
            yield from _yield_images(entry)
        return

    if isinstance(images, str) and os.path.isdir(images):
        for dirpath, _, filenames in os.walk(images):
            for filename in filenames:
                yield from _yield_images(os.path.join(dirpath, filename))
        return

    try:
        img = load_image(images)
        img.load()
        yield img
    except UnidentifiedImageError:
        # Not an image file; skip it.
        pass
+
+
def load_images(images: ImagesTyping) -> List[Image.Image]:
    """
    Load multiple PIL.Image objects from various sources.

    Accepts PIL.Image objects, file paths, directory paths (walked recursively),
    or lists of any of these, and materializes everything into a flat list.

    :param images: An image or a list of images (PIL.Image, file paths, or a combination).
    :type images: ImagesTyping

    :return: A list of PIL.Image objects.
    :rtype: List[Image.Image]
    """
    return [*_yield_images(images)]
+
+import io
+
+from tqdm.auto import tqdm as _origin_tqdm
+
+__all__ = ['tqdm']
+
+
def tqdm(*args, silent: bool = False, **kwargs):
    """
    An enhanced version of tqdm (progress bar) with an option to silence the output.

    When ``silent`` is True, the bar's output is redirected into an in-memory
    buffer so nothing reaches the terminal.

    :param args: Positional arguments to be passed to tqdm.
    :param silent: If True, the progress bar content will not be displayed.
    :type silent: bool
    :param kwargs: Additional keyword arguments to be passed to tqdm.
    :return: tqdm progress bar.
    :rtype: tqdm.std.tqdm
    """
    if silent:
        # BUGFIX: the previous implementation created the buffer inside a
        # ``with io.StringIO() as sio:`` block, which closed the buffer on
        # return — any later refresh by the returned bar would then raise
        # "ValueError: I/O operation on closed file". The buffer must outlive
        # this function, held by the bar's ``file`` reference.
        kwargs['file'] = io.StringIO()

    return _origin_tqdm(*args, **kwargs)
+