From a0c18aac0c5d922117b1a5a147cd297cb289c958 Mon Sep 17 00:00:00 2001 From: EdgeNeko Date: Sun, 24 Dec 2023 00:38:08 +0800 Subject: [PATCH] Add aspect_ratio and min_wh filter feature --- app/Controllers/search.py | 35 ++++++++++++--------- app/Models/query_params.py | 35 +++++++++++++++++++++ app/Services/authentication.py | 25 ++++++++++----- app/Services/vector_db_context.py | 51 +++++++++++++++++++++++++++++-- 4 files changed, 120 insertions(+), 26 deletions(-) create mode 100644 app/Models/query_params.py diff --git a/app/Controllers/search.py b/app/Controllers/search.py index 6b80393..ed39f06 100644 --- a/app/Controllers/search.py +++ b/app/Controllers/search.py @@ -9,6 +9,7 @@ from app.Models.api_model import AdvancedSearchModel, SearchBasisEnum from app.Models.api_response.search_api_response import SearchApiResponse +from app.Models.query_params import SearchPagingParams, FilterParams from app.Services import db_context from app.Services import transformers_service from app.Services.authentication import force_access_token_verify @@ -17,16 +18,6 @@ searchRouter = APIRouter(dependencies=([Depends(force_access_token_verify)] if config.access_protected else None)) -class SearchPagingParams: - def __init__( - self, - count: Annotated[int, Query(ge=1, le=100, description="The number of results you want to get.")] = 10, - skip: Annotated[int, Query(ge=0, description="The number of results you want to skip.")] = 0 - ): - self.count = count - self.skip = skip - - class SearchBasisParams: def __init__(self, basis: Annotated[SearchBasisEnum, Query( @@ -41,6 +32,7 @@ async def textSearch( prompt: Annotated[ str, Path(min_length=3, max_length=100, description="The image prompt text you want to search.")], basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)], + filter_param: Annotated[FilterParams, Depends(FilterParams)], paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)] ) -> SearchApiResponse: logger.info("Text search request received, prompt: {}", prompt) @@ -48,6 +40,7 @@ async def textSearch( else transformers_service.get_bert_vector(prompt) results = await db_context.querySearch(text_vector, query_vector_name=db_context.getVectorByBasis(basis.basis), + filter_param=filter_param, top_k=paging.count, skip=paging.skip) return SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4()) @@ -57,12 +50,17 @@ async def textSearch( async def imageSearch( image: Annotated[bytes, File(max_length=10 * 1024 * 1024, media_type="image/*", description="The image you want to search.")], - paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse: + filter_param: Annotated[FilterParams, Depends(FilterParams)], + paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)] +) -> SearchApiResponse: fakefile = BytesIO(image) img = Image.open(fakefile) logger.info("Image search request received") image_vector = transformers_service.get_image_vector(img) - results = await db_context.querySearch(image_vector, top_k=paging.count, skip=paging.skip) + results = await db_context.querySearch(image_vector, + top_k=paging.count, + skip=paging.skip, + filter_param=filter_param) return SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4()) @@ -72,12 +70,14 @@ async def imageSearch( async def similarWith( id: Annotated[UUID, Path(description="The id of the image you want to search.")], basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)], + filter_param: Annotated[FilterParams, Depends(FilterParams)], paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)] ) -> SearchApiResponse: logger.info("Similar search request received, id: {}", id) results = await db_context.querySimilar(str(id), top_k=paging.count, skip=paging.skip, + filter_param=filter_param, query_vector_name=db_context.getVectorByBasis(basis.basis)) return SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4()) @@ -86,6 +86,7 @@ async def similarWith( async def advancedSearch( model: AdvancedSearchModel, basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)], + filter_param: Annotated[FilterParams, Depends(FilterParams)], paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse: if len(model.criteria) + len(model.negative_criteria) == 0: raise ValueError("At least one criteria should be provided.") @@ -98,16 +99,20 @@ async def advancedSearch( negative_vectors = [transformers_service.get_text_vector(t) for t in model.negative_criteria] result = await db_context.queryAdvanced(positive_vectors, negative_vectors, db_context.getVectorByBasis(basis.basis), model.mode, + filter_param=filter_param, top_k=paging.count, - skip=paging.skip) + skip=paging.skip + ) return SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4()) @searchRouter.get("/random", description="Get random images") -async def randomPick(paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse: +async def randomPick( + filter_param: Annotated[FilterParams, Depends(FilterParams)], + paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse: logger.info("Random pick request received") random_vector = transformers_service.get_random_vector() - result = await db_context.querySearch(random_vector, top_k=paging.count) + result = await db_context.querySearch(random_vector, top_k=paging.count, filter_param=filter_param) return SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4()) diff --git a/app/Models/query_params.py b/app/Models/query_params.py new file mode 100644 index 0000000..8ee64b4 --- /dev/null +++ b/app/Models/query_params.py @@ -0,0 +1,35 @@ +from typing import Annotated + +from fastapi.params import Query + + +class SearchPagingParams: + def __init__( + self, + count: Annotated[int, Query(ge=1, le=100, description="The number of results you want to get.")] = 10, + skip: Annotated[int, Query(ge=0, description="The number of results you want to skip.")] = 0 + ): + self.count = count + self.skip = skip + + +class FilterParams: + def __init__( + self, + preferred_ratio: Annotated[ + float | None, Query(gt=0, description="The preferred aspect ratio of the image.")] = None, + ratio_tolerance: Annotated[ + float, Query(gt=0, lt=1, description="The tolerance of the aspect ratio.")] = 0.1, + min_width: Annotated[int | None, Query(gt=0, description="The minimum width of the image.")] = None, + min_height: Annotated[int | None, Query(gt=0, description="The minimum height of the image.")] = None): + self.preferred_ratio = preferred_ratio + self.ratio_tolerance = ratio_tolerance + self.min_width = min_width + self.min_height = min_height + + if self.preferred_ratio: + self.min_ratio = self.preferred_ratio * (1 - self.ratio_tolerance) + self.max_ratio = self.preferred_ratio * (1 + self.ratio_tolerance) + else: + self.min_ratio = None + self.max_ratio = None diff --git a/app/Services/authentication.py b/app/Services/authentication.py index aa0e27e..23ed0ac 100644 --- a/app/Services/authentication.py +++ b/app/Services/authentication.py @@ -1,7 +1,7 @@ from typing import Annotated from fastapi import HTTPException -from fastapi.params import Header +from fastapi.params import Header, Depends from app.config import config @@ -12,14 +12,23 @@ def verify_access_token(token: str | None) -> bool: return token is not None and token == config.access_token -def force_access_token_verify( - x_access_token: Annotated[str | None, Header( - description="Access token set in configuration (if access_protected is enabled)")] = None): - if not verify_access_token(x_access_token): - raise HTTPException(status_code=401, detail="Access token is not present or invalid.") - - def permissive_access_token_verify( x_access_token: Annotated[str | None, Header( description="Access token set in configuration (if access_protected is enabled)")] = None) -> bool: return verify_access_token(x_access_token) + + +def force_access_token_verify(token_passed: Annotated[bool, Depends(permissive_access_token_verify)]): + if not token_passed: + raise HTTPException(status_code=401, detail="Access token is not present or invalid.") + + +def permissive_admin_token_verify( + x_admin_token: Annotated[str | None, Header( + description="Admin token set in configuration (if admin_api_enable is enabled)")] = None) -> bool: + return x_admin_token == config.admin_token + + +def force_admin_token_verify(token_passed: Annotated[bool, Depends(permissive_admin_token_verify)]): + if not token_passed: + raise HTTPException(status_code=401, detail="Admin token is not present or invalid.") diff --git a/app/Services/vector_db_context.py b/app/Services/vector_db_context.py index 48b2169..786cf03 100644 --- a/app/Services/vector_db_context.py +++ b/app/Services/vector_db_context.py @@ -1,11 +1,13 @@ import numpy from loguru import logger from qdrant_client import AsyncQdrantClient +from qdrant_client.http import models from qdrant_client.http.models import PointStruct from qdrant_client.models import RecommendStrategy from app.Models.api_model import SearchModelEnum, SearchBasisEnum from app.Models.img_data import ImageData +from app.Models.query_params import FilterParams from app.Models.search_result import SearchResult from app.config import config @@ -26,23 +28,27 @@ async def retrieve_by_id(self, id: str, with_vectors=False) -> ImageData: return ImageData.from_payload(result[0].id, result[0].payload, numpy.array(result[0].vector, dtype=numpy.float32) if with_vectors else None) - async def querySearch(self, query_vector, query_vector_name: str = IMG_VECTOR, top_k=10, skip=0) -> list[ + async def querySearch(self, query_vector, query_vector_name: str = IMG_VECTOR, + top_k=10, skip=0, filter_param: FilterParams | None = None) -> list[ SearchResult]: logger.info("Querying Qdrant... top_k = {}", top_k) result = await self.client.search(collection_name=self.collection_name, query_vector=(query_vector_name, query_vector), + filters=self.getFiltersByFilterParam(filter_param), limit=top_k, offset=skip, with_payload=True) logger.success("Query completed!") return [SearchResult(img=ImageData.from_payload(t.id, t.payload), score=t.score) for t in result] - async def querySimilar(self, id: str, query_vector_name: str = IMG_VECTOR, top_k=10, skip=0) -> list[SearchResult]: + async def querySimilar(self, id: str, query_vector_name: str = IMG_VECTOR, + top_k=10, skip=0, filter_param: FilterParams | None = None) -> list[SearchResult]: logger.info("Querying Qdrant... top_k = {}", top_k) result = await self.client.recommend(collection_name=self.collection_name, positive=[id], negative=[], using=query_vector_name, + query_filter=self.getFiltersByFilterParam(filter_param), limit=top_k, offset=skip, with_vectors=False, @@ -52,12 +58,13 @@ async def querySimilar(self, id: str, query_vector_name: str = IMG_VECTOR, top_k async def queryAdvanced(self, positive_vectors: list[numpy.ndarray], negative_vectors: list[numpy.ndarray], query_vector_name: str = IMG_VECTOR, mode: SearchModelEnum = SearchModelEnum.average, - top_k=10, skip=0) -> list[SearchResult]: + top_k=10, skip=0, filter_param: FilterParams | None = None) -> list[SearchResult]: logger.info("Querying Qdrant... top_k = {}", top_k) result = await self.client.recommend(collection_name=self.collection_name, using=query_vector_name, positive=[t.tolist() for t in positive_vectors], negative=[t.tolist() for t in negative_vectors], + query_filter=self.getFiltersByFilterParam(filter_param), limit=top_k, offset=skip, strategy= @@ -111,3 +118,41 @@ def getVectorByBasis(cls, basis: SearchBasisEnum) -> str: return cls.TEXT_VECTOR case _: raise ValueError("Invalid basis") + + @staticmethod + def getFiltersByFilterParam(filter_param: FilterParams | None) -> models.Filter | None: + if filter_param is None: + return None + + filters = [] + if filter_param.min_width is not None: + filters.append(models.FieldCondition( + key="width", + range=models.Range( + gte=filter_param.min_width + ) + )) + + if filter_param.min_height is not None: + filters.append(models.FieldCondition( + key="height", + range=models.Range( + gte=filter_param.min_height + ) + )) + + if filter_param.min_ratio is not None: + filters.append(models.FieldCondition( + key="aspect_ratio", + range=models.Range( + gte=filter_param.min_ratio, + lte=filter_param.max_ratio + ) + )) + + if len(filters) > 0: + return models.Filter( + must=filters + ) + else: + return None