Skip to content

Commit

Permalink
Allow exact match in OCR text search
Browse files Browse the repository at this point in the history
  • Loading branch information
hv0905 committed Dec 26, 2023
1 parent f065642 commit 1694660
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
17 changes: 11 additions & 6 deletions app/Controllers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,16 @@ async def textSearch(
str, Path(max_length=100, description="The image prompt text you want to search.")],
basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)],
exact: Annotated[bool, Query(
description="If using OCR search, this option will require the ocr text contains **exactly** the "
"criteria you have given. This won't take any effect in vision search.")] = False
) -> SearchApiResponse:
logger.info("Text search request received, prompt: {}", prompt)
text_vector = transformers_service.get_text_vector(prompt) if basis.basis == SearchBasisEnum.vision \
else transformers_service.get_bert_vector(prompt)
if basis.basis == SearchBasisEnum.ocr and exact:
filter_param.ocr_text = prompt
results = await db_context.querySearch(text_vector,
query_vector_name=db_context.getVectorByBasis(basis.basis),
filter_param=filter_param,
Expand All @@ -77,17 +82,17 @@ async def imageSearch(
return SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4())


@searchRouter.get("/similar/{id}",
@searchRouter.get("/similar/{image_id}",
description="Search images similar to the image with given id. "
"Won't include the given image itself in the result.")
async def similarWith(
id: Annotated[UUID, Path(description="The id of the image you want to search.")],
image_id: Annotated[UUID, Path(description="The id of the image you want to search.")],
basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]
) -> SearchApiResponse:
logger.info("Similar search request received, id: {}", id)
results = await db_context.querySimilar(search_id=str(id),
logger.info("Similar search request received, id: {}", image_id)
results = await db_context.querySimilar(search_id=str(image_id),
top_k=paging.count,
skip=paging.skip,
filter_param=filter_param,
Expand Down Expand Up @@ -149,7 +154,7 @@ async def process_advanced_and_combined_search_query(model: Union[AdvancedSearch
positive_vectors = [transformers_service.get_text_vector(t) for t in model.criteria]
negative_vectors = [transformers_service.get_text_vector(t) for t in model.negative_criteria]
# In order to ensure the query effect of the combined query, modify the actual top_k
_query_top_k = min(max(30, paging.count*3), 100) if isinstance(model, CombinedSearchModel) else paging.count
_query_top_k = min(max(30, paging.count * 3), 100) if isinstance(model, CombinedSearchModel) else paging.count
result = await db_context.querySimilar(query_vector_name=db_context.getVectorByBasis(basis.basis),
positive_vectors=positive_vectors,
negative_vectors=negative_vectors,
Expand Down
1 change: 1 addition & 0 deletions app/Models/query_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
self.min_width = min_width
self.min_height = min_height
self.starred = starred
self.ocr_text = None # For exact search

if self.preferred_ratio:
self.min_ratio = self.preferred_ratio * (1 - self.ratio_tolerance)
Expand Down
8 changes: 8 additions & 0 deletions app/Services/vector_db_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,14 @@ def getFiltersByFilterParam(filter_param: FilterParams | None) -> models.Filter
)
))

if filter_param.ocr_text is not None:
filters.append(models.FieldCondition(
key="ocr_text",
match=models.MatchText(
text=filter_param.ocr_text
)
))

if len(filters) == 0:
return None
return models.Filter(
Expand Down

0 comments on commit 1694660

Please sign in to comment.