From c5c76c320ec29a355adb5c66dea665a380965ea1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20Budy=C5=9B?= Date: Fri, 11 Oct 2024 17:01:20 +0200 Subject: [PATCH] add GenerationMixin inheritance - needed for transformers >= 4.50 (#82) --- manga_ocr/_version.py | 2 +- manga_ocr/ocr.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/manga_ocr/_version.py b/manga_ocr/_version.py index 74acd0e..3cb7d95 100644 --- a/manga_ocr/_version.py +++ b/manga_ocr/_version.py @@ -1 +1 @@ -__version__ = "0.1.12" +__version__ = "0.1.13" diff --git a/manga_ocr/ocr.py b/manga_ocr/ocr.py index c5f398c..12eea7e 100644 --- a/manga_ocr/ocr.py +++ b/manga_ocr/ocr.py @@ -5,15 +5,18 @@ import torch from PIL import Image from loguru import logger -from transformers import ViTImageProcessor, AutoTokenizer, VisionEncoderDecoderModel +from transformers import ViTImageProcessor, AutoTokenizer, VisionEncoderDecoderModel, GenerationMixin +class MangaOcrModel(VisionEncoderDecoderModel, GenerationMixin): + pass + class MangaOcr: def __init__(self, pretrained_model_name_or_path="kha-white/manga-ocr-base", force_cpu=False): logger.info(f"Loading OCR model from {pretrained_model_name_or_path}") self.processor = ViTImageProcessor.from_pretrained(pretrained_model_name_or_path) self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) - self.model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path) + self.model = MangaOcrModel.from_pretrained(pretrained_model_name_or_path) if not force_cpu and torch.cuda.is_available(): logger.info("Using CUDA")