diff --git a/docs/source/api_doc/upscale/cdc.rst b/docs/source/api_doc/upscale/cdc.rst
new file mode 100644
index 00000000000..9d76f491f55
--- /dev/null
+++ b/docs/source/api_doc/upscale/cdc.rst
@@ -0,0 +1,15 @@
+imgutils.upscale.cdc
+====================================
+
+.. currentmodule:: imgutils.upscale.cdc
+
+.. automodule:: imgutils.upscale.cdc
+
+
+upscale_with_cdc
+---------------------------
+
+.. autofunction:: upscale_with_cdc
+
+
+
diff --git a/docs/source/api_doc/upscale/cdc_benchmark.plot.py b/docs/source/api_doc/upscale/cdc_benchmark.plot.py
new file mode 100644
index 00000000000..d21553d3f34
--- /dev/null
+++ b/docs/source/api_doc/upscale/cdc_benchmark.plot.py
@@ -0,0 +1,44 @@
+import os.path
+import random
+
+from huggingface_hub import HfFileSystem
+
+from benchmark import BaseBenchmark, create_plot_cli
+from imgutils.upscale.cdc import upscale_with_cdc
+
+hf_fs = HfFileSystem()
+repository = 'deepghs/cdc_anime_onnx'
+_CDC_MODELS = [
+ os.path.splitext(os.path.relpath(file, repository))[0]
+ for file in hf_fs.glob(f'{repository}/*.onnx')
+]
+
+
+class CDCUpscalerBenchmark(BaseBenchmark):
+ def __init__(self, model: str):
+ BaseBenchmark.__init__(self)
+ self.model = model
+
+ def load(self):
+ from imgutils.upscale.cdc import _open_cdc_upscaler_model
+ _open_cdc_upscaler_model(self.model)
+
+ def unload(self):
+ from imgutils.upscale.cdc import _open_cdc_upscaler_model
+ _open_cdc_upscaler_model.cache_clear()
+
+ def run(self):
+ image_file = random.choice(self.all_images)
+ _ = upscale_with_cdc(image_file, model=self.model)
+
+
+if __name__ == '__main__':
+ create_plot_cli(
+ [
+ (model, CDCUpscalerBenchmark(model))
+ for model in _CDC_MODELS
+ ],
+ title='Benchmark for CDCUpscaler Models',
+ run_times=3,
+ try_times=3,
+ )()
diff --git a/docs/source/api_doc/upscale/cdc_benchmark.plot.py.svg b/docs/source/api_doc/upscale/cdc_benchmark.plot.py.svg
new file mode 100644
index 00000000000..dba8439c9b5
--- /dev/null
+++ b/docs/source/api_doc/upscale/cdc_benchmark.plot.py.svg
@@ -0,0 +1,2154 @@
+
+
+
diff --git a/docs/source/api_doc/upscale/cdc_demo.plot.py b/docs/source/api_doc/upscale/cdc_demo.plot.py
new file mode 100644
index 00000000000..6c0d6a03c19
--- /dev/null
+++ b/docs/source/api_doc/upscale/cdc_demo.plot.py
@@ -0,0 +1,36 @@
+import os
+
+from huggingface_hub import HfFileSystem
+
+from imgutils.upscale import upscale_with_cdc
+from imgutils.upscale.cdc import _open_cdc_upscaler_model
+from plot import image_plot
+
+hf_fs = HfFileSystem()
+repository = 'deepghs/cdc_anime_onnx'
+_CDC_MODELS = [
+ os.path.splitext(os.path.relpath(file, repository))[0]
+ for file in hf_fs.glob(f'{repository}/*.onnx')
+]
+
+if __name__ == '__main__':
+ demo_images = [
+ ('sample/original.png', 'Small Logo'),
+ ('sample/skadi.jpg', 'Illustration'),
+ ('sample/hutao.png', 'Large Illustration'),
+ # ('sample/xx.jpg', 'Illustration #2'),
+ ('sample/rgba_restore.png', 'RGBA Artwork'),
+ ]
+
+ items = []
+ for file, title in demo_images:
+ items.append((file, title))
+ for model in _CDC_MODELS:
+ _, scale = _open_cdc_upscaler_model(model)
+ items.append((upscale_with_cdc(file, model=model), f'{title}\n({scale}X By {model})'))
+
+ image_plot(
+ *items,
+ columns=len(_CDC_MODELS) + 1,
+ figsize=(4 * (len(_CDC_MODELS) + 1), 3.5 * len(demo_images)),
+ )
diff --git a/docs/source/api_doc/upscale/cdc_demo.plot.py.svg b/docs/source/api_doc/upscale/cdc_demo.plot.py.svg
new file mode 100644
index 00000000000..ac7836f1904
--- /dev/null
+++ b/docs/source/api_doc/upscale/cdc_demo.plot.py.svg
@@ -0,0 +1,1593 @@
+
+
+
diff --git a/docs/source/api_doc/upscale/index.rst b/docs/source/api_doc/upscale/index.rst
new file mode 100644
index 00000000000..7182d0cf806
--- /dev/null
+++ b/docs/source/api_doc/upscale/index.rst
@@ -0,0 +1,13 @@
+imgutils.upscale
+========================
+
+.. currentmodule:: imgutils.upscale
+
+.. automodule:: imgutils.upscale
+
+
+.. toctree::
+ :maxdepth: 3
+
+ cdc
+
diff --git a/docs/source/api_doc/upscale/sample/hutao.png b/docs/source/api_doc/upscale/sample/hutao.png
new file mode 100644
index 00000000000..a8f42dce778
Binary files /dev/null and b/docs/source/api_doc/upscale/sample/hutao.png differ
diff --git a/docs/source/api_doc/upscale/sample/original.png b/docs/source/api_doc/upscale/sample/original.png
new file mode 100644
index 00000000000..4757b83a42f
Binary files /dev/null and b/docs/source/api_doc/upscale/sample/original.png differ
diff --git a/docs/source/api_doc/upscale/sample/rgba_restore.png b/docs/source/api_doc/upscale/sample/rgba_restore.png
new file mode 100644
index 00000000000..9bb79b8b253
Binary files /dev/null and b/docs/source/api_doc/upscale/sample/rgba_restore.png differ
diff --git a/docs/source/api_doc/upscale/sample/skadi.jpg b/docs/source/api_doc/upscale/sample/skadi.jpg
new file mode 100644
index 00000000000..a585ecb7c85
Binary files /dev/null and b/docs/source/api_doc/upscale/sample/skadi.jpg differ
diff --git a/docs/source/api_doc/upscale/sample/xx.jpg b/docs/source/api_doc/upscale/sample/xx.jpg
new file mode 100644
index 00000000000..dfa950eeab0
Binary files /dev/null and b/docs/source/api_doc/upscale/sample/xx.jpg differ
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 795e8fe0995..703e90bf13d 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -42,6 +42,7 @@ configuration file's structure and their versions.
api_doc/sd/index
api_doc/segment/index
api_doc/tagging/index
+ api_doc/upscale/index
api_doc/utils/index
api_doc/validate/index
diff --git a/imgutils/restore/nafnet.py b/imgutils/restore/nafnet.py
index 1ee0a050dc5..f13df1373f0 100644
--- a/imgutils/restore/nafnet.py
+++ b/imgutils/restore/nafnet.py
@@ -51,7 +51,8 @@ def _open_nafnet_model(model: NafNetModelTyping):
def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS',
- tile_size: int = 256, tile_overlap: int = 16, silent: bool = False) -> Image.Image:
+ tile_size: int = 256, tile_overlap: int = 16, batch_size: int = 4,
+ silent: bool = False) -> Image.Image:
"""
Restore an image using the NAFNet model.
@@ -63,6 +64,8 @@ def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS',
:type tile_size: int
:param tile_overlap: The overlap between tiles. Default is 16.
:type tile_overlap: int
+ :param batch_size: The batch size of inference. Default is 4.
+ :type batch_size: int
:param silent: If True, the progress will not be displayed. Default is False.
:type silent: bool
:return: The restored image.
@@ -79,8 +82,8 @@ def _method(ix):
output_ = area_batch_run(
input_, _method,
- tile_size=tile_size, tile_overlap=tile_overlap, silent=silent,
- process_title='NafNet Restore',
+ tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size,
+ silent=silent, process_title='NafNet Restore',
)
output_ = np.clip(output_, a_min=0.0, a_max=1.0)
ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
diff --git a/imgutils/restore/scunet.py b/imgutils/restore/scunet.py
index a95bd388ce5..368df4b5e03 100644
--- a/imgutils/restore/scunet.py
+++ b/imgutils/restore/scunet.py
@@ -46,7 +46,8 @@ def _open_scunet_model(model: SCUNetModelTyping):
def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN',
- tile_size: int = 128, tile_overlap: int = 16, silent: bool = False) -> Image.Image:
+ tile_size: int = 128, tile_overlap: int = 16, batch_size: int = 4,
+ silent: bool = False) -> Image.Image:
"""
Restore an image using the SCUNet model.
@@ -58,6 +59,8 @@ def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN',
:type tile_size: int
:param tile_overlap: The overlap between tiles. Default is 16.
:type tile_overlap: int
+ :param batch_size: The batch size of inference. Default is 4.
+ :type batch_size: int
:param silent: If True, the progress will not be displayed. Default is False.
:type silent: bool
:return: The restored image.
@@ -74,8 +77,8 @@ def _method(ix):
output_ = area_batch_run(
input_, _method,
- tile_size=tile_size, tile_overlap=tile_overlap, silent=silent,
- process_title='SCUNet Restore',
+ tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size,
+ silent=silent, process_title='SCUNet Restore',
)
output_ = np.clip(output_, a_min=0.0, a_max=1.0)
ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
diff --git a/imgutils/upscale/__init__.py b/imgutils/upscale/__init__.py
new file mode 100644
index 00000000000..45ce3d4f775
--- /dev/null
+++ b/imgutils/upscale/__init__.py
@@ -0,0 +1 @@
+from .cdc import upscale_with_cdc
diff --git a/imgutils/upscale/cdc.py b/imgutils/upscale/cdc.py
new file mode 100644
index 00000000000..b4996fb3bdd
--- /dev/null
+++ b/imgutils/upscale/cdc.py
@@ -0,0 +1,139 @@
+"""
+Overview:
+ Upscale images with CDC model, developed and trained by `7eu7d7 `_,
+ the models are hosted on `deepghs/cdc_anime_onnx `_.
+
+ Here are some examples:
+
+ .. image:: cdc_demo.plot.py.svg
+ :align: center
+
+ Here is the benchmark of CDC models:
+
+ .. image:: cdc_benchmark.plot.py.svg
+ :align: center
+
+ .. note::
+ CDC model has high quality, and really low running speed.
+ As we tested, when it upscales an image with 1024x1024 resolution on 2060 GPU,
+ the time cost is approx 70s/image. So we strongly recommend against running it on CPU.
+ Please run CDC model on environments with GPU for better experience.
+"""
+from functools import lru_cache
+from typing import Tuple, Any
+
+import cv2
+import numpy as np
+from PIL import Image
+from huggingface_hub import hf_hub_download
+
+from .transparent import _rgba_preprocess, _rgba_postprocess
+from ..data import ImageTyping, load_image
+from ..utils import open_onnx_model, area_batch_run
+
+
+@lru_cache()
+def _open_cdc_upscaler_model(model: str) -> Tuple[Any, int]:
+ """
+ Opens and initializes the CDC upscaler model.
+
+ :param model: The name of the model to use.
+ :type model: str
+
+ :return: Tuple of the ONNX model and the scale factor.
+ :rtype: Tuple[Any, int]
+ """
+ ort = open_onnx_model(hf_hub_download(
+ f'deepghs/cdc_anime_onnx',
+ f'{model}.onnx'
+ ))
+
+ input_ = np.random.randn(1, 3, 16, 16).astype(np.float32)
+ output_, = ort.run(['output'], {'input': input_})
+
+ batch, channels, scale_h, height, scale_w, width = output_.shape
+ assert batch == 1 and channels == 3 and height == 16 and width == 16, \
+ f'Unexpected output size found {output_.shape!r}.'
+ assert scale_h == scale_w, f'Scale of height and width not match - {output_.shape!r}.'
+
+ return ort, scale_h
+
+
+_CDC_INPUT_UNIT = 16
+
+
+def upscale_with_cdc(image: ImageTyping, model: str = 'HGSR-MHR-anime-aug_X4_320',
+ tile_size: int = 512, tile_overlap: int = 64, batch_size: int = 1,
+ alpha_interpolation: int = cv2.INTER_LINEAR, silent: bool = False, ) -> Image.Image:
+ """
+ Upscale the input image using the CDC upscaler model.
+
+ :param image: The input image.
+ :type image: ImageTyping
+
+ :param model: The name of the model to use. (default: 'HGSR-MHR-anime-aug_X4_320')
+ :type model: str
+
+ :param tile_size: The size of each tile. (default: 512)
+ :type tile_size: int
+
+ :param tile_overlap: The overlap between tiles. (default: 64)
+ :type tile_overlap: int
+
+ :param batch_size: The batch size. (default: 1)
+ :type batch_size: int
+
+ :param alpha_interpolation: Interpolation for :func:`cv2.resize`. Default is ``cv2.INTER_LINEAR``.
+ :type alpha_interpolation: int
+
+ :param silent: Whether to suppress progress messages. (default: False)
+ :type silent: bool
+
+ :return: The upscaled image.
+ :rtype: Image.Image
+
+ .. note::
+ RGBA images are supported. When you pass an image with transparency channel (e.g. RGBA image),
+ this function will return an RGBA image, otherwise return a RGB image.
+
+ Example::
+ >>> from PIL import Image
+ >>> from imgutils.upscale import upscale_with_cdc
+ >>>
+ >>> image = Image.open('cute_waifu_aroma.png')
+ >>> image
+
+ >>>
+ >>> upscale_with_cdc(image)
+
+ """
+ image, alpha_mask = _rgba_preprocess(image)
+ image = load_image(image, mode='RGB', force_background='white')
+ input_ = np.array(image).astype(np.float32) / 255.0
+ input_ = input_.transpose((2, 0, 1))[None, ...]
+
+ ort, scale = _open_cdc_upscaler_model(model)
+
+ def _method(ix):
+ ix = ix.astype(np.float32)
+ batch, channels, height, width = ix.shape
+ p_height = 0 if height % _CDC_INPUT_UNIT == 0 else _CDC_INPUT_UNIT - (height % _CDC_INPUT_UNIT)
+ p_width = 0 if width % _CDC_INPUT_UNIT == 0 else _CDC_INPUT_UNIT - (width % _CDC_INPUT_UNIT)
+ if p_height > 0 or p_width > 0: # align to 16
+ ix = np.pad(ix, ((0, 0), (0, 0), (0, p_height), (0, p_width)), mode='reflect')
+ actual_height, actual_width = height, width
+
+ ox, = ort.run(['output'], {'input': ix})
+ batch, channels, scale_, height, scale_, width = ox.shape
+ ox = ox.reshape((batch, channels, scale_ * height, scale_ * width))
+ ox = ox[..., :scale_ * actual_height, :scale_ * actual_width] # crop back
+ return ox
+
+ output_ = area_batch_run(
+ input_, _method,
+ tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size,
+ scale=scale, silent=silent, process_title='CDC Upscale',
+ )
+ output_ = np.clip(output_, a_min=0.0, a_max=1.0)
+ ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.uint8), 'RGB')
+ return _rgba_postprocess(ret_image, alpha_mask, interpolation=alpha_interpolation)
diff --git a/imgutils/upscale/transparent.py b/imgutils/upscale/transparent.py
new file mode 100644
index 00000000000..bc9687a4b40
--- /dev/null
+++ b/imgutils/upscale/transparent.py
@@ -0,0 +1,71 @@
+from typing import Optional
+
+import cv2
+import numpy as np
+from PIL import Image
+
+from ..data.image import ImageTyping, load_image
+
+
+def _has_alpha_channel(image: Image.Image) -> bool:
+ """
+ Check if the image has an alpha channel.
+
+ :param image: The image to check.
+ :type image: Image.Image
+
+ :return: True if the image has an alpha channel, False otherwise.
+ :rtype: bool
+ """
+ return any(band in {'A', 'a', 'P'} for band in image.getbands())
+
+
+def _rgba_preprocess(image: ImageTyping):
+ """
+ Preprocess the image for RGBA conversion.
+
+ :param image: The image to preprocess.
+ :type image: ImageTyping
+
+ :return: Preprocessed image and alpha mask.
+ :rtype: Tuple[Image.Image, Optional[np.ndarray]]
+ """
+ image = load_image(image, force_background=None, mode=None)
+ if _has_alpha_channel(image):
+ image = image.convert('RGBA')
+ pimage = image.convert('RGB')
+ alpha_mask = np.array(image)[:, :, 3].astype(np.float32) / 255.0
+ else:
+ pimage = image.convert('RGB')
+ alpha_mask = None
+
+ return pimage, alpha_mask
+
+
+def _rgba_postprocess(pimage, alpha_mask: Optional[np.ndarray] = None, interpolation: int = cv2.INTER_LINEAR):
+ """
+ Postprocess the image after RGBA conversion.
+
+ :param pimage: The processed image.
+ :type pimage: Image.Image
+
+ :param alpha_mask: The alpha mask.
+ :type alpha_mask: Optional[np.ndarray]
+
+ :param interpolation: Interpolation for :func:`cv2.resize`. Default is ``cv2.INTER_LINEAR``.
+ :type interpolation: int
+
+ :return: Postprocessed image.
+ :rtype: Image.Image
+ """
+ assert pimage.mode == 'RGB'
+ if alpha_mask is None:
+ return pimage
+ else:
+ channels = np.array(pimage)
+ alpha_mask = cv2.resize(alpha_mask, channels.shape[:2], interpolation=interpolation)
+ alpha_mask = np.clip(alpha_mask, a_min=0.0, a_max=1.0)
+ alpha_channel = (alpha_mask * 255.0).astype(np.uint8)[..., np.newaxis]
+ rgba_channels = np.concatenate([channels, alpha_channel], axis=-1)
+ assert rgba_channels.shape == (*channels.shape[:-1], 4)
+ return Image.fromarray(rgba_channels, mode='RGBA')
diff --git a/imgutils/utils/area.py b/imgutils/utils/area.py
index d7671f3ac1c..b99c2cd486e 100644
--- a/imgutils/utils/area.py
+++ b/imgutils/utils/area.py
@@ -48,8 +48,8 @@ def area_batch_run(origin_input: np.ndarray, func, scale: int = 1,
tile = min(tile_size, height, width)
stride = tile - tile_overlap
- h_idx_list = list(range(0, height - tile, stride)) + [height - tile]
- w_idx_list = list(range(0, width - tile, stride)) + [width - tile]
+ h_idx_list = sorted(set(list(range(0, height - tile, stride)) + [height - tile]))
+ w_idx_list = sorted(set(list(range(0, width - tile, stride)) + [width - tile]))
sum_ = np.zeros((batch, output_channels, height * scale, width * scale), dtype=origin_input.dtype)
weight = np.zeros_like(sum_, dtype=origin_input.dtype)
diff --git a/test/testfile/rgba_upscale.png b/test/testfile/rgba_upscale.png
new file mode 100644
index 00000000000..d11cc27b15a
Binary files /dev/null and b/test/testfile/rgba_upscale.png differ
diff --git a/test/testfile/rgba_upscale_4x.png b/test/testfile/rgba_upscale_4x.png
new file mode 100644
index 00000000000..2ecda1d4fbf
Binary files /dev/null and b/test/testfile/rgba_upscale_4x.png differ
diff --git a/test/testfile/surtr_logo_2x.png b/test/testfile/surtr_logo_2x.png
new file mode 100644
index 00000000000..3d417c86101
Binary files /dev/null and b/test/testfile/surtr_logo_2x.png differ
diff --git a/test/testfile/surtr_logo_4x.png b/test/testfile/surtr_logo_4x.png
new file mode 100644
index 00000000000..b70cae88cf8
Binary files /dev/null and b/test/testfile/surtr_logo_4x.png differ
diff --git a/test/testfile/surtr_logo_small_2x.png b/test/testfile/surtr_logo_small_2x.png
new file mode 100644
index 00000000000..a179f1c2331
Binary files /dev/null and b/test/testfile/surtr_logo_small_2x.png differ
diff --git a/test/testfile/surtr_logo_small_4x.png b/test/testfile/surtr_logo_small_4x.png
new file mode 100644
index 00000000000..ce93179ecc5
Binary files /dev/null and b/test/testfile/surtr_logo_small_4x.png differ
diff --git a/test/upscale/__init__.py b/test/upscale/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/test/upscale/conftest.py b/test/upscale/conftest.py
new file mode 100644
index 00000000000..1c9b9a538ef
--- /dev/null
+++ b/test/upscale/conftest.py
@@ -0,0 +1,24 @@
+import pytest
+
+from imgutils.data import load_image
+from test.testings import get_testfile
+
+
+@pytest.fixture()
+def sample_image():
+ yield load_image(get_testfile('surtr_logo.png'), mode='RGB', force_background='white')
+
+
+@pytest.fixture()
+def sample_image_small(sample_image):
+ yield sample_image.resize((127, 126))
+
+
+@pytest.fixture()
+def sample_rgba_image():
+ yield load_image(get_testfile('rgba_upscale.png'), mode='RGBA', force_background=None)
+
+
+@pytest.fixture()
+def sample_rgba_image_4x():
+ yield load_image(get_testfile('rgba_upscale_4x.png'), mode='RGBA', force_background=None)
diff --git a/test/upscale/test_cdc.py b/test/upscale/test_cdc.py
new file mode 100644
index 00000000000..27084d17086
--- /dev/null
+++ b/test/upscale/test_cdc.py
@@ -0,0 +1,50 @@
+import pytest
+from PIL import Image
+
+from imgutils.data import grid_transparent
+from imgutils.metrics import psnr
+from imgutils.upscale import upscale_with_cdc
+from imgutils.upscale.cdc import _open_cdc_upscaler_model
+
+
+@pytest.fixture(autouse=True, scope='function')
+def _release_model():
+ try:
+ yield
+ finally:
+ _open_cdc_upscaler_model.cache_clear()
+
+
+@pytest.mark.unittest
+class TestUpscaleCDC:
+ def test_upscale_with_cdc_4x(self, sample_image):
+ assert psnr(
+ upscale_with_cdc(sample_image),
+ sample_image.resize((sample_image.width * 4, sample_image.height * 4), Image.LANCZOS)
+ ) >= 34.5
+
+ def test_upscale_with_cdc_2x(self, sample_image):
+ assert psnr(
+ upscale_with_cdc(sample_image, model='HGSR-MHR_X2_1680'),
+ sample_image.resize((sample_image.width * 2, sample_image.height * 2), Image.LANCZOS)
+ ) >= 35.5
+
+ def test_upscale_with_cdc_small_4x(self, sample_image_small, sample_image):
+ assert psnr(
+ upscale_with_cdc(sample_image_small)
+ .resize(sample_image.size, Image.LANCZOS),
+ sample_image,
+ ) >= 28.5
+
+ def test_upscale_with_cdc_small_2x(self, sample_image_small, sample_image):
+ assert psnr(
+ upscale_with_cdc(sample_image_small, model='HGSR-MHR_X2_1680')
+ .resize(sample_image.size, Image.LANCZOS),
+ sample_image,
+ ) >= 28.0
+
+ def test_upscale_with_cdc_4x_rgba(self, sample_rgba_image, sample_rgba_image_4x):
+ assert psnr(
+ grid_transparent(upscale_with_cdc(sample_rgba_image)),
+ grid_transparent(sample_rgba_image_4x),
+ ) >= 34.5