diff --git a/.gitignore b/.gitignore index d324c63..c1eb246 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ __pycache__/ *.pyc .DS_Store +.idea diff --git a/llama_parse/base.py b/llama_parse/base.py index bda8b9c..8468f40 100644 --- a/llama_parse/base.py +++ b/llama_parse/base.py @@ -1,17 +1,22 @@ import os import asyncio +from io import TextIOWrapper + import httpx import mimetypes import time -from pathlib import Path -from typing import AsyncGenerator, List, Optional, Union +from pathlib import Path, PurePath, PurePosixPath +from typing import AsyncGenerator, Any, Dict, List, Optional, Union from contextlib import asynccontextmanager from io import BufferedIOBase +from fsspec import AbstractFileSystem +from fsspec.spec import AbstractBufferedFile from llama_index.core.async_utils import run_jobs from llama_index.core.bridge.pydantic import Field, field_validator from llama_index.core.constants import DEFAULT_BASE_URL from llama_index.core.readers.base import BasePydanticReader +from llama_index.core.readers.file.base import get_default_fs from llama_index.core.schema import Document from llama_parse.utils import ( nest_asyncio_err, @@ -178,7 +183,10 @@ async def client_context(self) -> AsyncGenerator[httpx.AsyncClient, None]: # upload a document and get back a job_id async def _create_job( - self, file_input: FileInput, extra_info: Optional[dict] = None + self, + file_input: FileInput, + extra_info: Optional[dict] = None, + fs: Optional[AbstractFileSystem] = None, ) -> str: headers = {"Authorization": f"Bearer {self.api_key}"} url = f"{self.base_url}/api/parsing/upload" @@ -193,7 +201,7 @@ async def _create_job( file_name = extra_info["file_name"] mime_type = mimetypes.guess_type(file_name)[0] files = {"file": (file_name, file_input, mime_type)} - elif isinstance(file_input, (str, Path)): + elif isinstance(file_input, (str, Path, PurePosixPath, PurePath)): file_path = str(file_input) file_ext = os.path.splitext(file_path)[1].lower() if file_ext not in SUPPORTED_FILE_TYPES: @@ -203,7 +211,9 @@ async def _create_job( ) mime_type = mimetypes.guess_type(file_path)[0] # Open the file here for the duration of the async context - file_handle = open(file_path, "rb") + # load data, set the mime type + fs = fs or get_default_fs() + file_handle = fs.open(file_input, "rb") files = {"file": (os.path.basename(file_path), file_handle, mime_type)} else: raise ValueError( @@ -259,9 +269,15 @@ async def _create_job( if file_handle is not None: file_handle.close() + @staticmethod + def __get_filename(f: Union[TextIOWrapper, AbstractBufferedFile]) -> str: + if isinstance(f, TextIOWrapper): + return f.name + return f.full_name + async def _get_job_result( self, job_id: str, result_type: str, verbose: bool = False - ) -> dict: + ) -> Dict[str, Any]: result_url = f"{self.base_url}/api/parsing/job/{job_id}/result/{result_type}" status_url = f"{self.base_url}/api/parsing/job/{job_id}" headers = {"Authorization": f"Bearer {self.api_key}"} @@ -300,21 +316,16 @@ async def _get_job_result( await asyncio.sleep(self.check_interval) - continue - else: - raise Exception( - f"Failed to parse the file: {job_id}, status: {status}" - ) - async def _aload_data( self, file_path: FileInput, extra_info: Optional[dict] = None, + fs: Optional[AbstractFileSystem] = None, verbose: bool = False, ) -> List[Document]: """Load data from the input path.""" try: - job_id = await self._create_job(file_path, extra_info=extra_info) + job_id = await self._create_job(file_path, extra_info=extra_info, fs=fs) if verbose: print("Started parsing the file under job_id %s" % job_id) @@ -345,17 +356,19 @@ async def aload_data( self, file_path: Union[List[FileInput], FileInput], extra_info: Optional[dict] = None, + fs: Optional[AbstractFileSystem] = None, ) -> List[Document]: """Load data from the input path.""" if isinstance(file_path, (str, Path, bytes, BufferedIOBase)): return await self._aload_data( - file_path, extra_info=extra_info, verbose=self.verbose + file_path, extra_info=extra_info, fs=fs, verbose=self.verbose ) elif isinstance(file_path, list): jobs = [ self._aload_data( f, extra_info=extra_info, + fs=fs, verbose=self.verbose and not self.show_progress, ) for f in file_path @@ -384,10 +397,11 @@ def load_data( self, file_path: Union[List[FileInput], FileInput], extra_info: Optional[dict] = None, + fs: Optional[AbstractFileSystem] = None, ) -> List[Document]: """Load data from the input path.""" try: - return asyncio.run(self.aload_data(file_path, extra_info)) + return asyncio.run(self.aload_data(file_path, extra_info, fs=fs)) except RuntimeError as e: if nest_asyncio_err in str(e): raise RuntimeError(nest_asyncio_msg) diff --git a/pyproject.toml b/pyproject.toml index 9ff52f8..ef33b93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "llama-parse" -version = "0.5.3" +version = "0.5.4" description = "Parse files into RAG-Optimized formats." authors = ["Logan Markewich "] license = "MIT" diff --git a/tests/test_reader.py b/tests/test_reader.py index 091da24..70da8aa 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -1,6 +1,8 @@ import os import pytest +from fsspec.implementations.local import LocalFileSystem from httpx import AsyncClient + from llama_parse import LlamaParse @@ -70,6 +72,20 @@ def test_simple_page_markdown_buffer(markdown_parser: LlamaParse) -> None: assert len(result[0].text) > 0 +@pytest.mark.skipif( + os.environ.get("LLAMA_CLOUD_API_KEY", "") == "", + reason="LLAMA_CLOUD_API_KEY not set", +) +def test_simple_page_with_custom_fs() -> None: + parser = LlamaParse(result_type="markdown") + fs = LocalFileSystem() + filepath = os.path.join( + os.path.dirname(__file__), "test_files/attention_is_all_you_need.pdf" + ) + result = parser.load_data(filepath, fs=fs) + assert len(result) == 1 + + @pytest.mark.skipif( os.environ.get("LLAMA_CLOUD_API_KEY", "") == "", reason="LLAMA_CLOUD_API_KEY not set",