From 92bb03e12955746a0bb94df103e943bbe604b66e Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 15 Nov 2023 22:43:25 +0800 Subject: [PATCH] dev(narugo): add corrupt docs --- docs/source/api_doc/corrupt/aicorrupt.rst | 22 +++++ docs/source/api_doc/corrupt/index.rst | 12 +++ docs/source/index.rst | 1 + sdeval/corrupt/aicorrupt.py | 97 ++++++++++++++++++++++- 4 files changed, 129 insertions(+), 3 deletions(-) create mode 100644 docs/source/api_doc/corrupt/aicorrupt.rst create mode 100644 docs/source/api_doc/corrupt/index.rst diff --git a/docs/source/api_doc/corrupt/aicorrupt.rst b/docs/source/api_doc/corrupt/aicorrupt.rst new file mode 100644 index 0000000..c53b313 --- /dev/null +++ b/docs/source/api_doc/corrupt/aicorrupt.rst @@ -0,0 +1,22 @@ +sdeval.corrupt.aicorrupt +================================= + +.. currentmodule:: sdeval.corrupt.aicorrupt + +.. automodule:: sdeval.corrupt.aicorrupt + + +get_ai_corrupted +--------------------------------- + +.. autofunction:: get_ai_corrupted + + + +AICorruptMetrics +-------------------------------- + +.. autoclass:: AICorruptMetrics + :members: __init__, score + + diff --git a/docs/source/api_doc/corrupt/index.rst b/docs/source/api_doc/corrupt/index.rst new file mode 100644 index 0000000..3da1dfe --- /dev/null +++ b/docs/source/api_doc/corrupt/index.rst @@ -0,0 +1,12 @@ +sdeval.corrupt +===================== + +.. currentmodule:: sdeval.corrupt + +.. automodule:: sdeval.corrupt + + +.. toctree:: + :maxdepth: 3 + + aicorrupt diff --git a/docs/source/index.rst b/docs/source/index.rst index a2ccbdb..a71f1aa 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -29,6 +29,7 @@ configuration file's structure and their versions. :caption: API Documentation api_doc/config/index + api_doc/corrupt/index api_doc/fidelity/index api_doc/utils/index diff --git a/sdeval/corrupt/aicorrupt.py b/sdeval/corrupt/aicorrupt.py index 766ed32..a509ec0 100644 --- a/sdeval/corrupt/aicorrupt.py +++ b/sdeval/corrupt/aicorrupt.py @@ -1,3 +1,7 @@ +""" +Overview: + AI image corrupt evaluation metrics. +""" import json from functools import lru_cache from typing import Tuple, Optional, Mapping @@ -14,7 +18,18 @@ @lru_cache() -def _open_anime_aicop_model(model_name): +def _open_anime_aicop_model(model_name: str): + """ + Open the AI image corrupted detection model. + + This function downloads and opens the AI image corrupted detection model specified by the given model name using Hugging Face Hub. + + :param model_name: The name of the AI image corrupted detection model. + :type model_name: str + + :return: The opened AI image corrupted detection model. + :rtype: Model + """ return open_onnx_model(hf_hub_download( f'deepghs/ai_image_corrupted', f'{model_name}/model.onnx', @@ -22,7 +37,18 @@ def _open_anime_aicop_model(model_name): @lru_cache() -def _open_anime_aicop_meta(model_name): +def _open_anime_aicop_meta(model_name: str): + """ + Open the meta information of the AI image corrupted detection model. + + This function downloads and opens the meta information of the AI image corrupted detection model specified by the given model name using Hugging Face Hub. + + :param model_name: The name of the AI image corrupted detection model. + :type model_name: str + + :return: The opened meta information of the AI image corrupted detection model. + :rtype: dict + """ with open(hf_hub_download( f'deepghs/ai_image_corrupted', f'{model_name}/meta.json', @@ -31,12 +57,38 @@ def _open_anime_aicop_meta(model_name): @lru_cache() -def _open_anime_aicop_labels(model_name): +def _open_anime_aicop_labels(model_name: str): + """ + Open the labels of the AI image corrupted detection model. + + This function opens the labels of the AI image corrupted detection model specified by the given model name. + + :param model_name: The name of the AI image corrupted detection model. + :type model_name: str + + :return: The labels of the AI image corrupted detection model. + :rtype: List[str] + """ return _open_anime_aicop_meta(model_name)['labels'] def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), normalize: Optional[Tuple[float, float]] = (0.5, 0.5)): + """ + Encode the image for AI image corrupted detection. + + This function resizes and encodes the image for AI image corrupted detection. + + :param image: The input image. + :type image: Image.Image + :param size: The target size for encoding. Default is (384, 384). + :type size: Tuple[int, int] + :param normalize: The normalization parameters. Default is (0.5, 0.5). + :type normalize: Optional[Tuple[float, float]] + + :return: The encoded image data. + :rtype: np.ndarray + """ image = image.resize(size, Image.BILINEAR) data = rgb_encode(image, order_='CHW') @@ -50,6 +102,19 @@ def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384), def get_ai_corrupted(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Mapping[str, float]: + """ + Get AI image corrupted detection scores for an image. + + This function calculates AI image corrupted detection scores for a given image using the specified model. + + :param image: The input image. + :type image: ImageTyping + :param model_name: The name of the AI image corrupted detection model. Default is 'caformer_s36_v0_focal'. + :type model_name: str + + :return: A dictionary containing the corrupted score. + :rtype: Mapping[str, float] + """ image = load_image(image, force_background='white', mode='RGB') input_ = _img_encode(image)[None, ...] output, = _open_anime_aicop_model(model_name).run(['output'], {'input': input_}) @@ -57,6 +122,19 @@ def get_ai_corrupted(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) class AICorruptMetrics: + """ + Class for calculating an AI image corruptness score. + + The `AICorruptMetrics` class allows you to calculate an AI image corruptness score using the AI image corrupted detection model. + + :param model_name: The name of the AI image corrupted detection model. Default is 'caformer_s36_v0_focal'. + :type model_name: str + :param silent: If True, suppresses progress bars and additional output during calculation. + :type silent: bool + :param tqdm_desc: Description for the tqdm progress bar during calculation. + :type tqdm_desc: str + """ + def __init__(self, model_name: str = _DEFAULT_MODEL_NAME, silent: bool = False, tqdm_desc: str = None): self._model_name = model_name @@ -64,6 +142,19 @@ def __init__(self, model_name: str = _DEFAULT_MODEL_NAME, self.tqdm_desc = tqdm_desc or self.__class__.__name__ def score(self, images: ImagesTyping, silent: bool = None): + """ + Calculate the AI image corruptness score for a set of images. + + This method calculates the AI image corruptness score for a set of input images using the AI image corrupted detection model. + + :param images: The set of input images for calculating the AI image corruptness score. + :type images: ImagesTyping + :param silent: If True, suppresses progress bars and additional output during calculation. + :type silent: bool + + :return: The AI image corruptness score. + :rtype: float + """ image_list = load_images(images) if not image_list: raise FileNotFoundError(f'Images for calculating AI corrupt score not provided - {images}.')