diff --git a/requirements.txt b/requirements.txt index 7183da22..d10012c5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 click>=8.1.7,<9.0.0 datasets>=2.18.0,<3.0.0 -docling>=2.4.2,<3.0.0 +docling[tesserocr]>=2.4.2,<3.0.0 GitPython>=3.1.42,<4.0.0 httpx>=0.25.0,<1.0.0 instructlab-schema>=0.4.0 diff --git a/src/instructlab/sdg/utils/chunkers.py b/src/instructlab/sdg/utils/chunkers.py index 8946b8b7..50fd692c 100644 --- a/src/instructlab/sdg/utils/chunkers.py +++ b/src/instructlab/sdg/utils/chunkers.py @@ -12,16 +12,14 @@ from datasets import Dataset from docling.datamodel.base_models import InputFormat from docling.datamodel.document import ConversionResult -from docling.datamodel.pipeline_options import PdfPipelineOptions -from docling.document_converter import ( - ConversionStatus, - DocumentConverter, - PdfFormatOption, +from docling.datamodel.pipeline_options import ( + EasyOcrOptions, + OcrOptions, + PdfPipelineOptions, + TesseractOcrOptions, ) -from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline from langchain_text_splitters import Language, RecursiveCharacterTextSplitter from tabulate import tabulate -from transformers import AutoTokenizer logger = logging.getLogger(__name__) _DEFAULT_CHUNK_OVERLAP = 100 @@ -35,6 +33,38 @@ def _num_chars_from_tokens(num_tokens) -> int: return int(num_tokens * 4) # 1 token ~ 4 English character +def resolve_ocr_options() -> OcrOptions: + # First, attempt to use tesserocr + try: + ocr_options = TesseractOcrOptions() + # pylint: disable=import-outside-toplevel + # Third Party + from docling.models.tesseract_ocr_model import TesseractOcrModel + + _ = TesseractOcrModel(True, ocr_options) + return ocr_options + except ImportError: + # No tesserocr, so try something else + pass + try: + ocr_options = EasyOcrOptions() + # Keep easyocr models on the CPU instead of GPU + ocr_options.use_gpu = False + # triggers torch loading, import lazily + # pylint: 
disable=import-outside-toplevel + # Third Party + from docling.models.easyocr_model import EasyOcrModel + + _ = EasyOcrModel(True, ocr_options) + return ocr_options + except ImportError: + # no easyocr either, so don't use any OCR + logger.error( + "Failed to load Tesseract and EasyOCR - disabling optical character recognition in PDF documents" + ) + return None + + class FileTypes(Enum): MD = ".md" PDF = ".pdf" @@ -208,13 +238,24 @@ def chunk_documents(self) -> List: Returns: List: a list of chunks from the documents """ + # triggers torch loading, import lazily + # pylint: disable=import-outside-toplevel + # Third Party + from docling.document_converter import DocumentConverter, PdfFormatOption + from docling.pipeline.standard_pdf_pipeline import StandardPdfPipeline + if self.document_paths == []: return [] model_artifacts_path = StandardPdfPipeline.download_models_hf() - pipeline_options = PdfPipelineOptions(artifacts_path=model_artifacts_path) - # Keep OCR models on the CPU instead of GPU - pipeline_options.ocr_options.use_gpu = False + pipeline_options = PdfPipelineOptions( + artifacts_path=model_artifacts_path, + do_ocr=False, + ) + ocr_options = resolve_ocr_options() + if ocr_options is not None: + pipeline_options.do_ocr = True + pipeline_options.ocr_options = ocr_options converter = DocumentConverter( format_options={ InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options) @@ -309,6 +350,11 @@ def create_tokenizer(self, model_name: str): Returns: AutoTokenizer: The tokenizer instance. 
""" + # import lazily to not load transformers at top level + # pylint: disable=import-outside-toplevel + # Third Party + from transformers import AutoTokenizer + try: tokenizer = AutoTokenizer.from_pretrained(model_name) logger.info(f"Successfully loaded tokenizer from: {model_name}") @@ -540,6 +586,11 @@ def export_documents(self, converted_docs: Iterable[ConversionResult]): Returns: Path: path to directory with docling json artifacts """ + # triggers torch loading, import lazily + # pylint: disable=import-outside-toplevel + # Third Party + from docling.document_converter import ConversionStatus + docling_artifacts_path = self.output_dir / "docling-artifacts" docling_artifacts_path.mkdir(parents=True, exist_ok=True) diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py new file mode 100644 index 00000000..c1793a94 --- /dev/null +++ b/tests/functional/conftest.py @@ -0,0 +1,14 @@ +# Standard +import pathlib +import typing + +# Third Party +import pytest + +TESTS_PATH = pathlib.Path(__file__).parent.parent.absolute() + + +@pytest.fixture +def testdata_path() -> typing.Generator[pathlib.Path, None, None]: + """Path to local test data directory""" + yield TESTS_PATH / "testdata" diff --git a/tests/functional/test_imports.py b/tests/functional/test_imports.py new file mode 100644 index 00000000..0b8b77fb --- /dev/null +++ b/tests/functional/test_imports.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Standard +import pathlib +import subprocess +import sys + + +def test_sdg_imports(testdata_path: pathlib.Path): + script = testdata_path / "leanimports.py" + subprocess.check_call([sys.executable, str(script)], text=True) diff --git a/tests/test_chunkers.py b/tests/test_chunkers.py index 0d327982..a686fb75 100644 --- a/tests/test_chunkers.py +++ b/tests/test_chunkers.py @@ -2,9 +2,11 @@ # Standard from pathlib import Path +from unittest.mock import MagicMock, patch import tempfile # Third Party +from docling.datamodel.pipeline_options 
import EasyOcrOptions, TesseractOcrOptions import pytest # First Party @@ -13,6 +15,7 @@ DocumentChunker, FileTypes, TextSplitChunker, + resolve_ocr_options, ) # Local @@ -86,3 +89,52 @@ def test_chunker_factory_empty_filetype(documents_dir): output_dir=temp_dir, tokenizer_model_name="instructlab/merlinite-7b-lab", ) + + +def test_resolve_ocr_options_is_not_none(): + """ + Test that resolve_ocr_options does not return None, which means it + found a valid OCR library on the machine running this test + """ + ocr_options = resolve_ocr_options() + assert ocr_options is not None + + +@patch("docling.models.tesseract_ocr_model.TesseractOcrModel") +def test_resolve_ocr_options_prefers_tessserocr(mock_tesseract): + """ + Ensure resolve_ocr_options defaults to tesserocr if we're able + to load that library without error. + """ + mock_tesseract.return_value = MagicMock() + ocr_options = resolve_ocr_options() + assert isinstance(ocr_options, TesseractOcrOptions) + + +@patch("docling.models.tesseract_ocr_model.TesseractOcrModel") +def test_resolve_ocr_options_falls_back_to_easyocr(mock_tesseract): + """ + Ensure resolve_ocr_options falls back to easyocr if we cannot + load tesserocr. + """ + mock_tesseract.side_effect = ImportError("mock import error") + ocr_options = resolve_ocr_options() + assert isinstance(ocr_options, EasyOcrOptions) + + +@patch("docling.models.tesseract_ocr_model.TesseractOcrModel") +@patch("docling.models.easyocr_model.EasyOcrModel") +@patch("logging.Logger.error") +def test_resolve_ocr_options_none_found_logs_error( + mock_logger, mock_easyocr, mock_tesseract +): + """ + If we cannot load tesserocr or easyocr, ensure + resolve_ocr_options logs an error so that users are aware optical + character recognition in PDFs will be disabled. 
+ """ + mock_tesseract.side_effect = ImportError("mock import error") + mock_easyocr.side_effect = ImportError("mock import error") + ocr_options = resolve_ocr_options() + assert ocr_options is None + mock_logger.assert_called() diff --git a/tests/testdata/leanimports.py b/tests/testdata/leanimports.py new file mode 100644 index 00000000..f6a50484 --- /dev/null +++ b/tests/testdata/leanimports.py @@ -0,0 +1,15 @@ +"""Helper for test_sdg_imports""" + +# Standard +import sys + +# block slow imports +for unwanted in ["deepspeed", "llama_cpp", "torch", "transformers", "vllm"]: + # importlib raises ModuleNotFoundError when the sys.modules value is None. + assert unwanted not in sys.modules + sys.modules[unwanted] = None # type: ignore[assignment] + +# First Party +# This will trigger errors if any part of the import chain tries to load +# the unwanted modules +from instructlab.sdg.generate_data import generate_data