Skip to content

Commit

Permalink
feat: update the gradio to be the same as hugginface
Browse files Browse the repository at this point in the history
  • Loading branch information
Kiteretsu77 committed Apr 3, 2024
1 parent 83b90db commit 6cad632
Showing 1 changed file with 32 additions and 11 deletions.
43 changes: 32 additions & 11 deletions gradio_apisr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
'''
Gradio demo (almost the same code as the one used in Huggingface space)
'''
import os, sys
import cv2
import time
import gradio as gr
import torch
import numpy as np
Expand All @@ -20,6 +24,10 @@ def auto_download_if_needed(weight_path):
if not os.path.exists("pretrained"):
os.makedirs("pretrained")

if weight_path == "pretrained/4x_APISR_RRDB_GAN_generator.pth":
os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.2.0/4x_APISR_RRDB_GAN_generator.pth")
os.system("mv 4x_APISR_RRDB_GAN_generator.pth pretrained")

if weight_path == "pretrained/4x_APISR_GRL_GAN_generator.pth":
os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/4x_APISR_GRL_GAN_generator.pth")
os.system("mv 4x_APISR_GRL_GAN_generator.pth pretrained")
Expand All @@ -28,6 +36,7 @@ def auto_download_if_needed(weight_path):
os.system("wget https://github.com/Kiteretsu77/APISR/releases/download/v0.1.0/2x_APISR_RRDB_GAN_generator.pth")
os.system("mv 2x_APISR_RRDB_GAN_generator.pth pretrained")




def inference(img_path, model_name):
Expand All @@ -41,22 +50,29 @@ def inference(img_path, model_name):
auto_download_if_needed(weight_path)
generator = load_grl(weight_path, scale=4) # Directly use default way now

elif model_name == "4xRRDB":
weight_path = "pretrained/4x_APISR_RRDB_GAN_generator.pth"
auto_download_if_needed(weight_path)
generator = load_rrdb(weight_path, scale=4) # Directly use default way now

elif model_name == "2xRRDB":
weight_path = "pretrained/2x_APISR_RRDB_GAN_generator.pth"
auto_download_if_needed(weight_path)
generator = load_rrdb(weight_path, scale=2) # Directly use default way now

else:
raise gr.Error(error)
raise gr.Error("We don't support such Model")

generator = generator.to(dtype=weight_dtype)


# In default, we will automatically use crop to match 4x size
super_resolved_img = super_resolve_img(generator, img_path, output_path=None, weight_dtype=weight_dtype, crop_for_4x=True)
save_image(super_resolved_img, "SR_result.png")
outputs = cv2.imread("SR_result.png")
super_resolved_img = super_resolve_img(generator, img_path, output_path=None, weight_dtype=weight_dtype, downsample_threshold=720, crop_for_4x=True)
store_name = str(time.time()) + ".png"
save_image(super_resolved_img, store_name)
outputs = cv2.imread(store_name)
outputs = cv2.cvtColor(outputs, cv2.COLOR_RGB2BGR)
os.remove(store_name)

return outputs

Expand All @@ -70,14 +86,18 @@ def inference(img_path, model_name):

MARKDOWN = \
"""
## APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024)
## <p style='text-align: center'> APISR: Anime Production Inspired Real-World Anime Super-Resolution (CVPR 2024) </p>
[GitHub](https://github.com/Kiteretsu77/APISR) | [Paper](https://arxiv.org/abs/2403.01598)
If APISR is helpful for you, please help star the GitHub Repo. Thanks!
APISR aims at restoring and enhancing low-quality low-resolution **anime** images and video sources with various degradations from real-world scenarios.
### Note: Due to memory restriction, all images whose short side is over 720 pixel will be downsampled to 720 pixel with the same aspect ratio. E.g., 1920x1080 -> 1280x720
### Note: Please check [Model Zoo](https://github.com/Kiteretsu77/APISR/blob/main/docs/model_zoo.md) for the description of each weight.
If APISR is helpful, please help star the [GitHub Repo](https://github.com/Kiteretsu77/APISR). Thanks!
"""

block = gr.Blocks().queue()
block = gr.Blocks().queue(max_size=10)
with block:
with gr.Row():
gr.Markdown(MARKDOWN)
Expand All @@ -87,6 +107,7 @@ def inference(img_path, model_name):
model_name = gr.Dropdown(
[
"2xRRDB",
"4xRRDB",
"4xGRL"
],
type="value",
Expand All @@ -106,7 +127,7 @@ def inference(img_path, model_name):
["__assets__/lr_inputs/41.png"],
["__assets__/lr_inputs/f91.jpg"],
["__assets__/lr_inputs/image-00440.png"],
["__assets__/lr_inputs/image-00164.png"],
["__assets__/lr_inputs/image-00164.jpg"],
["__assets__/lr_inputs/img_eva.jpeg"],
["__assets__/lr_inputs/naruto.jpg"],
],
Expand All @@ -115,4 +136,4 @@ def inference(img_path, model_name):

run_btn.click(inference, inputs=[input_image, model_name], outputs=[output_image])

block.launch()
block.launch()

0 comments on commit 6cad632

Please sign in to comment.