From e3d277c8bff3381a58222b1ecbf5fce1002bff29 Mon Sep 17 00:00:00 2001 From: alcholiclg Date: Tue, 22 Jul 2025 11:41:35 +0800 Subject: [PATCH 1/5] add examples in unittest --- tests/doc/test_img.py | 103 ++++++++++++++++++++++++++---------------- 1 file changed, 65 insertions(+), 38 deletions(-) diff --git a/tests/doc/test_img.py b/tests/doc/test_img.py index 457f0796..bc8fcf8f 100644 --- a/tests/doc/test_img.py +++ b/tests/doc/test_img.py @@ -1,5 +1,6 @@ import os import unittest +from typing import List from ms_agent.tools.docling.doc_loader import DocLoader @@ -10,10 +11,18 @@ class TestExtractImage(unittest.TestCase): base_dir: str = os.path.dirname(os.path.abspath(__file__)) - absolute_path_img_url: str = 'https://www.chinahighlights.com/hangzhou/food-restaurant.htm' - relative_path_img_url: str = 'https://github.com/asinghcsu/AgenticRAG-Survey' - figure_tag_img_url: str = 'https://blogs.nvidia.com/blog/what-is-retrieval-augmented-generation/' - data_uri_img_url: str = 'https://arxiv.org/html/2505.16120v1' + absolute_path_img_url: List[str] = [ + 'https://www.chinahighlights.com/hangzhou/food-restaurant.htm' + ] + relative_path_img_url: List[str] = [ + 'https://github.com/asinghcsu/AgenticRAG-Survey', + 'https://www.ams.org.cn/article/2021/0412-1961/0412-1961-2021-57-11-1343.shtml', + 'https://arxiv.org/html/2502.15214' + ] + figure_tag_img_url: List[str] = [ + 'https://blogs.nvidia.com/blog/what-is-retrieval-augmented-generation/' + ] + data_uri_img_url: List[str] = ['https://arxiv.org/html/2505.16120v1'] @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_absolute_path_img(self): @@ -21,15 +30,19 @@ def test_absolute_path_img(self): if not os.path.exists(save_dir): os.makedirs(save_dir) doc_loader = DocLoader() - doc_results = doc_loader.load( - urls_or_files=[self.absolute_path_img_url]) - for idx, pic in enumerate(doc_results[0].pictures): - print(f'Picture: {pic.self_ref} ...') - if pic.image: - pic.image.pil_image.save( - os.path.join(save_dir, 'picture_' + str(idx) + '.png')) - assert len( - doc_results[0].pictures) > 0, 'No pictures found in the document.' + doc_results = doc_loader.load(urls_or_files=self.absolute_path_img_url) + for task_id, doc in enumerate(doc_results): + task_save_dir = os.path.join(save_dir, f'task_{task_id}') + if not os.path.exists(task_save_dir): + os.makedirs(task_save_dir) + print(f'Task {task_id} ...') + for idx, pic in enumerate(doc.pictures): + print(f'Picture: {pic.self_ref} ...') + if pic.image: + pic.image.pil_image.save( + os.path.join(task_save_dir, + 'picture_' + str(idx) + '.png')) + assert len(doc.pictures) > 0, 'No pictures found in the document.' @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_relative_path_img(self): @@ -37,15 +50,19 @@ def test_relative_path_img(self): if not os.path.exists(save_dir): os.makedirs(save_dir) doc_loader = DocLoader() - doc_results = doc_loader.load( - urls_or_files=[self.relative_path_img_url]) - for idx, pic in enumerate(doc_results[0].pictures): - print(f'Picture: {pic.self_ref} ...') - if pic.image: - pic.image.pil_image.save( - os.path.join(save_dir, 'picture_' + str(idx) + '.png')) - assert len( - doc_results[0].pictures) > 0, 'No pictures found in the document.' 
+ doc_results = doc_loader.load(urls_or_files=self.relative_path_img_url) + for task_id, doc in enumerate(doc_results): + task_save_dir = os.path.join(save_dir, f'task_{task_id}') + if not os.path.exists(task_save_dir): + os.makedirs(task_save_dir) + print(f'Task {task_id} ...') + for idx, pic in enumerate(doc.pictures): + print(f'Picture: {pic.self_ref} ...') + if pic.image: + pic.image.pil_image.save( + os.path.join(task_save_dir, + 'picture_' + str(idx) + '.png')) + assert len(doc.pictures) > 0, 'No pictures found in the document.' @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_figure_tag_img(self): @@ -53,14 +70,19 @@ def test_figure_tag_img(self): if not os.path.exists(save_dir): os.makedirs(save_dir) doc_loader = DocLoader() - doc_results = doc_loader.load(urls_or_files=[self.figure_tag_img_url]) - for idx, pic in enumerate(doc_results[0].pictures): - print(f'Picture: {pic.self_ref} ...') - if pic.image: - pic.image.pil_image.save( - os.path.join(save_dir, 'picture_' + str(idx) + '.png')) - assert len( - doc_results[0].pictures) > 0, 'No pictures found in the document.' + doc_results = doc_loader.load(urls_or_files=self.figure_tag_img_url) + for task_id, doc in enumerate(doc_results): + task_save_dir = os.path.join(save_dir, f'task_{task_id}') + if not os.path.exists(task_save_dir): + os.makedirs(task_save_dir) + print(f'Task {task_id} ...') + for idx, pic in enumerate(doc.pictures): + print(f'Picture: {pic.self_ref} ...') + if pic.image: + pic.image.pil_image.save( + os.path.join(task_save_dir, + 'picture_' + str(idx) + '.png')) + assert len(doc.pictures) > 0, 'No pictures found in the document.' @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_data_uri_img(self): @@ -68,14 +90,19 @@ def test_data_uri_img(self): if not os.path.exists(save_dir): os.makedirs(save_dir) doc_loader = DocLoader() - doc_results = doc_loader.load(urls_or_files=[self.data_uri_img_url]) - for idx, pic in enumerate(doc_results[0].pictures): - print(f'Picture: {pic.self_ref} ...') - if pic.image: - pic.image.pil_image.save( - os.path.join(save_dir, 'picture_' + str(idx) + '.png')) - assert len( - doc_results[0].pictures) > 0, 'No pictures found in the document.' + doc_results = doc_loader.load(urls_or_files=self.data_uri_img_url) + for task_id, doc in enumerate(doc_results): + task_save_dir = os.path.join(save_dir, f'task_{task_id}') + if not os.path.exists(task_save_dir): + os.makedirs(task_save_dir) + print(f'Task {task_id} ...') + for idx, pic in enumerate(doc.pictures): + print(f'Picture: {pic.self_ref} ...') + if pic.image: + pic.image.pil_image.save( + os.path.join(task_save_dir, + 'picture_' + str(idx) + '.png')) + assert len(doc.pictures) > 0, 'No pictures found in the document.' 
if __name__ == '__main__': From 235bb2972bd34b76c1a6b2297c7458b2f99653e2 Mon Sep 17 00:00:00 2001 From: alcholiclg Date: Tue, 22 Jul 2025 13:02:59 +0800 Subject: [PATCH 2/5] 1. support capturing the current base URL in the backend 2. support fetching images using the base URL 3. support normalizing the base URL --- ms_agent/tools/docling/doc_loader.py | 96 +++++++++++++++++++++------- ms_agent/utils/utils.py | 74 ++++++++++++++++++--- 2 files changed, 138 insertions(+), 32 deletions(-) diff --git a/ms_agent/tools/docling/doc_loader.py b/ms_agent/tools/docling/doc_loader.py index 6d8bad65..b9c74e52 100644 --- a/ms_agent/tools/docling/doc_loader.py +++ b/ms_agent/tools/docling/doc_loader.py @@ -1,30 +1,82 @@ # flake8: noqa import os +import time +from functools import partial from pathlib import Path -from typing import Dict, Iterator, List, Union +from typing import Dict, Iterator, List, Optional, Union from bs4 import Tag from docling.backend.html_backend import HTMLDocumentBackend from docling.datamodel.accelerator_options import AcceleratorOptions from docling.datamodel.base_models import InputFormat -from docling.datamodel.document import ConversionResult +from docling.datamodel.document import (ConversionResult, + _DocumentConversionInput) from docling.datamodel.pipeline_options import PdfPipelineOptions +from docling.datamodel.settings import settings from docling.document_converter import DocumentConverter, PdfFormatOption from docling.models.document_picture_classifier import \ DocumentPictureClassifier from docling.models.layout_model import LayoutModel from docling.models.table_structure_model import TableStructureModel +from docling.utils.utils import chunkify from docling_core.types import DoclingDocument from docling_core.types.doc import DocItem, DocItemLabel, ImageRef from ms_agent.tools.docling.doc_postprocess import PostProcess from ms_agent.utils.logger import get_logger from ms_agent.utils.patcher import patch -from ms_agent.utils.utils import (load_image_from_uri_to_pil, - load_image_from_url_to_pil, validate_url) +from ms_agent.utils.utils import extract_image +from PIL import Image logger = get_logger() +def convert_ms(self, conv_input: _DocumentConversionInput, + raises_on_error: bool) -> Iterator[ConversionResult]: + """ + Patch the `docling.document_converter.DocumentConverter._convert` method for image parsing. + """ + start_time = time.monotonic() + + def _add_custom_attributes(doc: DoclingDocument, target: str, + attributes_dict: Dict) -> DoclingDocument: + """ + Add custom attributes to the target object. + """ + for key, value in attributes_dict.items(): + target_obj = getattr(doc, target) if target is not None else doc + if not hasattr(target_obj, key): + setattr(target_obj, key, value) + else: + raise ValueError( + f"Attribute '{key}' already exists in the document.") + return doc + + for input_batch in chunkify( + conv_input.docs(self.format_to_options), + settings.perf.doc_batch_size, # pass format_options + ): + # parallel processing only within input_batch + # with ThreadPoolExecutor( + # max_workers=settings.perf.doc_batch_concurrency + # ) as pool: + # yield from pool.map(self.process_document, input_batch) + # Note: PDF backends are not thread-safe, thread pool usage was disabled.
+ + for item in map( + partial( + self._process_document, raises_on_error=raises_on_error), + map( + lambda doc: _add_custom_attributes( + doc, '_backend', {'current_url': self.current_url}), + input_batch)): + elapsed = time.monotonic() - start_time + start_time = time.monotonic() + logger.info( + f'Finished converting document {item.input.file.name} in {elapsed:.2f} sec.' + ) + yield item + + def html_handle_figure(self, element: Tag, doc: DoclingDocument) -> None: """ Patch the `docling.backend.html_backend.HTMLDocumentBackend.handle_figure` method. @@ -39,16 +91,11 @@ def html_handle_figure(self, element: Tag, doc: DoclingDocument) -> None: else: img_url = None - if img_url: - if img_url.startswith('data:'): - img_pil = load_image_from_uri_to_pil(img_url) - else: - if not img_url.startswith('http'): - img_url = validate_url(img_url=img_url, backend=self) - img_pil = load_image_from_url_to_pil( - img_url) if img_url.startswith('http') else None - else: - img_pil = None + # extract image from url or data URI + img_pil: Optional[Image.Image] = extract_image( + img_url=img_url, + backend=self, + base_url=self.current_url if hasattr(self, 'current_url') else None) dpi: int = int(img_pil.info.get('dpi', (96, 96))[0]) if img_pil else 96 img_ref: ImageRef = None @@ -96,15 +143,11 @@ def html_handle_image(self, element: Tag, doc: DoclingDocument) -> None: # Get the image from element img_url: str = element.attrs.get('src', None) - if img_url: - if img_url.startswith('data:'): - img_pil = load_image_from_uri_to_pil(img_url) - else: - if not img_url.startswith('http'): - img_url = validate_url(img_url=img_url, backend=self) - img_pil = load_image_from_url_to_pil(img_url) - else: - img_pil = None + # extract image from url or data URI + img_pil: Optional[Image.Image] = extract_image( + img_url=img_url, + backend=self, + base_url=self.current_url if hasattr(self, 'current_url') else None) dpi: int = int(img_pil.info.get('dpi', (96, 96))[0]) if img_pil else 96 @@ -325,6 +368,7 @@ def _postprocess(doc: DoclingDocument) -> Union[DoclingDocument, None]: return doc + @patch(DocumentConverter, '_convert', convert_ms) @patch(LayoutModel, 'download_models', download_models_ms) @patch(TableStructureModel, 'download_models', download_models_ms) @patch(DocumentPictureClassifier, 'download_models', @@ -338,10 +382,16 @@ def load(self, urls_or_files: list[str]) -> List[DoclingDocument]: # TODO: Support progress bar for document loading (with pather) results: Iterator[ConversionResult] = self._converter.convert_all( source=urls_or_files, ) + iter_urls_or_files = iter(urls_or_files) final_results = [] while True: try: + # Record the current URL for parsing images + setattr(self._converter, 'current_url', + next(iter_urls_or_files)) + assert self._converter.current_url is not None, 'Current URL should not be None' + res: ConversionResult = next(results) if res is None or res.document is None: continue diff --git a/ms_agent/utils/utils.py b/ms_agent/utils/utils.py index ee2a4340..69d47c33 100644 --- a/ms_agent/utils/utils.py +++ b/ms_agent/utils/utils.py @@ -10,6 +10,7 @@ import json import requests from omegaconf import DictConfig, OmegaConf +from PIL import Image from .logger import get_logger @@ -321,7 +322,6 @@ def load_image_from_url_to_pil(url: str) -> 'Image.Image': Returns: A PIL Image object if successful, None otherwise. 
""" - from PIL import Image try: response = requests.get(url) # Raise an HTTPError for bad responses (4xx or 5xx) @@ -348,7 +348,6 @@ def load_image_from_uri_to_pil(uri: str) -> 'Image.Image': Returns: tuple: (PIL Image object, file extension string) or None if failed """ - from PIL import Image try: header, encoded = uri.split(',', 1) if ';base64' in header: @@ -371,26 +370,55 @@ def load_image_from_uri_to_pil(uri: str) -> 'Image.Image': return None -def validate_url( - img_url: str, - backend: 'docling.backend.html_backend.HTMLDocumentBackend') -> str: +def resolve_url(img_url: str, + backend: 'docling.backend.html_backend.HTMLDocumentBackend', + base_url: str) -> str: """ Validates and resolves a relative image URL using the base URL from the HTML document's metadata. + If the base URL is not provided or is invalid, the function attempts to resolve relative image + URLs by looking for base URLs in the following order: + 1. `` tag + 2. `` tag + 3. `` tag - This function attempts to resolve relative image URLs by looking for base URLs in the following order: - 1. tag - 2. tag - 3. tag Args: img_url (str): The image URL to validate/resolve backend (HTMLDocumentBackend): The HTML document backend containing the parsed document + base_url (str): The base URL to use for resolving relative URLs. Returns: str: The resolved absolute URL if successful, None otherwise """ from urllib.parse import urljoin, urlparse + base_url = None if base_url is None else base_url.strip() + + def normalized_base(url: str) -> str: + ''' + If the base URL doesn't ends with '/' and the last segment doesn't contain a dot + (meaning it doesn't look like a file), it indicates that it's actually a directory. + In this case, append a '/' to ensure correct concatenation. + ''' + parsed = urlparse(url) + if parsed.scheme in {'http', 'https'}: + if not parsed.path.endswith('/'): + need_append_slash = ('.' not in parsed.path.split('/')[-1] + or 'arxiv.org' in parsed.netloc) + + return url + '/' if need_append_slash else url + + # If the base URL is already valid, just join it with the image URL + if base_url and urlparse(base_url).scheme and urlparse(base_url).netloc: + base_url = normalized_base(base_url) + try: + valid_url = urljoin(base_url, img_url) + return valid_url + except Exception as e: + logger.error( + f'Error joining base URL with image URL: {e}, continuing to resolve...' + ) + # Check if we have a valid soup object in the backend if not backend or not hasattr( backend, 'soup') or not backend.soup or not backend.soup.head: @@ -424,3 +452,31 @@ def validate_url( # No valid base URL found return img_url + + +def extract_image( + img_url: Optional[str], + backend: 'docling.backend.html_backend.HTMLDocumentBackend', + base_url: Optional[str] = None, +) -> Optional['Image.Image']: + """ + Extracts an image from image URL, resolving the URL if necessary. + + Args: + img_url (Optional[str]): The image URL to extract. + backend (HTMLDocumentBackend): The HTML document backend containing the parsed document. + base_url (Optional[str]): The base URL to use for resolving relative URLs. + + Returns: + Optional[Image.Image]: A PIL Image object if successful, None otherwise. 
+ """ + if not img_url or not isinstance(img_url, str): + return None + + if img_url.startswith('data:'): + img_pil = load_image_from_uri_to_pil(img_url) + else: + if not img_url.startswith('http'): + img_url = resolve_url(img_url, backend, base_url) + img_pil = load_image_from_url_to_pil(img_url) if img_url else None + return img_pil if isinstance(img_pil, Image.Image) else None From a691ed41d8c93a9b770cd3ac7c25fc49bf9990b0 Mon Sep 17 00:00:00 2001 From: Gongsheng Li <58078985+alcholiclg@users.noreply.github.com> Date: Tue, 22 Jul 2025 13:50:54 +0800 Subject: [PATCH 3/5] Update ms_agent/utils/utils.py initialize need_append_slash to False Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- ms_agent/utils/utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ms_agent/utils/utils.py b/ms_agent/utils/utils.py index 69d47c33..b7a5207f 100644 --- a/ms_agent/utils/utils.py +++ b/ms_agent/utils/utils.py @@ -401,12 +401,13 @@ def normalized_base(url: str) -> str: In this case, append a '/' to ensure correct concatenation. ''' parsed = urlparse(url) - if parsed.scheme in {'http', 'https'}: - if not parsed.path.endswith('/'): - need_append_slash = ('.' not in parsed.path.split('/')[-1] - or 'arxiv.org' in parsed.netloc) + need_append_slash = False + if parsed.scheme in {'http', 'https'}: + if not parsed.path.endswith('/'): + need_append_slash = ('.' not in parsed.path.split('/')[-1] + or 'arxiv.org' in parsed.netloc) - return url + '/' if need_append_slash else url + return url + '/' if need_append_slash else url # If the base URL is already valid, just join it with the image URL if base_url and urlparse(base_url).scheme and urlparse(base_url).netloc: From 3b4644d6364b57d9ef264bd66693744249eb440d Mon Sep 17 00:00:00 2001 From: Gongsheng Li <58078985+alcholiclg@users.noreply.github.com> Date: Tue, 22 Jul 2025 13:54:03 +0800 Subject: [PATCH 4/5] fix wrong indentation introduced by Gemini --- ms_agent/utils/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ms_agent/utils/utils.py b/ms_agent/utils/utils.py index b7a5207f..71eae98f 100644 --- a/ms_agent/utils/utils.py +++ b/ms_agent/utils/utils.py @@ -401,11 +401,11 @@ def normalized_base(url: str) -> str: In this case, append a '/' to ensure correct concatenation. ''' parsed = urlparse(url) - need_append_slash = False - if parsed.scheme in {'http', 'https'}: - if not parsed.path.endswith('/'): - need_append_slash = ('.' not in parsed.path.split('/')[-1] - or 'arxiv.org' in parsed.netloc) + need_append_slash = False + if parsed.scheme in {'http', 'https'}: + if not parsed.path.endswith('/'): + need_append_slash = ('.' not in parsed.path.split('/')[-1] + or 'arxiv.org' in parsed.netloc) return url + '/' if need_append_slash else url From 17315944cc61ed35e4a025346b3da88ab42ee6ac Mon Sep 17 00:00:00 2001 From: alcholiclg Date: Tue, 22 Jul 2025 13:58:00 +0800 Subject: [PATCH 5/5] fix another wrong indentation introduced by Gemini --- ms_agent/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ms_agent/utils/utils.py b/ms_agent/utils/utils.py index 71eae98f..c2a2eb4e 100644 --- a/ms_agent/utils/utils.py +++ b/ms_agent/utils/utils.py @@ -407,7 +407,7 @@ def normalized_base(url: str) -> str: need_append_slash = ('.' 
not in parsed.path.split('/')[-1] or 'arxiv.org' in parsed.netloc) - return url + '/' if need_append_slash else url + return url + '/' if need_append_slash else url # If the base URL is already valid, just join it with the image URL if base_url and urlparse(base_url).scheme and urlparse(base_url).netloc:
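
Note on the trailing-slash logic in normalized_base() above: Python's urllib.parse.urljoin drops the last path segment of a base URL that does not end with '/', so a directory-like base such as an arXiv HTML article URL would otherwise resolve relative image paths one level too high. A minimal sketch of the behaviour the helper guards against (the relative image paths here are made up for illustration):

    from urllib.parse import urljoin

    # Directory-like base without a trailing slash: urljoin drops the last
    # segment ('2505.16120v1'), so the image resolves outside the article.
    urljoin('https://arxiv.org/html/2505.16120v1', 'x1.png')
    # -> 'https://arxiv.org/html/x1.png'

    # With the '/' appended by normalized_base(), the last segment is kept.
    urljoin('https://arxiv.org/html/2505.16120v1/', 'x1.png')
    # -> 'https://arxiv.org/html/2505.16120v1/x1.png'

    # File-like base (last segment contains a dot): dropping it is the
    # desired behaviour, so no slash is appended in that case.
    urljoin('https://www.chinahighlights.com/hangzhou/food-restaurant.htm', 'images/dish.jpg')
    # -> 'https://www.chinahighlights.com/hangzhou/images/dish.jpg'

This is also why 'arxiv.org' is special-cased in normalized_base(): its article URLs end in a segment that contains a dot ('2505.16120v1') yet still behaves like a directory.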
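
For the data: URI branch exercised by test_data_uri_img, decoding follows the standard 'data:<mime>;base64,<payload>' layout. A self-contained sketch that mirrors the steps visible in load_image_from_uri_to_pil (split on the first comma, base64-decode, open with PIL); the tiny sample image is generated in memory rather than taken from any of the test pages:

    import base64
    import io

    from PIL import Image

    # Build a 2x2 PNG in memory and wrap it in a data URI.
    buf = io.BytesIO()
    Image.new('RGB', (2, 2), 'red').save(buf, format='PNG')
    data_uri = 'data:image/png;base64,' + base64.b64encode(buf.getvalue()).decode('ascii')

    # Split the header from the payload, decode it, and open the bytes with PIL.
    header, encoded = data_uri.split(',', 1)
    img = Image.open(io.BytesIO(base64.b64decode(encoded)))
    print(img.size)  # (2, 2)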