Skip to content

Commit

Permalink
Merge pull request #7 from moeflow-com/add-gradio-ui
Browse files Browse the repository at this point in the history
Add gradio UI
  • Loading branch information
jokester authored Nov 17, 2024
2 parents c6f179b + 6827c88 commit b0d6380
Show file tree
Hide file tree
Showing 16 changed files with 534 additions and 80 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ RUN apt-get remove -y g++ && \
COPY . /app

# Prepare models
RUN python -u docker_prepare.py
RUN python -u docker_prepare.py --continue-on-error

RUN rm -rf /tmp

Expand Down
12 changes: 12 additions & 0 deletions conda.working.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# name: mit-py311
channels:
- conda-forge
- pytorch
- nvidia
dependencies:
- python==3.11
- pytorch==2.2.2
- torchvision==0.17.2
- torchaudio==2.2.2
- pytorch-cuda=12.1
- numpy<2
2 changes: 1 addition & 1 deletion docker_prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ async def download(dict):

async def main():
models: set[str] = set(filter(None, cli_args.models.split(",")))
# print("parsed.models", models)

await download(
{
k: v
Expand Down
165 changes: 165 additions & 0 deletions gradio-multi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
import logging
from typing import List
import gradio as gr
import asyncio
from pathlib import Path
import json
import uuid
from PIL import Image
import manga_translator.detection as mit_detection
import manga_translator.ocr as mit_ocr
import manga_translator.textline_merge as textline_merge
import manga_translator.utils.generic as utils_generic
from manga_translator.gradio import (
mit_detect_text_default_params,
mit_ocr_default_params,
storage_dir,
load_model_mutex,
MitJSONEncoder,
)
from manga_translator.utils.textblock import TextBlock

STORAGE_DIR_RESOLVED = storage_dir.resolve()

if gr.NO_RELOAD:
logging.basicConfig(
level=logging.WARN,
format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
force=True,
)
for name in ["httpx"]:
logging.getLogger(name).setLevel(logging.WARN)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


async def copy_files(gradio_temp_files: list[str]) -> list[str]:
new_root: Path = storage_dir / uuid.uuid4().hex
new_root.mkdir(parents=True, exist_ok=True)

ret: list[str] = []
for f in gradio_temp_files:
new_file = new_root / f.split("/")[-1]
new_file.write_bytes(Path(f).read_bytes())
ret.append(str(new_file.relative_to(storage_dir)))
logger.debug("copied %s to %s", f, new_file)

return ret


def log_file(basename: str, result: List[TextBlock]):
logger.info("file: %s", basename)
for i, b in enumerate(result):
logger.info(" block %d: %s", i, b.text)


async def process_files(
filename_list: list[str], detector_key: str, ocr_key: str, device: str
) -> str:
path_list: list[Path] = []
for f in filename_list:
assert f
# p = (storage_dir / f).resolve()
# assert p.is_file() and STORAGE_DIR_RESOLVED in p.parents, f"illegal path: {f}"
path_list.append(Path(f))

with load_model_mutex:
await mit_detection.prepare(detector_key)
await mit_ocr.prepare(ocr_key, device)

result = await asyncio.gather(
*[process_file(p, detector_key, ocr_key, device) for p in path_list]
)

for r in result:
log_file(r["filename"], r["text_blocks"])

return json.dumps(result, cls=MitJSONEncoder)


async def process_file(
img_path: Path, detector: str, ocr_key: str, device: str
) -> dict:
pil_img = Image.open(img_path)
img, mask = utils_generic.load_image(pil_img)
img_w, img_h = img.shape[:2]

try:
# detector
detector_args = {
**mit_detect_text_default_params,
"detector_key": detector,
"device": device,
}
regions, mask_raw, mask = await mit_detection.dispatch(
image=img, **detector_args
)
# ocr
ocr_args = {**mit_ocr_default_params, "ocr_key": ocr_key, "device": device}
textlines = await mit_ocr.dispatch(image=img, regions=regions, **ocr_args)
# textline merge
text_blocks = await textline_merge.dispatch(
textlines=textlines, width=img_w, height=img_h
)
except Exception as e:
logger.error("error processing %s: %s", img_path, e)
print(e)
text_blocks = []
else:
logger.debug("processed %s", img_path)

return {
"filename": img_path.name,
"text_blocks": text_blocks,
}


with gr.Blocks() as demo:
file_input = gr.File(
label="upload file",
file_count="multiple",
type="filepath",
)

ocr_output = gr.JSON(
label="OCR output",
)

device_input = gr.Radio(choices=["cpu", "cuda"], label="device", value="cuda")
detector_key_input = gr.Radio(
choices=[
"default",
# maybe broken: manga_translator.utils.inference.InvalidModelMappingException: [DBConvNextDetector->model] Invalid _MODEL_MAPPING - Malformed url property
# "dbconvnext",
"ctd",
"craft",
"none",
],
value="default",
label="detector",
)

ocr_key_input = gr.Radio(
choices=["48px", "48px_ctc", "mocr"], label="ocr", value="48px"
)
run_button = gr.Button("upload + text detection + OCR + textline_merge")

@run_button.click(
inputs=[file_input, detector_key_input, ocr_key_input, device_input],
outputs=[ocr_output],
)
async def on_run_button(
gradio_temp_files: list[str], detector_key: str, ocr_key: str, device: str
) -> str:
res = await process_files(gradio_temp_files, detector_key, ocr_key, device)
return res


if __name__ == "__main__":
demo.queue(api_open=True, max_size=100).launch(
debug=True,
server_name="0.0.0.0",
max_file_size=10 * gr.FileSize.MB,
)
147 changes: 147 additions & 0 deletions gradio-single.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import gradio as gr
import numpy as np
from PIL import Image


import dotenv
import logging
import os.path
from pathlib import Path
import manga_translator.detection as detection
import manga_translator.ocr as mit_ocr
import manga_translator.textline_merge as textline_merge
import manga_translator.utils.generic as utils_generic
from manga_translator.gradio import (
DetectionState,
OcrState,
mit_detect_text_default_params,
)
from typing import List, Optional, TypedDict

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

if gr.NO_RELOAD:
logging.basicConfig(level=logging.INFO, force=True)
for name in ["httpx"]:
logging.getLogger(name).setLevel(logging.WARN)

dotenv.load_dotenv()


with gr.Blocks() as demo:
gr.Markdown(
"""
# manga-image-translator demo
""".strip()
)

detector_state = gr.State(DetectionState())
ocr_state = gr.State(OcrState())

with gr.Row():
with gr.Column():
gr.Markdown("## Detection")
img_file = gr.Image(
label="input image", height=256, width=256, type="filepath"
)
detector_key = gr.Radio(
choices=[
"default",
# maybe broken: manga_translator.utils.inference.InvalidModelMappingException: [DBConvNextDetector->model] Invalid _MODEL_MAPPING - Malformed url property
# "dbconvnext",
"ctd",
"craft",
"none",
],
label="detector",
)

btn_detect = gr.Button("run detector")
detector_state_dump = gr.TextArea(
label="detector result" # , value=lambda: repr(detector_state.value)
)
with gr.Column():
gr.Markdown("## OCR")
ocr_key = gr.Radio(choices=["48px", "48px_ctc", "mocr"], label="ocr")
btn_ocr = gr.Button("ocr")
ocr_state_dump = gr.TextArea(label="ocr state")

@btn_detect.click(
inputs=[detector_state, img_file, detector_key],
outputs=[detector_state, detector_state_dump],
)
async def run_detector(
prev: DetectionState | gr.State,
img_path: Optional[str],
detector_key: Optional[str],
):
# print("prev", prev)
prev_value = prev if isinstance(prev, DetectionState) else None # prev.value
assert prev_value, "prev_value is None"
logger.debug("run_detector %s %s", prev_value, img_path)

value = prev_value.copy()

if img_path:
raw_bytes = Path(img_path).read_bytes()
pil_img = Image.open(img_path)
img, mask = utils_generic.load_image(pil_img)
value = value.copy(
raw_filename=os.path.basename(img_path), raw_bytes=raw_bytes, img=img
)
else:
value = prev_value.copy(raw_filename=None, raw_bytes=None, img=None)

if detector_key:
value = value.copy(
args={**mit_detect_text_default_params, "detector_key": detector_key}
)

if value.img is not None and value.args is not None:
logger.debug("run inference")
textlines, mask_raw, mask = await detection.dispatch(
image=img, **value.args
)
value = value.copy(textlines=textlines, mask_raw=mask_raw, mask=mask)

logger.debug("run_detector result %s", value)
return value, repr(value)

@btn_ocr.click(
inputs=[ocr_state, detector_state, ocr_key],
outputs=[ocr_state, ocr_state_dump],
)
async def run_ocr(
prev_value: OcrState,
detector_state: DetectionState,
ocr_key: Optional[str],
):
logger.debug(
"run ocr %s %s %s", type(prev_value), type(detector_state), ocr_key
)

if not (
ocr_key and (detector_state.img is not None) and detector_state.textlines
):
return prev_value, repr(prev_value)

textlines = await mit_ocr.dispatch(
ocr_key=ocr_key,
image=detector_state.img,
regions=detector_state.textlines,
args={},
verbose=True,
)

img_w, img_h = detector_state.img.shape[:2]
text_blocks = await textline_merge.dispatch(
textlines=textlines, width=img_w, height=img_h
)

value = prev_value.copy(text_blocks=text_blocks, ocr_key=ocr_key)
return value, repr(value)


if __name__ == "__main__":
demo.launch(server_name="0.0.0.0")
23 changes: 13 additions & 10 deletions manga_translator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,19 @@ async def dispatch(args: Namespace):
else: # batch
dest = args.dest
for path in natural_sort(args.input):
# Apply pre-translation dictionaries
await translator.translate_path(path, dest, args_dict)
for textline in translator.textlines:
textline.text = translator.apply_dictionary(textline.text, pre_dict)
logger.info(f'Pre-translation dictionary applied: {textline.text}')

# Apply post-translation dictionaries
for textline in translator.textlines:
textline.translation = translator.apply_dictionary(textline.translation, post_dict)
logger.info(f'Post-translation dictionary applied: {textline.translation}')
try :
# Apply pre-translation dictionaries
await translator.translate_path(path, dest, args_dict)
for textline in translator.textlines:
textline.text = translator.apply_dictionary(textline.text, pre_dict)
logger.info(f'Pre-translation dictionary applied: {textline.text}')

# Apply post-translation dictionaries
for textline in translator.textlines:
textline.translation = translator.apply_dictionary(textline.translation, post_dict)
logger.info(f'Post-translation dictionary applied: {textline.translation}')
except Exception :
pass

elif args.mode == 'web':
from .server.web_main import dispatch
Expand Down
1 change: 1 addition & 0 deletions manga_translator/gradio/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/storage
19 changes: 19 additions & 0 deletions manga_translator/gradio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from threading import RLock
from pathlib import Path
from .ocr import mit_ocr_default_params, OcrState
from .detection import mit_detect_text_default_params, DetectionState
from .json_encoder import JSONEncoder as MitJSONEncoder

load_model_mutex = RLock()

storage_dir = Path(__file__).parent.parent / "storage"

__all__ = [
"mit_ocr_default_params",
"OcrState",
"mit_detect_text_default_params",
"DetectionState",
"storage_dir",
"MitJSONEncoder",
"load_model_mutex",
]
Loading

0 comments on commit b0d6380

Please sign in to comment.