From 255988ec37b5918ab55e74cd6abd9bd432908a04 Mon Sep 17 00:00:00 2001
From: narugo1992
Date: Fri, 10 May 2024 21:35:05 +0800
Subject: [PATCH] dev(narugo): use enhance layer

---
 imgutils/restore/nafnet.py      | 53 +++++++++++++++++----------
 imgutils/restore/scunet.py      | 53 +++++++++++++++++----------
 imgutils/restore/transparent.py | 65 ---------------------------------
 3 files changed, 68 insertions(+), 103 deletions(-)
 delete mode 100644 imgutils/restore/transparent.py

diff --git a/imgutils/restore/nafnet.py b/imgutils/restore/nafnet.py
index f13df1373f0..cf792e082cc 100644
--- a/imgutils/restore/nafnet.py
+++ b/imgutils/restore/nafnet.py
@@ -28,8 +28,8 @@
 from PIL import Image
 from huggingface_hub import hf_hub_download
 
-from .transparent import _rgba_preprocess, _rgba_postprocess
-from ..data import ImageTyping, load_image
+from ..data import ImageTyping
+from ..generic import ImageEnhancer
 from ..utils import open_onnx_model, area_batch_run
 
 NafNetModelTyping = Literal['REDS', 'GoPro', 'SIDD']
@@ -50,6 +50,37 @@ def _open_nafnet_model(model: NafNetModelTyping):
     ))
 
 
+class _Enhancer(ImageEnhancer):
+    def __init__(self, model: NafNetModelTyping = 'REDS', tile_size: int = 256, tile_overlap: int = 16,
+                 batch_size: int = 4, silent: bool = False):
+        self.model = model
+        self.tile_size = tile_size
+        self.tile_overlap = tile_overlap
+        self.batch_size = batch_size
+        self.silent = silent
+
+    def _process_rgb(self, rgb_array: np.ndarray):
+        input_ = rgb_array[None, ...]
+
+        def _method(ix):
+            ox, = _open_nafnet_model(self.model).run(['output'], {'input': ix})
+            return ox
+
+        output_ = area_batch_run(
+            input_, _method,
+            tile_size=self.tile_size, tile_overlap=self.tile_overlap, batch_size=self.batch_size,
+            silent=self.silent, process_title='NafNet Restore',
+        )
+        output_ = np.clip(output_, a_min=0.0, a_max=1.0)
+        return output_[0]
+
+
+@lru_cache()
+def _get_enhancer(model: NafNetModelTyping = 'REDS', tile_size: int = 256, tile_overlap: int = 16,
+                  batch_size: int = 4, silent: bool = False) -> _Enhancer:
+    return _Enhancer(model, tile_size, tile_overlap, batch_size, silent)
+
+
 def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS',
                         tile_size: int = 256, tile_overlap: int = 16, batch_size: int = 4,
                         silent: bool = False) -> Image.Image:
@@ -71,20 +102,4 @@ def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS',
     :return: The restored image.
     :rtype: Image.Image
     """
-    image, alpha_mask = _rgba_preprocess(image)
-    image = load_image(image, mode='RGB', force_background='white')
-    input_ = np.array(image).astype(np.float32) / 255.0
-    input_ = input_.transpose((2, 0, 1))[None, ...]
-
-    def _method(ix):
-        ox, = _open_nafnet_model(model).run(['output'], {'input': ix})
-        return ox
-
-    output_ = area_batch_run(
-        input_, _method,
-        tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size,
-        silent=silent, process_title='NafNet Restore',
-    )
-    output_ = np.clip(output_, a_min=0.0, a_max=1.0)
-    ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
-    return _rgba_postprocess(ret_image, alpha_mask)
+    return _get_enhancer(model, tile_size, tile_overlap, batch_size, silent).process(image)
diff --git a/imgutils/restore/scunet.py b/imgutils/restore/scunet.py
index 368df4b5e03..4cd25e9f278 100644
--- a/imgutils/restore/scunet.py
+++ b/imgutils/restore/scunet.py
@@ -23,8 +23,8 @@
 from PIL import Image
 from huggingface_hub import hf_hub_download
 
-from .transparent import _rgba_preprocess, _rgba_postprocess
-from ..data import ImageTyping, load_image
+from ..data import ImageTyping
+from ..generic import ImageEnhancer
 from ..utils import open_onnx_model, area_batch_run
 
 SCUNetModelTyping = Literal['GAN', 'PSNR']
@@ -45,6 +45,37 @@ def _open_scunet_model(model: SCUNetModelTyping):
     ))
 
 
+class _Enhancer(ImageEnhancer):
+    def __init__(self, model: SCUNetModelTyping = 'GAN', tile_size: int = 128, tile_overlap: int = 16,
+                 batch_size: int = 4, silent: bool = False):
+        self.model = model
+        self.tile_size = tile_size
+        self.tile_overlap = tile_overlap
+        self.batch_size = batch_size
+        self.silent = silent
+
+    def _process_rgb(self, rgb_array: np.ndarray):
+        input_ = rgb_array[None, ...]
+
+        def _method(ix):
+            ox, = _open_scunet_model(self.model).run(['output'], {'input': ix})
+            return ox
+
+        output_ = area_batch_run(
+            input_, _method,
+            tile_size=self.tile_size, tile_overlap=self.tile_overlap, batch_size=self.batch_size,
+            silent=self.silent, process_title='SCUNet Restore',
+        )
+        output_ = np.clip(output_, a_min=0.0, a_max=1.0)
+        return output_[0]
+
+
+@lru_cache()
+def _get_enhancer(model: SCUNetModelTyping = 'GAN', tile_size: int = 128, tile_overlap: int = 16,
+                  batch_size: int = 4, silent: bool = False) -> _Enhancer:
+    return _Enhancer(model, tile_size, tile_overlap, batch_size, silent)
+
+
 def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN',
                         tile_size: int = 128, tile_overlap: int = 16, batch_size: int = 4,
                         silent: bool = False) -> Image.Image:
@@ -66,20 +97,4 @@ def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN',
     :return: The restored image.
     :rtype: Image.Image
     """
-    image, alpha_mask = _rgba_preprocess(image)
-    image = load_image(image, mode='RGB', force_background='white')
-    input_ = np.array(image).astype(np.float32) / 255.0
-    input_ = input_.transpose((2, 0, 1))[None, ...]
-
-    def _method(ix):
-        ox, = _open_scunet_model(model).run(['output'], {'input': ix})
-        return ox
-
-    output_ = area_batch_run(
-        input_, _method,
-        tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size,
-        silent=silent, process_title='SCUNet Restore',
-    )
-    output_ = np.clip(output_, a_min=0.0, a_max=1.0)
-    ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
-    return _rgba_postprocess(ret_image, alpha_mask)
+    return _get_enhancer(model, tile_size, tile_overlap, batch_size, silent).process(image)
diff --git a/imgutils/restore/transparent.py b/imgutils/restore/transparent.py
deleted file mode 100644
index 0bb5e7a0306..00000000000
--- a/imgutils/restore/transparent.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from typing import Optional
-
-import numpy as np
-from PIL import Image
-
-from ..data.image import ImageTyping, load_image
-
-
-def _has_alpha_channel(image: Image.Image) -> bool:
-    """
-    Check if the image has an alpha channel.
-
-    :param image: The image to check.
-    :type image: Image.Image
-
-    :return: True if the image has an alpha channel, False otherwise.
-    :rtype: bool
-    """
-    return any(band in {'A', 'a', 'P'} for band in image.getbands())
-
-
-def _rgba_preprocess(image: ImageTyping):
-    """
-    Preprocess the image for RGBA conversion.
-
-    :param image: The image to preprocess.
-    :type image: ImageTyping
-
-    :return: Preprocessed image and alpha mask.
-    :rtype: Tuple[Image.Image, Optional[np.ndarray]]
-    """
-    image = load_image(image, force_background=None, mode=None)
-    if _has_alpha_channel(image):
-        image = image.convert('RGBA')
-        pimage = image.convert('RGB')
-        alpha_mask = np.array(image)[:, :, 3].astype(np.float32) / 255.0
-    else:
-        pimage = image.convert('RGB')
-        alpha_mask = None
-
-    return pimage, alpha_mask
-
-
-def _rgba_postprocess(pimage, alpha_mask: Optional[np.ndarray] = None):
-    """
-    Postprocess the image after RGBA conversion.
-
-    :param pimage: The processed image.
-    :type pimage: Image.Image
-
-    :param alpha_mask: The alpha mask.
-    :type alpha_mask: Optional[np.ndarray]
-
-    :return: Postprocessed image.
-    :rtype: Image.Image
-    """
-    assert pimage.mode == 'RGB'
-    if alpha_mask is None:
-        return pimage
-    else:
-        alpha_channel = (alpha_mask * 255.0).astype(np.uint8)[..., np.newaxis]
-        channels = np.array(pimage)
-        rgba_channels = np.concatenate([channels, alpha_channel], axis=-1)
-        assert rgba_channels.shape == (*channels.shape[:-1], 4)
-        return Image.fromarray(rgba_channels, mode='RGBA')
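
Usage sketch (hedged, not part of the patch itself): after this refactor both restorers run their tiled ONNX inference inside an ImageEnhancer subclass, which presumably takes over the RGB/alpha pre- and post-processing that transparent.py used to provide, so the public entry points are called exactly as before. The snippet below assumes restore_with_nafnet and restore_with_scunet remain re-exported from imgutils.restore and that a file path is an acceptable ImageTyping value, as elsewhere in imgutils; the file names are placeholders.

    from PIL import Image

    from imgutils.restore import restore_with_nafnet, restore_with_scunet

    # Deblur with the NafNet REDS weights, tiled at 256px with 16px overlap
    # (the defaults exposed by restore_with_nafnet in this patch).
    restored: Image.Image = restore_with_nafnet('blurry.png', model='REDS',
                                                tile_size=256, tile_overlap=16)
    restored.save('restored_nafnet.png')

    # Denoise with the SCUNet GAN weights; an RGBA input should come back as RGBA,
    # assuming the alpha splitting/merging now lives in the ImageEnhancer base.
    restore_with_scunet('noisy.png', model='GAN').save('restored_scunet.png')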