diff --git a/app/backend/prepdocslib/blobmanager.py b/app/backend/prepdocslib/blobmanager.py
index e9f18e795a..4d5b37772c 100644
--- a/app/backend/prepdocslib/blobmanager.py
+++ b/app/backend/prepdocslib/blobmanager.py
@@ -3,11 +3,13 @@
 import logging
 import os
 import re
+from enum import Enum
 from typing import List, Optional, Union
 
 import fitz  # type: ignore
 from azure.core.credentials_async import AsyncTokenCredential
 from azure.storage.blob import (
+    BlobClient,
     BlobSasPermissions,
     UserDelegationKey,
     generate_blob_sas,
@@ -45,6 +47,24 @@ def __init__(
         self.subscriptionId = subscriptionId
         self.user_delegation_key: Optional[UserDelegationKey] = None
 
+    async def _create_new_blob(self, file: File, container_client: ContainerClient) -> BlobClient:
+        with open(file.content.name, "rb") as reopened_file:
+            blob_name = BlobManager.blob_name_from_file_name(file.content.name)
+            logger.info("Uploading blob for whole file -> %s", blob_name)
+            return await container_client.upload_blob(blob_name, reopened_file, overwrite=True, metadata=file.metadata)
+
+    async def _file_blob_update_needed(self, blob_client: BlobClient, file: File) -> bool:
+        # Get existing blob properties
+        blob_properties = await blob_client.get_blob_properties()
+        blob_metadata = blob_properties.metadata
+
+        # Check if the md5 values are the same
+        file_md5 = file.metadata.get("md5")
+        blob_md5 = blob_metadata.get("md5")
+
+        # If the file has an md5 value, check whether it differs from the blob's
+        return file_md5 is not None and file_md5 != blob_md5
+
     async def upload_blob(self, file: File) -> Optional[List[str]]:
         async with BlobServiceClient(
             account_url=self.endpoint, credential=self.credential, max_single_put_size=4 * 1024 * 1024
@@ -52,22 +72,40 @@ async def upload_blob(self, file: File) -> Optional[List[str]]:
             if not await container_client.exists():
                 await container_client.create_container()
 
-            # Re-open and upload the original file
+            # Re-open and upload the original file if the blob does not exist or the md5 values do not match
+            class MD5Check(Enum):
+                NOT_DONE = 0
+                MATCH = 1
+                NO_MATCH = 2
+
+            md5_check = MD5Check.NOT_DONE
+
+            # Upload the file to Azure Storage
+            # file.url is only None if the file has not been uploaded yet; for datalake files it is already set
             if file.url is None:
-                with open(file.content.name, "rb") as reopened_file:
-                    blob_name = BlobManager.blob_name_from_file_name(file.content.name)
-                    logger.info("Uploading blob for whole file -> %s", blob_name)
-                    blob_client = await container_client.upload_blob(blob_name, reopened_file, overwrite=True)
-                    file.url = blob_client.url
+                blob_client = container_client.get_blob_client(file.url)
 
-            if self.store_page_images:
+                if not await blob_client.exists():
+                    logger.info("Blob %s does not exist, uploading", file.url)
+                    blob_client = await self._create_new_blob(file, container_client)
+                else:
+                    if await self._file_blob_update_needed(blob_client, file):
+                        logger.info("Blob %s exists but md5 values do not match, updating", file.url)
+                        md5_check = MD5Check.NO_MATCH
+                        # Upload the file with the updated metadata
+                        with open(file.content.name, "rb") as data:
+                            await blob_client.upload_blob(data, overwrite=True, metadata=file.metadata)
+                    else:
+                        logger.info("Blob %s exists and md5 values match, skipping upload", file.url)
+                        md5_check = MD5Check.MATCH
+                file.url = blob_client.url
+
+            if md5_check != MD5Check.MATCH and self.store_page_images:
                 if os.path.splitext(file.content.name)[1].lower() == ".pdf":
                     return await self.upload_pdf_blob_images(service_client, container_client, file)
                 else:
                     logger.info("File %s is not a PDF, 
skipping image upload", file.content.name) - return None - def get_managedidentity_connectionstring(self): return f"ResourceId=/subscriptions/{self.subscriptionId}/resourceGroups/{self.resourceGroup}/providers/Microsoft.Storage/storageAccounts/{self.account};" @@ -93,6 +131,20 @@ async def upload_pdf_blob_images( for i in range(page_count): blob_name = BlobManager.blob_image_name_from_file_page(file.content.name, i) + + blob_client = container_client.get_blob_client(blob_name) + if await blob_client.exists(): + # Get existing blob properties + blob_properties = await blob_client.get_blob_properties() + blob_metadata = blob_properties.metadata + + # Check if the md5 values are the same + file_md5 = file.metadata.get("md5") + blob_md5 = blob_metadata.get("md5") + if file_md5 == blob_md5: + logger.info("Blob %s exists and md5 values match, skipping upload", blob_name) + continue # document already uploaded + logger.info("Converting page %s to image and uploading -> %s", i, blob_name) doc = fitz.open(file.content.name) @@ -120,15 +172,15 @@ async def upload_pdf_blob_images( new_img.save(output, format="PNG") output.seek(0) - blob_client = await container_client.upload_blob(blob_name, output, overwrite=True) + await blob_client.upload_blob(data=output, overwrite=True, metadata=file.metadata) if not self.user_delegation_key: self.user_delegation_key = await service_client.get_user_delegation_key(start_time, expiry_time) - if blob_client.account_name is not None: + if container_client.account_name is not None: sas_token = generate_blob_sas( - account_name=blob_client.account_name, - container_name=blob_client.container_name, - blob_name=blob_client.blob_name, + account_name=container_client.account_name, + container_name=container_client.container_name, + blob_name=blob_name, user_delegation_key=self.user_delegation_key, permission=BlobSasPermissions(read=True), expiry=expiry_time, diff --git a/app/backend/prepdocslib/filestrategy.py b/app/backend/prepdocslib/filestrategy.py index 55b24b6f3a..fc8b8f41af 100644 --- a/app/backend/prepdocslib/filestrategy.py +++ b/app/backend/prepdocslib/filestrategy.py @@ -44,6 +44,7 @@ def __init__( blob_manager: BlobManager, search_info: SearchInfo, file_processors: dict[str, FileProcessor], + ignore_checksum: bool, document_action: DocumentAction = DocumentAction.Add, embeddings: Optional[OpenAIEmbeddings] = None, image_embeddings: Optional[ImageEmbeddings] = None, @@ -55,6 +56,7 @@ def __init__( self.blob_manager = blob_manager self.file_processors = file_processors self.document_action = document_action + self.ignore_checksum = ignore_checksum self.embeddings = embeddings self.image_embeddings = image_embeddings self.search_analyzer_name = search_analyzer_name @@ -77,25 +79,44 @@ async def run(self): search_manager = SearchManager( self.search_info, self.search_analyzer_name, self.use_acls, False, self.embeddings ) + doc_count = self.list_file_strategy.count_docs() + logger.info("Processing %s files", doc_count) + processed_count = 0 if self.document_action == DocumentAction.Add: files = self.list_file_strategy.list() async for file in files: try: + if not self.ignore_checksum and await search_manager.file_exists(file): + logger.info("'%s' has already been processed", file.filename()) + continue sections = await parse_file(file, self.file_processors, self.category, self.image_embeddings) if sections: blob_sas_uris = await self.blob_manager.upload_blob(file) blob_image_embeddings: Optional[List[List[float]]] = None if self.image_embeddings and blob_sas_uris: 
blob_image_embeddings = await self.image_embeddings.create_embeddings(blob_sas_uris)
-                        await search_manager.update_content(sections, blob_image_embeddings, url=file.url)
+                        await search_manager.update_content(
+                            sections=sections, file=file, image_embeddings=blob_image_embeddings
+                        )
                 finally:
                     if file:
                         file.close()
+                processed_count += 1
+                if processed_count % 10 == 0:
+                    remaining = max(doc_count - processed_count, 1)
+                    logger.info("%s processed, %s documents remaining", processed_count, remaining)
+
         elif self.document_action == DocumentAction.Remove:
+            doc_count = self.list_file_strategy.count_docs()
             paths = self.list_file_strategy.list_paths()
             async for path in paths:
                 await self.blob_manager.remove_blob(path)
                 await search_manager.remove_content(path)
+                processed_count += 1
+                if processed_count % 10 == 0:
+                    remaining = max(doc_count - processed_count, 1)
+                    logger.info("%s removed, %s documents remaining", processed_count, remaining)
+
         elif self.document_action == DocumentAction.RemoveAll:
             await self.blob_manager.remove_blob()
             await search_manager.remove_content()
@@ -124,7 +145,7 @@ async def add_file(self, file: File):
             logging.warning("Image embeddings are not currently supported for the user upload feature")
         sections = await parse_file(file, self.file_processors)
         if sections:
-            await self.search_manager.update_content(sections, url=file.url)
+            await self.search_manager.update_content(sections=sections, file=file)
 
     async def remove_file(self, filename: str, oid: str):
         if filename is None or filename == "":
diff --git a/app/backend/prepdocslib/listfilestrategy.py b/app/backend/prepdocslib/listfilestrategy.py
index 3c8fcd27b0..dd426c55f9 100644
--- a/app/backend/prepdocslib/listfilestrategy.py
+++ b/app/backend/prepdocslib/listfilestrategy.py
@@ -9,9 +9,8 @@
 from typing import IO, AsyncGenerator, Dict, List, Optional, Union
 
 from azure.core.credentials_async import AsyncTokenCredential
-from azure.storage.filedatalake.aio import (
-    DataLakeServiceClient,
-)
+from azure.storage.blob import BlobServiceClient
+from azure.storage.filedatalake.aio import DataLakeServiceClient
 
 logger = logging.getLogger("scripts")
 
@@ -22,10 +21,17 @@ class File:
     This file might contain access control information about which users or groups can access it
     """
 
-    def __init__(self, content: IO, acls: Optional[dict[str, list]] = None, url: Optional[str] = None):
+    def __init__(
+        self,
+        content: IO,
+        acls: Optional[dict[str, list]] = None,
+        url: Optional[str] = None,
+        metadata: Optional[Dict[str, str]] = None,
+    ):
         self.content = content
         self.acls = acls or {}
         self.url = url
+        self.metadata = metadata or {}
 
     def filename(self):
         return os.path.basename(self.content.name)
@@ -59,6 +65,10 @@ async def list_paths(self) -> AsyncGenerator[str, None]:
         if False:  # pragma: no cover - this is necessary for mypy to type check
             yield
 
+    def count_docs(self) -> int:
+        # Implemented by the concrete list strategies below
+        raise NotImplementedError
+
 
 class LocalListFileStrategy(ListFileStrategy):
     """
@@ -110,6 +120,22 @@ def check_md5(self, path: str) -> bool:
 
         return False
 
+    def count_docs(self) -> int:
+        """
+        Return the number of files that match the path pattern.
+        """
+        return sum(1 for _ in self._list_paths_sync(self.path_pattern))
+
+    def _list_paths_sync(self, path_pattern: str):
+        """
+        Synchronous version of _list_paths to be used for counting files. 
+ """ + for path in glob(path_pattern): + if os.path.isdir(path): + yield from self._list_paths_sync(f"{path}/*") + else: + yield path + class ADLSGen2ListFileStrategy(ListFileStrategy): """ @@ -168,10 +194,33 @@ async def list(self) -> AsyncGenerator[File, None]: acls["oids"].append(acl_parts[1]) if acl_parts[0] == "group" and "r" in acl_parts[2]: acls["groups"].append(acl_parts[1]) - yield File(content=open(temp_file_path, "rb"), acls=acls, url=file_client.url) + properties = await file_client.get_file_properties() + yield File( + content=open(temp_file_path, "rb"), acls=acls, url=file_client.url, metadata=properties.metadata + ) except Exception as data_lake_exception: logger.error(f"\tGot an error while reading {path} -> {data_lake_exception} --> skipping file") try: os.remove(temp_file_path) except Exception as file_delete_exception: logger.error(f"\tGot an error while deleting {temp_file_path} -> {file_delete_exception}") + + def count_docs(self) -> int: + """ + Return the number of blobs in the specified folder within the Azure Blob Storage container. + """ + + # Create a BlobServiceClient using account URL and credentials + service_client = BlobServiceClient( + account_url=f"https://{self.data_lake_storage_account}.blob.core.windows.net", + credential=self.credential, + ) + + # Get the container client + container_client = service_client.get_container_client(self.data_lake_filesystem) + + # Count blobs within the specified folder + if self.data_lake_path != "/": + return sum(1 for _ in container_client.list_blobs(name_starts_with=self.data_lake_path)) + else: + return sum(1 for _ in container_client.list_blobs()) diff --git a/app/backend/prepdocslib/searchmanager.py b/app/backend/prepdocslib/searchmanager.py index f75af03514..47945786a5 100644 --- a/app/backend/prepdocslib/searchmanager.py +++ b/app/backend/prepdocslib/searchmanager.py @@ -1,8 +1,10 @@ import asyncio +import datetime import logging import os from typing import List, Optional +import dateutil.parser as parser from azure.search.documents.indexes.models import ( AzureOpenAIVectorizer, AzureOpenAIVectorizerParameters, @@ -70,92 +72,95 @@ async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]] logger.info("Checking whether search index %s exists...", self.search_info.index_name) async with self.search_info.create_search_index_client() as search_index_client: - - if self.search_info.index_name not in [name async for name in search_index_client.list_index_names()]: - logger.info("Creating new search index %s", self.search_info.index_name) - fields = [ - ( - SimpleField(name="id", type="Edm.String", key=True) - if not self.use_int_vectorization - else SearchField( - name="id", - type="Edm.String", - key=True, - sortable=True, - filterable=True, - facetable=True, - analyzer_name="keyword", - ) - ), - SearchableField( - name="content", + fields = [ + ( + SimpleField(name="id", type="Edm.String", key=True) + if not self.use_int_vectorization + else SearchField( + name="id", type="Edm.String", - analyzer_name=self.search_analyzer_name, - ), + key=True, + sortable=True, + filterable=True, + facetable=True, + analyzer_name="keyword", + ) + ), + SearchableField( + name="content", + type="Edm.String", + analyzer_name=self.search_analyzer_name, + ), + SearchField( + name="embedding", + type=SearchFieldDataType.Collection(SearchFieldDataType.Single), + hidden=False, + searchable=True, + filterable=False, + sortable=False, + facetable=False, + vector_search_dimensions=self.embedding_dimensions, + 
vector_search_profile_name="embedding_config", + ), + SimpleField(name="category", type="Edm.String", filterable=True, facetable=True), + SimpleField(name="md5", type="Edm.String", filterable=True, facetable=True), + SimpleField(name="deeplink", type="Edm.String", filterable=True, facetable=False), + SimpleField(name="updated", type="Edm.DateTimeOffset", filterable=True, facetable=True), + SimpleField( + name="sourcepage", + type="Edm.String", + filterable=True, + facetable=True, + ), + SimpleField( + name="sourcefile", + type="Edm.String", + filterable=True, + facetable=True, + ), + SimpleField( + name="storageUrl", + type="Edm.String", + filterable=True, + facetable=False, + ), + ] + if self.use_acls: + fields.append( + SimpleField( + name="oids", + type=SearchFieldDataType.Collection(SearchFieldDataType.String), + filterable=True, + ) + ) + fields.append( + SimpleField( + name="groups", + type=SearchFieldDataType.Collection(SearchFieldDataType.String), + filterable=True, + ) + ) + if self.use_int_vectorization: + logger.info("Including parent_id field in new index %s", self.search_info.index_name) + fields.append(SearchableField(name="parent_id", type="Edm.String", filterable=True)) + if self.search_images: + logger.info("Including imageEmbedding field in new index %s", self.search_info.index_name) + fields.append( SearchField( - name="embedding", + name="imageEmbedding", type=SearchFieldDataType.Collection(SearchFieldDataType.Single), hidden=False, searchable=True, filterable=False, sortable=False, facetable=False, - vector_search_dimensions=self.embedding_dimensions, + vector_search_dimensions=1024, vector_search_profile_name="embedding_config", ), - SimpleField(name="category", type="Edm.String", filterable=True, facetable=True), - SimpleField( - name="sourcepage", - type="Edm.String", - filterable=True, - facetable=True, - ), - SimpleField( - name="sourcefile", - type="Edm.String", - filterable=True, - facetable=True, - ), - SimpleField( - name="storageUrl", - type="Edm.String", - filterable=True, - facetable=False, - ), - ] - if self.use_acls: - fields.append( - SimpleField( - name="oids", - type=SearchFieldDataType.Collection(SearchFieldDataType.String), - filterable=True, - ) - ) - fields.append( - SimpleField( - name="groups", - type=SearchFieldDataType.Collection(SearchFieldDataType.String), - filterable=True, - ) - ) - if self.use_int_vectorization: - logger.info("Including parent_id field in new index %s", self.search_info.index_name) - fields.append(SearchableField(name="parent_id", type="Edm.String", filterable=True)) - if self.search_images: - logger.info("Including imageEmbedding field in new index %s", self.search_info.index_name) - fields.append( - SearchField( - name="imageEmbedding", - type=SearchFieldDataType.Collection(SearchFieldDataType.Single), - hidden=False, - searchable=True, - filterable=False, - sortable=False, - facetable=False, - vector_search_dimensions=1024, - vector_search_profile_name="embedding_config", - ), - ) + ) + + if self.search_info.index_name not in [name async for name in search_index_client.list_index_names()]: + logger.info("Creating new search index %s", self.search_info.index_name) vectorizers = [] if self.embeddings and isinstance(self.embeddings, AzureOpenAIEmbeddingService): @@ -217,16 +222,17 @@ async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]] else: logger.info("Search index %s already exists", self.search_info.index_name) existing_index = await search_index_client.get_index(self.search_info.index_name) 
-                if not any(field.name == "storageUrl" for field in existing_index.fields):
-                    logger.info("Adding storageUrl field to index %s", self.search_info.index_name)
-                    existing_index.fields.append(
-                        SimpleField(
-                            name="storageUrl",
-                            type="Edm.String",
-                            filterable=True,
-                            facetable=False,
-                        ),
+                existing_field_names = {field.name for field in existing_index.fields}
+
+                # Check for and add any missing fields
+                missing_fields = [field for field in fields if field.name not in existing_field_names]
+                if missing_fields:
+                    logger.info(
+                        "Adding missing fields to index %s: %s",
+                        self.search_info.index_name,
+                        [field.name for field in missing_fields],
                     )
+                    existing_index.fields.extend(missing_fields)
                 await search_index_client.create_or_update_index(existing_index)
 
                 if existing_index.vector_search is not None and (
@@ -252,19 +258,55 @@ async def create_index(self, vectorizers: Optional[List[VectorSearchVectorizer]]
             self.search_info,
         )
 
+    async def file_exists(self, file: File) -> bool:
+        async with self.search_info.create_search_client() as search_client:
+            ## make sure that we don't update unchanged sections: a document is unchanged if its sourcefile and md5 both match
+            if file.metadata.get("md5") is not None:
+                filter = None
+                assert file.filename() is not None
+                filter = f"sourcefile eq '{str(file.filename())}' and md5 eq '{file.metadata.get('md5')}'"
+
+                # make sure (when applicable) that we don't skip files from different categories that share the same file.filename()
+                # TODO: refactoring: check whether using file.filename() as the blob's primary key is a good idea, or whether sha256 (instead of md5) would be a more reliable primary key for the blob and the index
+                if file.metadata.get("category") is not None:
+                    filter = filter + f" and category eq '{file.metadata.get('category')}'"
+                max_results = 1
+                result = await search_client.search(
+                    search_text="", filter=filter, top=max_results, include_total_count=True
+                )
+                result_count = await result.get_count()
+                if result_count > 0:
+                    logger.debug("Skipping %s, no changes detected.", file.filename())
+                    return True
+                else:
+                    return False
+            return False
+
     async def update_content(
-        self, sections: List[Section], image_embeddings: Optional[List[List[float]]] = None, url: Optional[str] = None
+        self, sections: List[Section], file: File, image_embeddings: Optional[List[List[float]]] = None
     ):
         MAX_BATCH_SIZE = 1000
         section_batches = [sections[i : i + MAX_BATCH_SIZE] for i in range(0, len(sections), MAX_BATCH_SIZE)]
 
         async with self.search_info.create_search_client() as search_client:
+
+            ## calculate a (default) updated timestamp in the format used by the index
+            if file.metadata.get("updated") is None:
+                docdate = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
+            else:
+                docdate = parser.isoparse(file.metadata.get("updated")).strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
+
             for batch_index, batch in enumerate(section_batches):
                 documents = [
                     {
                         "id": f"{section.content.filename_to_id()}-page-{section_index + batch_index * MAX_BATCH_SIZE}",
                         "content": section.split_page.text,
-                        "category": section.category,
+                        "category": file.metadata.get("category"),
+                        "md5": file.metadata.get("md5"),
+                        "deeplink": file.metadata.get(
+                            "deeplink"
+                        ),  # optional deep link to the original doc source, for citation / inline view
+                        "updated": docdate,
                         "sourcepage": (
                             BlobManager.blob_image_name_from_file_page(
                                 filename=section.content.filename(),
@@ -281,9 +323,9 @@ async def update_content(
                     }
                     for section_index, section in enumerate(batch)
                 ]
-                if url:
+                if file.url:
                     for document in documents:
-                        document["storageUrl"] = url
+                        document["storageUrl"] = file.url
                 if self.embeddings:
                     embeddings = await self.embeddings.create_embeddings(
                         texts=[section.split_page.text for section in batch]
diff --git a/app/backend/requirements.txt b/app/backend/requirements.txt
index fe339f08c1..51df00e14b 100644
--- a/app/backend/requirements.txt
+++ b/app/backend/requirements.txt
@@ -350,7 +350,7 @@ python-dateutil==2.9.0.post0
     #   time-machine
 python-dotenv==1.0.1
     # via -r requirements.in
-quart==0.19.7
+quart==0.19.6
     # via
     #   -r requirements.in
     #   quart-cors
@@ -403,7 +403,6 @@ typing-extensions==4.12.2
     #   azure-ai-documentintelligence
     #   azure-core
     #   azure-identity
-    #   azure-search-documents
     #   azure-storage-blob
     #   azure-storage-file-datalake
     #   openai
@@ -416,7 +415,7 @@ urllib3==2.2.2
     # via requests
 uvicorn==0.30.6
     # via -r requirements.in
-werkzeug==3.0.6
+werkzeug==3.0.4
     # via
     #   flask
     #   quart
diff --git a/docs/login_and_acl.md b/docs/login_and_acl.md
index 77748a975c..247fd139a4 100644
--- a/docs/login_and_acl.md
+++ b/docs/login_and_acl.md
@@ -254,7 +254,9 @@ The script performs the following steps:
 - Creates example [groups](https://learn.microsoft.com/entra/fundamentals/how-to-manage-groups) listed in the [sampleacls.json](/scripts/sampleacls.json) file.
 - Creates a filesystem / container `gptkbcontainer` in the storage account.
 - Creates the directories listed in the [sampleacls.json](/scripts/sampleacls.json) file.
-- Uploads the sample PDFs referenced in the [sampleacls.json](/scripts/sampleacls.json) file into the appropriate directories.
+- Scans the directories recursively for files if you pass the `--scandirs` option (off by default) and the directory entry in [sampleacls.json](/scripts/sampleacls.json) does not set `"scandir": false` (it defaults to `true`).
+- Calculates the MD5 checksum of each referenced file and compares it with the `md5` value stored in the metadata of the previously uploaded file. The upload is skipped if the checksums match; otherwise the file is uploaded and the new MD5 value is stored in its metadata.
+- Uploads the sample PDFs referenced in the [sampleacls.json](/scripts/sampleacls.json) file, as well as any files found in directories with the scandir option enabled, into the appropriate directories.
 - [Recursively sets Access Control Lists (ACLs)](https://learn.microsoft.com/azure/storage/blobs/data-lake-storage-acl-cli) using the information from the [sampleacls.json](/scripts/sampleacls.json) file.
 
 In order to use the sample access control, you need to join these groups in your Microsoft Entra tenant. 
diff --git a/scripts/adlsgen2setup.py b/scripts/adlsgen2setup.py index 1deccdf199..10a5cd59a0 100644 --- a/scripts/adlsgen2setup.py +++ b/scripts/adlsgen2setup.py @@ -1,8 +1,10 @@ import argparse import asyncio +import hashlib import json import logging import os +from datetime import datetime from typing import Any, Optional import aiohttp @@ -16,6 +18,9 @@ from load_azd_env import load_azd_env logger = logging.getLogger("scripts") +# Set the logging level for the azure package to DEBUG +logging.getLogger("azure").setLevel(logging.DEBUG) +logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.DEBUG) class AdlsGen2Setup: @@ -56,7 +61,7 @@ def __init__( self.data_access_control_format = data_access_control_format self.graph_headers: Optional[dict[str, str]] = None - async def run(self): + async def run(self, scandirs: bool = False): async with self.create_service_client() as service_client: logger.info(f"Ensuring {self.filesystem_name} exists...") async with service_client.get_file_system_client(self.filesystem_name) as filesystem_client: @@ -80,6 +85,10 @@ async def run(self): ) directories[directory] = directory_client + logger.info("Uploading scanned files...") + if scandirs: + await self.scan_and_upload_directories(directories, filesystem_client) + logger.info("Uploading files...") for file, file_info in self.data_access_control_format["files"].items(): directory = file_info["directory"] @@ -110,15 +119,109 @@ async def run(self): for directory_client in directories.values(): await directory_client.close() + async def walk_files(self, src_filepath="."): + filepath_list = [] + + # This for loop uses the os.walk() function to walk through the files and directories + # and records the filepaths of the files to a list + for root, dirs, files in os.walk(src_filepath): + + # iterate through the files currently obtained by os.walk() and + # create the filepath string for that file and add it to the filepath_list list + root_found: bool = False + for file in files: + # Checks to see if the root is '.' and changes it to the correct current + # working directory by calling os.getcwd(). Otherwise root_path will just be the root variable value. 
+
+                if not root_found and root == ".":
+                    filepath = os.path.join(os.getcwd() + "/", file)
+                    root_found = True
+                else:
+                    filepath = os.path.join(root, file)
+
+                # Append the filepath to filepath_list if it is not already present
+                if filepath not in filepath_list:
+                    filepath_list.append(filepath)
+
+        # Return the collected file paths
+        return filepath_list
+
+    async def scan_and_upload_directories(self, directories: dict[str, DataLakeDirectoryClient], filesystem_client):
+        logger.info("Scanning and uploading files from directories recursively...")
+
+        for directory, directory_client in directories.items():
+            directory_path = os.path.join(self.data_directory, directory)
+            if directory == "/":
+                continue
+
+            # Skip the directory if 'scandir' is explicitly set to False
+            if not self.data_access_control_format["directories"][directory].get("scandir", True):
+                logger.info(f"Skipping directory {directory} as 'scandir' is set to False")
+                continue
+
+            # Check that the directory exists before walking it
+            if not os.path.exists(directory_path):
+                logger.warning(f"Directory does not exist: {directory_path}")
+                continue
+
+            # Get all file paths using the walk_files function
+            file_paths = await self.walk_files(directory_path)
+
+            # Upload each collected file
+            count = 0
+            num = len(file_paths)
+            for file_path in file_paths:
+                await self.upload_file(directory_client, file_path, directory)
+                count += 1
+                logger.info(f"Uploaded [{count}/{num}] {directory}/{file_path}")
+
     def create_service_client(self):
         return DataLakeServiceClient(
             account_url=f"https://{self.storage_account_name}.dfs.core.windows.net", credential=self.credentials
         )
 
-    async def upload_file(self, directory_client: DataLakeDirectoryClient, file_path: str):
-        with open(file=file_path, mode="rb") as f:
-            file_client = directory_client.get_file_client(file=os.path.basename(file_path))
-            await file_client.upload_data(f, overwrite=True)
+    async def calc_md5(self, path: str) -> str:
+        hash_md5 = hashlib.md5()
+        with open(path, "rb") as file:
+            for chunk in iter(lambda: file.read(4096), b""):
+                hash_md5.update(chunk)
+        return hash_md5.hexdigest()
+
+    async def get_blob_md5(self, directory_client: DataLakeDirectoryClient, filename: str) -> Optional[str]:
+        """
+        Retrieves the MD5 checksum from the metadata of the specified blob. 
+ """ + file_client = directory_client.get_file_client(filename) + try: + properties = await file_client.get_file_properties() + return properties.metadata.get("md5") + except Exception as e: + logger.error(f"Error getting blob properties for {filename}: {e}") + return None + + async def upload_file(self, directory_client: DataLakeDirectoryClient, file_path: str, category: str = ""): + # Calculate MD5 hash once + md5_hash = await self.calc_md5(file_path) + + # Get the filename + filename = os.path.basename(file_path) + + # Get the MD5 checksum from the blob metadata + blob_md5 = await self.get_blob_md5(directory_client, filename) + + # Upload the file if it does not exist or the checksum differs + if blob_md5 is None or md5_hash != blob_md5: + with open(file_path, "rb") as f: + file_client = directory_client.get_file_client(filename) + tmtime = os.path.getmtime(file_path) + last_modified = datetime.fromtimestamp(tmtime).isoformat() + title = os.path.splitext(filename)[0] + metadata = {"md5": md5_hash, "category": category, "updated": last_modified, "title": title} + await file_client.upload_data(f, overwrite=True) + await file_client.set_metadata(metadata) + logger.info(f"Uploaded and updated metadata for {filename}") + else: + logger.info(f"No upload needed for {filename}, checksums match") async def create_or_get_group(self, group_name: str): group_id = None @@ -166,18 +269,18 @@ async def main(args: Any): command = AdlsGen2Setup( data_directory=args.data_directory, storage_account_name=os.environ["AZURE_ADLS_GEN2_STORAGE_ACCOUNT"], - filesystem_name="gptkbcontainer", + filesystem_name=os.environ["AZURE_ADLS_GEN2_FILESYSTEM"], security_enabled_groups=args.create_security_enabled_groups, credentials=credentials, data_access_control_format=data_access_control_format, ) - await command.run() + await command.run(args.scandirs) if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Upload sample data to a Data Lake Storage Gen2 account and associate sample access control lists with it using sample groups", - epilog="Example: ./scripts/adlsgen2setup.py ./data --data-access-control ./scripts/sampleacls.json --create-security-enabled-groups ", + description="Upload data to a Data Lake Storage Gen2 account and associate access control lists with it using sample groups", + epilog="Example: ./scripts/adlsgen2setup.py ./data --data-access-control .azure/${AZURE_ENV_NAME}/docs_acls.json --create-security-enabled-groups --scandirs", ) parser.add_argument("data_directory", help="Data directory that contains sample PDFs") parser.add_argument( @@ -190,6 +293,9 @@ async def main(args: Any): "--data-access-control", required=True, help="JSON file describing access control for the sample data" ) parser.add_argument("--verbose", "-v", required=False, action="store_true", help="Verbose output") + parser.add_argument( + "--scandirs", required=False, action="store_true", help="Scan and upload all files from directories recursively" + ) args = parser.parse_args() if args.verbose: logging.basicConfig() diff --git a/scripts/sampleacls.json b/scripts/sampleacls.json index dd2d4888fa..b83ca3708c 100644 --- a/scripts/sampleacls.json +++ b/scripts/sampleacls.json @@ -21,10 +21,16 @@ }, "directories": { "employeeinfo": { - "groups": ["GPTKB_HRTest"] + "groups": ["GPTKB_HRTest"], + "scandir": false }, "benefitinfo": { - "groups": ["GPTKB_EmployeeTest", "GPTKB_HRTest"] + "groups": ["GPTKB_EmployeeTest", "GPTKB_HRTest"], + "scandir": false + }, + "GPT4V_Examples": { + "groups": ["GPTKB_EmployeeTest", 
"GPTKB_HRTest"], + "scandir": true }, "/": { "groups": ["GPTKB_AdminTest"]