From 1694660eb2d357c31c8a02e2519f404f7debfd44 Mon Sep 17 00:00:00 2001 From: EdgeNeko Date: Tue, 26 Dec 2023 23:48:02 +0800 Subject: [PATCH] Allow exact match in OCR text search --- app/Controllers/search.py | 17 +++++++++++------ app/Models/query_params.py | 1 + app/Services/vector_db_context.py | 8 ++++++++ 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/app/Controllers/search.py b/app/Controllers/search.py index 3546dd7..fa7904f 100644 --- a/app/Controllers/search.py +++ b/app/Controllers/search.py @@ -46,11 +46,16 @@ async def textSearch( str, Path(max_length=100, description="The image prompt text you want to search.")], basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)], filter_param: Annotated[FilterParams, Depends(FilterParams)], - paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)] + paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)], + exact: Annotated[bool, Query( + description="If using OCR search, this option will require the ocr text contains **exactly** the " + "criteria you have given. This won't take any effect in vision search.")] = False ) -> SearchApiResponse: logger.info("Text search request received, prompt: {}", prompt) text_vector = transformers_service.get_text_vector(prompt) if basis.basis == SearchBasisEnum.vision \ else transformers_service.get_bert_vector(prompt) + if basis.basis == SearchBasisEnum.ocr and exact: + filter_param.ocr_text = prompt results = await db_context.querySearch(text_vector, query_vector_name=db_context.getVectorByBasis(basis.basis), filter_param=filter_param, @@ -77,17 +82,17 @@ async def imageSearch( return SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4()) -@searchRouter.get("/similar/{id}", +@searchRouter.get("/similar/{image_id}", description="Search images similar to the image with given id. " "Won't include the given image itself in the result.") async def similarWith( - id: Annotated[UUID, Path(description="The id of the image you want to search.")], + image_id: Annotated[UUID, Path(description="The id of the image you want to search.")], basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)], filter_param: Annotated[FilterParams, Depends(FilterParams)], paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)] ) -> SearchApiResponse: - logger.info("Similar search request received, id: {}", id) - results = await db_context.querySimilar(search_id=str(id), + logger.info("Similar search request received, id: {}", image_id) + results = await db_context.querySimilar(search_id=str(image_id), top_k=paging.count, skip=paging.skip, filter_param=filter_param, @@ -149,7 +154,7 @@ async def process_advanced_and_combined_search_query(model: Union[AdvancedSearch positive_vectors = [transformers_service.get_text_vector(t) for t in model.criteria] negative_vectors = [transformers_service.get_text_vector(t) for t in model.negative_criteria] # In order to ensure the query effect of the combined query, modify the actual top_k - _query_top_k = min(max(30, paging.count*3), 100) if isinstance(model, CombinedSearchModel) else paging.count + _query_top_k = min(max(30, paging.count * 3), 100) if isinstance(model, CombinedSearchModel) else paging.count result = await db_context.querySimilar(query_vector_name=db_context.getVectorByBasis(basis.basis), positive_vectors=positive_vectors, negative_vectors=negative_vectors, diff --git a/app/Models/query_params.py b/app/Models/query_params.py index 6a1b857..02655a4 100644 --- a/app/Models/query_params.py +++ b/app/Models/query_params.py @@ -28,6 +28,7 @@ def __init__( self.min_width = min_width self.min_height = min_height self.starred = starred + self.ocr_text = None # For exact search if self.preferred_ratio: self.min_ratio = self.preferred_ratio * (1 - self.ratio_tolerance) diff --git a/app/Services/vector_db_context.py b/app/Services/vector_db_context.py index 248de15..5615ea9 100644 --- a/app/Services/vector_db_context.py +++ b/app/Services/vector_db_context.py @@ -188,6 +188,14 @@ def getFiltersByFilterParam(filter_param: FilterParams | None) -> models.Filter ) )) + if filter_param.ocr_text is not None: + filters.append(models.FieldCondition( + key="ocr_text", + match=models.MatchText( + text=filter_param.ocr_text + ) + )) + if len(filters) == 0: return None return models.Filter(