Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion ms_agent/tools/docling/doc_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from docling_core.types import DoclingDocument
from docling_core.types.doc import DocItem
from ms_agent.tools.docling.doc_postprocess import PostProcess
from ms_agent.tools.docling.patches import (download_models_ms,
from ms_agent.tools.docling.patches import (convert_ms, download_models_ms,
download_models_pic_classifier_ms,
html_handle_figure,
html_handle_image,
Expand Down Expand Up @@ -211,6 +211,7 @@ def _postprocess(doc: DoclingDocument) -> Union[DoclingDocument, None]:

return doc

@patch(DocumentConverter, '_convert', convert_ms)
@patch(LayoutModel, 'download_models', download_models_ms)
@patch(TableStructureModel, 'download_models', download_models_ms)
@patch(DocumentPictureClassifier, 'download_models',
Expand All @@ -230,10 +231,16 @@ def load(self, urls_or_files: list[str]) -> List[DoclingDocument]:
# TODO: Support progress bar for document loading (with pather)
results: Iterator[ConversionResult] = self._converter.convert_all(
source=urls_or_files, )
iter_urls_or_files = iter(urls_or_files)

final_results = []
while True:
try:
# Record the current URL for parsing images
setattr(self._converter, 'current_url',
next(iter_urls_or_files))
assert self._converter.current_url is not None, 'Current URL should not be None'

res: ConversionResult = next(results)
if res is None or res.document is None:
continue
Expand Down
86 changes: 65 additions & 21 deletions ms_agent/tools/docling/patches.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
# flake8: noqa
import time
from functools import partial
from pathlib import Path
from typing import Dict, Iterator

from bs4 import Tag
from docling.datamodel.document import (ConversionResult,
_DocumentConversionInput)
from docling.datamodel.settings import settings
from docling.utils.utils import chunkify
from docling_core.types import DoclingDocument
from docling_core.types.doc import DocItemLabel, ImageRef
from ms_agent.utils.logger import get_logger
from ms_agent.utils.utils import (load_image_from_uri_to_pil,
load_image_from_url_to_pil, validate_url)
from ms_agent.utils.utils import extract_image

logger = get_logger()

Expand All @@ -25,16 +31,11 @@ def html_handle_figure(self, element: Tag, doc: DoclingDocument) -> None:
else:
img_url = None

if img_url:
if img_url.startswith('data:'):
img_pil = load_image_from_uri_to_pil(img_url)
else:
if not img_url.startswith('http'):
img_url = validate_url(img_url=img_url, backend=self)
img_pil = load_image_from_url_to_pil(
img_url) if img_url.startswith('http') else None
else:
img_pil = None
# extract image from url or data URI
img_pil: 'Image.Image' = extract_image(
img_url=img_url,
backend=self,
base_url=self.current_url if hasattr(self, 'current_url') else None)

dpi: int = int(img_pil.info.get('dpi', (96, 96))[0]) if img_pil else 96
img_ref: ImageRef = None
Expand Down Expand Up @@ -82,15 +83,11 @@ def html_handle_image(self, element: Tag, doc: DoclingDocument) -> None:
# Get the image from element
img_url: str = element.attrs.get('src', None)

if img_url:
if img_url.startswith('data:'):
img_pil = load_image_from_uri_to_pil(img_url)
else:
if not img_url.startswith('http'):
img_url = validate_url(img_url=img_url, backend=self)
img_pil = load_image_from_url_to_pil(img_url)
else:
img_pil = None
# extract image from url or data URI
img_pil: 'Image.Image' = extract_image(
img_url=img_url,
backend=self,
base_url=self.current_url if hasattr(self, 'current_url') else None)

dpi: int = int(img_pil.info.get('dpi', (96, 96))[0]) if img_pil else 96

Expand Down Expand Up @@ -170,3 +167,50 @@ def patch_easyocr_models():
'url'] = 'https://modelscope.cn/models/ms-agent/kannada_g2/resolve/master/kannada_g2.zip'
recognition_models['gen2']['cyrillic_g2'][
'url'] = 'https://modelscope.cn/models/ms-agent/cyrillic_g2/resolve/master/cyrillic_g2.zip'


def convert_ms(self, conv_input: _DocumentConversionInput,
               raises_on_error: bool) -> Iterator[ConversionResult]:
    """
    Patch for `docling.document_converter.DocumentConverter._convert`.

    Mirrors the upstream implementation, except that each input document's
    backend is tagged with the converter's ``current_url`` attribute before
    processing, so the patched HTML image handlers can resolve relative
    image URLs against the page they were loaded from.

    Args:
        conv_input: The wrapped collection of input documents.
        raises_on_error: Whether per-document failures should raise instead
            of being recorded in the `ConversionResult`.

    Yields:
        ConversionResult: One result per processed document.
    """
    start_time = time.monotonic()

    def _add_custom_attributes(doc: DoclingDocument, target: str,
                               attributes_dict: Dict) -> DoclingDocument:
        """Attach each key/value in `attributes_dict` to `doc.<target>`
        (or to `doc` itself when `target` is None)."""
        # The target object does not change between keys; resolve it once
        # instead of on every loop iteration.
        target_obj = getattr(doc, target) if target is not None else doc
        for key, value in attributes_dict.items():
            if hasattr(target_obj, key):
                # Refuse to silently clobber a pre-existing attribute.
                raise ValueError(
                    f"Attribute '{key}' already exists on the target object.")
            setattr(target_obj, key, value)
        return doc

    for input_batch in chunkify(
            conv_input.docs(self.format_to_options),
            settings.perf.doc_batch_size,  # pass format_options
    ):
        # NOTE: PDF backends are not thread-safe, so upstream's thread-pool
        # fan-out stays disabled and documents are processed sequentially.
        for item in map(
                partial(
                    self._process_document, raises_on_error=raises_on_error),
                map(
                    # `current_url` is set on the converter by the caller
                    # (DocLoader) before each document is consumed; default
                    # to None defensively so a missing attribute cannot
                    # abort the whole conversion with an AttributeError.
                    lambda doc: _add_custom_attributes(
                        doc, '_backend',
                        {'current_url': getattr(self, 'current_url', None)}),
                    input_batch)):
            elapsed = time.monotonic() - start_time
            start_time = time.monotonic()
            logger.info(
                f'Finished converting document {item.input.file.name} in {elapsed:.2f} sec.'
            )
            yield item
75 changes: 66 additions & 9 deletions ms_agent/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import requests
import yaml
from omegaconf import DictConfig, OmegaConf
from PIL import Image

from .logger import get_logger

Expand Down Expand Up @@ -322,7 +323,6 @@ def load_image_from_url_to_pil(url: str) -> 'Image.Image':
Returns:
A PIL Image object if successful, None otherwise.
"""
from PIL import Image
try:
response = requests.get(url)
# Raise an HTTPError for bad responses (4xx or 5xx)
Expand All @@ -349,7 +349,6 @@ def load_image_from_uri_to_pil(uri: str) -> 'Image.Image':
Returns:
tuple: (PIL Image object, file extension string) or None if failed
"""
from PIL import Image
try:
header, encoded = uri.split(',', 1)
if ';base64' in header:
Expand All @@ -373,26 +372,56 @@ def load_image_from_uri_to_pil(uri: str) -> 'Image.Image':
return None


def validate_url(
img_url: str,
backend: 'docling.backend.html_backend.HTMLDocumentBackend') -> str:
def resolve_url(img_url: str,
backend: 'docling.backend.html_backend.HTMLDocumentBackend',
base_url: str) -> str:
"""
Validates and resolves a relative image URL using the base URL from the HTML document's metadata.
If the base URL is not provided or is invalid, the function attempts to resolve relative image
URLs by looking for base URLs in the following order:
1. `<base href="...">` tag
2. `<link rel="canonical" href="...">` tag
3. `<meta property="og:url" content="...">` tag

This function attempts to resolve relative image URLs by looking for base URLs in the following order:
1. <base href="..."> tag
2. <link rel="canonical" href="..."> tag
3. <meta property="og:url" content="..."> tag

Args:
img_url (str): The image URL to validate/resolve
backend (HTMLDocumentBackend): The HTML document backend containing the parsed document
base_url (str): The base URL to use for resolving relative URLs.

Returns:
str: The resolved absolute URL if successful, None otherwise
"""
from urllib.parse import urljoin, urlparse

base_url = None if base_url is None else base_url.strip()

def normalized_base(url: str) -> str:
'''
If the base URL doesn't end with '/' and the last segment doesn't contain a dot
(meaning it doesn't look like a file), it indicates that it's actually a directory.
In this case, append a '/' to ensure correct concatenation.
'''
parsed = urlparse(url)
need_append_slash = False
if parsed.scheme in {'http', 'https'}:
if not parsed.path.endswith('/'):
need_append_slash = ('.' not in parsed.path.split('/')[-1]
or 'arxiv.org' in parsed.netloc)
Comment on lines +409 to +410

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This logic includes a hardcoded special case for arxiv.org. This makes the function less generic and might require adding more special cases in the future as you encounter other sites with similar URL structures. Consider adding a code comment to explain why this special case is necessary. If a more general solution is feasible, that would be even better for long-term maintainability.


return url + '/' if need_append_slash else url

# If the base URL is already valid, just join it with the image URL
if base_url and urlparse(base_url).scheme and urlparse(base_url).netloc:
base_url = normalized_base(base_url)
try:
valid_url = urljoin(base_url, img_url)
return valid_url
except Exception as e:
logger.error(
f'Error joining base URL with image URL: {e}, continuing to resolve...'
)

# Check if we have a valid soup object in the backend
if not backend or not hasattr(
backend, 'soup') or not backend.soup or not backend.soup.head:
Expand Down Expand Up @@ -428,6 +457,34 @@ def validate_url(
return img_url


def extract_image(
    img_url: Optional[str],
    backend: 'docling.backend.html_backend.HTMLDocumentBackend',
    base_url: Optional[str] = None,
) -> Optional['Image.Image']:
    """
    Load a PIL image referenced by `img_url`.

    `data:` URIs are decoded directly; anything else is treated as a URL,
    with relative references resolved via `resolve_url` first.

    Args:
        img_url (Optional[str]): The image location — a `data:` URI or a URL.
        backend (HTMLDocumentBackend): The HTML document backend containing
            the parsed document, used to discover a base URL when needed.
        base_url (Optional[str]): The base URL to use for resolving relative
            URLs.

    Returns:
        Optional[Image.Image]: A PIL Image object if successful, None
            otherwise.
    """
    # Guard clause: anything that is not a non-empty string yields no image.
    if not isinstance(img_url, str) or not img_url:
        return None

    if img_url.startswith('data:'):
        candidate = load_image_from_uri_to_pil(img_url)
    else:
        target = img_url
        if not target.startswith('http'):
            # Relative reference: make it absolute before fetching.
            target = resolve_url(target, backend, base_url)
        candidate = load_image_from_url_to_pil(target) if target else None

    # The helper loaders may hand back None (or a non-image) on failure;
    # only ever return a genuine PIL Image to callers.
    return candidate if isinstance(candidate, Image.Image) else None


def get_default_config():
"""
Load and return the default configuration from 'ms_agent/agent/agent.yaml'.
Expand Down
Loading
Loading