Skip to content

Commit

Permalink
Merge pull request #43 from hv0905/custom-comment
Browse files Browse the repository at this point in the history
Add comments field for images and improve documentation
  • Loading branch information
hv0905 authored Jul 12, 2024
2 parents 0f242b1 + 0c3a34d commit a6cc2b2
Show file tree
Hide file tree
Showing 12 changed files with 144 additions and 125 deletions.
25 changes: 14 additions & 11 deletions app/Controllers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
DuplicateValidationResponse
from app.Models.api_response.base import NekoProtocol
from app.Models.errors import PointDuplicateError
from app.Models.img_data import ImageData
from app.Models.mapped_image import MappedImage
from app.Services.authentication import force_admin_token_verify
from app.Services.provider import ServiceProvider
from app.Services.vector_db_context import PointNotFoundError
Expand Down Expand Up @@ -82,6 +82,8 @@ async def update_image(image_id: Annotated[UUID, params.Path(description="The id
point.starred = model.starred
if model.categories is not None:
point.categories = model.categories
if model.comments is not None:
point.comments = model.comments

await services.db_context.updatePayload(point)
logger.success("Image {} updated.", point.id)
Expand Down Expand Up @@ -129,16 +131,17 @@ async def upload_image(image_file: Annotated[UploadFile, File(description="The i
logger.warning("Invalid image file from upload request. id: {}", img_id)
raise HTTPException(422, "Cannot open the image file.") from ex

image_data = ImageData(id=img_id,
url=model.url,
thumbnail_url=model.thumbnail_url,
local=model.local,
categories=model.categories,
starred=model.starred,
format=img_type,
index_date=datetime.now())

await services.upload_service.queue_upload_image(image_data, img_bytes, model.skip_ocr, model.local_thumbnail)
mapped_image = MappedImage(id=img_id,
url=model.url,
thumbnail_url=model.thumbnail_url,
local=model.local,
categories=model.categories,
starred=model.starred,
comments=model.comments,
format=img_type,
index_date=datetime.now())

await services.upload_service.queue_upload_image(mapped_image, img_bytes, model.skip_ocr, model.local_thumbnail)
return ImageUploadResponse(message="OK. Image added to upload queue.", image_id=img_id)


Expand Down
2 changes: 2 additions & 0 deletions app/Models/api_models/admin_api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class ImageOptUpdateModel(BaseModel):
description="The url of the thumbnail. Leave empty to keep the value "
"unchanged. Changing the thumbnail_url of an image with a local "
"thumbnail is not allowed.")
comments: Optional[str] = Field(None,
description="The comments of the image. Leave empty to keep the value unchanged.")

def empty(self) -> bool:
return all([item is None for item in self.model_dump().values()])
Expand Down
5 changes: 4 additions & 1 deletion app/Models/api_models/admin_query_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@ def __init__(self,
"This is the default value if `local=True`\n"
" - `always`: Always generate thumbnail.\n"
" - `never`: Never generate thumbnail. This is the default value if `local=False`."),
skip_ocr: bool = Query(False, description="Whether to skip the OCR process.")):
skip_ocr: bool = Query(False, description="Whether to skip the OCR process."),
comments: Optional[str] = Query(None,
description="Any custom comments or text payload for the image.")):
self.url = url
self.thumbnail_url = thumbnail_url
self.categories = [t.strip() for t in categories.split(',') if t.strip()] if categories else None
self.starred = starred
self.local = local
self.skip_ocr = skip_ocr
self.comments = comments
self.local_thumbnail = local_thumbnail if (local_thumbnail is not None) else (
UploadImageThumbnailMode.IF_NECESSARY if local else UploadImageThumbnailMode.NEVER)
if not self.url and not self.local:
Expand Down
6 changes: 3 additions & 3 deletions app/Models/api_response/images_api_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pydantic import Field

from app.Models.api_response.base import NekoProtocol
from app.Models.img_data import ImageData
from app.Models.mapped_image import MappedImage


class ImageStatus(str, Enum):
Expand All @@ -16,10 +16,10 @@ class QueryByIdApiResponse(NekoProtocol):
"Warning: If NekoImageGallery is deployed in a cluster, "
"the `in_queue` might not be accurate since the index queue "
"is independent of each service instance.")
img: ImageData | None = Field(description="The mapped image data. Only available when `img_status = mapped`.")
img: MappedImage | None = Field(description="The mapped image data. Only available when `img_status = mapped`.")


class QueryImagesApiResponse(NekoProtocol):
images: list[ImageData] = Field(description="The list of images.")
images: list[MappedImage] = Field(description="The list of images.")
next_page_offset: str | None = Field(description="The offset ID for the next page query. "
"If there are no more images, this field will be null.")
53 changes: 0 additions & 53 deletions app/Models/img_data.py

This file was deleted.

62 changes: 62 additions & 0 deletions app/Models/mapped_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from datetime import datetime
from typing import Optional, Annotated
from uuid import UUID

from numpy import ndarray
from pydantic import BaseModel, Field, ConfigDict


class MappedImage(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra='ignore')

id: Annotated[
UUID, Field(description="The unique ID of the image. The ID is generated from the digest of the image.")] = None
url: Annotated[Optional[str], Field(
description="The URL of the image. For non-local images, this is specified by uploader.")] = None
thumbnail_url: Annotated[Optional[str], Field(
description="The URL of the thumbnail image. For non-local thumbnails, "
"this is specified by uploader.")] = None
ocr_text: Annotated[
Optional[str], Field(description="The OCR text of the image. None if no OCR text was detected.")] = None
image_vector: Annotated[Optional[ndarray], Field(exclude=True)] = None
text_contain_vector: Annotated[Optional[ndarray], Field(exclude=True)] = None
index_date: Annotated[datetime, Field(description="The date when the image was indexed.")]
width: Annotated[Optional[int], Field(description="The width of the image in pixels.")] = None
height: Annotated[Optional[int], Field(description="The height of the image in pixels.")] = None
aspect_ratio: Annotated[Optional[float], Field(
description="The aspect ratio of the image. calculated by width / height.")] = None
starred: Annotated[Optional[bool], Field(description="Whether the image is starred.")] = False
categories: Annotated[Optional[list[str]], Field(description="The categories of the image.")] = []
local: Annotated[Optional[bool], Field(
description="Whether the image is stored in local storage.(local image)")] = False
local_thumbnail: Annotated[Optional[bool], Field(
description="Whether the thumbnail image is stored in local storage.")] = False
format: Optional[str] = None # required for s3 local storage
comments: Annotated[Optional[str], Field(description="Any custom comments or text payload for the image.")] = None

@property
def ocr_text_lower(self) -> str | None:
if self.ocr_text is None:
return None
return self.ocr_text.lower()

@property
def payload(self):
result = self.model_dump(exclude={'id', 'index_date'})
# Qdrant database cannot accept datetime object, so we have to convert it to string
result['index_date'] = self.index_date.isoformat()
# Qdrant doesn't support case-insensitive search, so we need to store a lowercase version of the text
result['ocr_text_lower'] = self.ocr_text_lower
return result

@classmethod
def from_payload(cls, img_id: str, payload: dict,
image_vector: Optional[ndarray] = None, text_contain_vector: Optional[ndarray] = None):
# Convert the datetime string back to datetime object
index_date = datetime.fromisoformat(payload['index_date'])
del payload['index_date']
return cls(id=UUID(img_id),
index_date=index_date,
**payload,
image_vector=image_vector if image_vector is not None else None,
text_contain_vector=text_contain_vector if text_contain_vector is not None else None)
5 changes: 3 additions & 2 deletions app/Models/search_result.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from pydantic import BaseModel
from .img_data import ImageData

from .mapped_image import MappedImage


class SearchResult(BaseModel):
img: ImageData
img: MappedImage
score: float
10 changes: 5 additions & 5 deletions app/Services/index_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from fastapi.concurrency import run_in_threadpool

from app.Models.errors import PointDuplicateError
from app.Models.img_data import ImageData
from app.Models.mapped_image import MappedImage
from app.Services.lifespan_service import LifespanService
from app.Services.ocr_services import OCRService
from app.Services.transformers_service import TransformersService
Expand All @@ -16,7 +16,7 @@ def __init__(self, ocr_service: OCRService, transformers_service: TransformersSe
self._transformers_service = transformers_service
self._db_context = db_context

def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False):
def _prepare_image(self, image: Image.Image, image_data: MappedImage, skip_ocr=False):
image_data.width = image.width
image_data.height = image.height
image_data.aspect_ratio = float(image.width) / image.height
Expand All @@ -34,12 +34,12 @@ def _prepare_image(self, image: Image.Image, image_data: ImageData, skip_ocr=Fal
image_data.ocr_text = None

# currently, here only need just a simple check
async def _is_point_duplicate(self, image_data: list[ImageData]) -> bool:
async def _is_point_duplicate(self, image_data: list[MappedImage]) -> bool:
image_id_list = [str(item.id) for item in image_data]
result = await self._db_context.validate_ids(image_id_list)
return len(result) != 0

async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=False, skip_duplicate_check=False,
async def index_image(self, image: Image.Image, image_data: MappedImage, skip_ocr=False, skip_duplicate_check=False,
background=False):
if not skip_duplicate_check and (await self._is_point_duplicate([image_data])):
raise PointDuplicateError("The uploaded points are contained in the database!", image_data.id)
Expand All @@ -51,7 +51,7 @@ async def index_image(self, image: Image.Image, image_data: ImageData, skip_ocr=

await self._db_context.insertItems([image_data])

async def index_image_batch(self, image: list[Image.Image], image_data: list[ImageData],
async def index_image_batch(self, image: list[Image.Image], image_data: list[MappedImage],
skip_ocr=False, allow_overwrite=False):
if not allow_overwrite and (await self._is_point_duplicate(image_data)):
raise PointDuplicateError("The uploaded points are contained in the database!")
Expand Down
46 changes: 23 additions & 23 deletions app/Services/upload_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from app.Models.api_models.admin_query_params import UploadImageThumbnailMode
from app.Models.errors import PointDuplicateError
from app.Models.img_data import ImageData
from app.Models.mapped_image import MappedImage
from app.Services.index_service import IndexService
from app.Services.lifespan_service import LifespanService
from app.Services.storage import StorageService
Expand Down Expand Up @@ -46,44 +46,44 @@ async def _upload_worker(self):
if self._processed_count % 50 == 0:
gc.collect()

async def _upload_task(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
async def _upload_task(self, mapped_img: MappedImage, img_bytes: bytes, skip_ocr: bool,
thumbnail_mode: UploadImageThumbnailMode):
img = Image.open(BytesIO(img_bytes))
logger.info('Start indexing image {}. Local: {}. Size: {}', img_data.id, img_data.local, len(img_bytes))
file_name = f"{img_data.id}.{img_data.format}"
thumb_path = f"thumbnails/{img_data.id}.webp"
logger.info('Start indexing image {}. Local: {}. Size: {}', mapped_img.id, mapped_img.local, len(img_bytes))
file_name = f"{mapped_img.id}.{mapped_img.format}"
thumb_path = f"thumbnails/{mapped_img.id}.webp"
gen_thumb = thumbnail_mode == UploadImageThumbnailMode.ALWAYS or (
thumbnail_mode == UploadImageThumbnailMode.IF_NECESSARY and len(img_bytes) > 1024 * 500)

if img_data.local:
img_data.url = await self._storage_service.active_storage.url(file_name)
if mapped_img.local:
mapped_img.url = await self._storage_service.active_storage.url(file_name)
if gen_thumb:
img_data.thumbnail_url = await self._storage_service.active_storage.url(
f"thumbnails/{img_data.id}.webp")
img_data.local_thumbnail = True
mapped_img.thumbnail_url = await self._storage_service.active_storage.url(
f"thumbnails/{mapped_img.id}.webp")
mapped_img.local_thumbnail = True

await self._index_service.index_image(img, img_data, skip_ocr=skip_ocr, background=True)
logger.success("Image {} indexed.", img_data.id)
await self._index_service.index_image(img, mapped_img, skip_ocr=skip_ocr, background=True)
logger.success("Image {} indexed.", mapped_img.id)

if img_data.local:
logger.info("Start uploading image {} to local storage.", img_data.id)
if mapped_img.local:
logger.info("Start uploading image {} to local storage.", mapped_img.id)
await self._storage_service.active_storage.upload(img_bytes, file_name)
logger.success("Image {} uploaded to local storage.", img_data.id)
logger.success("Image {} uploaded to local storage.", mapped_img.id)
if gen_thumb:
logger.info("Start generate and upload thumbnail for {}.", img_data.id)
logger.info("Start generate and upload thumbnail for {}.", mapped_img.id)
img.thumbnail((256, 256), resample=Image.Resampling.LANCZOS)
img_byte_arr = BytesIO()
img.save(img_byte_arr, 'WebP', save_all=True)
await self._storage_service.active_storage.upload(img_byte_arr.getvalue(), thumb_path)
logger.success("Thumbnail for {} generated and uploaded!", img_data.id)
logger.success("Thumbnail for {} generated and uploaded!", mapped_img.id)

img.close()

async def queue_upload_image(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
async def queue_upload_image(self, mapped_img: MappedImage, img_bytes: bytes, skip_ocr: bool,
thumbnail_mode: UploadImageThumbnailMode):
self.uploading_ids.add(img_data.id)
await self._queue.put((img_data, img_bytes, skip_ocr, thumbnail_mode))
logger.success("Image {} added to upload queue. Queue Length: {} [+1]", img_data.id, self._queue.qsize())
self.uploading_ids.add(mapped_img.id)
await self._queue.put((mapped_img, img_bytes, skip_ocr, thumbnail_mode))
logger.success("Image {} added to upload queue. Queue Length: {} [+1]", mapped_img.id, self._queue.qsize())

async def assign_image_id(self, img_file: pathlib.Path | io.BytesIO | bytes):
img_id = generate_uuid(img_file)
Expand All @@ -94,9 +94,9 @@ async def assign_image_id(self, img_file: pathlib.Path | io.BytesIO | bytes):
img_id)
return img_id

async def sync_upload_image(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
async def sync_upload_image(self, mapped_img: MappedImage, img_bytes: bytes, skip_ocr: bool,
thumbnail_mode: UploadImageThumbnailMode):
await self._upload_task(img_data, img_bytes, skip_ocr, thumbnail_mode)
await self._upload_task(mapped_img, img_bytes, skip_ocr, thumbnail_mode)

def get_queue_size(self):
return self._queue.qsize()
Expand Down
Loading

0 comments on commit a6cc2b2

Please sign in to comment.