Skip to content

Commit

Permalink
docker_prepare: allow set of models/continue_on_error
Browse files Browse the repository at this point in the history
  • Loading branch information
jokester committed Nov 17, 2024
1 parent 815dcd3 commit d8458f9
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 20 deletions.
1 change: 1 addition & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ ocrs
models/*
test/testdata/bboxes
/venv
.git
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ RUN apt-get remove -y g++ && \
COPY . /app

# Prepare models
RUN python -u docker_prepare.py
RUN python -u docker_prepare.py --continue-on-error

RUN rm -rf /tmp

Expand Down
65 changes: 46 additions & 19 deletions docker_prepare.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,55 @@
import asyncio

from argparse import ArgumentParser
from manga_translator.utils import ModelWrapper
from manga_translator.detection import DETECTORS
from manga_translator.ocr import OCRS
from manga_translator.inpainting import INPAINTERS


arg_parser = ArgumentParser()
arg_parser.add_argument("--models", default="")
arg_parser.add_argument("--continue-on-error", action="store_true")


cli_args = arg_parser.parse_args()


async def download(dict):
for key, value in dict.items():
if issubclass(value, ModelWrapper):
print(' -- Downloading', key)
try:
inst = value()
await inst.download()
except Exception as e:
print('Failed to download', key, value)
print(e)
""" """
for key, value in dict.items():
if issubclass(value, ModelWrapper):
print(" -- Downloading", key)
try:
inst = value()
await inst.download()
except Exception as e:
print("Failed to download", key, value)
print(e)
if not cli_args.continue_on_error:
raise


async def main():
await download(DETECTORS)
await download(OCRS)
await download({
k: v for k, v in INPAINTERS.items()
if k not in ['sd']
})

if __name__ == '__main__':
asyncio.run(main())
models: set[str] = set(filter(None, cli_args.models.split(",")))

await download(
{
k: v
for k, v in DETECTORS.items()
if (not models) or (f"detector.{k}" in models)
}
)
await download(
{k: v for k, v in OCRS.items() if (not models) or (f"ocr.{k}" in models)}
)
await download(
{
k: v
for k, v in INPAINTERS.items()
if (not models) or (f"inpaint.{k}" in models) and (k not in ["sd"])
}
)


if __name__ == "__main__":
asyncio.run(main())

0 comments on commit d8458f9

Please sign in to comment.