forked from zyddnys/manga-image-translator
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #7 from moeflow-com/add-gradio-ui
Add gradio UI
- Loading branch information
Showing
16 changed files
with
534 additions
and
80 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# name: mit-py311 | ||
channels: | ||
- conda-forge | ||
- pytorch | ||
- nvidia | ||
dependencies: | ||
- python==3.11 | ||
- pytorch==2.2.2 | ||
- torchvision==0.17.2 | ||
- torchaudio==2.2.2 | ||
- pytorch-cuda=12.1 | ||
- numpy<2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
import logging | ||
from typing import List | ||
import gradio as gr | ||
import asyncio | ||
from pathlib import Path | ||
import json | ||
import uuid | ||
from PIL import Image | ||
import manga_translator.detection as mit_detection | ||
import manga_translator.ocr as mit_ocr | ||
import manga_translator.textline_merge as textline_merge | ||
import manga_translator.utils.generic as utils_generic | ||
from manga_translator.gradio import ( | ||
mit_detect_text_default_params, | ||
mit_ocr_default_params, | ||
storage_dir, | ||
load_model_mutex, | ||
MitJSONEncoder, | ||
) | ||
from manga_translator.utils.textblock import TextBlock | ||
|
||
# Resolved form of the storage root imported from manga_translator.gradio.
STORAGE_DIR_RESOLVED = storage_dir.resolve()

# Logging is configured only under the gr.NO_RELOAD guard.
# NOTE(review): presumably this keeps the block from re-running on gradio
# hot reload -- confirm against the gradio docs.
if gr.NO_RELOAD:
    logging.basicConfig(
        force=True,
        level=logging.WARN,
        datefmt="%Y-%m-%d %H:%M:%S",
        format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s",
    )
    # Third-party loggers that are capped at WARN.
    for name in ("httpx",):
        logging.getLogger(name).setLevel(logging.WARN)

# This module's own logger emits INFO and above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
|
||
|
||
async def copy_files(gradio_temp_files: list[str]) -> list[str]:
    """Copy files into a fresh sub-directory of ``storage_dir``.

    A directory named with a random UUID hex string is created under
    ``storage_dir`` and every input file is copied into it under its own
    base name.

    Args:
        gradio_temp_files: Paths of the files to copy.

    Returns:
        The paths of the copies, relative to ``storage_dir``, in input order.
    """
    new_root: Path = storage_dir / uuid.uuid4().hex
    new_root.mkdir(parents=True, exist_ok=True)

    ret: list[str] = []
    for f in gradio_temp_files:
        src = Path(f)
        # Path.name honours the platform's path separator; splitting on "/"
        # would keep the whole path as the "file name" on Windows.
        new_file = new_root / src.name
        new_file.write_bytes(src.read_bytes())
        ret.append(str(new_file.relative_to(storage_dir)))
        logger.debug("copied %s to %s", f, new_file)

    return ret
|
||
|
||
def log_file(basename: str, result: List[TextBlock]):
    """Log a file name followed by the text of each of its blocks (INFO level)."""
    logger.info("file: %s", basename)
    block_texts = (block.text for block in result)
    for index, text in enumerate(block_texts):
        logger.info(" block %d: %s", index, text)
|
||
|
||
async def process_files(
    filename_list: list[str], detector_key: str, ocr_key: str, device: str
) -> str:
    """Run text detection, OCR and textline merge on a batch of images.

    Args:
        filename_list: Paths of the images to process. ``None`` or an empty
            list means "nothing to do".
        detector_key: Detector name, passed to ``mit_detection.prepare`` and
            on to ``process_file``.
        ocr_key: OCR model name, passed to ``mit_ocr.prepare`` and on to
            ``process_file``.
        device: Device name, passed to ``mit_ocr.prepare`` and on to
            ``process_file``.

    Returns:
        A JSON string (encoded with ``MitJSONEncoder``) holding one
        ``{"filename", "text_blocks"}`` object per input file, in input order.

    Raises:
        ValueError: If an entry of ``filename_list`` is empty.
    """
    path_list: list[Path] = []
    # The click handler can hand over None when no file was uploaded.
    for f in filename_list or []:
        if not f:
            # Explicit raise instead of ``assert``, which is stripped under -O.
            raise ValueError(f"empty file path in input: {filename_list!r}")
        # NOTE(review): paths are used verbatim; the check that confined them
        # to STORAGE_DIR_RESOLVED is disabled. Confirm this is intended.
        path_list.append(Path(f))

    if not path_list:
        # Nothing to process: skip model preparation altogether.
        return json.dumps([], cls=MitJSONEncoder)

    # NOTE(review): load_model_mutex is a threading.RLock held across awaits.
    # Coroutines running on the same thread can re-enter it, so it only
    # serialises callers on different threads -- confirm that is sufficient.
    with load_model_mutex:
        await mit_detection.prepare(detector_key)
        await mit_ocr.prepare(ocr_key, device)

    result = await asyncio.gather(
        *[process_file(p, detector_key, ocr_key, device) for p in path_list]
    )

    for r in result:
        log_file(r["filename"], r["text_blocks"])

    return json.dumps(result, cls=MitJSONEncoder)
|
||
|
||
async def process_file(
    img_path: Path, detector: str, ocr_key: str, device: str
) -> dict:
    """Detect, OCR and merge the text of a single image.

    The pipeline is best-effort: any exception raised while loading or
    processing the image is logged with its traceback and the file is
    reported with an empty block list instead of failing the whole batch.

    Args:
        img_path: Path of the image to process.
        detector: Detector name, passed to ``mit_detection.dispatch`` as
            ``detector_key``.
        ocr_key: OCR model name, passed to ``mit_ocr.dispatch``.
        device: Device name, passed to both the detector and the OCR model.

    Returns:
        ``{"filename": <base name of img_path>, "text_blocks": <merged
        blocks, or [] on error>}``.
    """
    try:
        # Close the file handle as soon as the pixel data has been loaded.
        # NOTE(review): assumes load_image fully materialises the image into
        # an array before returning -- confirm.
        with Image.open(img_path) as pil_img:
            img, _alpha = utils_generic.load_image(pil_img)
        # NOTE(review): assumes a NumPy-style (rows, cols, ...) array, i.e.
        # shape[:2] is (height, width); the previous code unpacked it as
        # (width, height) and so passed swapped dimensions to the merge step.
        img_h, img_w = img.shape[:2]

        # detector
        detector_args = {
            **mit_detect_text_default_params,
            "detector_key": detector,
            "device": device,
        }
        regions, _mask_raw, _mask = await mit_detection.dispatch(
            image=img, **detector_args
        )
        # ocr
        ocr_args = {**mit_ocr_default_params, "ocr_key": ocr_key, "device": device}
        textlines = await mit_ocr.dispatch(image=img, regions=regions, **ocr_args)
        # textline merge
        text_blocks = await textline_merge.dispatch(
            textlines=textlines, width=img_w, height=img_h
        )
    except Exception:
        # Per-file boundary: record the failure (with traceback) and carry on.
        logger.exception("error processing %s", img_path)
        text_blocks = []
    else:
        logger.debug("processed %s", img_path)

    return {
        "filename": img_path.name,
        "text_blocks": text_blocks,
    }
|
||
|
||
with gr.Blocks() as demo:
    # Components are declared in the same order as before: file upload,
    # JSON result, then the three option selectors and the run button.
    file_input = gr.File(type="filepath", file_count="multiple", label="upload file")

    ocr_output = gr.JSON(label="OCR output")

    device_input = gr.Radio(label="device", value="cuda", choices=["cpu", "cuda"])
    detector_key_input = gr.Radio(
        label="detector",
        value="default",
        choices=[
            "default",
            # maybe broken: manga_translator.utils.inference.InvalidModelMappingException: [DBConvNextDetector->model] Invalid _MODEL_MAPPING - Malformed url property
            # "dbconvnext",
            "ctd",
            "craft",
            "none",
        ],
    )

    ocr_key_input = gr.Radio(
        label="ocr", value="48px", choices=["48px", "48px_ctc", "mocr"]
    )
    run_button = gr.Button("upload + text detection + OCR + textline_merge")

    @run_button.click(
        outputs=[ocr_output],
        inputs=[file_input, detector_key_input, ocr_key_input, device_input],
    )
    async def on_run_button(
        gradio_temp_files: list[str], detector_key: str, ocr_key: str, device: str
    ) -> str:
        """Click handler: forward the form values to ``process_files``.

        Returns the JSON string produced by ``process_files``, which is
        displayed in the ``ocr_output`` component.
        """
        return await process_files(gradio_temp_files, detector_key, ocr_key, device)
|
||
|
||
if __name__ == "__main__":
    # Enable the request queue first, then start the server on all interfaces.
    demo.queue(max_size=100, api_open=True)
    demo.launch(
        server_name="0.0.0.0",
        max_file_size=10 * gr.FileSize.MB,
        debug=True,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
import gradio as gr | ||
import numpy as np | ||
from PIL import Image | ||
|
||
|
||
import dotenv | ||
import logging | ||
import os.path | ||
from pathlib import Path | ||
import manga_translator.detection as detection | ||
import manga_translator.ocr as mit_ocr | ||
import manga_translator.textline_merge as textline_merge | ||
import manga_translator.utils.generic as utils_generic | ||
from manga_translator.gradio import ( | ||
DetectionState, | ||
OcrState, | ||
mit_detect_text_default_params, | ||
) | ||
from typing import List, Optional, TypedDict | ||
|
||
# This module's own logger emits everything from DEBUG upwards.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Root logging is configured only under the gr.NO_RELOAD guard.
# NOTE(review): presumably this keeps the block from re-running on gradio
# hot reload -- confirm against the gradio docs.
if gr.NO_RELOAD:
    logging.basicConfig(force=True, level=logging.INFO)
    # Third-party loggers that are capped at WARN.
    for name in ("httpx",):
        logging.getLogger(name).setLevel(logging.WARN)

# Read settings from a .env file, if any, into the process environment.
dotenv.load_dotenv()
|
||
|
||
with gr.Blocks() as demo:
    gr.Markdown("# manga-image-translator demo")

    # State objects handed to, and returned from, the click handlers below.
    detector_state = gr.State(DetectionState())
    ocr_state = gr.State(OcrState())

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Detection")
            img_file = gr.Image(
                label="input image", height=256, width=256, type="filepath"
            )
            detector_key = gr.Radio(
                choices=[
                    "default",
                    # maybe broken: manga_translator.utils.inference.InvalidModelMappingException: [DBConvNextDetector->model] Invalid _MODEL_MAPPING - Malformed url property
                    # "dbconvnext",
                    "ctd",
                    "craft",
                    "none",
                ],
                label="detector",
            )

            btn_detect = gr.Button("run detector")
            detector_state_dump = gr.TextArea(label="detector result")
        with gr.Column():
            gr.Markdown("## OCR")
            ocr_key = gr.Radio(choices=["48px", "48px_ctc", "mocr"], label="ocr")
            btn_ocr = gr.Button("ocr")
            ocr_state_dump = gr.TextArea(label="ocr state")

    @btn_detect.click(
        inputs=[detector_state, img_file, detector_key],
        outputs=[detector_state, detector_state_dump],
    )
    async def run_detector(
        prev: DetectionState | gr.State,
        img_path: Optional[str],
        detector_key: Optional[str],
    ):
        """Load the selected image and run the selected text detector on it.

        Args:
            prev: The current detection state.
            img_path: Path of the input image, or ``None`` when none is set.
            detector_key: Name of the detector, or ``None`` when none is
                selected.

        Returns:
            ``(new_state, repr(new_state))`` for the state and the text dump.
            Inference only runs when the new state has both an image and
            detector arguments.

        Raises:
            TypeError: If ``prev`` is not a ``DetectionState``.
        """
        # Explicit check instead of ``assert``, which is stripped under -O.
        if not isinstance(prev, DetectionState):
            raise TypeError(
                f"expected a DetectionState, got {type(prev).__name__}"
            )
        logger.debug("run_detector %s %s", prev, img_path)

        value = prev.copy()

        if img_path:
            raw_bytes = Path(img_path).read_bytes()
            # Close the file handle once the pixel data has been loaded.
            # NOTE(review): assumes load_image fully materialises the image
            # into an array before returning -- confirm.
            with Image.open(img_path) as pil_img:
                img, _alpha = utils_generic.load_image(pil_img)
            value = value.copy(
                raw_filename=os.path.basename(img_path), raw_bytes=raw_bytes, img=img
            )
        else:
            value = prev.copy(raw_filename=None, raw_bytes=None, img=None)

        if detector_key:
            value = value.copy(
                args={**mit_detect_text_default_params, "detector_key": detector_key}
            )

        if value.img is not None and value.args is not None:
            logger.debug("run inference")
            # Use the image stored in the state: the local ``img`` only exists
            # when an image path was supplied on this call.
            textlines, mask_raw, mask = await detection.dispatch(
                image=value.img, **value.args
            )
            value = value.copy(textlines=textlines, mask_raw=mask_raw, mask=mask)

        logger.debug("run_detector result %s", value)
        return value, repr(value)

    @btn_ocr.click(
        inputs=[ocr_state, detector_state, ocr_key],
        outputs=[ocr_state, ocr_state_dump],
    )
    async def run_ocr(
        prev_value: OcrState,
        detector_state: DetectionState,
        ocr_key: Optional[str],
    ):
        """Run OCR and textline merge on the regions found by the detector.

        Args:
            prev_value: The current OCR state.
            detector_state: Detection state providing the image and textlines.
            ocr_key: Name of the OCR model, or ``None`` when none is selected.

        Returns:
            ``(new_state, repr(new_state))``. The previous state is returned
            unchanged when no OCR model is selected or the detection state
            has no image or no textlines.
        """
        logger.debug(
            "run ocr %s %s %s", type(prev_value), type(detector_state), ocr_key
        )

        if not (
            ocr_key and (detector_state.img is not None) and detector_state.textlines
        ):
            return prev_value, repr(prev_value)

        textlines = await mit_ocr.dispatch(
            ocr_key=ocr_key,
            image=detector_state.img,
            regions=detector_state.textlines,
            args={},
            verbose=True,
        )

        # NOTE(review): assumes a NumPy-style (rows, cols, ...) array, i.e.
        # shape[:2] is (height, width); the previous code unpacked it as
        # (width, height) and so passed swapped dimensions to the merge step.
        img_h, img_w = detector_state.img.shape[:2]
        text_blocks = await textline_merge.dispatch(
            textlines=textlines, width=img_w, height=img_h
        )

        value = prev_value.copy(text_blocks=text_blocks, ocr_key=ocr_key)
        return value, repr(value)
|
||
|
||
if __name__ == "__main__":
    # Bind to every interface so the demo is reachable from other hosts.
    bind_address = "0.0.0.0"
    demo.launch(server_name=bind_address)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
/storage |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from threading import RLock | ||
from pathlib import Path | ||
from .ocr import mit_ocr_default_params, OcrState | ||
from .detection import mit_detect_text_default_params, DetectionState | ||
from .json_encoder import JSONEncoder as MitJSONEncoder | ||
|
||
# Re-entrant lock shared by importers of this package.
# NOTE(review): appears intended to guard model loading -- confirm with callers.
load_model_mutex = RLock()

# "storage" directory located two levels above this file's path.
_package_dir = Path(__file__).parent
storage_dir = _package_dir.parent / "storage"

# Public names re-exported by this package.
__all__ = [
    "mit_ocr_default_params",
    "OcrState",
    "mit_detect_text_default_params",
    "DetectionState",
    "storage_dir",
    "MitJSONEncoder",
    "load_model_mutex",
]
Oops, something went wrong.