diff --git a/docs/source/api_doc/upscale/cdc.rst b/docs/source/api_doc/upscale/cdc.rst
new file mode 100644
index 00000000000..9d76f491f55
--- /dev/null
+++ b/docs/source/api_doc/upscale/cdc.rst
@@ -0,0 +1,15 @@
+imgutils.upscale.cdc
+====================================
+
+.. currentmodule:: imgutils.upscale.cdc
+
+.. automodule:: imgutils.upscale.cdc
+
+
+upscale_with_cdc
+---------------------------
+
+.. autofunction:: upscale_with_cdc
+
+
+
diff --git a/docs/source/api_doc/upscale/cdc_benchmark.plot.py b/docs/source/api_doc/upscale/cdc_benchmark.plot.py
new file mode 100644
index 00000000000..d21553d3f34
--- /dev/null
+++ b/docs/source/api_doc/upscale/cdc_benchmark.plot.py
@@ -0,0 +1,44 @@
+import os.path
+import random
+
+from huggingface_hub import HfFileSystem
+
+from benchmark import BaseBenchmark, create_plot_cli
+from imgutils.upscale.cdc import upscale_with_cdc
+
+hf_fs = HfFileSystem()
+repository = 'deepghs/cdc_anime_onnx'
+_CDC_MODELS = [
+    os.path.splitext(os.path.relpath(file, repository))[0]
+    for file in hf_fs.glob(f'{repository}/*.onnx')
+]
+
+
+class CDCUpscalerBenchmark(BaseBenchmark):
+    def __init__(self, model: str):
+        BaseBenchmark.__init__(self)
+        self.model = model
+
+    def load(self):
+        from imgutils.upscale.cdc import _open_cdc_upscaler_model
+        _open_cdc_upscaler_model(self.model)
+
+    def unload(self):
+        from imgutils.upscale.cdc import _open_cdc_upscaler_model
+        _open_cdc_upscaler_model.cache_clear()
+
+    def run(self):
+        image_file = random.choice(self.all_images)
+        _ = upscale_with_cdc(image_file, model=self.model)
+
+
+if __name__ == '__main__':
+    create_plot_cli(
+        [
+            (model, CDCUpscalerBenchmark(model))
+            for model in _CDC_MODELS
+        ],
+        title='Benchmark for CDCUpscaler Models',
+        run_times=3,
+        try_times=3,
+    )()
diff --git a/docs/source/api_doc/upscale/cdc_benchmark.plot.py.svg b/docs/source/api_doc/upscale/cdc_benchmark.plot.py.svg
new file mode 100644
index 00000000000..dba8439c9b5
--- /dev/null
+++ b/docs/source/api_doc/upscale/cdc_benchmark.plot.py.svg
@@ -0,0 +1,2154 @@
[2154 lines of Matplotlib v3.7.5 SVG markup omitted: rendered benchmark plot for the CDC upscaler models, generated 2024-05-09T20:55:19]
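For a quick local sanity check of these benchmark numbers without the documentation's `benchmark`/`create_plot_cli` harness, a minimal timing sketch along the following lines should work; the sample image path and the single-run timing strategy are illustrative assumptions, not part of this PR.

import time

from huggingface_hub import HfFileSystem
from imgutils.upscale.cdc import upscale_with_cdc, _open_cdc_upscaler_model

repository = 'deepghs/cdc_anime_onnx'
models = [
    file.rsplit('/', 1)[-1][:-len('.onnx')]  # strip directory prefix and '.onnx' suffix
    for file in HfFileSystem().glob(f'{repository}/*.onnx')
]

sample = 'sample/skadi.jpg'  # any local test image (placeholder path)
for model in models:
    _open_cdc_upscaler_model(model)  # warm up: download and load the ONNX model
    start = time.time()
    _ = upscale_with_cdc(sample, model=model)
    print(f'{model}: {time.time() - start:.1f}s')
    _open_cdc_upscaler_model.cache_clear()  # drop the cached session before the next model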
diff --git a/docs/source/api_doc/upscale/cdc_demo.plot.py b/docs/source/api_doc/upscale/cdc_demo.plot.py
new file mode 100644
index 00000000000..6c0d6a03c19
--- /dev/null
+++ b/docs/source/api_doc/upscale/cdc_demo.plot.py
@@ -0,0 +1,36 @@
+import os
+
+from huggingface_hub import HfFileSystem
+
+from imgutils.upscale import upscale_with_cdc
+from imgutils.upscale.cdc import _open_cdc_upscaler_model
+from plot import image_plot
+
+hf_fs = HfFileSystem()
+repository = 'deepghs/cdc_anime_onnx'
+_CDC_MODELS = [
+    os.path.splitext(os.path.relpath(file, repository))[0]
+    for file in hf_fs.glob(f'{repository}/*.onnx')
+]
+
+if __name__ == '__main__':
+    demo_images = [
+        ('sample/original.png', 'Small Logo'),
+        ('sample/skadi.jpg', 'Illustration'),
+        ('sample/hutao.png', 'Large Illustration'),
+        # ('sample/xx.jpg', 'Illustration #2'),
+        ('sample/rgba_restore.png', 'RGBA Artwork'),
+    ]
+
+    items = []
+    for file, title in demo_images:
+        items.append((file, title))
+        for model in _CDC_MODELS:
+            _, scale = _open_cdc_upscaler_model(model)
+            items.append((upscale_with_cdc(file, model=model), f'{title}\n({scale}X By {model})'))
+
+    image_plot(
+        *items,
+        columns=len(_CDC_MODELS) + 1,
+        figsize=(4 * (len(_CDC_MODELS) + 1), 3.5 * len(demo_images)),
+    )
diff --git a/docs/source/api_doc/upscale/cdc_demo.plot.py.svg b/docs/source/api_doc/upscale/cdc_demo.plot.py.svg
new file mode 100644
index 00000000000..ac7836f1904
--- /dev/null
+++ b/docs/source/api_doc/upscale/cdc_demo.plot.py.svg
@@ -0,0 +1,1593 @@
[1593 lines of Matplotlib v3.7.5 SVG markup omitted: rendered demo grid comparing the sample images with their CDC-upscaled versions, generated 2024-05-09T21:56:50]
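Outside the plotting harness, the simplest end-to-end use of the new module mirrors what the demo script does for each cell of the grid. A minimal sketch follows; the input and output file names are placeholders, and the model name is the function's documented default.

from imgutils.upscale import upscale_with_cdc

# 4x upscale with the default CDC model; if the input has an alpha channel,
# the result is returned as RGBA (see imgutils/upscale/transparent.py below),
# so it can be saved straight back to PNG.
upscaled = upscale_with_cdc('sample/original.png', model='HGSR-MHR-anime-aug_X4_320')
upscaled.save('original_4x.png')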
diff --git a/docs/source/api_doc/upscale/index.rst b/docs/source/api_doc/upscale/index.rst
new file mode 100644
index 00000000000..7182d0cf806
--- /dev/null
+++ b/docs/source/api_doc/upscale/index.rst
@@ -0,0 +1,13 @@
+imgutils.upscale
+========================
+
+.. currentmodule:: imgutils.upscale
+
+.. automodule:: imgutils.upscale
+
+
+.. toctree::
+    :maxdepth: 3
+
+    cdc
+
diff --git a/docs/source/api_doc/upscale/sample/hutao.png b/docs/source/api_doc/upscale/sample/hutao.png
new file mode 100644
index 00000000000..a8f42dce778
Binary files /dev/null and b/docs/source/api_doc/upscale/sample/hutao.png differ
diff --git a/docs/source/api_doc/upscale/sample/original.png b/docs/source/api_doc/upscale/sample/original.png
new file mode 100644
index 00000000000..4757b83a42f
Binary files /dev/null and b/docs/source/api_doc/upscale/sample/original.png differ
diff --git a/docs/source/api_doc/upscale/sample/rgba_restore.png b/docs/source/api_doc/upscale/sample/rgba_restore.png
new file mode 100644
index 00000000000..9bb79b8b253
Binary files /dev/null and b/docs/source/api_doc/upscale/sample/rgba_restore.png differ
diff --git a/docs/source/api_doc/upscale/sample/skadi.jpg b/docs/source/api_doc/upscale/sample/skadi.jpg
new file mode 100644
index 00000000000..a585ecb7c85
Binary files /dev/null and b/docs/source/api_doc/upscale/sample/skadi.jpg differ
diff --git a/docs/source/api_doc/upscale/sample/xx.jpg b/docs/source/api_doc/upscale/sample/xx.jpg
new file mode 100644
index 00000000000..dfa950eeab0
Binary files /dev/null and b/docs/source/api_doc/upscale/sample/xx.jpg differ
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 795e8fe0995..703e90bf13d 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -42,6 +42,7 @@ configuration file's structure and their versions.
    api_doc/sd/index
    api_doc/segment/index
    api_doc/tagging/index
+   api_doc/upscale/index
    api_doc/utils/index
    api_doc/validate/index

diff --git a/imgutils/restore/nafnet.py b/imgutils/restore/nafnet.py
index 1ee0a050dc5..f13df1373f0 100644
--- a/imgutils/restore/nafnet.py
+++ b/imgutils/restore/nafnet.py
@@ -51,7 +51,8 @@ def _open_nafnet_model(model: NafNetModelTyping):
 
 
 def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS',
-                        tile_size: int = 256, tile_overlap: int = 16, silent: bool = False) -> Image.Image:
+                        tile_size: int = 256, tile_overlap: int = 16, batch_size: int = 4,
+                        silent: bool = False) -> Image.Image:
     """
     Restore an image using the NAFNet model.
 
@@ -63,6 +64,8 @@ def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS',
     :type tile_size: int
     :param tile_overlap: The overlap between tiles. Default is 16.
     :type tile_overlap: int
+    :param batch_size: The batch size of inference. Default is 4.
+    :type batch_size: int
     :param silent: If True, the progress will not be displayed. Default is False.
     :type silent: bool
     :return: The restored image.
@@ -79,8 +82,8 @@ def _method(ix):
 
     output_ = area_batch_run(
         input_, _method,
-        tile_size=tile_size, tile_overlap=tile_overlap, silent=silent,
-        process_title='NafNet Restore',
+        tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size,
+        silent=silent, process_title='NafNet Restore',
     )
     output_ = np.clip(output_, a_min=0.0, a_max=1.0)
     ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
diff --git a/imgutils/restore/scunet.py b/imgutils/restore/scunet.py
index a95bd388ce5..368df4b5e03 100644
--- a/imgutils/restore/scunet.py
+++ b/imgutils/restore/scunet.py
@@ -46,7 +46,8 @@ def _open_scunet_model(model: SCUNetModelTyping):
 
 
 def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN',
-                        tile_size: int = 128, tile_overlap: int = 16, silent: bool = False) -> Image.Image:
+                        tile_size: int = 128, tile_overlap: int = 16, batch_size: int = 4,
+                        silent: bool = False) -> Image.Image:
     """
     Restore an image using the SCUNet model.
 
@@ -58,6 +59,8 @@ def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN',
     :type tile_size: int
     :param tile_overlap: The overlap between tiles. Default is 16.
     :type tile_overlap: int
+    :param batch_size: The batch size of inference. Default is 4.
+    :type batch_size: int
     :param silent: If True, the progress will not be displayed. Default is False.
     :type silent: bool
     :return: The restored image.
@@ -74,8 +77,8 @@ def _method(ix):
 
     output_ = area_batch_run(
         input_, _method,
-        tile_size=tile_size, tile_overlap=tile_overlap, silent=silent,
-        process_title='SCUNet Restore',
+        tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size,
+        silent=silent, process_title='SCUNet Restore',
     )
     output_ = np.clip(output_, a_min=0.0, a_max=1.0)
     ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
diff --git a/imgutils/upscale/__init__.py b/imgutils/upscale/__init__.py
new file mode 100644
index 00000000000..45ce3d4f775
--- /dev/null
+++ b/imgutils/upscale/__init__.py
@@ -0,0 +1 @@
+from .cdc import upscale_with_cdc
diff --git a/imgutils/upscale/cdc.py b/imgutils/upscale/cdc.py
new file mode 100644
index 00000000000..b4996fb3bdd
--- /dev/null
+++ b/imgutils/upscale/cdc.py
@@ -0,0 +1,139 @@
+"""
+Overview:
+    Upscale images with the CDC model, developed and trained by `7eu7d7 `_;
+    the models are hosted on `deepghs/cdc_anime_onnx `_.
+
+    Here are some examples:
+
+    .. image:: cdc_demo.plot.py.svg
+        :align: center
+
+    Here is the benchmark of CDC models:
+
+    .. image:: cdc_benchmark.plot.py.svg
+        :align: center
+
+    .. note::
+        The CDC model produces high-quality results but runs very slowly.
+        In our tests, upscaling a 1024x1024 image on a 2060 GPU
+        takes approximately 70 seconds per image, so we strongly recommend against running it on CPU.
+        Please run the CDC model in a GPU environment for a better experience.
+"""
+from functools import lru_cache
+from typing import Tuple, Any
+
+import cv2
+import numpy as np
+from PIL import Image
+from huggingface_hub import hf_hub_download
+
+from .transparent import _rgba_preprocess, _rgba_postprocess
+from ..data import ImageTyping, load_image
+from ..utils import open_onnx_model, area_batch_run
+
+
+@lru_cache()
+def _open_cdc_upscaler_model(model: str) -> Tuple[Any, int]:
+    """
+    Opens and initializes the CDC upscaler model.
+
+    :param model: The name of the model to use.
+    :type model: str
+
+    :return: Tuple of the ONNX model and the scale factor.
+ :rtype: Tuple[Any, int] + """ + ort = open_onnx_model(hf_hub_download( + f'deepghs/cdc_anime_onnx', + f'{model}.onnx' + )) + + input_ = np.random.randn(1, 3, 16, 16).astype(np.float32) + output_, = ort.run(['output'], {'input': input_}) + + batch, channels, scale_h, height, scale_w, width = output_.shape + assert batch == 1 and channels == 3 and height == 16 and width == 16, \ + f'Unexpected output size found {output_.shape!r}.' + assert scale_h == scale_w, f'Scale of height and width not match - {output_.shape!r}.' + + return ort, scale_h + + +_CDC_INPUT_UNIT = 16 + + +def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320', + tile_size: int = 512, tile_overlap: int = 64, batch_size: int = 1, + alpha_interpolation: int = cv2.INTER_LINEAR, silent: bool = False, ) -> Image.Image: + """ + Upscale the input image using the CDC upscaler model. + + :param image: The input image. + :type image: ImageTyping + + :param model: The name of the model to use. (default: 'HGSR-MHR-anime-aug_X4_320') + :type model: str + + :param tile_size: The size of each tile. (default: 512) + :type tile_size: int + + :param tile_overlap: The overlap between tiles. (default: 64) + :type tile_overlap: int + + :param batch_size: The batch size. (default: 1) + :type batch_size: int + + :param alpha_interpolation: Interpolation for :func:`cv2.resize`. Default is ``cv2.INTER_LINEAR``. + :type alpha_interpolation: int + + :param silent: Whether to suppress progress messages. (default: False) + :type silent: bool + + :return: The upscaled image. + :rtype: Image.Image + + .. note:: + RGBA images are supported. When you pass an image with transparency channel (e.g. RGBA image), + this function will return an RGBA image, otherwise return a RGB image. + + Example:: + >>> from PIL import Image + >>> from imgutils.upscale import upscale_with_cdc + >>> + >>> image = Image.open('cute_waifu_aroma.png') + >>> image + + >>> + >>> upscale_with_cdc(image) + + """ + image, alpha_mask = _rgba_preprocess(image) + image = load_image(image, mode='RGB', force_background='white') + input_ = np.array(image).astype(np.float32) / 255.0 + input_ = input_.transpose((2, 0, 1))[None, ...] 
+ + ort, scale = _open_cdc_upscaler_model(model) + + def _method(ix): + ix = ix.astype(np.float32) + batch, channels, height, width = ix.shape + p_height = 0 if height % _CDC_INPUT_UNIT == 0 else _CDC_INPUT_UNIT - (height % _CDC_INPUT_UNIT) + p_width = 0 if width % _CDC_INPUT_UNIT == 0 else _CDC_INPUT_UNIT - (width % _CDC_INPUT_UNIT) + if p_height > 0 or p_width > 0: # align to 16 + ix = np.pad(ix, ((0, 0), (0, 0), (0, p_height), (0, p_width)), mode='reflect') + actual_height, actual_width = height, width + + ox, = ort.run(['output'], {'input': ix}) + batch, channels, scale_, height, scale_, width = ox.shape + ox = ox.reshape((batch, channels, scale_ * height, scale_ * width)) + ox = ox[..., :scale_ * actual_height, :scale_ * actual_width] # crop back + return ox + + output_ = area_batch_run( + input_, _method, + tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size, + scale=scale, silent=silent, process_title='CDC Upscale', + ) + output_ = np.clip(output_, a_min=0.0, a_max=1.0) + ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.uint8), 'RGB') + return _rgba_postprocess(ret_image, alpha_mask, interpolation=alpha_interpolation) diff --git a/imgutils/upscale/transparent.py b/imgutils/upscale/transparent.py new file mode 100644 index 00000000000..bc9687a4b40 --- /dev/null +++ b/imgutils/upscale/transparent.py @@ -0,0 +1,71 @@ +from typing import Optional + +import cv2 +import numpy as np +from PIL import Image + +from ..data.image import ImageTyping, load_image + + +def _has_alpha_channel(image: Image.Image) -> bool: + """ + Check if the image has an alpha channel. + + :param image: The image to check. + :type image: Image.Image + + :return: True if the image has an alpha channel, False otherwise. + :rtype: bool + """ + return any(band in {'A', 'a', 'P'} for band in image.getbands()) + + +def _rgba_preprocess(image: ImageTyping): + """ + Preprocess the image for RGBA conversion. + + :param image: The image to preprocess. + :type image: ImageTyping + + :return: Preprocessed image and alpha mask. + :rtype: Tuple[Image.Image, Optional[np.ndarray]] + """ + image = load_image(image, force_background=None, mode=None) + if _has_alpha_channel(image): + image = image.convert('RGBA') + pimage = image.convert('RGB') + alpha_mask = np.array(image)[:, :, 3].astype(np.float32) / 255.0 + else: + pimage = image.convert('RGB') + alpha_mask = None + + return pimage, alpha_mask + + +def _rgba_postprocess(pimage, alpha_mask: Optional[np.ndarray] = None, interpolation: int = cv2.INTER_LINEAR): + """ + Postprocess the image after RGBA conversion. + + :param pimage: The processed image. + :type pimage: Image.Image + + :param alpha_mask: The alpha mask. + :type alpha_mask: Optional[np.ndarray] + + :param interpolation: Interpolation for :func:`cv2.resize`. Default is ``cv2.INTER_LINEAR``. + :type interpolation: int + + :return: Postprocessed image. 
+ :rtype: Image.Image + """ + assert pimage.mode == 'RGB' + if alpha_mask is None: + return pimage + else: + channels = np.array(pimage) + alpha_mask = cv2.resize(alpha_mask, channels.shape[:2], interpolation=interpolation) + alpha_mask = np.clip(alpha_mask, a_min=0.0, a_max=1.0) + alpha_channel = (alpha_mask * 255.0).astype(np.uint8)[..., np.newaxis] + rgba_channels = np.concatenate([channels, alpha_channel], axis=-1) + assert rgba_channels.shape == (*channels.shape[:-1], 4) + return Image.fromarray(rgba_channels, mode='RGBA') diff --git a/imgutils/utils/area.py b/imgutils/utils/area.py index d7671f3ac1c..b99c2cd486e 100644 --- a/imgutils/utils/area.py +++ b/imgutils/utils/area.py @@ -48,8 +48,8 @@ def area_batch_run(origin_input: np.ndarray, func, scale: int = 1, tile = min(tile_size, height, width) stride = tile - tile_overlap - h_idx_list = list(range(0, height - tile, stride)) + [height - tile] - w_idx_list = list(range(0, width - tile, stride)) + [width - tile] + h_idx_list = sorted(set(list(range(0, height - tile, stride)) + [height - tile])) + w_idx_list = sorted(set(list(range(0, width - tile, stride)) + [width - tile])) sum_ = np.zeros((batch, output_channels, height * scale, width * scale), dtype=origin_input.dtype) weight = np.zeros_like(sum_, dtype=origin_input.dtype) diff --git a/test/testfile/rgba_upscale.png b/test/testfile/rgba_upscale.png new file mode 100644 index 00000000000..d11cc27b15a Binary files /dev/null and b/test/testfile/rgba_upscale.png differ diff --git a/test/testfile/rgba_upscale_4x.png b/test/testfile/rgba_upscale_4x.png new file mode 100644 index 00000000000..2ecda1d4fbf Binary files /dev/null and b/test/testfile/rgba_upscale_4x.png differ diff --git a/test/testfile/surtr_logo_2x.png b/test/testfile/surtr_logo_2x.png new file mode 100644 index 00000000000..3d417c86101 Binary files /dev/null and b/test/testfile/surtr_logo_2x.png differ diff --git a/test/testfile/surtr_logo_4x.png b/test/testfile/surtr_logo_4x.png new file mode 100644 index 00000000000..b70cae88cf8 Binary files /dev/null and b/test/testfile/surtr_logo_4x.png differ diff --git a/test/testfile/surtr_logo_small_2x.png b/test/testfile/surtr_logo_small_2x.png new file mode 100644 index 00000000000..a179f1c2331 Binary files /dev/null and b/test/testfile/surtr_logo_small_2x.png differ diff --git a/test/testfile/surtr_logo_small_4x.png b/test/testfile/surtr_logo_small_4x.png new file mode 100644 index 00000000000..ce93179ecc5 Binary files /dev/null and b/test/testfile/surtr_logo_small_4x.png differ diff --git a/test/upscale/__init__.py b/test/upscale/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/upscale/conftest.py b/test/upscale/conftest.py new file mode 100644 index 00000000000..1c9b9a538ef --- /dev/null +++ b/test/upscale/conftest.py @@ -0,0 +1,24 @@ +import pytest + +from imgutils.data import load_image +from test.testings import get_testfile + + +@pytest.fixture() +def sample_image(): + yield load_image(get_testfile('surtr_logo.png'), mode='RGB', force_background='white') + + +@pytest.fixture() +def sample_image_small(sample_image): + yield sample_image.resize((127, 126)) + + +@pytest.fixture() +def sample_rgba_image(): + yield load_image(get_testfile('rgba_upscale.png'), mode='RGBA', force_background=None) + + +@pytest.fixture() +def sample_rgba_image_4x(): + yield load_image(get_testfile('rgba_upscale_4x.png'), mode='RGBA', force_background=None) diff --git a/test/upscale/test_cdc.py b/test/upscale/test_cdc.py new file mode 100644 index 
00000000000..27084d17086 --- /dev/null +++ b/test/upscale/test_cdc.py @@ -0,0 +1,50 @@ +import pytest +from PIL import Image + +from imgutils.data import grid_transparent +from imgutils.metrics import psnr +from imgutils.upscale import upscale_with_cdc +from imgutils.upscale.cdc import _open_cdc_upscaler_model + + +@pytest.fixture(autouse=True, scope='function') +def _release_model(): + try: + yield + finally: + _open_cdc_upscaler_model.cache_clear() + + +@pytest.mark.unittest +class TestUpscaleCDC: + def test_upscale_with_cdc_4x(self, sample_image): + assert psnr( + upscale_with_cdc(sample_image), + sample_image.resize((sample_image.width * 4, sample_image.height * 4), Image.LANCZOS) + ) >= 34.5 + + def test_upscale_with_cdc_2x(self, sample_image): + assert psnr( + upscale_with_cdc(sample_image, model='HGSR-MHR_X2_1680'), + sample_image.resize((sample_image.width * 2, sample_image.height * 2), Image.LANCZOS) + ) >= 35.5 + + def test_upscale_with_cdc_small_4x(self, sample_image_small, sample_image): + assert psnr( + upscale_with_cdc(sample_image_small) + .resize(sample_image.size, Image.LANCZOS), + sample_image, + ) >= 28.5 + + def test_upscale_with_cdc_small_2x(self, sample_image_small, sample_image): + assert psnr( + upscale_with_cdc(sample_image_small, model='HGSR-MHR_X2_1680') + .resize(sample_image.size, Image.LANCZOS), + sample_image, + ) >= 28.0 + + def test_upscale_with_cdc_4x_rgba(self, sample_rgba_image, sample_rgba_image_4x): + assert psnr( + grid_transparent(upscale_with_cdc(sample_rgba_image)), + grid_transparent(sample_rgba_image_4x), + ) >= 34.5
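The PSNR thresholds in these tests can be checked by hand with the public API. A rough manual equivalent of `test_upscale_with_cdc_4x` is sketched below; the test file path is assumed to be resolved from the repository root, which is what `get_testfile` does in the fixtures.

from PIL import Image

from imgutils.data import load_image
from imgutils.metrics import psnr
from imgutils.upscale import upscale_with_cdc

# Flatten the logo onto a white background (as the sample_image fixture does),
# upscale it with the default 4x CDC model, and compare against a plain LANCZOS
# 4x resize; the test treats PSNR >= 34.5 dB as a pass.
image = load_image('test/testfile/surtr_logo.png', mode='RGB', force_background='white')
reference = image.resize((image.width * 4, image.height * 4), Image.LANCZOS)
print(psnr(upscale_with_cdc(image), reference))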