Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

checksum support , improve control with new cmd line switches #2134

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 66 additions & 14 deletions app/backend/prepdocslib/blobmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import logging
import os
import re
from enum import Enum
from typing import List, Optional, Union

import fitz # type: ignore
from azure.core.credentials_async import AsyncTokenCredential
from azure.storage.blob import (
BlobClient,
BlobSasPermissions,
UserDelegationKey,
generate_blob_sas,
Expand Down Expand Up @@ -45,29 +47,65 @@ def __init__(
self.subscriptionId = subscriptionId
self.user_delegation_key: Optional[UserDelegationKey] = None

async def _create_new_blob(self, file: File, container_client: ContainerClient) -> BlobClient:
with open(file.content.name, "rb") as reopened_file:
blob_name = BlobManager.blob_name_from_file_name(file.content.name)
logger.info("Uploading blob for whole file -> %s", blob_name)
return await container_client.upload_blob(blob_name, reopened_file, overwrite=True, metadata=file.metadata)

async def _file_blob_update_needed(self, blob_client: BlobClient, file: File) -> bool:
# Get existing blob properties
blob_properties = await blob_client.get_blob_properties()
blob_metadata = blob_properties.metadata

# Check if the md5 values are the same
file_md5 = file.metadata.get("md5")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see logic in the main prepdocs for setting the md5 of the local file metadata, I only see that in ADLS2. How does this work when not using the ADLS2 strategy?

blob_md5 = blob_metadata.get("md5")

# If the file has an md5 value, check if it is different from the blob
return file_md5 and file_md5 != blob_md5

async def upload_blob(self, file: File) -> Optional[List[str]]:
async with BlobServiceClient(
account_url=self.endpoint, credential=self.credential, max_single_put_size=4 * 1024 * 1024
) as service_client, service_client.get_container_client(self.container) as container_client:
if not await container_client.exists():
await container_client.create_container()

# Re-open and upload the original file
# Re-open and upload the original file if the blob does not exist or the md5 values do not match
class MD5Check(Enum):
NOT_DONE = 0
MATCH = 1
NO_MATCH = 2

md5_check = MD5Check.NOT_DONE

# Upload the file to Azure Storage
# file.url is only None if files are not uploaded yet, for datalake it is set
if file.url is None:
with open(file.content.name, "rb") as reopened_file:
blob_name = BlobManager.blob_name_from_file_name(file.content.name)
logger.info("Uploading blob for whole file -> %s", blob_name)
blob_client = await container_client.upload_blob(blob_name, reopened_file, overwrite=True)
file.url = blob_client.url
blob_client = container_client.get_blob_client(file.url)

if self.store_page_images:
if not await blob_client.exists():
logger.info("Blob %s does not exist, uploading", file.url)
blob_client = await self._create_new_blob(file, container_client)
else:
if self._blob_update_needed(blob_client, file):
logger.info("Blob %s exists but md5 values do not match, updating", file.url)
md5_check = MD5Check.NO_MATCH
# Upload the file with the updated metadata
with open(file.content.name, "rb") as data:
await blob_client.upload_blob(data, overwrite=True, metadata=file.metadata)
else:
logger.info("Blob %s exists and md5 values match, skipping upload", file.url)
md5_check = MD5Check.MATCH
file.url = blob_client.url

if md5_check != MD5Check.MATCH and self.store_page_images:
if os.path.splitext(file.content.name)[1].lower() == ".pdf":
return await self.upload_pdf_blob_images(service_client, container_client, file)
else:
logger.info("File %s is not a PDF, skipping image upload", file.content.name)

return None

def get_managedidentity_connectionstring(self):
return f"ResourceId=/subscriptions/{self.subscriptionId}/resourceGroups/{self.resourceGroup}/providers/Microsoft.Storage/storageAccounts/{self.account};"

Expand All @@ -93,6 +131,20 @@ async def upload_pdf_blob_images(

for i in range(page_count):
blob_name = BlobManager.blob_image_name_from_file_page(file.content.name, i)

blob_client = container_client.get_blob_client(blob_name)
if await blob_client.exists():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this additional check here, given the check that happens in the function that calls this code?

# Get existing blob properties
blob_properties = await blob_client.get_blob_properties()
blob_metadata = blob_properties.metadata

# Check if the md5 values are the same
file_md5 = file.metadata.get("md5")
blob_md5 = blob_metadata.get("md5")
if file_md5 == blob_md5:
logger.info("Blob %s exists and md5 values match, skipping upload", blob_name)
continue # document already uploaded

logger.info("Converting page %s to image and uploading -> %s", i, blob_name)

doc = fitz.open(file.content.name)
Expand Down Expand Up @@ -120,15 +172,15 @@ async def upload_pdf_blob_images(
new_img.save(output, format="PNG")
output.seek(0)

blob_client = await container_client.upload_blob(blob_name, output, overwrite=True)
await blob_client.upload_blob(data=output, overwrite=True, metadata=file.metadata)
if not self.user_delegation_key:
self.user_delegation_key = await service_client.get_user_delegation_key(start_time, expiry_time)

if blob_client.account_name is not None:
if container_client.account_name is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am curious why the need to go from blob_client to container_client here? Does it matter?

sas_token = generate_blob_sas(
account_name=blob_client.account_name,
container_name=blob_client.container_name,
blob_name=blob_client.blob_name,
account_name=container_client.account_name,
container_name=container_client.container_name,
blob_name=blob_name,
user_delegation_key=self.user_delegation_key,
permission=BlobSasPermissions(read=True),
expiry=expiry_time,
Expand Down
25 changes: 23 additions & 2 deletions app/backend/prepdocslib/filestrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
blob_manager: BlobManager,
search_info: SearchInfo,
file_processors: dict[str, FileProcessor],
ignore_checksum: bool,
document_action: DocumentAction = DocumentAction.Add,
embeddings: Optional[OpenAIEmbeddings] = None,
image_embeddings: Optional[ImageEmbeddings] = None,
Expand All @@ -55,6 +56,7 @@ def __init__(
self.blob_manager = blob_manager
self.file_processors = file_processors
self.document_action = document_action
self.ignore_checksum = ignore_checksum
self.embeddings = embeddings
self.image_embeddings = image_embeddings
self.search_analyzer_name = search_analyzer_name
Expand All @@ -77,25 +79,44 @@ async def run(self):
search_manager = SearchManager(
self.search_info, self.search_analyzer_name, self.use_acls, False, self.embeddings
)
doc_count = self.list_file_strategy.count_docs()
logger.info("Processing %s files", doc_count)
processed_count = 0
if self.document_action == DocumentAction.Add:
files = self.list_file_strategy.list()
async for file in files:
try:
if not self.ignore_checksum and await search_manager.file_exists(file):
logger.info("'%s' has already been processed", file.filename())
continue
sections = await parse_file(file, self.file_processors, self.category, self.image_embeddings)
if sections:
blob_sas_uris = await self.blob_manager.upload_blob(file)
blob_image_embeddings: Optional[List[List[float]]] = None
if self.image_embeddings and blob_sas_uris:
blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris)
await search_manager.update_content(sections, blob_image_embeddings, url=file.url)
await search_manager.update_content(
sections=sections, file=file, image_embeddings=blob_image_embeddings
)
finally:
if file:
file.close()
processed_count += 1
if processed_count % 10 == 0:
remaining = max(doc_count - processed_count, 1)
logger.info("%s processed, %s documents remaining", processed_count, remaining)

elif self.document_action == DocumentAction.Remove:
doc_count = self.list_file_strategy.count_docs()
paths = self.list_file_strategy.list_paths()
async for path in paths:
await self.blob_manager.remove_blob(path)
await search_manager.remove_content(path)
processed_count += 1
if processed_count % 10 == 0:
remaining = max(doc_count - processed_count, 1)
logger.info("%s removed, %s documents remaining", processed_count, remaining)

elif self.document_action == DocumentAction.RemoveAll:
await self.blob_manager.remove_blob()
await search_manager.remove_content()
Expand Down Expand Up @@ -124,7 +145,7 @@ async def add_file(self, file: File):
logging.warning("Image embeddings are not currently supported for the user upload feature")
sections = await parse_file(file, self.file_processors)
if sections:
await self.search_manager.update_content(sections, url=file.url)
await self.search_manager.update_content(sections=sections, file=file)

async def remove_file(self, filename: str, oid: str):
if filename is None or filename == "":
Expand Down
59 changes: 54 additions & 5 deletions app/backend/prepdocslib/listfilestrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from typing import IO, AsyncGenerator, Dict, List, Optional, Union

from azure.core.credentials_async import AsyncTokenCredential
from azure.storage.filedatalake.aio import (
DataLakeServiceClient,
)
from azure.storage.blob import BlobServiceClient
from azure.storage.filedatalake.aio import DataLakeServiceClient

logger = logging.getLogger("scripts")

Expand All @@ -22,10 +21,17 @@ class File:
This file might contain access control information about which users or groups can access it
"""

def __init__(self, content: IO, acls: Optional[dict[str, list]] = None, url: Optional[str] = None):
def __init__(
self,
content: IO,
acls: Optional[dict[str, list]] = None,
url: Optional[str] = None,
metadata: Dict[str, str] = None,
):
self.content = content
self.acls = acls or {}
self.url = url
self.metadata = metadata

def filename(self):
return os.path.basename(self.content.name)
Expand Down Expand Up @@ -59,6 +65,10 @@ async def list_paths(self) -> AsyncGenerator[str, None]:
if False: # pragma: no cover - this is necessary for mypy to type check
yield

def count_docs(self) -> int:
if False: # pragma: no cover - this is necessary for mypy to type check
yield


class LocalListFileStrategy(ListFileStrategy):
"""
Expand Down Expand Up @@ -110,6 +120,22 @@ def check_md5(self, path: str) -> bool:

return False

def count_docs(self) -> int:
"""
Return the number of files that match the path pattern.
"""
return sum(1 for _ in self._list_paths_sync(self.path_pattern))

def _list_paths_sync(self, path_pattern: str):
"""
Synchronous version of _list_paths to be used for counting files.
"""
for path in glob(path_pattern):
if os.path.isdir(path):
yield from self._list_paths_sync(f"{path}/*")
else:
yield path


class ADLSGen2ListFileStrategy(ListFileStrategy):
"""
Expand Down Expand Up @@ -168,10 +194,33 @@ async def list(self) -> AsyncGenerator[File, None]:
acls["oids"].append(acl_parts[1])
if acl_parts[0] == "group" and "r" in acl_parts[2]:
acls["groups"].append(acl_parts[1])
yield File(content=open(temp_file_path, "rb"), acls=acls, url=file_client.url)
properties = await file_client.get_file_properties()
yield File(
content=open(temp_file_path, "rb"), acls=acls, url=file_client.url, metadata=properties.metadata
)
except Exception as data_lake_exception:
logger.error(f"\tGot an error while reading {path} -> {data_lake_exception} --> skipping file")
try:
os.remove(temp_file_path)
except Exception as file_delete_exception:
logger.error(f"\tGot an error while deleting {temp_file_path} -> {file_delete_exception}")

def count_docs(self) -> int:
"""
Return the number of blobs in the specified folder within the Azure Blob Storage container.
"""

# Create a BlobServiceClient using account URL and credentials
service_client = BlobServiceClient(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use BlobServiceClient to interact with a DataLake Storage account? That's what you seem to be doing here, but I didn't realize that was possible.

account_url=f"https://{self.data_lake_storage_account}.blob.core.windows.net",
credential=self.credential,
)

# Get the container client
container_client = service_client.get_container_client(self.data_lake_filesystem)

# Count blobs within the specified folder
if self.data_lake_path != "/":
return sum(1 for _ in container_client.list_blobs(name_starts_with=self.data_lake_path))
else:
return sum(1 for _ in container_client.list_blobs())
Loading
Loading