Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

File CDK unstructured parser: Improve file type detection #31997

Merged
merged 4 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,29 @@
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.schema_helpers import SchemaType
from unstructured.documents.elements import Formula, ListItem, Title
from unstructured.file_utils.filetype import FileType, detect_filetype
from unstructured.file_utils.filetype import STR_TO_FILETYPE, FileType, detect_filetype

unstructured_partition = None
unstructured_partition_pdf = None
unstructured_partition_docx = None
unstructured_partition_pptx = None
unstructured_optional_decode = None


def _import_unstructured() -> None:
"""Dynamically imported as needed, due to slow import speed."""
global unstructured_partition
global unstructured_partition_pdf
global unstructured_partition_docx
global unstructured_partition_pptx
global unstructured_optional_decode
from unstructured.partition.auto import partition
from unstructured.partition.docx import partition_docx
from unstructured.partition.md import optional_decode
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.pptx import partition_pptx

unstructured_partition = partition
# separate global variables to properly propagate typing
unstructured_partition_pdf = partition_pdf
unstructured_partition_docx = partition_docx
unstructured_partition_pptx = partition_pptx
unstructured_optional_decode = optional_decode


Expand All @@ -52,7 +61,7 @@ async def infer_schema(
logger: logging.Logger,
) -> SchemaType:
with stream_reader.open_file(file, self.file_read_mode, None, logger) as file_handle:
filetype = self._get_filetype(file_handle, file.uri)
filetype = self._get_filetype(file_handle, file)

if filetype not in self._supported_file_types():
raise RecordParseError(FileBasedSourceError.ERROR_PARSING_RECORD, filename=file.uri)
Expand All @@ -71,46 +80,77 @@ def parse_records(
discovered_schema: Optional[Mapping[str, SchemaType]],
) -> Iterable[Dict[str, Any]]:
with stream_reader.open_file(file, self.file_read_mode, None, logger) as file_handle:
markdown = self._read_file(file_handle, file.uri)
markdown = self._read_file(file_handle, file)
yield {
"content": markdown,
"document_key": file.uri,
}

def _read_file(self, file_handle: IOBase, file_name: str) -> str:
def _read_file(self, file_handle: IOBase, remote_file: RemoteFile) -> str:
_import_unstructured()
if (not unstructured_partition) or (not unstructured_optional_decode):
if (
(not unstructured_partition_pdf)
or (not unstructured_partition_docx)
or (not unstructured_partition_pptx)
or (not unstructured_optional_decode)
):
# check whether unstructured library is actually available for better error message and to ensure proper typing (can't be None after this point)
raise Exception("unstructured library is not available")

filetype = self._get_filetype(file_handle, file_name)
filetype = self._get_filetype(file_handle, remote_file)

if filetype == FileType.MD:
file_content: bytes = file_handle.read()
decoded_content: str = unstructured_optional_decode(file_content)
return decoded_content
if filetype not in self._supported_file_types():
raise RecordParseError(FileBasedSourceError.ERROR_PARSING_RECORD, filename=file_name)
raise RecordParseError(FileBasedSourceError.ERROR_PARSING_RECORD, filename=remote_file.uri)

file: Any = file_handle
if filetype == FileType.PDF:
# for PDF, read the file into a BytesIO object because some code paths in pdf parsing are doing an instance check on the file object and don't work with file-like objects
file_handle.seek(0)
file = BytesIO(file_handle.read())
file_handle.seek(0)
elements = unstructured_partition_pdf(file=file)
elif filetype == FileType.DOCX:
elements = unstructured_partition_docx(file=file)
elif filetype == FileType.PPTX:
elements = unstructured_partition_pptx(file=file)

elements = unstructured_partition(file=file, metadata_filename=file_name)
return self._render_markdown(elements)

def _get_filetype(self, file: IOBase, file_name: str) -> Any:
def _get_filetype(self, file: IOBase, remote_file: RemoteFile) -> Optional[FileType]:
"""
Detect the file type based on the file name and the file content.

There are three strategies to determine the file type:
1. Use the mime type if available (only some sources support it)
2. Use the file name if available
3. Use the file content
"""
if remote_file.mime_type and remote_file.mime_type in STR_TO_FILETYPE:
return STR_TO_FILETYPE[remote_file.mime_type]

# set name to none, otherwise unstructured will try to get the modified date from the local file system
if hasattr(file, "name"):
file.name = None

return detect_filetype(
file=file,
file_filename=file_name,
# detect_filetype is either using the file name or file content
# if possible, try to leverage the file name to detect the file type
# if the file name is not available, use the file content
file_type = detect_filetype(
filename=remote_file.uri,
)
if file_type is not None and not file_type == FileType.UNK:
return file_type

type_based_on_content = detect_filetype(file=file)

# detect_filetype is reading to read the file content
file.seek(0)

return type_based_on_content

def _supported_file_types(self) -> List[Any]:
return [FileType.MD, FileType.PDF, FileType.DOCX, FileType.PPTX]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#

from datetime import datetime
from typing import Optional

from pydantic import BaseModel

Expand All @@ -14,3 +15,4 @@ class RemoteFile(BaseModel):

uri: str
last_modified: datetime
mime_type: Optional[str] = None
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
#

import asyncio
from datetime import datetime
from unittest.mock import MagicMock, mock_open, patch

import pytest
from airbyte_cdk.sources.file_based.exceptions import RecordParseError
from airbyte_cdk.sources.file_based.file_types import UnstructuredParser
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from unstructured.documents.elements import ElementMetadata, Formula, ListItem, Text, Title
from unstructured.file_utils.filetype import FileType

Expand Down Expand Up @@ -143,17 +145,31 @@ def test_infer_schema(mock_detect_filetype, filetype, raises):
),
],
)
@patch("unstructured.partition.auto.partition")
@patch("unstructured.partition.pdf.partition_pdf")
@patch("unstructured.partition.pptx.partition_pptx")
@patch("unstructured.partition.docx.partition_docx")
@patch("unstructured.partition.md.optional_decode")
@patch("airbyte_cdk.sources.file_based.file_types.unstructured_parser.detect_filetype")
def test_parse_records(mock_detect_filetype, mock_optional_decode, mock_partition, filetype, parse_result, raises, expected_records):
def test_parse_records(
mock_detect_filetype,
mock_optional_decode,
mock_partition_docx,
mock_partition_pptx,
mock_partition_pdf,
filetype,
parse_result,
raises,
expected_records,
):
stream_reader = MagicMock()
mock_open(stream_reader.open_file, read_data=bytes(str(parse_result), "utf-8"))
fake_file = MagicMock()
fake_file = RemoteFile(uri=FILE_URI, last_modified=datetime.now())
fake_file.uri = FILE_URI
logger = MagicMock()
mock_detect_filetype.return_value = filetype
mock_partition.return_value = parse_result
mock_partition_docx.return_value = parse_result
mock_partition_pptx.return_value = parse_result
mock_partition_pdf.return_value = parse_result
mock_optional_decode.side_effect = lambda x: x.decode("utf-8")
if raises:
with pytest.raises(RecordParseError):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ def get_matching_files(
) -> Iterable[RemoteFile]:
yield from self.filter_files_by_globs_and_start_date(
[
RemoteFile(uri=f, last_modified=datetime.strptime(data["last_modified"], "%Y-%m-%dT%H:%M:%S.%fZ"))
RemoteFile(
uri=f,
mime_type=data.get("mime_type", None),
last_modified=datetime.strptime(data["last_modified"], "%Y-%m-%dT%H:%M:%S.%fZ"),
)
for f, data in self.files.items()
],
globs,
Expand Down

Large diffs are not rendered by default.