diff --git a/.dockerignore b/.dockerignore index 4bfc25102..ce674a1bf 100644 --- a/.dockerignore +++ b/.dockerignore @@ -8,3 +8,4 @@ ocrs models/* test/testdata/bboxes /venv +.git diff --git a/Dockerfile b/Dockerfile index 4249d46ba..f94d11ace 100644 --- a/Dockerfile +++ b/Dockerfile @@ -22,7 +22,7 @@ RUN apt-get remove -y g++ && \ COPY . /app # Prepare models -RUN python -u docker_prepare.py +RUN python -u docker_prepare.py --continue-on-error RUN rm -rf /tmp diff --git a/docker_prepare.py b/docker_prepare.py index 3a6e79cc2..5acbabf04 100644 --- a/docker_prepare.py +++ b/docker_prepare.py @@ -1,28 +1,55 @@ import asyncio - +from argparse import ArgumentParser from manga_translator.utils import ModelWrapper from manga_translator.detection import DETECTORS from manga_translator.ocr import OCRS from manga_translator.inpainting import INPAINTERS + +arg_parser = ArgumentParser() +arg_parser.add_argument("--models", default="") +arg_parser.add_argument("--continue-on-error", action="store_true") + + +cli_args = arg_parser.parse_args() + + async def download(dict): - for key, value in dict.items(): - if issubclass(value, ModelWrapper): - print(' -- Downloading', key) - try: - inst = value() - await inst.download() - except Exception as e: - print('Failed to download', key, value) - print(e) + """ """ + for key, value in dict.items(): + if issubclass(value, ModelWrapper): + print(" -- Downloading", key) + try: + inst = value() + await inst.download() + except Exception as e: + print("Failed to download", key, value) + print(e) + if not cli_args.continue_on_error: + raise + async def main(): - await download(DETECTORS) - await download(OCRS) - await download({ - k: v for k, v in INPAINTERS.items() - if k not in ['sd'] - }) - -if __name__ == '__main__': - asyncio.run(main()) + models: set[str] = set(filter(None, cli_args.models.split(","))) + + await download( + { + k: v + for k, v in DETECTORS.items() + if (not models) or (f"detector.{k}" in models) + } + ) + await download( + {k: v for k, v in OCRS.items() if (not models) or (f"ocr.{k}" in models)} + ) + await download( + { + k: v + for k, v in INPAINTERS.items() + if (not models) or (f"inpaint.{k}" in models) and (k not in ["sd"]) + } + ) + + +if __name__ == "__main__": + asyncio.run(main())