Skip to content

Commit

Permalink
add GenerationMixin inheritance - needed for transformers >= 4.50 (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
kha-white authored Oct 11, 2024
1 parent 083ddbd commit c5c76c3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion manga_ocr/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.12"
__version__ = "0.1.13"
7 changes: 5 additions & 2 deletions manga_ocr/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
import torch
from PIL import Image
from loguru import logger
from transformers import ViTImageProcessor, AutoTokenizer, VisionEncoderDecoderModel
from transformers import ViTImageProcessor, AutoTokenizer, VisionEncoderDecoderModel, GenerationMixin


class MangaOcrModel(VisionEncoderDecoderModel, GenerationMixin):
pass

class MangaOcr:
def __init__(self, pretrained_model_name_or_path="kha-white/manga-ocr-base", force_cpu=False):
logger.info(f"Loading OCR model from {pretrained_model_name_or_path}")
self.processor = ViTImageProcessor.from_pretrained(pretrained_model_name_or_path)
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
self.model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path)
self.model = MangaOcrModel.from_pretrained(pretrained_model_name_or_path)

if not force_cpu and torch.cuda.is_available():
logger.info("Using CUDA")
Expand Down

0 comments on commit c5c76c3

Please sign in to comment.