Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable support for custom filesystem #117

Merged
merged 12 commits into from
Sep 10, 2024
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.git
__pycache__/
*.pyc
.DS_Store
.DS_Store
.idea
86 changes: 50 additions & 36 deletions llama_parse/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import os
import asyncio
from io import TextIOWrapper

import httpx
import mimetypes
import time
from enum import Enum
from pathlib import Path
from pathlib import Path, PurePath
from typing import List, Optional, Union

from fsspec import AbstractFileSystem
from fsspec.spec import AbstractBufferedFile
from llama_index.core.async_utils import run_jobs
from llama_index.core.bridge.pydantic import Field, validator
from llama_index.core.constants import DEFAULT_BASE_URL
from llama_index.core.readers.base import BasePydanticReader
from llama_index.core.readers.file.base import get_default_fs
from llama_index.core.schema import Document


Expand Down Expand Up @@ -127,7 +132,7 @@ class Language(str, Enum):

# Open Office
".sxw",
".stw",
".stw",
".sxg",

# Apple
Expand Down Expand Up @@ -161,9 +166,9 @@ class Language(str, Enum):
".odg",
".otp",
".fopd",
".sxi",
".sxi",
".sti",

# ebook
".epub"
]
Expand All @@ -183,7 +188,7 @@ class LlamaParse(BasePydanticReader):
num_workers: int = Field(
default=4,
gt=0,
lt=10,
lt=10,
description="The number of workers to use sending API requests for parsing."
)
check_interval: int = Field(
Expand Down Expand Up @@ -214,34 +219,37 @@ def validate_api_key(cls, v: str) -> str:
if api_key is None:
raise ValueError("The API key is required.")
return api_key

return v

@validator("base_url", pre=True, always=True)
def validate_base_url(cls, v: str) -> str:
    """Resolve the base URL: the LLAMA_CLOUD_BASE_URL env var wins, then the
    provided value, then the library default."""
    env_override = os.getenv("LLAMA_CLOUD_BASE_URL")
    if env_override:
        return env_override
    return v if v else DEFAULT_BASE_URL

# upload a document and get back a job_id
async def _create_job(self, file_path: str, extra_info: Optional[dict] = None) -> str:
file_path = str(file_path)
file_ext = os.path.splitext(file_path)[1]
async def _create_job(self, file_path: str | PurePath, extra_info: Optional[dict] = None, fs: Optional[AbstractFileSystem] = None,) -> str:
logan-markewich marked this conversation as resolved.
Show resolved Hide resolved
str_file_path = file_path
if isinstance(file_path, PurePath):
str_file_path = file_path.name
file_ext = os.path.splitext(str_file_path)[1]
if file_ext not in SUPPORTED_FILE_TYPES:
raise Exception(
f"Currently, only the following file types are supported: {SUPPORTED_FILE_TYPES}\n"
f"Current file type: {file_ext}"
)

extra_info = extra_info or {}
extra_info["file_path"] = file_path
extra_info["file_path"] = str_file_path

headers = {"Authorization": f"Bearer {self.api_key}"}

# load data, set the mime type
with open(file_path, "rb") as f:
mime_type = mimetypes.guess_type(file_path)[0]
files = {"file": (f.name, f, mime_type)}
fs = fs or get_default_fs()
with fs.open(file_path, "rb") as f:
mime_type = mimetypes.guess_type(str_file_path)[0]
files = {"file": (self.__get_filename(f), f, mime_type)}
Copy link
Contributor

@logan-markewich logan-markewich Apr 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this works, at least locally, it doesn't work for me

Since we already have the path, can't we use that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am using it in production right now and it is working properly. What is the error that you are getting locally?

The issue with the path is that I've seen situations in production where people upload a file with .txt but the file is actually a .csv and it doesn't work.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:sigh: this still doesn't work guys lol

>>> from llama_parse import LlamaParse
>>> documents = LlamaParse(api_key="llx-...").load_data("2023.acl-srw.0.pdf")
Error while parsing the file '2023.acl-srw.0.pdf': '_io.BufferedReader' object has no attribute 'full_name'
>>> 


# send the request, start job
url = f"{self.base_url}/api/parsing/upload"
Expand All @@ -254,6 +262,12 @@ async def _create_job(self, file_path: str, extra_info: Optional[dict] = None) -
job_id = response.json()["id"]
return job_id

@staticmethod
def __get_filename(f: TextIOWrapper | AbstractBufferedFile) -> str:
logan-markewich marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(f, TextIOWrapper):
return f.name
return f.full_name

async def _get_job_result(self, job_id: str, result_type: str) -> dict:
result_url = f"{self.base_url}/api/parsing/job/{job_id}/result/{result_type}"
headers = {"Authorization": f"Bearer {self.api_key}"}
Expand All @@ -262,9 +276,9 @@ async def _get_job_result(self, job_id: str, result_type: str) -> dict:
tries = 0
while True:
await asyncio.sleep(self.check_interval)
async with httpx.AsyncClient(timeout=self.max_timeout) as client:
tries += 1
async with httpx.AsyncClient(timeout=self.max_timeout) as client:
tries += 1

result = await client.get(result_url, headers=headers)

if result.status_code == 404:
Expand All @@ -283,13 +297,13 @@ async def _get_job_result(self, job_id: str, result_type: str) -> dict:

return result.json()

async def _aload_data(self, file_path: str, extra_info: Optional[dict] = None) -> List[Document]:
async def _aload_data(self, file_path: str | PurePath, extra_info: Optional[dict] = None, fs: Optional[AbstractFileSystem] = None,) -> List[Document]:
"""Load data from the input path."""
try:
job_id = await self._create_job(file_path, extra_info=extra_info)
job_id = await self._create_job(file_path, extra_info=extra_info, fs=fs)
if self.verbose:
print("Started parsing the file under job_id %s" % job_id)

result = await self._get_job_result(job_id, self.result_type.value)

return [
Expand All @@ -298,22 +312,22 @@ async def _aload_data(self, file_path: str, extra_info: Optional[dict] = None) -
metadata=extra_info or {},
)
]

except Exception as e:
print(f"Error while parsing the file '{file_path}':", e)
raise e
return []


async def aload_data(self, file_path: Union[List[str], str], extra_info: Optional[dict] = None) -> List[Document]:

async def aload_data(self, file_path: Union[List[str], str, PurePath, List[PurePath]], extra_info: Optional[dict] = None, fs: Optional[AbstractFileSystem] = None,) -> List[Document]:
"""Load data from the input path."""
if isinstance(file_path, (str, Path)):
return await self._aload_data(file_path, extra_info=extra_info)
if isinstance(file_path, (str, PurePath)):
return await self._aload_data(file_path, extra_info=extra_info, fs=fs)
elif isinstance(file_path, list):
jobs = [self._aload_data(f, extra_info=extra_info) for f in file_path]
jobs = [self._aload_data(f, extra_info=extra_info, fs=fs) for f in file_path]
try:
results = await run_jobs(jobs, workers=self.num_workers)

# return flattened results
return [item for sublist in results for item in sublist]
except RuntimeError as e:
Expand All @@ -324,34 +338,34 @@ async def aload_data(self, file_path: Union[List[str], str], extra_info: Optiona
else:
raise ValueError("The input file_path must be a string or a list of strings.")

def load_data(self, file_path: Union[List[str], str, PurePath, List[PurePath]], extra_info: Optional[dict] = None, fs: Optional[AbstractFileSystem] = None,) -> List[Document]:
    """Synchronous wrapper around ``aload_data``.

    Runs the async loader to completion; translates the well-known
    nested-event-loop RuntimeError into a friendlier message.
    """
    try:
        return asyncio.run(self.aload_data(file_path, extra_info, fs=fs))
    except RuntimeError as e:
        if nest_asyncio_err not in str(e):
            raise e
        raise RuntimeError(nest_asyncio_msg)


async def _aget_json(self, file_path: Union[str, PurePath], extra_info: Optional[dict] = None, fs: Optional["AbstractFileSystem"] = None,) -> List[dict]:
    """Create a parse job for ``file_path`` and return its raw JSON result.

    Args:
        file_path: Path of the file to parse.
        extra_info: Extra metadata forwarded to the job.
        fs: Optional filesystem to read the file from; defaults to the
            local filesystem. Added for consistency with ``_aload_data``,
            which already threads ``fs`` through to ``_create_job`` —
            without it, JSON retrieval ignored custom filesystems.

    Returns:
        A single-element list containing the job's JSON result, annotated
        with ``job_id`` and ``file_path``.

    Raises:
        Exception: Re-raises any error from job creation or polling after
            printing a diagnostic message.
    """
    try:
        job_id = await self._create_job(file_path, extra_info=extra_info, fs=fs)
        if self.verbose:
            print("Started parsing the file under job_id %s" % job_id)

        # Request the "json" result type explicitly (vs. self.result_type).
        result = await self._get_job_result(job_id, "json")
        result["job_id"] = job_id
        result["file_path"] = file_path
        return [result]

    except Exception as e:
        print(f"Error while parsing the file '{file_path}':", e)
        raise e



async def aget_json(self, file_path: Union[List[str], str], extra_info: Optional[dict] = None) -> List[dict]:
"""Load data from the input path."""
Expand All @@ -361,7 +375,7 @@ async def aget_json(self, file_path: Union[List[str], str], extra_info: Optional
jobs = [self._aget_json(f, extra_info=extra_info) for f in file_path]
try:
results = await run_jobs(jobs, workers=self.num_workers)

# return flattened results
return [item for sublist in results for item in sublist]
except RuntimeError as e:
Expand All @@ -382,7 +396,7 @@ def get_json_result(self, file_path: Union[List[str], str], extra_info: Optional
raise RuntimeError(nest_asyncio_msg)
else:
raise e

def get_images(self, json_result: list[dict], download_path: str) -> List[dict]:
"""Download images from the parsed result."""
headers = {"Authorization": f"Bearer {self.api_key}"}
Expand Down
Loading