diff --git a/ms_agent/tools/docling/doc_loader.py b/ms_agent/tools/docling/doc_loader.py
index a19a03df5..c074115a3 100644
--- a/ms_agent/tools/docling/doc_loader.py
+++ b/ms_agent/tools/docling/doc_loader.py
@@ -15,7 +15,7 @@
from docling_core.types import DoclingDocument
from docling_core.types.doc import DocItem
from ms_agent.tools.docling.doc_postprocess import PostProcess
-from ms_agent.tools.docling.patches import (download_models_ms,
+from ms_agent.tools.docling.patches import (convert_ms, download_models_ms,
download_models_pic_classifier_ms,
html_handle_figure,
html_handle_image,
@@ -211,6 +211,7 @@ def _postprocess(doc: DoclingDocument) -> Union[DoclingDocument, None]:
return doc
+ @patch(DocumentConverter, '_convert', convert_ms)
@patch(LayoutModel, 'download_models', download_models_ms)
@patch(TableStructureModel, 'download_models', download_models_ms)
@patch(DocumentPictureClassifier, 'download_models',
@@ -230,10 +231,16 @@ def load(self, urls_or_files: list[str]) -> List[DoclingDocument]:
# TODO: Support progress bar for document loading (with pather)
results: Iterator[ConversionResult] = self._converter.convert_all(
source=urls_or_files, )
+ iter_urls_or_files = iter(urls_or_files)
final_results = []
while True:
try:
+                # Record the URL of the document being converted so that
+                # patched backends can resolve relative image URLs against it
+                setattr(self._converter, 'current_url',
+                        next(iter_urls_or_files))
+                assert self._converter.current_url is not None, \
+                    'Current URL should not be None'
+
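+                # `convert_all` yields results lazily, one per input, so the
+                # URL recorded above pairs with the document produced by the
+                # `next(results)` call below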
res: ConversionResult = next(results)
if res is None or res.document is None:
continue
diff --git a/ms_agent/tools/docling/patches.py b/ms_agent/tools/docling/patches.py
index ac4f45521..cc32398e5 100644
--- a/ms_agent/tools/docling/patches.py
+++ b/ms_agent/tools/docling/patches.py
@@ -1,12 +1,18 @@
# flake8: noqa
+import time
+from functools import partial
from pathlib import Path
+from typing import Dict, Iterator, Optional
from bs4 import Tag
+from docling.datamodel.document import (ConversionResult,
+ _DocumentConversionInput)
+from docling.datamodel.settings import settings
+from docling.utils.utils import chunkify
from docling_core.types import DoclingDocument
from docling_core.types.doc import DocItemLabel, ImageRef
from ms_agent.utils.logger import get_logger
-from ms_agent.utils.utils import (load_image_from_uri_to_pil,
- load_image_from_url_to_pil, validate_url)
+from ms_agent.utils.utils import extract_image
logger = get_logger()
@@ -25,16 +31,11 @@ def html_handle_figure(self, element: Tag, doc: DoclingDocument) -> None:
else:
img_url = None
- if img_url:
- if img_url.startswith('data:'):
- img_pil = load_image_from_uri_to_pil(img_url)
- else:
- if not img_url.startswith('http'):
- img_url = validate_url(img_url=img_url, backend=self)
- img_pil = load_image_from_url_to_pil(
- img_url) if img_url.startswith('http') else None
- else:
- img_pil = None
+    # Extract the image from a URL or data URI; `current_url` is attached
+    # to the backend by the patched `convert_ms` and is used to resolve
+    # relative image URLs
+    img_pil: Optional['Image.Image'] = extract_image(
+        img_url=img_url,
+        backend=self,
+        base_url=self.current_url if hasattr(self, 'current_url') else None)
dpi: int = int(img_pil.info.get('dpi', (96, 96))[0]) if img_pil else 96
img_ref: ImageRef = None
@@ -82,15 +83,11 @@ def html_handle_image(self, element: Tag, doc: DoclingDocument) -> None:
# Get the image from element
img_url: str = element.attrs.get('src', None)
- if img_url:
- if img_url.startswith('data:'):
- img_pil = load_image_from_uri_to_pil(img_url)
- else:
- if not img_url.startswith('http'):
- img_url = validate_url(img_url=img_url, backend=self)
- img_pil = load_image_from_url_to_pil(img_url)
- else:
- img_pil = None
+    # Extract the image from a URL or data URI; `current_url` is attached
+    # to the backend by the patched `convert_ms` and is used to resolve
+    # relative image URLs
+    img_pil: Optional['Image.Image'] = extract_image(
+        img_url=img_url,
+        backend=self,
+        base_url=self.current_url if hasattr(self, 'current_url') else None)
dpi: int = int(img_pil.info.get('dpi', (96, 96))[0]) if img_pil else 96
@@ -170,3 +167,50 @@ def patch_easyocr_models():
'url'] = 'https://modelscope.cn/models/ms-agent/kannada_g2/resolve/master/kannada_g2.zip'
recognition_models['gen2']['cyrillic_g2'][
'url'] = 'https://modelscope.cn/models/ms-agent/cyrillic_g2/resolve/master/cyrillic_g2.zip'
+
+
+def convert_ms(self, conv_input: _DocumentConversionInput,
+ raises_on_error: bool) -> Iterator[ConversionResult]:
+    """
+    Patch for `docling.document_converter.DocumentConverter._convert` that
+    attaches the converter's `current_url` to each document's backend (so
+    relative image URLs can be resolved) and logs per-document conversion
+    time.
+    """
+ start_time = time.monotonic()
+
+    def _add_custom_attributes(doc, target: Optional[str],
+                               attributes_dict: Dict):
+        """
+        Attach custom attributes to `getattr(doc, target)`, or to `doc`
+        itself if `target` is None. Here `doc` is an `InputDocument`
+        yielded by `conv_input.docs()`.
+        """
+        for key, value in attributes_dict.items():
+            target_obj = getattr(doc, target) if target is not None else doc
+            if not hasattr(target_obj, key):
+                setattr(target_obj, key, value)
+            else:
+                raise ValueError(
+                    f"Attribute '{key}' already exists on the target object.")
+        return doc
+
+    for input_batch in chunkify(
+            conv_input.docs(self.format_to_options),  # pass format_options
+            settings.perf.doc_batch_size,
+    ):
+ # parallel processing only within input_batch
+ # with ThreadPoolExecutor(
+ # max_workers=settings.perf.doc_batch_concurrency
+ # ) as pool:
+ # yield from pool.map(self.process_document, input_batch)
+ # Note: PDF backends are not thread-safe, thread pool usage was disabled.
+
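+        # Note: `map` is lazy, so `self.current_url` is read only when the
+        # caller pulls the next result, i.e. after `DocLoader.load` has set
+        # it for that document; `_add_custom_attributes` then pins the URL
+        # onto the document's backend for the patched HTML handlers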
+ for item in map(
+ partial(
+ self._process_document, raises_on_error=raises_on_error),
+ map(
+ lambda doc: _add_custom_attributes(
+ doc, '_backend', {'current_url': self.current_url}),
+ input_batch)):
+ elapsed = time.monotonic() - start_time
+ start_time = time.monotonic()
+ logger.info(
+ f'Finished converting document {item.input.file.name} in {elapsed:.2f} sec.'
+ )
+ yield item
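+
+
+# Usage sketch (for illustration): `convert_ms` is installed with
+# `@patch(DocumentConverter, '_convert', convert_ms)` in
+# `ms_agent/tools/docling/doc_loader.py`; it assumes the caller sets
+# `converter.current_url` before pulling each result from `convert_all`.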
diff --git a/ms_agent/utils/utils.py b/ms_agent/utils/utils.py
index fda3e152d..04edd3278 100644
--- a/ms_agent/utils/utils.py
+++ b/ms_agent/utils/utils.py
@@ -11,6 +11,7 @@
import requests
import yaml
from omegaconf import DictConfig, OmegaConf
+from PIL import Image
from .logger import get_logger
@@ -322,7 +323,6 @@ def load_image_from_url_to_pil(url: str) -> 'Image.Image':
Returns:
A PIL Image object if successful, None otherwise.
"""
- from PIL import Image
try:
response = requests.get(url)
# Raise an HTTPError for bad responses (4xx or 5xx)
@@ -349,7 +349,6 @@ def load_image_from_uri_to_pil(uri: str) -> 'Image.Image':
Returns:
tuple: (PIL Image object, file extension string) or None if failed
"""
- from PIL import Image
try:
header, encoded = uri.split(',', 1)
if ';base64' in header:
@@ -373,26 +372,56 @@ def load_image_from_uri_to_pil(uri: str) -> 'Image.Image':
return None
-def validate_url(
- img_url: str,
- backend: 'docling.backend.html_backend.HTMLDocumentBackend') -> str:
+def resolve_url(img_url: str,
+                backend: 'docling.backend.html_backend.HTMLDocumentBackend',
+                base_url: Optional[str] = None) -> Optional[str]:
"""
Validates and resolves a relative image URL using the base URL from the HTML document's metadata.
+ If the base URL is not provided or is invalid, the function attempts to resolve relative image
+ URLs by looking for base URLs in the following order:
+    1. `<base>` tag
+    2. `<link rel="canonical">` tag
+    3. `<meta property="og:url">` tag
-    This function attempts to resolve relative image URLs by looking for base URLs in the following order:
-    1. <base> tag
-    2. <link rel="canonical"> tag
-    3. <meta property="og:url"> tag
Args:
img_url (str): The image URL to validate/resolve
backend (HTMLDocumentBackend): The HTML document backend containing the parsed document
+        base_url (Optional[str]): The base URL to use for resolving relative URLs; may be None.
Returns:
str: The resolved absolute URL if successful, None otherwise
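+
+    Example (hypothetical URLs, for illustration):
+        resolve_url('figs/arch.png', backend,
+                    'https://example.com/docs/index.html')
+        -> 'https://example.com/docs/figs/arch.png'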
"""
from urllib.parse import urljoin, urlparse
+ base_url = None if base_url is None else base_url.strip()
+
+ def normalized_base(url: str) -> str:
+ '''
+        If the base URL doesn't end with '/' and its last segment doesn't
+        contain a dot (i.e. it doesn't look like a file name), it is treated
+        as a directory, and a '/' is appended to ensure correct joining.
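+
+        Example (illustrative): 'https://example.com/docs' becomes
+        'https://example.com/docs/', while 'https://example.com/docs/page.html'
+        is left unchanged.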
+ '''
+ parsed = urlparse(url)
+ need_append_slash = False
+ if parsed.scheme in {'http', 'https'}:
+            if not parsed.path.endswith('/'):
+                # A dot-free last segment usually names a directory; arXiv
+                # HTML pages (e.g. 'https://arxiv.org/html/2502.15214')
+                # contain a dot in the last segment yet are directories, so
+                # arxiv.org is special-cased
+                need_append_slash = ('.' not in parsed.path.split('/')[-1]
+                                     or 'arxiv.org' in parsed.netloc)
+
+ return url + '/' if need_append_slash else url
+
+ # If the base URL is already valid, just join it with the image URL
+ if base_url and urlparse(base_url).scheme and urlparse(base_url).netloc:
+ base_url = normalized_base(base_url)
+ try:
+ valid_url = urljoin(base_url, img_url)
+ return valid_url
+ except Exception as e:
+            logger.error(
+                f'Error joining base URL with image URL: {e}; falling back '
+                f'to resolving via the document head metadata...')
+
# Check if we have a valid soup object in the backend
if not backend or not hasattr(
backend, 'soup') or not backend.soup or not backend.soup.head:
@@ -428,6 +457,34 @@ def validate_url(
return img_url
+def extract_image(
+ img_url: Optional[str],
+ backend: 'docling.backend.html_backend.HTMLDocumentBackend',
+ base_url: Optional[str] = None,
+) -> Optional['Image.Image']:
+ """
+    Extracts an image from an image URL or data URI, resolving relative URLs when necessary.
+
+ Args:
+ img_url (Optional[str]): The image URL to extract.
+ backend (HTMLDocumentBackend): The HTML document backend containing the parsed document.
+ base_url (Optional[str]): The base URL to use for resolving relative URLs.
+
+ Returns:
+ Optional[Image.Image]: A PIL Image object if successful, None otherwise.
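+
+    Example (hypothetical values; `backend` is an HTMLDocumentBackend):
+        extract_image('data:image/png;base64,iVBOR...', backend)
+            -> PIL image decoded from the data URI
+        extract_image('figs/arch.png', backend,
+                      base_url='https://example.com/docs/')
+            -> image fetched from 'https://example.com/docs/figs/arch.png'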
+ """
+ if not img_url or not isinstance(img_url, str):
+ return None
+
+    if img_url.startswith('data:'):
+        # Data URIs embed the image payload directly
+        img_pil = load_image_from_uri_to_pil(img_url)
+    else:
+        # Resolve relative URLs against the page URL or document metadata
+        if not img_url.startswith('http'):
+            img_url = resolve_url(img_url, backend, base_url)
+        img_pil = load_image_from_url_to_pil(img_url) if img_url else None
+    return img_pil if isinstance(img_pil, Image.Image) else None
+
+
def get_default_config():
"""
Load and return the default configuration from 'ms_agent/agent/agent.yaml'.
diff --git a/tests/doc/test_img.py b/tests/doc/test_img.py
index c21ddcba5..d050ed854 100644
--- a/tests/doc/test_img.py
+++ b/tests/doc/test_img.py
@@ -1,80 +1,105 @@
# import os
# import unittest
-#
+# from typing import List
+
# from ms_agent.tools.docling.doc_loader import DocLoader
-#
+
# from modelscope.utils.test_utils import test_level
-#
-#
+
# class TestExtractImage(unittest.TestCase):
# base_dir: str = os.path.dirname(os.path.abspath(__file__))
-# absolute_path_img_url: str = 'https://www.chinahighlights.com/hangzhou/food-restaurant.htm'
-# relative_path_img_url: str = 'https://github.com/asinghcsu/AgenticRAG-Survey'
-# figure_tag_img_url: str = 'https://blogs.nvidia.com/blog/what-is-retrieval-augmented-generation/'
-# data_uri_img_url: str = 'https://arxiv.org/html/2505.16120v1'
-#
+# absolute_path_img_url: List[str] = [
+# 'https://www.chinahighlights.com/hangzhou/food-restaurant.htm'
+# ]
+# relative_path_img_url: List[str] = [
+# 'https://github.com/asinghcsu/AgenticRAG-Survey',
+# 'https://www.ams.org.cn/article/2021/0412-1961/0412-1961-2021-57-11-1343.shtml',
+# 'https://arxiv.org/html/2502.15214'
+# ]
+# figure_tag_img_url: List[str] = [
+# 'https://blogs.nvidia.com/blog/what-is-retrieval-augmented-generation/'
+# ]
+# data_uri_img_url: List[str] = ['https://arxiv.org/html/2505.16120v1']
+
# @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
# def test_absolute_path_img(self):
# save_dir = os.path.join(self.base_dir, 'absolute_path_img')
# if not os.path.exists(save_dir):
# os.makedirs(save_dir)
# doc_loader = DocLoader()
-# doc_results = doc_loader.load(
-# urls_or_files=[self.absolute_path_img_url])
-# for idx, pic in enumerate(doc_results[0].pictures):
-# print(f'Picture: {pic.self_ref} ...')
-# if pic.image:
-# pic.image.pil_image.save(
-# os.path.join(save_dir, 'picture_' + str(idx) + '.png'))
-# assert len(
-# doc_results[0].pictures) > 0, 'No pictures found in the document.'
-#
+# doc_results = doc_loader.load(urls_or_files=self.absolute_path_img_url)
+# for task_id, doc in enumerate(doc_results):
+# task_save_dir = os.path.join(save_dir, f'task_{task_id}')
+# if not os.path.exists(task_save_dir):
+# os.makedirs(task_save_dir)
+# print(f'Task {task_id} ...')
+# for idx, pic in enumerate(doc.pictures):
+# print(f'Picture: {pic.self_ref} ...')
+# if pic.image:
+# pic.image.pil_image.save(
+# os.path.join(task_save_dir,
+# 'picture_' + str(idx) + '.png'))
+# assert len(doc.pictures) > 0, 'No pictures found in the document.'
+
# @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
# def test_relative_path_img(self):
# save_dir = os.path.join(self.base_dir, 'relative_path_img')
# if not os.path.exists(save_dir):
# os.makedirs(save_dir)
# doc_loader = DocLoader()
-# doc_results = doc_loader.load(
-# urls_or_files=[self.relative_path_img_url])
-# for idx, pic in enumerate(doc_results[0].pictures):
-# print(f'Picture: {pic.self_ref} ...')
-# if pic.image:
-# pic.image.pil_image.save(
-# os.path.join(save_dir, 'picture_' + str(idx) + '.png'))
-# assert len(
-# doc_results[0].pictures) > 0, 'No pictures found in the document.'
-#
+# doc_results = doc_loader.load(urls_or_files=self.relative_path_img_url)
+# for task_id, doc in enumerate(doc_results):
+# task_save_dir = os.path.join(save_dir, f'task_{task_id}')
+# if not os.path.exists(task_save_dir):
+# os.makedirs(task_save_dir)
+# print(f'Task {task_id} ...')
+# for idx, pic in enumerate(doc.pictures):
+# print(f'Picture: {pic.self_ref} ...')
+# if pic.image:
+# pic.image.pil_image.save(
+# os.path.join(task_save_dir,
+# 'picture_' + str(idx) + '.png'))
+# assert len(doc.pictures) > 0, 'No pictures found in the document.'
+
# @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
# def test_figure_tag_img(self):
# save_dir = os.path.join(self.base_dir, 'figure_tag_img')
# if not os.path.exists(save_dir):
# os.makedirs(save_dir)
# doc_loader = DocLoader()
-# doc_results = doc_loader.load(urls_or_files=[self.figure_tag_img_url])
-# for idx, pic in enumerate(doc_results[0].pictures):
-# print(f'Picture: {pic.self_ref} ...')
-# if pic.image:
-# pic.image.pil_image.save(
-# os.path.join(save_dir, 'picture_' + str(idx) + '.png'))
-# assert len(
-# doc_results[0].pictures) > 0, 'No pictures found in the document.'
-#
+# doc_results = doc_loader.load(urls_or_files=self.figure_tag_img_url)
+# for task_id, doc in enumerate(doc_results):
+# task_save_dir = os.path.join(save_dir, f'task_{task_id}')
+# if not os.path.exists(task_save_dir):
+# os.makedirs(task_save_dir)
+# print(f'Task {task_id} ...')
+# for idx, pic in enumerate(doc.pictures):
+# print(f'Picture: {pic.self_ref} ...')
+# if pic.image:
+# pic.image.pil_image.save(
+# os.path.join(task_save_dir,
+# 'picture_' + str(idx) + '.png'))
+# assert len(doc.pictures) > 0, 'No pictures found in the document.'
+
# @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
# def test_data_uri_img(self):
# save_dir = os.path.join(self.base_dir, 'data_uri_img')
# if not os.path.exists(save_dir):
# os.makedirs(save_dir)
# doc_loader = DocLoader()
-# doc_results = doc_loader.load(urls_or_files=[self.data_uri_img_url])
-# for idx, pic in enumerate(doc_results[0].pictures):
-# print(f'Picture: {pic.self_ref} ...')
-# if pic.image:
-# pic.image.pil_image.save(
-# os.path.join(save_dir, 'picture_' + str(idx) + '.png'))
-# assert len(
-# doc_results[0].pictures) > 0, 'No pictures found in the document.'
-#
-#
+# doc_results = doc_loader.load(urls_or_files=self.data_uri_img_url)
+# for task_id, doc in enumerate(doc_results):
+# task_save_dir = os.path.join(save_dir, f'task_{task_id}')
+# if not os.path.exists(task_save_dir):
+# os.makedirs(task_save_dir)
+# print(f'Task {task_id} ...')
+# for idx, pic in enumerate(doc.pictures):
+# print(f'Picture: {pic.self_ref} ...')
+# if pic.image:
+# pic.image.pil_image.save(
+# os.path.join(task_save_dir,
+# 'picture_' + str(idx) + '.png'))
+# assert len(doc.pictures) > 0, 'No pictures found in the document.'
+
# if __name__ == '__main__':
# unittest.main()