Skip to content

Commit

Permalink
File CDK unstructured parser: Improve file type detection (#31997)
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Reuter authored Nov 2, 2023
1 parent 402ac60 commit 66dd29f
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,29 @@
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from airbyte_cdk.sources.file_based.schema_helpers import SchemaType
from unstructured.documents.elements import Formula, ListItem, Title
from unstructured.file_utils.filetype import FileType, detect_filetype
from unstructured.file_utils.filetype import STR_TO_FILETYPE, FileType, detect_filetype

unstructured_partition = None
unstructured_partition_pdf = None
unstructured_partition_docx = None
unstructured_partition_pptx = None
unstructured_optional_decode = None


def _import_unstructured() -> None:
"""Dynamically imported as needed, due to slow import speed."""
global unstructured_partition
global unstructured_partition_pdf
global unstructured_partition_docx
global unstructured_partition_pptx
global unstructured_optional_decode
from unstructured.partition.auto import partition
from unstructured.partition.docx import partition_docx
from unstructured.partition.md import optional_decode
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.pptx import partition_pptx

unstructured_partition = partition
# separate global variables to properly propagate typing
unstructured_partition_pdf = partition_pdf
unstructured_partition_docx = partition_docx
unstructured_partition_pptx = partition_pptx
unstructured_optional_decode = optional_decode


Expand All @@ -52,7 +61,7 @@ async def infer_schema(
logger: logging.Logger,
) -> SchemaType:
with stream_reader.open_file(file, self.file_read_mode, None, logger) as file_handle:
filetype = self._get_filetype(file_handle, file.uri)
filetype = self._get_filetype(file_handle, file)

if filetype not in self._supported_file_types():
raise RecordParseError(FileBasedSourceError.ERROR_PARSING_RECORD, filename=file.uri)
Expand All @@ -71,46 +80,77 @@ def parse_records(
discovered_schema: Optional[Mapping[str, SchemaType]],
) -> Iterable[Dict[str, Any]]:
with stream_reader.open_file(file, self.file_read_mode, None, logger) as file_handle:
markdown = self._read_file(file_handle, file.uri)
markdown = self._read_file(file_handle, file)
yield {
"content": markdown,
"document_key": file.uri,
}

def _read_file(self, file_handle: IOBase, file_name: str) -> str:
def _read_file(self, file_handle: IOBase, remote_file: RemoteFile) -> str:
_import_unstructured()
if (not unstructured_partition) or (not unstructured_optional_decode):
if (
(not unstructured_partition_pdf)
or (not unstructured_partition_docx)
or (not unstructured_partition_pptx)
or (not unstructured_optional_decode)
):
# check whether unstructured library is actually available for better error message and to ensure proper typing (can't be None after this point)
raise Exception("unstructured library is not available")

filetype = self._get_filetype(file_handle, file_name)
filetype = self._get_filetype(file_handle, remote_file)

if filetype == FileType.MD:
file_content: bytes = file_handle.read()
decoded_content: str = unstructured_optional_decode(file_content)
return decoded_content
if filetype not in self._supported_file_types():
raise RecordParseError(FileBasedSourceError.ERROR_PARSING_RECORD, filename=file_name)
raise RecordParseError(FileBasedSourceError.ERROR_PARSING_RECORD, filename=remote_file.uri)

file: Any = file_handle
if filetype == FileType.PDF:
# for PDF, read the file into a BytesIO object because some code paths in pdf parsing are doing an instance check on the file object and don't work with file-like objects
file_handle.seek(0)
file = BytesIO(file_handle.read())
file_handle.seek(0)
elements = unstructured_partition_pdf(file=file)
elif filetype == FileType.DOCX:
elements = unstructured_partition_docx(file=file)
elif filetype == FileType.PPTX:
elements = unstructured_partition_pptx(file=file)

elements = unstructured_partition(file=file, metadata_filename=file_name)
return self._render_markdown(elements)

def _get_filetype(self, file: IOBase, file_name: str) -> Any:
def _get_filetype(self, file: IOBase, remote_file: RemoteFile) -> Optional[FileType]:
"""
Detect the file type based on the file name and the file content.
There are three strategies to determine the file type:
1. Use the mime type if available (only some sources support it)
2. Use the file name if available
3. Use the file content
"""
if remote_file.mime_type and remote_file.mime_type in STR_TO_FILETYPE:
return STR_TO_FILETYPE[remote_file.mime_type]

# set name to none, otherwise unstructured will try to get the modified date from the local file system
if hasattr(file, "name"):
file.name = None

return detect_filetype(
file=file,
file_filename=file_name,
# detect_filetype is either using the file name or file content
# if possible, try to leverage the file name to detect the file type
# if the file name is not available, use the file content
file_type = detect_filetype(
filename=remote_file.uri,
)
if file_type is not None and not file_type == FileType.UNK:
return file_type

type_based_on_content = detect_filetype(file=file)

# detect_filetype is reading to read the file content
file.seek(0)

return type_based_on_content

def _supported_file_types(self) -> List[Any]:
return [FileType.MD, FileType.PDF, FileType.DOCX, FileType.PPTX]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#

from datetime import datetime
from typing import Optional

from pydantic import BaseModel

Expand All @@ -14,3 +15,4 @@ class RemoteFile(BaseModel):

uri: str
last_modified: datetime
mime_type: Optional[str] = None
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
#

import asyncio
from datetime import datetime
from unittest.mock import MagicMock, mock_open, patch

import pytest
from airbyte_cdk.sources.file_based.exceptions import RecordParseError
from airbyte_cdk.sources.file_based.file_types import UnstructuredParser
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
from unstructured.documents.elements import ElementMetadata, Formula, ListItem, Text, Title
from unstructured.file_utils.filetype import FileType

Expand Down Expand Up @@ -143,17 +145,31 @@ def test_infer_schema(mock_detect_filetype, filetype, raises):
),
],
)
@patch("unstructured.partition.auto.partition")
@patch("unstructured.partition.pdf.partition_pdf")
@patch("unstructured.partition.pptx.partition_pptx")
@patch("unstructured.partition.docx.partition_docx")
@patch("unstructured.partition.md.optional_decode")
@patch("airbyte_cdk.sources.file_based.file_types.unstructured_parser.detect_filetype")
def test_parse_records(mock_detect_filetype, mock_optional_decode, mock_partition, filetype, parse_result, raises, expected_records):
def test_parse_records(
mock_detect_filetype,
mock_optional_decode,
mock_partition_docx,
mock_partition_pptx,
mock_partition_pdf,
filetype,
parse_result,
raises,
expected_records,
):
stream_reader = MagicMock()
mock_open(stream_reader.open_file, read_data=bytes(str(parse_result), "utf-8"))
fake_file = MagicMock()
fake_file = RemoteFile(uri=FILE_URI, last_modified=datetime.now())
fake_file.uri = FILE_URI
logger = MagicMock()
mock_detect_filetype.return_value = filetype
mock_partition.return_value = parse_result
mock_partition_docx.return_value = parse_result
mock_partition_pptx.return_value = parse_result
mock_partition_pdf.return_value = parse_result
mock_optional_decode.side_effect = lambda x: x.decode("utf-8")
if raises:
with pytest.raises(RecordParseError):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ def get_matching_files(
) -> Iterable[RemoteFile]:
yield from self.filter_files_by_globs_and_start_date(
[
RemoteFile(uri=f, last_modified=datetime.strptime(data["last_modified"], "%Y-%m-%dT%H:%M:%S.%fZ"))
RemoteFile(
uri=f,
mime_type=data.get("mime_type", None),
last_modified=datetime.strptime(data["last_modified"], "%Y-%m-%dT%H:%M:%S.%fZ"),
)
for f, data in self.files.items()
],
globs,
Expand Down

Large diffs are not rendered by default.

0 comments on commit 66dd29f

Please sign in to comment.