-
Notifications
You must be signed in to change notification settings - Fork 4.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
checksum support , improve control with new cmd line switches #2134
base: main
Are you sure you want to change the base?
Changes from 18 commits
a608e57
3206092
a482213
ce4bffe
285ae7d
bfb95ab
d07c9eb
4769133
9a385a4
66b5506
f4a504c
4ac3780
283d6d2
797347c
66438f2
15f34a5
2375e70
9d2dbf1
4ef5622
7a1c07f
0b43812
d7f1c80
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -3,14 +3,15 @@ | |||||
import logging | ||||||
import os | ||||||
import re | ||||||
from typing import List, Optional, Union | ||||||
from typing import List, Optional, Union, NamedTuple, Tuple | ||||||
|
||||||
import fitz # type: ignore | ||||||
from azure.core.credentials_async import AsyncTokenCredential | ||||||
from azure.storage.blob import ( | ||||||
BlobSasPermissions, | ||||||
UserDelegationKey, | ||||||
generate_blob_sas, | ||||||
generate_blob_sas, | ||||||
BlobClient | ||||||
) | ||||||
from azure.storage.blob.aio import BlobServiceClient, ContainerClient | ||||||
from PIL import Image, ImageDraw, ImageFont | ||||||
|
@@ -20,7 +21,6 @@ | |||||
|
||||||
logger = logging.getLogger("scripts") | ||||||
|
||||||
|
||||||
class BlobManager: | ||||||
""" | ||||||
Class to manage uploading and deleting blobs containing citation information from a blob storage account | ||||||
|
@@ -45,29 +45,63 @@ def __init__( | |||||
self.subscriptionId = subscriptionId | ||||||
self.user_delegation_key: Optional[UserDelegationKey] = None | ||||||
|
||||||
#async def upload_blob(self, file: File, container_client:ContainerClient) -> Optional[List[str]]: | ||||||
|
||||||
async def _create_new_blob(self, file: File, container_client:ContainerClient) -> BlobClient: | ||||||
with open(file.content.name, "rb") as reopened_file: | ||||||
blob_name = BlobManager.blob_name_from_file_name(file.content.name) | ||||||
logger.info("Uploading blob for whole file -> %s", blob_name) | ||||||
return await container_client.upload_blob(blob_name, reopened_file, overwrite=True, metadata=file.metadata) | ||||||
|
||||||
async def _file_blob_update_needed(self, blob_client: BlobClient, file : File) -> bool: | ||||||
md5_check : int = 0 # 0= not done, 1, positive,. 2 negative | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you end up not using this variable? |
||||||
# Get existing blob properties | ||||||
blob_properties = await blob_client.get_blob_properties() | ||||||
blob_metadata = blob_properties.metadata | ||||||
|
||||||
# Check if the md5 values are the same | ||||||
file_md5 = file.metadata.get('md5') | ||||||
blob_md5 = blob_metadata.get('md5') | ||||||
|
||||||
# Remove md5 from file metadata if it matches the blob metadata | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This comment doesn't seem to describe the next line? |
||||||
if file_md5 and file_md5 != blob_md5: | ||||||
return True | ||||||
else: | ||||||
return False | ||||||
|
||||||
async def upload_blob(self, file: File) -> Optional[List[str]]: | ||||||
async with BlobServiceClient( | ||||||
account_url=self.endpoint, credential=self.credential, max_single_put_size=4 * 1024 * 1024 | ||||||
) as service_client, service_client.get_container_client(self.container) as container_client: | ||||||
if not await container_client.exists(): | ||||||
await container_client.create_container() | ||||||
|
||||||
# Re-open and upload the original file | ||||||
|
||||||
# Re-open and upload the original file | ||||||
md5_check : int = 0 # 0= not done, 1, positive,. 2 negative | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think a Python Enum would be clearer here? |
||||||
|
||||||
# upload the file local storage zu azure storage | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
# file.url is only None if files are not uploaded yet, for datalake it is set | ||||||
if file.url is None: | ||||||
with open(file.content.name, "rb") as reopened_file: | ||||||
blob_name = BlobManager.blob_name_from_file_name(file.content.name) | ||||||
logger.info("Uploading blob for whole file -> %s", blob_name) | ||||||
blob_client = await container_client.upload_blob(blob_name, reopened_file, overwrite=True) | ||||||
file.url = blob_client.url | ||||||
blob_client = container_client.get_blob_client(file.url) | ||||||
|
||||||
if self.store_page_images: | ||||||
if not await blob_client.exists(): | ||||||
blob_client = await self._create_new_blob(file, container_client) | ||||||
else: | ||||||
if self._blob_update_needed(blob_client, file): | ||||||
md5_check = 2 | ||||||
# Upload the file with the updated metadata | ||||||
with open(file.content.name, "rb") as data: | ||||||
await blob_client.upload_blob(data, overwrite=True, metadata=file.metadata) | ||||||
else: | ||||||
md5_check = 1 | ||||||
file.url = blob_client.url | ||||||
|
||||||
if md5_check!=1 and self.store_page_images: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please install the pre-commit (see CONTRIBUTING guide) as it'll fix Python formatting issues for you, like the whitespace around operators. Thanks! |
||||||
if os.path.splitext(file.content.name)[1].lower() == ".pdf": | ||||||
return await self.upload_pdf_blob_images(service_client, container_client, file) | ||||||
else: | ||||||
logger.info("File %s is not a PDF, skipping image upload", file.content.name) | ||||||
|
||||||
return None | ||||||
|
||||||
def get_managedidentity_connectionstring(self): | ||||||
return f"ResourceId=/subscriptions/{self.subscriptionId}/resourceGroups/{self.resourceGroup}/providers/Microsoft.Storage/storageAccounts/{self.account};" | ||||||
|
||||||
|
@@ -93,7 +127,21 @@ async def upload_pdf_blob_images( | |||||
|
||||||
for i in range(page_count): | ||||||
blob_name = BlobManager.blob_image_name_from_file_page(file.content.name, i) | ||||||
logger.info("Converting page %s to image and uploading -> %s", i, blob_name) | ||||||
|
||||||
blob_client = container_client.get_blob_client(blob_name) | ||||||
do_upload : bool = True | ||||||
if await blob_client.exists(): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need this additional check here, given the check that happens in the function that calls this code? |
||||||
# Get existing blob properties | ||||||
blob_properties = await blob_client.get_blob_properties() | ||||||
blob_metadata = blob_properties.metadata | ||||||
|
||||||
# Check if the md5 values are the same | ||||||
file_md5 = file.metadata.get('md5') | ||||||
blob_md5 = blob_metadata.get('md5') | ||||||
if file_md5 == blob_md5: | ||||||
continue # documemt already uploaded | ||||||
|
||||||
logger.debug("Converting page %s to image and uploading -> %s", i, blob_name) | ||||||
|
||||||
doc = fitz.open(file.content.name) | ||||||
page = doc.load_page(i) | ||||||
|
@@ -119,21 +167,21 @@ async def upload_pdf_blob_images( | |||||
output = io.BytesIO() | ||||||
new_img.save(output, format="PNG") | ||||||
output.seek(0) | ||||||
|
||||||
blob_client = await container_client.upload_blob(blob_name, output, overwrite=True) | ||||||
await blob_client.upload_blob(data=output, overwrite=True, metadata=file.metadata) | ||||||
if not self.user_delegation_key: | ||||||
self.user_delegation_key = await service_client.get_user_delegation_key(start_time, expiry_time) | ||||||
|
||||||
if blob_client.account_name is not None: | ||||||
if container_client.account_name is not None: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Am curious why the need to go from blob_client to container_client here? Does it matter? |
||||||
sas_token = generate_blob_sas( | ||||||
account_name=blob_client.account_name, | ||||||
container_name=blob_client.container_name, | ||||||
blob_name=blob_client.blob_name, | ||||||
account_name=container_client.account_name, | ||||||
container_name=container_client.container_name, | ||||||
blob_name=blob_name, | ||||||
user_delegation_key=self.user_delegation_key, | ||||||
permission=BlobSasPermissions(read=True), | ||||||
expiry=expiry_time, | ||||||
start=start_time, | ||||||
) | ||||||
) | ||||||
sas_uris.append(f"{blob_client.url}?{sas_token}") | ||||||
|
||||||
return sas_uris | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,10 @@ | ||
import logging | ||
import asyncio | ||
from concurrent.futures import ThreadPoolExecutor | ||
from typing import List, Optional | ||
|
||
from concurrent.futures import ThreadPoolExecutor | ||
from typing import List, Optional | ||
from tqdm.asyncio import tqdm | ||
from .blobmanager import BlobManager | ||
from .embeddings import ImageEmbeddings, OpenAIEmbeddings | ||
from .fileprocessor import FileProcessor | ||
|
@@ -22,17 +26,16 @@ async def parse_file( | |
if processor is None: | ||
logger.info("Skipping '%s', no parser found.", file.filename()) | ||
return [] | ||
logger.info("Ingesting '%s'", file.filename()) | ||
logger.debug("Ingesting '%s'", file.filename()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like you changed the levels to debug, does the current output feel too verbose? I find it helpful. |
||
pages = [page async for page in processor.parser.parse(content=file.content)] | ||
logger.info("Splitting '%s' into sections", file.filename()) | ||
logger.debug("Splitting '%s' into sections", file.filename()) | ||
if image_embeddings: | ||
logger.warning("Each page will be split into smaller chunks of text, but images will be of the entire page.") | ||
logger.debug("Each page will be split into smaller chunks of text, but images will be of the entire page.") | ||
sections = [ | ||
Section(split_page, content=file, category=category) for split_page in processor.splitter.split_pages(pages) | ||
] | ||
return sections | ||
|
||
|
||
class FileStrategy(Strategy): | ||
""" | ||
Strategy for ingesting documents into a search service from files stored either locally or in a data lake storage account | ||
|
@@ -44,6 +47,7 @@ def __init__( | |
blob_manager: BlobManager, | ||
search_info: SearchInfo, | ||
file_processors: dict[str, FileProcessor], | ||
ignore_checksum: bool, | ||
document_action: DocumentAction = DocumentAction.Add, | ||
embeddings: Optional[OpenAIEmbeddings] = None, | ||
image_embeddings: Optional[ImageEmbeddings] = None, | ||
|
@@ -55,6 +59,7 @@ def __init__( | |
self.blob_manager = blob_manager | ||
self.file_processors = file_processors | ||
self.document_action = document_action | ||
self.ignore_checksum = ignore_checksum | ||
self.embeddings = embeddings | ||
self.image_embeddings = image_embeddings | ||
self.search_analyzer_name = search_analyzer_name | ||
|
@@ -77,29 +82,61 @@ async def run(self): | |
search_manager = SearchManager( | ||
self.search_info, self.search_analyzer_name, self.use_acls, False, self.embeddings | ||
) | ||
doccount = self.list_file_strategy.count_docs() | ||
logger.info(f"Processing {doccount} files") | ||
processed_count = 0 | ||
if self.document_action == DocumentAction.Add: | ||
files = self.list_file_strategy.list() | ||
async for file in files: | ||
try: | ||
sections = await parse_file(file, self.file_processors, self.category, self.image_embeddings) | ||
if sections: | ||
blob_sas_uris = await self.blob_manager.upload_blob(file) | ||
blob_image_embeddings: Optional[List[List[float]]] = None | ||
if self.image_embeddings and blob_sas_uris: | ||
blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris) | ||
await search_manager.update_content(sections, blob_image_embeddings, url=file.url) | ||
if self.ignore_checksum or not await search_manager.file_exists(file): | ||
sections = await parse_file(file, self.file_processors, self.category, self.image_embeddings) | ||
if sections: | ||
blob_sas_uris = await self.blob_manager.upload_blob(file) | ||
blob_image_embeddings: Optional[List[List[float]]] = None | ||
if self.image_embeddings and blob_sas_uris: | ||
blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris) | ||
await search_manager.update_content(sections=sections, file=file, image_embeddings=blob_image_embeddings) | ||
finally: | ||
if file: | ||
file.close() | ||
processed_count += 1 | ||
if processed_count % 10 == 0: | ||
remaining = max(doccount - processed_count, 1) | ||
logger.info(f"{processed_count} processed, {remaining} documents remaining") | ||
|
||
elif self.document_action == DocumentAction.Remove: | ||
doccount = self.list_file_strategy.count_docs() | ||
paths = self.list_file_strategy.list_paths() | ||
async for path in paths: | ||
await self.blob_manager.remove_blob(path) | ||
await search_manager.remove_content(path) | ||
processed_count += 1 | ||
if processed_count % 10 == 0: | ||
remaining = max(doccount - processed_count, 1) | ||
logger.info(f"{processed_count} removed, {remaining} documents remaining") | ||
|
||
elif self.document_action == DocumentAction.RemoveAll: | ||
await self.blob_manager.remove_blob() | ||
await search_manager.remove_content() | ||
|
||
async def process_file(self, file, search_manager): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These two functions seem to be unused currently- perhaps your thought was to refactor the code above to call these two functions. I'll remove them for now to make the diff smaller. |
||
try: | ||
sections = await parse_file(file, self.file_processors, self.category, self.image_embeddings) | ||
if sections: | ||
blob_sas_uris = await self.blob_manager.upload_blob(file) | ||
blob_image_embeddings: Optional[List[List[float]]] = None | ||
if self.image_embeddings and blob_sas_uris: | ||
blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris) | ||
await search_manager.update_content(sections=sections, file=file, image_embeddings=blob_image_embeddings) | ||
finally: | ||
if file: | ||
file.close() | ||
|
||
async def remove_file(self, path, search_manager): | ||
await self.blob_manager.remove_blob(path) | ||
await search_manager.remove_content(path) | ||
|
||
|
||
class UploadUserFileStrategy: | ||
""" | ||
|
@@ -124,7 +161,7 @@ async def add_file(self, file: File): | |
logging.warning("Image embeddings are not currently supported for the user upload feature") | ||
sections = await parse_file(file, self.file_processors) | ||
if sections: | ||
await self.search_manager.update_content(sections, url=file.url) | ||
await self.search_manager.update_content(sections=sections, file=file) | ||
|
||
async def remove_file(self, filename: str, oid: str): | ||
if filename is None or filename == "": | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
from azure.storage.filedatalake import DataLakeServiceClient | ||
from azure.storage.blob import BlobServiceClient | ||
import base64 | ||
import hashlib | ||
import logging | ||
|
@@ -8,6 +10,8 @@ | |
from glob import glob | ||
from typing import IO, AsyncGenerator, Dict, List, Optional, Union | ||
|
||
from azure.identity import DefaultAzureCredential | ||
|
||
from azure.core.credentials_async import AsyncTokenCredential | ||
from azure.storage.filedatalake.aio import ( | ||
DataLakeServiceClient, | ||
|
@@ -22,10 +26,11 @@ class File: | |
This file might contain access control information about which users or groups can access it | ||
""" | ||
|
||
def __init__(self, content: IO, acls: Optional[dict[str, list]] = None, url: Optional[str] = None): | ||
def __init__(self, content: IO, acls: Optional[dict[str, list]] = None, url: Optional[str] = None, metadata : Dict[str, str]= None): | ||
self.content = content | ||
self.acls = acls or {} | ||
self.url = url | ||
self.metadata = metadata | ||
|
||
def filename(self): | ||
return os.path.basename(self.content.name) | ||
|
@@ -58,7 +63,10 @@ async def list(self) -> AsyncGenerator[File, None]: | |
async def list_paths(self) -> AsyncGenerator[str, None]: | ||
if False: # pragma: no cover - this is necessary for mypy to type check | ||
yield | ||
|
||
|
||
def count_docs(self) -> int: | ||
if False: # pragma: no cover - this is necessary for mypy to type check | ||
yield | ||
|
||
class LocalListFileStrategy(ListFileStrategy): | ||
""" | ||
|
@@ -109,7 +117,23 @@ def check_md5(self, path: str) -> bool: | |
md5_f.write(existing_hash) | ||
|
||
return False | ||
|
||
|
||
|
||
def count_docs(self) -> int: | ||
""" | ||
Return the number of files that match the path pattern. | ||
""" | ||
return sum(1 for _ in self._list_paths_sync(self.path_pattern)) | ||
|
||
def _list_paths_sync(self, path_pattern: str): | ||
""" | ||
Synchronous version of _list_paths to be used for counting files. | ||
""" | ||
for path in glob(path_pattern): | ||
if os.path.isdir(path): | ||
yield from self._list_paths_sync(f"{path}/*") | ||
else: | ||
yield path | ||
|
||
class ADLSGen2ListFileStrategy(ListFileStrategy): | ||
""" | ||
|
@@ -167,11 +191,32 @@ async def list(self) -> AsyncGenerator[File, None]: | |
if acl_parts[0] == "user" and "r" in acl_parts[2]: | ||
acls["oids"].append(acl_parts[1]) | ||
if acl_parts[0] == "group" and "r" in acl_parts[2]: | ||
acls["groups"].append(acl_parts[1]) | ||
yield File(content=open(temp_file_path, "rb"), acls=acls, url=file_client.url) | ||
acls["groups"].append(acl_parts[1]) | ||
properties = await file_client.get_file_properties() | ||
yield File(content=open(temp_file_path, "rb"), acls=acls, url=file_client.url, metadata=properties.metadata) | ||
except Exception as data_lake_exception: | ||
logger.error(f"\tGot an error while reading {path} -> {data_lake_exception} --> skipping file") | ||
try: | ||
os.remove(temp_file_path) | ||
except Exception as file_delete_exception: | ||
logger.error(f"\tGot an error while deleting {temp_file_path} -> {file_delete_exception}") | ||
|
||
def count_docs(self) -> int: | ||
""" | ||
Return the number of blobs in the specified folder within the Azure Blob Storage container. | ||
""" | ||
|
||
# Create a BlobServiceClient using account URL and credentials | ||
service_client = BlobServiceClient( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you use BlobServiceClient to interact with a DataLake Storage account? That's what you seem to be doing here, but I didn't realize that was possible. |
||
account_url=f"https://{self.data_lake_storage_account}.blob.core.windows.net", | ||
credential=DefaultAzureCredential()) | ||
|
||
# Get the container client | ||
container_client = service_client.get_container_client(self.data_lake_filesystem) | ||
|
||
# Count blobs within the specified folder | ||
if self.data_lake_path != "/": | ||
return sum(1 for _ in container_client.list_blobs(name_starts_with= self.data_lake_path)) | ||
else: | ||
return sum(1 for _ in container_client.list_blobs()) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line can be rm'ed now?