From d693e03ead6e97e978b7a43ee82c36e1bd6d9bfc Mon Sep 17 00:00:00 2001 From: EdgeNeko Date: Fri, 12 Jul 2024 22:08:41 +0800 Subject: [PATCH 1/3] Add comments field for images --- app/Controllers/admin.py | 3 ++ app/Models/api_models/admin_api_model.py | 2 ++ app/Models/api_models/admin_query_params.py | 5 +++- app/Models/img_data.py | 32 +++++++++++++-------- tests/api/test_upload.py | 3 +- 5 files changed, 31 insertions(+), 14 deletions(-) diff --git a/app/Controllers/admin.py b/app/Controllers/admin.py index b4c3416..175ec02 100644 --- a/app/Controllers/admin.py +++ b/app/Controllers/admin.py @@ -82,6 +82,8 @@ async def update_image(image_id: Annotated[UUID, params.Path(description="The id point.starred = model.starred if model.categories is not None: point.categories = model.categories + if model.comments is not None: + point.comments = model.comments await services.db_context.updatePayload(point) logger.success("Image {} updated.", point.id) @@ -135,6 +137,7 @@ async def upload_image(image_file: Annotated[UploadFile, File(description="The i local=model.local, categories=model.categories, starred=model.starred, + comments=model.comments, format=img_type, index_date=datetime.now()) diff --git a/app/Models/api_models/admin_api_model.py b/app/Models/api_models/admin_api_model.py index 12037aa..773fc92 100644 --- a/app/Models/api_models/admin_api_model.py +++ b/app/Models/api_models/admin_api_model.py @@ -18,6 +18,8 @@ class ImageOptUpdateModel(BaseModel): description="The url of the thumbnail. Leave empty to keep the value " "unchanged. Changing the thumbnail_url of an image with a local " "thumbnail is not allowed.") + comments: Optional[str] = Field(None, + description="The comments of the image. Leave empty to keep the value unchanged.") def empty(self) -> bool: return all([item is None for item in self.model_dump().values()]) diff --git a/app/Models/api_models/admin_query_params.py b/app/Models/api_models/admin_query_params.py index eba8dd4..03bb112 100644 --- a/app/Models/api_models/admin_query_params.py +++ b/app/Models/api_models/admin_query_params.py @@ -35,13 +35,16 @@ def __init__(self, "This is the default value if `local=True`\n" " - `always`: Always generate thumbnail.\n" " - `never`: Never generate thumbnail. This is the default value if `local=False`."), - skip_ocr: bool = Query(False, description="Whether to skip the OCR process.")): + skip_ocr: bool = Query(False, description="Whether to skip the OCR process."), + comments: Optional[str] = Query(None, + description="Any custom comments or text payload for the image.")): self.url = url self.thumbnail_url = thumbnail_url self.categories = [t.strip() for t in categories.split(',') if t.strip()] if categories else None self.starred = starred self.local = local self.skip_ocr = skip_ocr + self.comments = comments self.local_thumbnail = local_thumbnail if (local_thumbnail is not None) else ( UploadImageThumbnailMode.IF_NECESSARY if local else UploadImageThumbnailMode.NEVER) if not self.url and not self.local: diff --git a/app/Models/img_data.py b/app/Models/img_data.py index 9c8a854..66e474b 100644 --- a/app/Models/img_data.py +++ b/app/Models/img_data.py @@ -9,21 +9,29 @@ class ImageData(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, extra='ignore') - id: UUID - url: Optional[str] = None - thumbnail_url: Optional[str] = None - ocr_text: Optional[str] = None + id: UUID = Field(description="The unique ID of the image. The ID is generated from the digest of the image.") + url: Optional[str] = Field(default=None, + description="The URL of the image. For non-local images, this is specified by uploader.") + thumbnail_url: Optional[str] = Field(default=None, + description="The URL of the thumbnail image. For non-local thumbnails, " + "this is specified by uploader.") + ocr_text: Optional[str] = Field(default=None, + description="The OCR text of the image. None if no OCR text was detected.") image_vector: Optional[ndarray] = Field(None, exclude=True) text_contain_vector: Optional[ndarray] = Field(None, exclude=True) - index_date: datetime - width: Optional[int] = None - height: Optional[int] = None - aspect_ratio: Optional[float] = None - starred: Optional[bool] = False - categories: Optional[list[str]] = [] - local: Optional[bool] = False - local_thumbnail: Optional[bool] = False + index_date: datetime = Field(description="The date when the image was indexed.") + width: Optional[int] = Field(default=None, description="The width of the image in pixels.") + height: Optional[int] = Field(default=None, description="The height of the image in pixels.") + aspect_ratio: Optional[float] = Field(default=None, + description="The aspect ratio of the image. calculated by width / height.") + starred: Optional[bool] = Field(default=False, description="Whether the image is starred.") + categories: Optional[list[str]] = Field(default=[], description="The categories of the image.") + local: Optional[bool] = Field(default=False, + description="Whether the image is stored in local storage.(local image).") + local_thumbnail: Optional[bool] = Field(default=False, + description="Whether the thumbnail image is stored in local storage.") format: Optional[str] = None # required for s3 local storage + comments: Optional[str] = Field(default=None, description="Any custom comments or text payload for the image.") @property def ocr_text_lower(self) -> str | None: diff --git a/tests/api/test_upload.py b/tests/api/test_upload.py index 03e6ffa..a7c3168 100644 --- a/tests/api/test_upload.py +++ b/tests/api/test_upload.py @@ -156,7 +156,8 @@ async def test_upload_thumbnails(test_client, ensure_local_dir_empty, wait_for_b ({'url': TEST_FAKE_URL}, {'url': TEST_FAKE_URL_NEW, 'thumbnail_url': TEST_FAKE_THUMBNAIL_URL_NEW}, {'url': TEST_FAKE_URL_NEW, 'thumbnail_url': TEST_FAKE_THUMBNAIL_URL_NEW}, 200), ({'local_thumbnail': 'always', 'url': TEST_FAKE_URL}, {'url': TEST_FAKE_URL_NEW}, {'url': TEST_FAKE_URL_NEW}, 200), - ({'local': True}, {'categories': ['1'], 'starred': True}, {'categories': ['1'], 'starred': True}, 200), + ({'local': True}, {'categories': ['1'], 'starred': True, 'comments': 'ciallo'}, + {'categories': ['1'], 'starred': True, 'comments': 'ciallo'}, 200), ({'local': True}, {'url': TEST_FAKE_URL_NEW}, {}, 422), ({'local': True}, {'thumbnail_url': TEST_FAKE_THUMBNAIL_URL_NEW}, {}, 422), ({'local_thumbnail': 'always', 'url': TEST_FAKE_URL}, {'thumbnail_url': TEST_FAKE_THUMBNAIL_URL_NEW}, {}, 422), From 0d9ee7687940965e1b63b13201ea8de62ffc7b2e Mon Sep 17 00:00:00 2001 From: EdgeNeko Date: Fri, 12 Jul 2024 22:16:30 +0800 Subject: [PATCH 2/3] Rename ImageData to MappedImage to have a better readability --- app/Controllers/admin.py | 24 +++++----- .../api_response/images_api_response.py | 6 +-- app/Models/{img_data.py => mapped_image.py} | 2 +- app/Models/search_result.py | 5 +- app/Services/index_service.py | 10 ++-- app/Services/upload_service.py | 46 +++++++++---------- app/Services/vector_db_context.py | 36 +++++++-------- scripts/local_indexing.py | 16 +++---- 8 files changed, 73 insertions(+), 72 deletions(-) rename app/Models/{img_data.py => mapped_image.py} (99%) diff --git a/app/Controllers/admin.py b/app/Controllers/admin.py index 175ec02..adc1c4c 100644 --- a/app/Controllers/admin.py +++ b/app/Controllers/admin.py @@ -14,7 +14,7 @@ DuplicateValidationResponse from app.Models.api_response.base import NekoProtocol from app.Models.errors import PointDuplicateError -from app.Models.img_data import ImageData +from app.Models.mapped_image import MappedImage from app.Services.authentication import force_admin_token_verify from app.Services.provider import ServiceProvider from app.Services.vector_db_context import PointNotFoundError @@ -131,17 +131,17 @@ async def upload_image(image_file: Annotated[UploadFile, File(description="The i logger.warning("Invalid image file from upload request. id: {}", img_id) raise HTTPException(422, "Cannot open the image file.") from ex - image_data = ImageData(id=img_id, - url=model.url, - thumbnail_url=model.thumbnail_url, - local=model.local, - categories=model.categories, - starred=model.starred, - comments=model.comments, - format=img_type, - index_date=datetime.now()) - - await services.upload_service.queue_upload_image(image_data, img_bytes, model.skip_ocr, model.local_thumbnail) + mapped_image = MappedImage(id=img_id, + url=model.url, + thumbnail_url=model.thumbnail_url, + local=model.local, + categories=model.categories, + starred=model.starred, + comments=model.comments, + format=img_type, + index_date=datetime.now()) + + await services.upload_service.queue_upload_image(mapped_image, img_bytes, model.skip_ocr, model.local_thumbnail) return ImageUploadResponse(message="OK. Image added to upload queue.", image_id=img_id) diff --git a/app/Models/api_response/images_api_response.py b/app/Models/api_response/images_api_response.py index 875d19a..ec3709a 100644 --- a/app/Models/api_response/images_api_response.py +++ b/app/Models/api_response/images_api_response.py @@ -3,7 +3,7 @@ from pydantic import Field from app.Models.api_response.base import NekoProtocol -from app.Models.img_data import ImageData +from app.Models.mapped_image import MappedImage class ImageStatus(str, Enum): @@ -16,10 +16,10 @@ class QueryByIdApiResponse(NekoProtocol): "Warning: If NekoImageGallery is deployed in a cluster, " "the `in_queue` might not be accurate since the index queue " "is independent of each service instance.") - img: ImageData | None = Field(description="The mapped image data. Only available when `img_status = mapped`.") + img: MappedImage | None = Field(description="The mapped image data. Only available when `img_status = mapped`.") class QueryImagesApiResponse(NekoProtocol): - images: list[ImageData] = Field(description="The list of images.") + images: list[MappedImage] = Field(description="The list of images.") next_page_offset: str | None = Field(description="The offset ID for the next page query. " "If there are no more images, this field will be null.") diff --git a/app/Models/img_data.py b/app/Models/mapped_image.py similarity index 99% rename from app/Models/img_data.py rename to app/Models/mapped_image.py index 66e474b..4238135 100644 --- a/app/Models/img_data.py +++ b/app/Models/mapped_image.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, ConfigDict -class ImageData(BaseModel): +class MappedImage(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, extra='ignore') id: UUID = Field(description="The unique ID of the image. The ID is generated from the digest of the image.") diff --git a/app/Models/search_result.py b/app/Models/search_result.py index c5bd431..a657360 100644 --- a/app/Models/search_result.py +++ b/app/Models/search_result.py @@ -1,7 +1,8 @@ from pydantic import BaseModel -from .img_data import ImageData + +from .mapped_image import MappedImage class SearchResult(BaseModel): - img: ImageData + img: MappedImage score: float diff --git a/app/Services/index_service.py b/app/Services/index_service.py index 00cdcf5..6b9f930 100644 --- a/app/Services/index_service.py +++ b/app/Services/index_service.py @@ -2,7 +2,7 @@ from fastapi.concurrency import run_in_threadpool from app.Models.errors import PointDuplicateError -from app.Models.img_data import ImageData +from app.Models.mapped_image import MappedImage from app.Services.lifespan_service import LifespanService from app.Services.ocr_services import OCRService from app.Services.transformers_service import TransformersService @@ -16,7 +16,7 @@ def __init__(self, ocr_service: OCRService, transformers_service: TransformersSe self._transformers_service = transformers_service self._db_context = db_context - def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False): + def _prepare_image(self, image: Image.Image, image_data: MappedImage, skip_ocr=False): image_data.width = image.width image_data.height = image.height image_data.aspect_ratio = float(image.width) / image.height @@ -34,12 +34,12 @@ def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=Fal image_data.ocr_text = None # currently, here only need just a simple check - async def _is_point_duplicate(self, image_data: list[ImageData]) -> bool: + async def _is_point_duplicate(self, image_data: list[MappedImage]) -> bool: image_id_list = [str(item.id) for item in image_data] result = await self._db_context.validate_ids(image_id_list) return len(result) != 0 - async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False, skip_duplicate_check=False, + async def index_image(self, image: Image.Image, image_data: MappedImage, skip_ocr=False, skip_duplicate_check=False, background=False): if not skip_duplicate_check and (await self._is_point_duplicate([image_data])): raise PointDuplicateError("The uploaded points are contained in the database!", image_data.id) @@ -51,7 +51,7 @@ async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr= await self._db_context.insertItems([image_data]) - async def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData], + async def index_image_batch(self, image: list[Image.Image], image_data: list[MappedImage], skip_ocr=False, allow_overwrite=False): if not allow_overwrite and (await self._is_point_duplicate(image_data)): raise PointDuplicateError("The uploaded points are contained in the database!") diff --git a/app/Services/upload_service.py b/app/Services/upload_service.py index 51b59fb..815b127 100644 --- a/app/Services/upload_service.py +++ b/app/Services/upload_service.py @@ -9,7 +9,7 @@ from app.Models.api_models.admin_query_params import UploadImageThumbnailMode from app.Models.errors import PointDuplicateError -from app.Models.img_data import ImageData +from app.Models.mapped_image import MappedImage from app.Services.index_service import IndexService from app.Services.lifespan_service import LifespanService from app.Services.storage import StorageService @@ -46,44 +46,44 @@ async def _upload_worker(self): if self._processed_count % 50 == 0: gc.collect() - async def _upload_task(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool, + async def _upload_task(self, mapped_img: MappedImage, img_bytes: bytes, skip_ocr: bool, thumbnail_mode: UploadImageThumbnailMode): img = Image.open(BytesIO(img_bytes)) - logger.info('Start indexing image {}. Local: {}. Size: {}', img_data.id, img_data.local, len(img_bytes)) - file_name = f"{img_data.id}.{img_data.format}" - thumb_path = f"thumbnails/{img_data.id}.webp" + logger.info('Start indexing image {}. Local: {}. Size: {}', mapped_img.id, mapped_img.local, len(img_bytes)) + file_name = f"{mapped_img.id}.{mapped_img.format}" + thumb_path = f"thumbnails/{mapped_img.id}.webp" gen_thumb = thumbnail_mode == UploadImageThumbnailMode.ALWAYS or ( thumbnail_mode == UploadImageThumbnailMode.IF_NECESSARY and len(img_bytes) > 1024 * 500) - if img_data.local: - img_data.url = await self._storage_service.active_storage.url(file_name) + if mapped_img.local: + mapped_img.url = await self._storage_service.active_storage.url(file_name) if gen_thumb: - img_data.thumbnail_url = await self._storage_service.active_storage.url( - f"thumbnails/{img_data.id}.webp") - img_data.local_thumbnail = True + mapped_img.thumbnail_url = await self._storage_service.active_storage.url( + f"thumbnails/{mapped_img.id}.webp") + mapped_img.local_thumbnail = True - await self._index_service.index_image(img, img_data, skip_ocr=skip_ocr, background=True) - logger.success("Image {} indexed.", img_data.id) + await self._index_service.index_image(img, mapped_img, skip_ocr=skip_ocr, background=True) + logger.success("Image {} indexed.", mapped_img.id) - if img_data.local: - logger.info("Start uploading image {} to local storage.", img_data.id) + if mapped_img.local: + logger.info("Start uploading image {} to local storage.", mapped_img.id) await self._storage_service.active_storage.upload(img_bytes, file_name) - logger.success("Image {} uploaded to local storage.", img_data.id) + logger.success("Image {} uploaded to local storage.", mapped_img.id) if gen_thumb: - logger.info("Start generate and upload thumbnail for {}.", img_data.id) + logger.info("Start generate and upload thumbnail for {}.", mapped_img.id) img.thumbnail((256, 256), resample=Image.Resampling.LANCZOS) img_byte_arr = BytesIO() img.save(img_byte_arr, 'WebP', save_all=True) await self._storage_service.active_storage.upload(img_byte_arr.getvalue(), thumb_path) - logger.success("Thumbnail for {} generated and uploaded!", img_data.id) + logger.success("Thumbnail for {} generated and uploaded!", mapped_img.id) img.close() - async def queue_upload_image(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool, + async def queue_upload_image(self, mapped_img: MappedImage, img_bytes: bytes, skip_ocr: bool, thumbnail_mode: UploadImageThumbnailMode): - self.uploading_ids.add(img_data.id) - await self._queue.put((img_data, img_bytes, skip_ocr, thumbnail_mode)) - logger.success("Image {} added to upload queue. Queue Length: {} [+1]", img_data.id, self._queue.qsize()) + self.uploading_ids.add(mapped_img.id) + await self._queue.put((mapped_img, img_bytes, skip_ocr, thumbnail_mode)) + logger.success("Image {} added to upload queue. Queue Length: {} [+1]", mapped_img.id, self._queue.qsize()) async def assign_image_id(self, img_file: pathlib.Path | io.BytesIO | bytes): img_id = generate_uuid(img_file) @@ -94,9 +94,9 @@ async def assign_image_id(self, img_file: pathlib.Path | io.BytesIO | bytes): img_id) return img_id - async def sync_upload_image(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool, + async def sync_upload_image(self, mapped_img: MappedImage, img_bytes: bytes, skip_ocr: bool, thumbnail_mode: UploadImageThumbnailMode): - await self._upload_task(img_data, img_bytes, skip_ocr, thumbnail_mode) + await self._upload_task(mapped_img, img_bytes, skip_ocr, thumbnail_mode) def get_queue_size(self): return self._queue.qsize() diff --git a/app/Services/vector_db_context.py b/app/Services/vector_db_context.py index fe1c22b..4a09485 100644 --- a/app/Services/vector_db_context.py +++ b/app/Services/vector_db_context.py @@ -9,7 +9,7 @@ from qdrant_client.models import RecommendStrategy from app.Models.api_models.search_api_model import SearchModelEnum, SearchBasisEnum -from app.Models.img_data import ImageData +from app.Models.mapped_image import MappedImage from app.Models.query_params import FilterParams from app.Models.search_result import SearchResult from app.Services.lifespan_service import LifespanService @@ -50,7 +50,7 @@ async def on_load(self): logger.warning("Collection not found. Initializing...") await self.initialize_collection() - async def retrieve_by_id(self, image_id: str, with_vectors=False) -> ImageData: + async def retrieve_by_id(self, image_id: str, with_vectors=False) -> MappedImage: """ Retrieve an item from database by id. Will raise PointNotFoundError if the given ID doesn't exist. :param image_id: The ID to retrieve. @@ -65,9 +65,9 @@ async def retrieve_by_id(self, image_id: str, with_vectors=False) -> ImageData: if len(result) != 1: logger.error("Point not exist.") raise PointNotFoundError(image_id) - return self._get_img_data_from_point(result[0]) + return self._get_mapped_image_from_point(result[0]) - async def retrieve_by_ids(self, image_id: list[str], with_vectors=False) -> list[ImageData]: + async def retrieve_by_ids(self, image_id: list[str], with_vectors=False) -> list[MappedImage]: """ Retrieve items from the database by IDs. An exception is thrown if there are items in the IDs that do not exist in the database. @@ -85,7 +85,7 @@ async def retrieve_by_ids(self, image_id: list[str], with_vectors=False) -> list if len(missing_point_ids) > 0: logger.error("{} points not exist.", len(missing_point_ids)) raise PointNotFoundError(str(missing_point_ids)) - return self._get_img_data_from_points(result) + return self._get_mapped_image_from_point_batch(result) async def validate_ids(self, image_id: list[str]) -> list[str]: """ @@ -144,10 +144,10 @@ async def querySimilar(self, return [self._get_search_result_from_scored_point(t) for t in result] - async def insertItems(self, items: list[ImageData]): + async def insertItems(self, items: list[MappedImage]): logger.info("Inserting {} items into Qdrant...", len(items)) - points = [self._get_point_from_img_data(t) for t in items] + points = [self._get_point_from_mapped_image(t) for t in items] response = await self._client.upsert(collection_name=self.collection_name, wait=True, @@ -163,7 +163,7 @@ async def deleteItems(self, ids: list[str]): ) logger.success("Delete completed! Status: {}", response.status) - async def updatePayload(self, new_data: ImageData): + async def updatePayload(self, new_data: MappedImage): """ Update the payload of an existing item in the database. Warning: This method will not update the vector of the item. @@ -175,7 +175,7 @@ async def updatePayload(self, new_data: ImageData): wait=True) logger.success("Update completed! Status: {}", response.status) - async def updateVectors(self, new_points: list[ImageData]): + async def updateVectors(self, new_points: list[MappedImage]): resp = await self._client.update_vectors(collection_name=self.collection_name, points=[self._get_vector_from_img_data(t) for t in new_points], ) @@ -186,7 +186,7 @@ async def scroll_points(self, count=50, with_vectors=False, filter_param: FilterParams | None = None, - ) -> tuple[list[ImageData], str]: + ) -> tuple[list[MappedImage], str]: resp, next_id = await self._client.scroll(collection_name=self.collection_name, limit=count, offset=from_id, @@ -194,7 +194,7 @@ async def scroll_points(self, scroll_filter=self._get_filters_by_filter_param(filter_param) ) - return [self._get_img_data_from_point(t) for t in resp], next_id + return [self._get_mapped_image_from_point(t) for t in resp], next_id async def get_counts(self, exact: bool) -> int: resp = await self._client.count(collection_name=self.collection_name, exact=exact) @@ -219,7 +219,7 @@ async def initialize_collection(self): logger.success("Collection created!") @classmethod - def _get_vector_from_img_data(cls, img_data: ImageData) -> models.PointVectors: + def _get_vector_from_img_data(cls, img_data: MappedImage) -> models.PointVectors: vector = {} if img_data.image_vector is not None: vector[cls.IMG_VECTOR] = img_data.image_vector.tolist() @@ -231,15 +231,15 @@ def _get_vector_from_img_data(cls, img_data: ImageData) -> models.PointVectors: ) @classmethod - def _get_point_from_img_data(cls, img_data: ImageData) -> models.PointStruct: + def _get_point_from_mapped_image(cls, img_data: MappedImage) -> models.PointStruct: return models.PointStruct( id=str(img_data.id), payload=img_data.payload, vector=cls._get_vector_from_img_data(img_data).vector ) - def _get_img_data_from_point(self, point: AVAILABLE_POINT_TYPES) -> ImageData: - return (ImageData + def _get_mapped_image_from_point(self, point: AVAILABLE_POINT_TYPES) -> MappedImage: + return (MappedImage .from_payload(point.id, point.payload, image_vector=numpy.array(point.vector[self.IMG_VECTOR], dtype=numpy.float32) @@ -248,11 +248,11 @@ def _get_img_data_from_point(self, point: AVAILABLE_POINT_TYPES) -> ImageData: if point.vector and self.TEXT_VECTOR in point.vector else None )) - def _get_img_data_from_points(self, points: list[AVAILABLE_POINT_TYPES]) -> list[ImageData]: - return [self._get_img_data_from_point(t) for t in points] + def _get_mapped_image_from_point_batch(self, points: list[AVAILABLE_POINT_TYPES]) -> list[MappedImage]: + return [self._get_mapped_image_from_point(t) for t in points] def _get_search_result_from_scored_point(self, point: models.ScoredPoint) -> SearchResult: - return SearchResult(img=self._get_img_data_from_point(point), score=point.score) + return SearchResult(img=self._get_mapped_image_from_point(point), score=point.score) @classmethod def vector_name_for_basis(cls, basis: SearchBasisEnum) -> str: diff --git a/scripts/local_indexing.py b/scripts/local_indexing.py index 11c086e..9ac0a09 100644 --- a/scripts/local_indexing.py +++ b/scripts/local_indexing.py @@ -8,7 +8,7 @@ from app.Models.api_models.admin_query_params import UploadImageThumbnailMode from app.Models.errors import PointDuplicateError -from app.Models.img_data import ImageData +from app.Models.mapped_image import MappedImage from app.Services.provider import ServiceProvider from app.util.local_file_utility import glob_local_files @@ -18,13 +18,13 @@ async def index_task(file_path: Path, categories: list[str], starred: bool, thumbnail_mode: UploadImageThumbnailMode): try: img_id = await services.upload_service.assign_image_id(file_path) - image_data = ImageData(id=img_id, - local=True, - categories=categories, - starred=starred, - format=file_path.suffix[1:], # remove the dot - index_date=datetime.now()) - await services.upload_service.sync_upload_image(image_data, file_path.read_bytes(), skip_ocr=False, + mapped_image = MappedImage(id=img_id, + local=True, + categories=categories, + starred=starred, + format=file_path.suffix[1:], # remove the dot + index_date=datetime.now()) + await services.upload_service.sync_upload_image(mapped_image, file_path.read_bytes(), skip_ocr=False, thumbnail_mode=thumbnail_mode) except PointDuplicateError as ex: logger.warning("Image {} already exists in the database", file_path) From 0c3a34d28c958ae11fcc3488a8e1049da6d38876 Mon Sep 17 00:00:00 2001 From: EdgeNeko Date: Fri, 12 Jul 2024 22:35:06 +0800 Subject: [PATCH 3/3] Convert field description to annotated format for MappedImage --- app/Models/mapped_image.py | 47 +++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/app/Models/mapped_image.py b/app/Models/mapped_image.py index 4238135..d55027d 100644 --- a/app/Models/mapped_image.py +++ b/app/Models/mapped_image.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional +from typing import Optional, Annotated from uuid import UUID from numpy import ndarray @@ -9,29 +9,30 @@ class MappedImage(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, extra='ignore') - id: UUID = Field(description="The unique ID of the image. The ID is generated from the digest of the image.") - url: Optional[str] = Field(default=None, - description="The URL of the image. For non-local images, this is specified by uploader.") - thumbnail_url: Optional[str] = Field(default=None, - description="The URL of the thumbnail image. For non-local thumbnails, " - "this is specified by uploader.") - ocr_text: Optional[str] = Field(default=None, - description="The OCR text of the image. None if no OCR text was detected.") - image_vector: Optional[ndarray] = Field(None, exclude=True) - text_contain_vector: Optional[ndarray] = Field(None, exclude=True) - index_date: datetime = Field(description="The date when the image was indexed.") - width: Optional[int] = Field(default=None, description="The width of the image in pixels.") - height: Optional[int] = Field(default=None, description="The height of the image in pixels.") - aspect_ratio: Optional[float] = Field(default=None, - description="The aspect ratio of the image. calculated by width / height.") - starred: Optional[bool] = Field(default=False, description="Whether the image is starred.") - categories: Optional[list[str]] = Field(default=[], description="The categories of the image.") - local: Optional[bool] = Field(default=False, - description="Whether the image is stored in local storage.(local image).") - local_thumbnail: Optional[bool] = Field(default=False, - description="Whether the thumbnail image is stored in local storage.") + id: Annotated[ + UUID, Field(description="The unique ID of the image. The ID is generated from the digest of the image.")] = None + url: Annotated[Optional[str], Field( + description="The URL of the image. For non-local images, this is specified by uploader.")] = None + thumbnail_url: Annotated[Optional[str], Field( + description="The URL of the thumbnail image. For non-local thumbnails, " + "this is specified by uploader.")] = None + ocr_text: Annotated[ + Optional[str], Field(description="The OCR text of the image. None if no OCR text was detected.")] = None + image_vector: Annotated[Optional[ndarray], Field(exclude=True)] = None + text_contain_vector: Annotated[Optional[ndarray], Field(exclude=True)] = None + index_date: Annotated[datetime, Field(description="The date when the image was indexed.")] + width: Annotated[Optional[int], Field(description="The width of the image in pixels.")] = None + height: Annotated[Optional[int], Field(description="The height of the image in pixels.")] = None + aspect_ratio: Annotated[Optional[float], Field( + description="The aspect ratio of the image. calculated by width / height.")] = None + starred: Annotated[Optional[bool], Field(description="Whether the image is starred.")] = False + categories: Annotated[Optional[list[str]], Field(description="The categories of the image.")] = [] + local: Annotated[Optional[bool], Field( + description="Whether the image is stored in local storage.(local image)")] = False + local_thumbnail: Annotated[Optional[bool], Field( + description="Whether the thumbnail image is stored in local storage.")] = False format: Optional[str] = None # required for s3 local storage - comments: Optional[str] = Field(default=None, description="Any custom comments or text payload for the image.") + comments: Annotated[Optional[str], Field(description="Any custom comments or text payload for the image.")] = None @property def ocr_text_lower(self) -> str | None: