Skip to content

Commit

Permalink
Merge pull request #39 from hv0905/fix_combined
Browse files Browse the repository at this point in the history
Improve tests and fix combined search
  • Loading branch information
hv0905 authored Jul 2, 2024
2 parents 14f2172 + 8a6111d commit fc7d148
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 111 deletions.
98 changes: 53 additions & 45 deletions app/Controllers/search.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from io import BytesIO
from typing import Annotated, List, Union
from typing import Annotated, List
from uuid import uuid4, UUID

from PIL import Image
from fastapi import APIRouter, HTTPException
from fastapi.params import File, Query, Path, Depends
from loguru import logger

from app.Models.api_models.search_api_model import AdvancedSearchModel, CombinedSearchModel, SearchBasisEnum, \
SearchCombinedBasisEnum
from app.Models.api_models.search_api_model import AdvancedSearchModel, CombinedSearchModel, SearchBasisEnum
from app.Models.api_response.search_api_response import SearchApiResponse
from app.Models.query_params import SearchPagingParams, FilterParams
from app.Models.search_result import SearchResult
Expand All @@ -32,16 +31,6 @@ def __init__(self,
self.basis = basis


class SearchCombinedParams:
def __init__(self,
basis: Annotated[SearchCombinedBasisEnum, Query(
description="The primary basis used for searching the image.")] = SearchCombinedBasisEnum.vision):
if not config.ocr_search.enable:
raise HTTPException(400, "You used combined search, but it needs OCR search which is not "
"enabled.")
self.basis = basis


async def result_postprocessing(resp: SearchApiResponse) -> SearchApiResponse:
if not config.storage.method.enabled:
return resp
Expand Down Expand Up @@ -74,7 +63,8 @@ async def textSearch(
if basis.basis == SearchBasisEnum.ocr and exact:
filter_param.ocr_text = prompt
results = await services.db_context.querySearch(text_vector,
query_vector_name=services.db_context.getVectorByBasis(basis.basis),
query_vector_name=services.db_context.vector_name_for_basis(
basis.basis),
filter_param=filter_param,
top_k=paging.count,
skip=paging.skip)
Expand Down Expand Up @@ -115,7 +105,7 @@ async def similarWith(
top_k=paging.count,
skip=paging.skip,
filter_param=filter_param,
query_vector_name=services.db_context.getVectorByBasis(
query_vector_name=services.db_context.vector_name_for_basis(
basis.basis))
return await result_postprocessing(
SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4()))
Expand All @@ -136,11 +126,14 @@ async def advancedSearch(
@search_router.post("/combined", description="Search with combined criteria")
async def combinedSearch(
model: CombinedSearchModel,
basis: Annotated[SearchCombinedParams, Depends(SearchCombinedParams)],
basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
if not config.ocr_search.enable:
raise HTTPException(400, "You used combined search, but it needs OCR search which is not "
"enabled.")
logger.info("Combined search request received: {}", model)
result = await process_advanced_and_combined_search_query(model, basis, filter_param, paging)
result = await process_advanced_and_combined_search_query(model, basis, filter_param, paging, True)
calculate_and_sort_by_combined_scores(model, basis, result)
result = result[:paging.count] if len(result) > paging.count else result
return await result_postprocessing(
Expand All @@ -166,41 +159,56 @@ async def randomPick(
# async def recallQuery(query_id: str):
# raise NotImplementedError()

async def process_advanced_and_combined_search_query(model: Union[AdvancedSearchModel, CombinedSearchModel],
basis: Union[SearchBasisParams, SearchCombinedParams],
async def process_advanced_and_combined_search_query(model: AdvancedSearchModel,
basis: SearchBasisParams,
filter_param: FilterParams,
paging: SearchPagingParams) -> List[SearchResult]:
if basis.basis == SearchBasisEnum.ocr:
positive_vectors = [services.transformers_service.get_bert_vector(t) for t in model.criteria]
negative_vectors = [services.transformers_service.get_bert_vector(t) for t in model.negative_criteria]
else:
positive_vectors = [services.transformers_service.get_text_vector(t) for t in model.criteria]
negative_vectors = [services.transformers_service.get_text_vector(t) for t in model.negative_criteria]
paging: SearchPagingParams,
is_combined_search=False) -> List[SearchResult]:
match basis.basis:
case SearchBasisEnum.ocr:
positive_vectors = [services.transformers_service.get_bert_vector(t) for t in model.criteria]
negative_vectors = [services.transformers_service.get_bert_vector(t) for t in model.negative_criteria]
case SearchBasisEnum.vision:
positive_vectors = [services.transformers_service.get_text_vector(t) for t in model.criteria]
negative_vectors = [services.transformers_service.get_text_vector(t) for t in model.negative_criteria]
case _: # pragma: no cover
raise NotImplementedError()
# Combined search re-ranks its candidates afterwards, so fetch a larger pool than the requested page: three times paging.count, clamped to the range 30-100
_query_top_k = min(max(30, paging.count * 3), 100) if isinstance(model, CombinedSearchModel) else paging.count
result = await services.db_context.querySimilar(query_vector_name=services.db_context.getVectorByBasis(basis.basis),
positive_vectors=positive_vectors,
negative_vectors=negative_vectors,
mode=model.mode,
filter_param=filter_param,
with_vectors=True if isinstance(basis,
SearchCombinedParams) else False,
top_k=_query_top_k,
skip=paging.skip)
_query_top_k = min(max(30, paging.count * 3), 100) if is_combined_search else paging.count
result = await services.db_context.querySimilar(
query_vector_name=services.db_context.vector_name_for_basis(basis.basis),
positive_vectors=positive_vectors,
negative_vectors=negative_vectors,
mode=model.mode,
filter_param=filter_param,
with_vectors=is_combined_search,
top_k=_query_top_k,
skip=paging.skip)
return result


def calculate_and_sort_by_combined_scores(model: CombinedSearchModel,
basis: SearchCombinedParams,
basis: SearchBasisParams,
result: List[SearchResult]) -> None:
# First, calculate the extra prompt vector
extra_prompt_vector = services.transformers_service.get_text_vector(model.extra_prompt) \
if basis.basis == SearchCombinedBasisEnum.ocr \
else services.transformers_service.get_bert_vector(model.extra_prompt)
# Then, calculate combined_similar_score (original score * similar_score) and write to SearchResult.score
# Use a different method to calculate the extra prompt vector based on the basis
match basis.basis:
case SearchBasisEnum.ocr:
extra_prompt_vector = services.transformers_service.get_text_vector(model.extra_prompt)
case SearchBasisEnum.vision:
extra_prompt_vector = services.transformers_service.get_bert_vector(model.extra_prompt)
case _: # pragma: no cover
raise NotImplementedError()
# Calculate combined_similar_score (original score * similar_score) and write to SearchResult.score
for itm in result:
extra_vector = itm.img.image_vector if itm.img.image_vector is not None else itm.img.text_contain_vector
similar_score = calculate_vectors_cosine(extra_vector, extra_prompt_vector)
itm.score = similar_score * itm.score
match basis.basis:
case SearchBasisEnum.ocr:
extra_vector = itm.img.image_vector
case SearchBasisEnum.vision:
extra_vector = itm.img.text_contain_vector
case _: # pragma: no cover
raise NotImplementedError()
if extra_vector is not None:
similar_score = calculate_vectors_cosine(extra_vector, extra_prompt_vector)
itm.score = (1 + similar_score) * itm.score
# Finally, sort the result by combined_similar_score
result.sort(key=lambda i: i.score, reverse=True)
5 changes: 0 additions & 5 deletions app/Models/api_models/search_api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@ class SearchModelEnum(str, Enum):
best = "best"


class SearchCombinedBasisEnum(str, Enum):
vision = "vision"
ocr = "ocr"


class AdvancedSearchModel(BaseModel):
criteria: list[str] = Field([],
description="The positive criteria you want to search with",
Expand Down
4 changes: 2 additions & 2 deletions app/Services/vector_db_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def querySimilar(self,
mode == SearchModelEnum.average else RecommendStrategy.BEST_SCORE)
# Only combined search needs vectors to be returned, so _combined_search_need_vectors can be defined as below
_combined_search_need_vectors = [
self.IMG_VECTOR if query_vector_name == self.TEXT_VECTOR else self.IMG_VECTOR] if with_vectors else None
self.IMG_VECTOR if query_vector_name == self.TEXT_VECTOR else self.TEXT_VECTOR] if with_vectors else None
logger.info("Querying Qdrant... top_k = {}", top_k)
result = await self._client.recommend(collection_name=self.collection_name,
using=query_vector_name,
Expand Down Expand Up @@ -251,7 +251,7 @@ def _get_search_result_from_scored_point(self, point: models.ScoredPoint) -> Sea
return SearchResult(img=self._get_img_data_from_point(point), score=point.score)

@classmethod
def getVectorByBasis(cls, basis: SearchBasisEnum) -> str:
def vector_name_for_basis(cls, basis: SearchBasisEnum) -> str:
match basis:
case SearchBasisEnum.vision:
return cls.IMG_VECTOR
Expand Down
2 changes: 1 addition & 1 deletion app/util/fastapi_log_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from loguru import logger


class InterceptHandler(logging.Handler):
class InterceptHandler(logging.Handler): # pragma: no cover Hard to test in test environments

def emit(self, record: logging.LogRecord):
# Get corresponding Loguru level if it exists
Expand Down
19 changes: 15 additions & 4 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


@pytest.fixture(scope="session")
def test_client(tmp_path_factory) -> TestClient:
def unauthorized_test_client(tmp_path_factory) -> TestClient:
# Modify the configuration for testing
config.config.qdrant.mode = "memory"
config.config.admin_api_enable = True
Expand All @@ -27,9 +27,14 @@ def test_client(tmp_path_factory) -> TestClient:
yield client


@pytest.fixture()
@pytest.fixture(scope="module")
def test_client(unauthorized_test_client):
unauthorized_test_client.headers = {'x-access-token': TEST_ACCESS_TOKEN, 'x-admin-token': TEST_ADMIN_TOKEN}
yield unauthorized_test_client
unauthorized_test_client.headers = {}


def check_local_dir_empty():
yield
dir = Path(config.config.storage.local.path)
files = [f for f in dir.glob('*.*') if f.is_file()]
assert len(files) == 0
Expand All @@ -41,10 +46,16 @@ def check_local_dir_empty():


@pytest.fixture()
def ensure_local_dir_empty():
yield
check_local_dir_empty()


@pytest.fixture(scope="module")
def wait_for_background_task(test_client):
async def func(expected_image_count):
while True:
resp = test_client.get('/admin/server_info', headers={'x-admin-token': TEST_ADMIN_TOKEN})
resp = test_client.get('/admin/server_info')
if resp.json()['image_count'] >= expected_image_count:
break
await asyncio.sleep(0.2)
Expand Down
13 changes: 7 additions & 6 deletions tests/api/test_home.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
class TestHome:

def test_get_home_no_tokens(self, test_client):
response = test_client.get("/")
def test_get_home_no_tokens(self, unauthorized_test_client):
response = unauthorized_test_client.get("/")
assert response.status_code == 200
assert response.json()['authorization']['required']
assert not response.json()['authorization']['passed']
assert response.json()['admin_api']['available']
assert not response.json()['admin_api']['passed']

def test_get_home_access_token(self, test_client):
response = test_client.get("/", headers={'x-access-token': 'test_token'})
def test_get_home_access_token(self, unauthorized_test_client):
response = unauthorized_test_client.get("/", headers={'x-access-token': 'test_token'})
assert response.status_code == 200
assert response.json()['authorization']['required']
assert response.json()['authorization']['passed']

def test_get_home_admin_token(self, test_client):
response = test_client.get("/", headers={'x-admin-token': 'test_admin_token', 'x-access-token': 'test_token'})
def test_get_home_admin_token(self, unauthorized_test_client):
response = unauthorized_test_client.get("/", headers={'x-admin-token': 'test_admin_token',
'x-access-token': 'test_token'})
assert response.status_code == 200
assert response.json()['admin_api']['available']
assert response.json()['admin_api']['passed']
Expand Down
77 changes: 46 additions & 31 deletions tests/api/integrate_test.py → tests/api/test_search.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
import pytest
import pytest_asyncio

from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN
from .conftest import check_local_dir_empty
from ..assets import assets_path

test_images = {'bsn': ['bsn_0.jpg', 'bsn_1.jpg', 'bsn_2.jpg'],
'cat': ['cat_0.jpg', 'cat_1.jpg'],
'cg': ['cg_0.jpg', 'cg_1.png']}


@pytest.mark.asyncio
async def test_search(test_client, check_local_dir_empty, wait_for_background_task):
credentials = {'x-admin-token': TEST_ADMIN_TOKEN, 'x-access-token': TEST_ACCESS_TOKEN}
resp = test_client.get("/", headers=credentials)
assert resp.status_code == 200
@pytest_asyncio.fixture(scope="module")
async def img_ids(test_client, wait_for_background_task):
img_ids = {}
for img_cls, item_images in test_images.items():
img_ids[img_cls] = []
Expand All @@ -21,7 +18,6 @@ async def test_search(test_client, check_local_dir_empty, wait_for_background_ta
with open(assets_path / 'test_images' / image, 'rb') as f:
resp = test_client.post('/admin/upload',
files={'image_file': f},
headers=credentials,
params={'local': True})
assert resp.status_code == 200
img_ids[img_cls].append(resp.json()['image_id'])
Expand All @@ -30,50 +26,69 @@ async def test_search(test_client, check_local_dir_empty, wait_for_background_ta

await wait_for_background_task(sum(len(v) for v in test_images.values()))

resp = test_client.get('/search/text/hatsune+miku',
headers=credentials)
yield img_ids

# cleanup
for img_cls in test_images.keys():
for img_id in img_ids[img_cls]:
resp = test_client.delete(f"/admin/delete/{img_id}")
assert resp.status_code == 200

check_local_dir_empty()


def test_search_text(test_client, img_ids):
resp = test_client.get('/search/text/hatsune+miku')
assert resp.status_code == 200
assert resp.json()['result'][0]['img']['id'] in img_ids['cg']


def test_search_image(test_client, img_ids):
with open(assets_path / 'test_images' / test_images['cat'][0], 'rb') as f:
resp = test_client.post('/search/image',
files={'image': f},
headers=credentials)
files={'image': f})

assert resp.status_code == 200
assert resp.json()['result'][0]['img']['id'] in img_ids['cat']

resp = test_client.get(f"/search/similar/{img_ids['bsn'][0]}",
headers=credentials)

def test_search_similar(test_client, img_ids):
resp = test_client.get(f"/search/similar/{img_ids['bsn'][0]}")

assert resp.status_code == 200
assert resp.json()['result'][0]['img']['id'] in img_ids['bsn']

image_request = test_client.get(resp.json()['result'][0]['img']['url'])
assert image_request.status_code == 200
assert image_request.headers['Content-Type'] == 'image/jpeg'

resp = test_client.put(f"/admin/update_opt/{img_ids['bsn'][0]}", json={'categories': ['bsn'], 'starred': True},
headers=credentials)
def test_search_advanced(test_client, img_ids):
resp = test_client.post("/search/advanced",
json={'criteria': ['white background', 'grayscale image'],
'negative_criteria': ['cat', 'hatsune miku']})
assert resp.status_code == 200
assert resp.json()['result'][0]['img']['id'] in img_ids['bsn']


def test_search_combined(test_client, img_ids):
resp = test_client.post('/search/combined', json={'criteria': ['hatsune miku'],
'negative_criteria': ['grayscale image', 'cat'],
'extra_prompt': 'hatsunemiku'})

resp = test_client.get("/search/text/cat", params={'categories': 'bsn'}, headers=credentials)
assert resp.status_code == 200
assert resp.json()['result'][0]['img']['id'] in img_ids['bsn']
assert resp.json()['result'][0]['img']['id'] == img_ids['cg'][1]

resp = test_client.get("/search/text/cat", params={'starred': True}, headers=credentials)
resp = test_client.post('/search/combined?basis=ocr',
json={'criteria': ['hatsunemiku'], 'extra_prompt': 'hatsune miku'})
assert resp.status_code == 200
assert resp.json()['result'][0]['img']['id'] in img_ids['bsn']
assert resp.json()['result'][0]['img']['id'] == img_ids['cg'][1]

resp = test_client.delete(f"/admin/delete/{img_ids['bsn'][0]}", headers=credentials)

def test_search_filters(test_client, img_ids):
resp = test_client.put(f"/admin/update_opt/{img_ids['bsn'][0]}", json={'categories': ['bsn'], 'starred': True})
assert resp.status_code == 200

resp = test_client.get("/search/text/cat", params={'categories': 'bsn'}, headers=credentials)
resp = test_client.get("/search/text/cat", params={'categories': 'bsn'})
assert resp.status_code == 200
assert len(resp.json()['result']) == 0
assert resp.json()['result'][0]['img']['id'] == img_ids['bsn'][0]

# cleanup
for img_cls in test_images.keys():
for img_id in img_ids[img_cls]:
resp = test_client.delete(f"/admin/delete/{img_id}", headers=credentials)
assert resp.status_code == (404 if img_id == img_ids['bsn'][0] else 200)
resp = test_client.get("/search/text/cat", params={'starred': True})
assert resp.status_code == 200
assert resp.json()['result'][0]['img']['id'] == img_ids['bsn'][0]
Loading

0 comments on commit fc7d148

Please sign in to comment.