diff --git a/ms_agent/tools/docling/doc_loader.py b/ms_agent/tools/docling/doc_loader.py
index a19a03df5..c074115a3 100644
--- a/ms_agent/tools/docling/doc_loader.py
+++ b/ms_agent/tools/docling/doc_loader.py
@@ -15,7 +15,7 @@
 from docling_core.types import DoclingDocument
 from docling_core.types.doc import DocItem
 from ms_agent.tools.docling.doc_postprocess import PostProcess
-from ms_agent.tools.docling.patches import (download_models_ms,
+from ms_agent.tools.docling.patches import (convert_ms, download_models_ms,
                                             download_models_pic_classifier_ms,
                                             html_handle_figure,
                                             html_handle_image,
@@ -211,6 +211,7 @@
     def _postprocess(doc: DoclingDocument) -> Union[DoclingDocument, None]:
         return doc
 
+    @patch(DocumentConverter, '_convert', convert_ms)
     @patch(LayoutModel, 'download_models', download_models_ms)
     @patch(TableStructureModel, 'download_models', download_models_ms)
     @patch(DocumentPictureClassifier, 'download_models',
@@ -230,10 +231,16 @@
     def load(self, urls_or_files: list[str]) -> List[DoclingDocument]:
         # TODO: Support progress bar for document loading (with pather)
         results: Iterator[ConversionResult] = self._converter.convert_all(
             source=urls_or_files, )
+        iter_urls_or_files = iter(urls_or_files)
         final_results = []
         while True:
             try:
+                # Record the current URL for parsing images
+                setattr(self._converter, 'current_url',
+                        next(iter_urls_or_files))
+                assert self._converter.current_url is not None, 'Current URL should not be None'
+
                 res: ConversionResult = next(results)
                 if res is None or res.document is None:
                     continue
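The `load` change above keeps a side iterator over `urls_or_files` in lockstep with the
results generator; it relies on `convert_all` yielding exactly one result per input, in
input order. A minimal runnable sketch of that pairing assumption (all names below are
illustrative, not part of the patch):

    def fake_convert_all(sources):
        # Stands in for DocumentConverter.convert_all: one result per source, in order.
        for s in sources:
            yield {'source': s}

    urls = ['https://example.com/a.html', 'https://example.com/b.html']
    results = fake_convert_all(urls)
    url_iter = iter(urls)
    while True:
        try:
            current_url = next(url_iter)  # recorded before pulling the matching result
            res = next(results)
        except StopIteration:
            break
        assert res['source'] == current_url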
diff --git a/ms_agent/tools/docling/patches.py b/ms_agent/tools/docling/patches.py
index ac4f45521..cc32398e5 100644
--- a/ms_agent/tools/docling/patches.py
+++ b/ms_agent/tools/docling/patches.py
@@ -1,12 +1,18 @@
 # flake8: noqa
+import time
+from functools import partial
 from pathlib import Path
+from typing import Dict, Iterator
 
 from bs4 import Tag
+from docling.datamodel.document import (ConversionResult,
+                                        _DocumentConversionInput)
+from docling.datamodel.settings import settings
+from docling.utils.utils import chunkify
 from docling_core.types import DoclingDocument
 from docling_core.types.doc import DocItemLabel, ImageRef
 from ms_agent.utils.logger import get_logger
-from ms_agent.utils.utils import (load_image_from_uri_to_pil,
-                                  load_image_from_url_to_pil, validate_url)
+from ms_agent.utils.utils import extract_image
 
 logger = get_logger()
@@ -25,16 +31,11 @@ def html_handle_figure(self, element: Tag, doc: DoclingDocument) -> None:
     else:
         img_url = None
 
-    if img_url:
-        if img_url.startswith('data:'):
-            img_pil = load_image_from_uri_to_pil(img_url)
-        else:
-            if not img_url.startswith('http'):
-                img_url = validate_url(img_url=img_url, backend=self)
-            img_pil = load_image_from_url_to_pil(
-                img_url) if img_url.startswith('http') else None
-    else:
-        img_pil = None
+    # extract image from url or data URI
+    img_pil: 'Image.Image' = extract_image(
+        img_url=img_url,
+        backend=self,
+        base_url=self.current_url if hasattr(self, 'current_url') else None)
 
     dpi: int = int(img_pil.info.get('dpi', (96, 96))[0]) if img_pil else 96
 
     img_ref: ImageRef = None
@@ -82,15 +83,11 @@ def html_handle_image(self, element: Tag, doc: DoclingDocument) -> None:
     # Get the image from element
     img_url: str = element.attrs.get('src', None)
 
-    if img_url:
-        if img_url.startswith('data:'):
-            img_pil = load_image_from_uri_to_pil(img_url)
-        else:
-            if not img_url.startswith('http'):
-                img_url = validate_url(img_url=img_url, backend=self)
-            img_pil = load_image_from_url_to_pil(img_url)
-    else:
-        img_pil = None
+    # extract image from url or data URI
+    img_pil: 'Image.Image' = extract_image(
+        img_url=img_url,
+        backend=self,
+        base_url=self.current_url if hasattr(self, 'current_url') else None)
 
     dpi: int = int(img_pil.info.get('dpi', (96, 96))[0]) if img_pil else 96
@@ -170,3 +167,50 @@ def patch_easyocr_models():
         'url'] = 'https://modelscope.cn/models/ms-agent/kannada_g2/resolve/master/kannada_g2.zip'
     recognition_models['gen2']['cyrillic_g2'][
         'url'] = 'https://modelscope.cn/models/ms-agent/cyrillic_g2/resolve/master/cyrillic_g2.zip'
+
+
+def convert_ms(self, conv_input: _DocumentConversionInput,
+               raises_on_error: bool) -> Iterator[ConversionResult]:
+    """
+    Patch the `docling.document_converter.DocumentConverter._convert` method for image parsing.
+    """
+    start_time = time.monotonic()
+
+    def _add_custom_attributes(doc: DoclingDocument, target: str,
+                               attributes_dict: Dict) -> DoclingDocument:
+        """
+        Add custom attributes to the target object.
+        """
+        for key, value in attributes_dict.items():
+            target_obj = getattr(doc, target) if target is not None else doc
+            if not hasattr(target_obj, key):
+                setattr(target_obj, key, value)
+            else:
+                raise ValueError(
+                    f"Attribute '{key}' already exists in the document.")
+        return doc
+
+    for input_batch in chunkify(
+            conv_input.docs(self.format_to_options),
+            settings.perf.doc_batch_size,  # pass format_options
+    ):
+        # parallel processing only within input_batch
+        # with ThreadPoolExecutor(
+        #     max_workers=settings.perf.doc_batch_concurrency
+        # ) as pool:
+        #     yield from pool.map(self.process_document, input_batch)
+        # Note: PDF backends are not thread-safe, thread pool usage was disabled.
+
+        for item in map(
+                partial(
+                    self._process_document, raises_on_error=raises_on_error),
+                map(
+                    lambda doc: _add_custom_attributes(
+                        doc, '_backend', {'current_url': self.current_url}),
+                    input_batch)):
+            elapsed = time.monotonic() - start_time
+            start_time = time.monotonic()
+            logger.info(
+                f'Finished converting document {item.input.file.name} in {elapsed:.2f} sec.'
+            )
+            yield item
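`convert_ms` threads the recorded `current_url` through to each input document's
`_backend`, which is what the patched `html_handle_figure`/`html_handle_image` read back
as `self.current_url`. A self-contained sketch of the same set-once injection pattern,
using stand-in classes instead of docling's input documents (names are hypothetical):

    class _Backend:
        pass

    class _InputDocument:
        def __init__(self):
            self._backend = _Backend()

    def add_custom_attributes(doc, target, attributes_dict):
        # Attach attributes to a nested object, refusing to clobber existing ones.
        target_obj = getattr(doc, target) if target is not None else doc
        for key, value in attributes_dict.items():
            if hasattr(target_obj, key):
                raise ValueError(f"Attribute '{key}' already exists.")
            setattr(target_obj, key, value)
        return doc

    doc = add_custom_attributes(_InputDocument(), '_backend',
                                {'current_url': 'https://example.com/page.html'})
    assert doc._backend.current_url == 'https://example.com/page.html'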
diff --git a/ms_agent/utils/utils.py b/ms_agent/utils/utils.py
index fda3e152d..04edd3278 100644
--- a/ms_agent/utils/utils.py
+++ b/ms_agent/utils/utils.py
@@ -11,6 +11,7 @@
 import requests
 import yaml
 from omegaconf import DictConfig, OmegaConf
+from PIL import Image
 
 from .logger import get_logger
@@ -322,7 +323,6 @@
     Returns:
         A PIL Image object if successful, None otherwise.
     """
-    from PIL import Image
     try:
         response = requests.get(url)
         # Raise an HTTPError for bad responses (4xx or 5xx)
@@ -349,7 +349,6 @@
     Returns:
         tuple: (PIL Image object, file extension string) or None if failed
     """
-    from PIL import Image
     try:
         header, encoded = uri.split(',', 1)
         if ';base64' in header:
@@ -373,26 +372,56 @@
-def validate_url(
-        img_url: str,
-        backend: 'docling.backend.html_backend.HTMLDocumentBackend') -> str:
+def resolve_url(img_url: str,
+                backend: 'docling.backend.html_backend.HTMLDocumentBackend',
+                base_url: str) -> str:
     """
     Validates and resolves a relative image URL using the base URL from the HTML document's metadata.
+    If the base URL is not provided or is invalid, the function attempts to resolve relative image
+    URLs by looking for base URLs in the following order:
+    1. `<base>` tag
+    2. `<meta>` tag
+    3. `<link>` tag
 
-    This function attempts to resolve relative image URLs by looking for base URLs in the following order:
-    1. <base> tag
-    2. <meta> tag
-    3. <link> tag
     Args:
         img_url (str): The image URL to validate/resolve
         backend (HTMLDocumentBackend): The HTML document backend containing the parsed document
+        base_url (str): The base URL to use for resolving relative URLs.
 
     Returns:
         str: The resolved absolute URL if successful, None otherwise
     """
     from urllib.parse import urljoin, urlparse
 
+    base_url = None if base_url is None else base_url.strip()
+
+    def normalized_base(url: str) -> str:
+        '''
+        If the base URL doesn't end with '/' and the last segment doesn't contain a dot
+        (meaning it doesn't look like a file), it indicates that it's actually a directory.
+        In this case, append a '/' to ensure correct concatenation.
+        '''
+        parsed = urlparse(url)
+        need_append_slash = False
+        if parsed.scheme in {'http', 'https'}:
+            if not parsed.path.endswith('/'):
+                need_append_slash = ('.' not in parsed.path.split('/')[-1]
+                                     or 'arxiv.org' in parsed.netloc)
+
+        return url + '/' if need_append_slash else url
+
+    # If the base URL is already valid, just join it with the image URL
+    if base_url and urlparse(base_url).scheme and urlparse(base_url).netloc:
+        base_url = normalized_base(base_url)
+        try:
+            valid_url = urljoin(base_url, img_url)
+            return valid_url
+        except Exception as e:
+            logger.error(
+                f'Error joining base URL with image URL: {e}, continuing to resolve...'
+            )
+
     # Check if we have a valid soup object in the backend
     if not backend or not hasattr(
             backend, 'soup') or not backend.soup or not backend.soup.head:
@@ -428,6 +457,34 @@
     return img_url
 
 
+def extract_image(
+    img_url: Optional[str],
+    backend: 'docling.backend.html_backend.HTMLDocumentBackend',
+    base_url: Optional[str] = None,
+) -> Optional['Image.Image']:
+    """
+    Extracts an image from an image URL, resolving the URL if necessary.
+
+    Args:
+        img_url (Optional[str]): The image URL to extract.
+        backend (HTMLDocumentBackend): The HTML document backend containing the parsed document.
+        base_url (Optional[str]): The base URL to use for resolving relative URLs.
+
+    Returns:
+        Optional[Image.Image]: A PIL Image object if successful, None otherwise.
+    """
+    if not img_url or not isinstance(img_url, str):
+        return None
+
+    if img_url.startswith('data:'):
+        img_pil = load_image_from_uri_to_pil(img_url)
+    else:
+        if not img_url.startswith('http'):
+            img_url = resolve_url(img_url, backend, base_url)
+        img_pil = load_image_from_url_to_pil(img_url) if img_url else None
+    return img_pil if isinstance(img_pil, Image.Image) else None
+
+
 def get_default_config():
     """
     Load and return the default configuration from 'ms_agent/agent/agent.yaml'.
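A note on the trailing-slash logic in `normalized_base`: `urljoin` drops the last path
segment of a base URL that does not end in '/', which is exactly the failure mode for
directory-like pages such as arXiv HTML articles, whose final segment contains a dot.
A runnable illustration, with a hypothetical image name:

    from urllib.parse import urljoin

    # Without the slash, 'x1.png' replaces the page's own path segment:
    assert urljoin('https://arxiv.org/html/2502.15214',
                   'x1.png') == 'https://arxiv.org/html/x1.png'
    # With the slash appended, it resolves inside the page's directory:
    assert urljoin('https://arxiv.org/html/2502.15214/',
                   'x1.png') == 'https://arxiv.org/html/2502.15214/x1.png'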
diff --git a/tests/doc/test_img.py b/tests/doc/test_img.py
index c21ddcba5..d050ed854 100644
--- a/tests/doc/test_img.py
+++ b/tests/doc/test_img.py
@@ -1,80 +1,105 @@
 # import os
 # import unittest
-#
+# from typing import List
+
 # from ms_agent.tools.docling.doc_loader import DocLoader
-#
+
 # from modelscope.utils.test_utils import test_level
-#
-#
 # class TestExtractImage(unittest.TestCase):
 #     base_dir: str = os.path.dirname(os.path.abspath(__file__))
-#     absolute_path_img_url: str = 'https://www.chinahighlights.com/hangzhou/food-restaurant.htm'
-#     relative_path_img_url: str = 'https://github.com/asinghcsu/AgenticRAG-Survey'
-#     figure_tag_img_url: str = 'https://blogs.nvidia.com/blog/what-is-retrieval-augmented-generation/'
-#     data_uri_img_url: str = 'https://arxiv.org/html/2505.16120v1'
-#
+#     absolute_path_img_url: List[str] = [
+#         'https://www.chinahighlights.com/hangzhou/food-restaurant.htm'
+#     ]
+#     relative_path_img_url: List[str] = [
+#         'https://github.com/asinghcsu/AgenticRAG-Survey',
+#         'https://www.ams.org.cn/article/2021/0412-1961/0412-1961-2021-57-11-1343.shtml',
+#         'https://arxiv.org/html/2502.15214'
+#     ]
+#     figure_tag_img_url: List[str] = [
+#         'https://blogs.nvidia.com/blog/what-is-retrieval-augmented-generation/'
+#     ]
+#     data_uri_img_url: List[str] = ['https://arxiv.org/html/2505.16120v1']
+
 #     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
 #     def test_absolute_path_img(self):
 #         save_dir = os.path.join(self.base_dir, 'absolute_path_img')
 #         if not os.path.exists(save_dir):
 #             os.makedirs(save_dir)
 #         doc_loader = DocLoader()
-#         doc_results = doc_loader.load(
-#             urls_or_files=[self.absolute_path_img_url])
-#         for idx, pic in enumerate(doc_results[0].pictures):
-#             print(f'Picture: {pic.self_ref} ...')
-#             if pic.image:
-#                 pic.image.pil_image.save(
-#                     os.path.join(save_dir, 'picture_' + str(idx) + '.png'))
-#         assert len(
-#             doc_results[0].pictures) > 0, 'No pictures found in the document.'
-#
+#         doc_results = doc_loader.load(urls_or_files=self.absolute_path_img_url)
+#         for task_id, doc in enumerate(doc_results):
+#             task_save_dir = os.path.join(save_dir, f'task_{task_id}')
+#             if not os.path.exists(task_save_dir):
+#                 os.makedirs(task_save_dir)
+#             print(f'Task {task_id} ...')
+#             for idx, pic in enumerate(doc.pictures):
+#                 print(f'Picture: {pic.self_ref} ...')
+#                 if pic.image:
+#                     pic.image.pil_image.save(
+#                         os.path.join(task_save_dir,
+#                                      'picture_' + str(idx) + '.png'))
+#             assert len(doc.pictures) > 0, 'No pictures found in the document.'
+
 #     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
 #     def test_relative_path_img(self):
 #         save_dir = os.path.join(self.base_dir, 'relative_path_img')
 #         if not os.path.exists(save_dir):
 #             os.makedirs(save_dir)
 #         doc_loader = DocLoader()
-#         doc_results = doc_loader.load(
-#             urls_or_files=[self.relative_path_img_url])
-#         for idx, pic in enumerate(doc_results[0].pictures):
-#             print(f'Picture: {pic.self_ref} ...')
-#             if pic.image:
-#                 pic.image.pil_image.save(
-#                     os.path.join(save_dir, 'picture_' + str(idx) + '.png'))
-#         assert len(
-#             doc_results[0].pictures) > 0, 'No pictures found in the document.'
-#
+#         doc_results = doc_loader.load(urls_or_files=self.relative_path_img_url)
+#         for task_id, doc in enumerate(doc_results):
+#             task_save_dir = os.path.join(save_dir, f'task_{task_id}')
+#             if not os.path.exists(task_save_dir):
+#                 os.makedirs(task_save_dir)
+#             print(f'Task {task_id} ...')
+#             for idx, pic in enumerate(doc.pictures):
+#                 print(f'Picture: {pic.self_ref} ...')
+#                 if pic.image:
+#                     pic.image.pil_image.save(
+#                         os.path.join(task_save_dir,
+#                                      'picture_' + str(idx) + '.png'))
+#             assert len(doc.pictures) > 0, 'No pictures found in the document.'
+
 #     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
 #     def test_figure_tag_img(self):
 #         save_dir = os.path.join(self.base_dir, 'figure_tag_img')
 #         if not os.path.exists(save_dir):
 #             os.makedirs(save_dir)
 #         doc_loader = DocLoader()
-#         doc_results = doc_loader.load(urls_or_files=[self.figure_tag_img_url])
-#         for idx, pic in enumerate(doc_results[0].pictures):
-#             print(f'Picture: {pic.self_ref} ...')
-#             if pic.image:
-#                 pic.image.pil_image.save(
-#                     os.path.join(save_dir, 'picture_' + str(idx) + '.png'))
-#         assert len(
-#             doc_results[0].pictures) > 0, 'No pictures found in the document.'
-#
+#         doc_results = doc_loader.load(urls_or_files=self.figure_tag_img_url)
+#         for task_id, doc in enumerate(doc_results):
+#             task_save_dir = os.path.join(save_dir, f'task_{task_id}')
+#             if not os.path.exists(task_save_dir):
+#                 os.makedirs(task_save_dir)
+#             print(f'Task {task_id} ...')
+#             for idx, pic in enumerate(doc.pictures):
+#                 print(f'Picture: {pic.self_ref} ...')
+#                 if pic.image:
+#                     pic.image.pil_image.save(
+#                         os.path.join(task_save_dir,
+#                                      'picture_' + str(idx) + '.png'))
+#             assert len(doc.pictures) > 0, 'No pictures found in the document.'
+
 #     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
 #     def test_data_uri_img(self):
 #         save_dir = os.path.join(self.base_dir, 'data_uri_img')
 #         if not os.path.exists(save_dir):
 #             os.makedirs(save_dir)
 #         doc_loader = DocLoader()
-#         doc_results = doc_loader.load(urls_or_files=[self.data_uri_img_url])
-#         for idx, pic in enumerate(doc_results[0].pictures):
-#             print(f'Picture: {pic.self_ref} ...')
-#             if pic.image:
-#                 pic.image.pil_image.save(
-#                     os.path.join(save_dir, 'picture_' + str(idx) + '.png'))
-#         assert len(
-#             doc_results[0].pictures) > 0, 'No pictures found in the document.'
-#
-#
+#         doc_results = doc_loader.load(urls_or_files=self.data_uri_img_url)
+#         for task_id, doc in enumerate(doc_results):
+#             task_save_dir = os.path.join(save_dir, f'task_{task_id}')
+#             if not os.path.exists(task_save_dir):
+#                 os.makedirs(task_save_dir)
+#             print(f'Task {task_id} ...')
+#             for idx, pic in enumerate(doc.pictures):
+#                 print(f'Picture: {pic.self_ref} ...')
+#                 if pic.image:
+#                     pic.image.pil_image.save(
+#                         os.path.join(task_save_dir,
+#                                      'picture_' + str(idx) + '.png'))
+#             assert len(doc.pictures) > 0, 'No pictures found in the document.'
+
 # if __name__ == '__main__':
 #     unittest.main()
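For completeness, a small round trip of the `data:` URI branch exercised by
`test_data_uri_img`, mirroring the `header, encoded = uri.split(',', 1)` logic in
`load_image_from_uri_to_pil` (the PNG is generated on the fly rather than hard-coded):

    import base64
    from io import BytesIO
    from PIL import Image

    # Build a data URI from a freshly generated 1x1 PNG...
    buf = BytesIO()
    Image.new('RGB', (1, 1), 'white').save(buf, format='PNG')
    uri = 'data:image/png;base64,' + base64.b64encode(buf.getvalue()).decode()

    # ...then parse it back the same way the loader does.
    header, encoded = uri.split(',', 1)
    assert ';base64' in header
    img = Image.open(BytesIO(base64.b64decode(encoded)))
    assert img.size == (1, 1)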