Skip to content

Commit

Permalink
dev(narugo): use enhance layer
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed May 10, 2024
1 parent ba04f1c commit 255988e
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 103 deletions.
53 changes: 34 additions & 19 deletions imgutils/restore/nafnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from PIL import Image
from huggingface_hub import hf_hub_download

from .transparent import _rgba_preprocess, _rgba_postprocess
from ..data import ImageTyping, load_image
from ..data import ImageTyping
from ..generic import ImageEnhancer
from ..utils import open_onnx_model, area_batch_run

NafNetModelTyping = Literal['REDS', 'GoPro', 'SIDD']
Expand All @@ -50,6 +50,37 @@ def _open_nafnet_model(model: NafNetModelTyping):
))


class _Enhancer(ImageEnhancer):
def __init__(self, model: NafNetModelTyping = 'REDS', tile_size: int = 256, tile_overlap: int = 16,
batch_size: int = 4, silent: bool = False):
self.model = model
self.tile_size = tile_size
self.tile_overlap = tile_overlap
self.batch_size = batch_size
self.silent = silent

def _process_rgb(self, rgb_array: np.ndarray):
input_ = rgb_array[None, ...]

def _method(ix):
ox, = _open_nafnet_model(self.model).run(['output'], {'input': ix})
return ox

output_ = area_batch_run(
input_, _method,
tile_size=self.tile_size, tile_overlap=self.tile_overlap, batch_size=self.batch_size,
silent=self.silent, process_title='NafNet Restore',
)
output_ = np.clip(output_, a_min=0.0, a_max=1.0)
return output_[0]


@lru_cache()
def _get_enhancer(model: NafNetModelTyping = 'REDS', tile_size: int = 256, tile_overlap: int = 16,
batch_size: int = 4, silent: bool = False) -> _Enhancer:
return _Enhancer(model, tile_size, tile_overlap, batch_size, silent)


def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS',
tile_size: int = 256, tile_overlap: int = 16, batch_size: int = 4,
silent: bool = False) -> Image.Image:
Expand All @@ -71,20 +102,4 @@ def restore_with_nafnet(image: ImageTyping, model: NafNetModelTyping = 'REDS',
:return: The restored image.
:rtype: Image.Image
"""
image, alpha_mask = _rgba_preprocess(image)
image = load_image(image, mode='RGB', force_background='white')
input_ = np.array(image).astype(np.float32) / 255.0
input_ = input_.transpose((2, 0, 1))[None, ...]

def _method(ix):
ox, = _open_nafnet_model(model).run(['output'], {'input': ix})
return ox

output_ = area_batch_run(
input_, _method,
tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size,
silent=silent, process_title='NafNet Restore',
)
output_ = np.clip(output_, a_min=0.0, a_max=1.0)
ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
return _rgba_postprocess(ret_image, alpha_mask)
return _get_enhancer(model, tile_size, tile_overlap, batch_size, silent).process(image)
53 changes: 34 additions & 19 deletions imgutils/restore/scunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from PIL import Image
from huggingface_hub import hf_hub_download

from .transparent import _rgba_preprocess, _rgba_postprocess
from ..data import ImageTyping, load_image
from ..data import ImageTyping
from ..generic import ImageEnhancer
from ..utils import open_onnx_model, area_batch_run

SCUNetModelTyping = Literal['GAN', 'PSNR']
Expand All @@ -45,6 +45,37 @@ def _open_scunet_model(model: SCUNetModelTyping):
))


class _Enhancer(ImageEnhancer):
def __init__(self, model: SCUNetModelTyping = 'GAN', tile_size: int = 128, tile_overlap: int = 16,
batch_size: int = 4, silent: bool = False):
self.model = model
self.tile_size = tile_size
self.tile_overlap = tile_overlap
self.batch_size = batch_size
self.silent = silent

def _process_rgb(self, rgb_array: np.ndarray):
input_ = rgb_array[None, ...]

def _method(ix):
ox, = _open_scunet_model(self.model).run(['output'], {'input': ix})
return ox

output_ = area_batch_run(
input_, _method,
tile_size=self.tile_size, tile_overlap=self.tile_overlap, batch_size=self.batch_size,
silent=self.silent, process_title='SCUNet Restore',
)
output_ = np.clip(output_, a_min=0.0, a_max=1.0)
return output_[0]


@lru_cache()
def _get_enhancer(model: SCUNetModelTyping = 'GAN', tile_size: int = 128, tile_overlap: int = 16,
batch_size: int = 4, silent: bool = False) -> _Enhancer:
return _Enhancer(model, tile_size, tile_overlap, batch_size, silent)


def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN',
tile_size: int = 128, tile_overlap: int = 16, batch_size: int = 4,
silent: bool = False) -> Image.Image:
Expand All @@ -66,20 +97,4 @@ def restore_with_scunet(image: ImageTyping, model: SCUNetModelTyping = 'GAN',
:return: The restored image.
:rtype: Image.Image
"""
image, alpha_mask = _rgba_preprocess(image)
image = load_image(image, mode='RGB', force_background='white')
input_ = np.array(image).astype(np.float32) / 255.0
input_ = input_.transpose((2, 0, 1))[None, ...]

def _method(ix):
ox, = _open_scunet_model(model).run(['output'], {'input': ix})
return ox

output_ = area_batch_run(
input_, _method,
tile_size=tile_size, tile_overlap=tile_overlap, batch_size=batch_size,
silent=silent, process_title='SCUNet Restore',
)
output_ = np.clip(output_, a_min=0.0, a_max=1.0)
ret_image = Image.fromarray((output_[0].transpose((1, 2, 0)) * 255).astype(np.int8), 'RGB')
return _rgba_postprocess(ret_image, alpha_mask)
return _get_enhancer(model, tile_size, tile_overlap, batch_size, silent).process(image)
65 changes: 0 additions & 65 deletions imgutils/restore/transparent.py

This file was deleted.

0 comments on commit 255988e

Please sign in to comment.