Skip to content

Commit

Permalink
Refactor service provider
Browse files Browse the repository at this point in the history
  • Loading branch information
hv0905 committed May 6, 2024
1 parent d106ade commit 0f8aed6
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 120 deletions.
25 changes: 13 additions & 12 deletions app/Controllers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,39 +14,40 @@
from app.Models.api_response.base import NekoProtocol
from app.Models.img_data import ImageData
from app.Services.authentication import force_admin_token_verify
from app.Services.provider import db_context, storage_service, upload_service
from app.Services.provider import ServiceProvider
from app.Services.vector_db_context import PointNotFoundError
from app.config import config
from app.util.generate_uuid import generate_uuid

admin_router = APIRouter(dependencies=[Depends(force_admin_token_verify)], tags=["Admin"])

services: ServiceProvider | None = None

@admin_router.delete("/delete/{image_id}",
description="Delete image with the given id from database. "
"If the image is a local image, it will be moved to `/static/_deleted` folder.")
async def delete_image(
image_id: Annotated[UUID, params.Path(description="The id of the image you want to delete.")]) -> NekoProtocol:
try:
point = await db_context.retrieve_by_id(str(image_id))
point = await services.db_context.retrieve_by_id(str(image_id))
except PointNotFoundError as ex:
raise HTTPException(404, "Cannot find the image with the given ID.") from ex
await db_context.deleteItems([str(point.id)])
await services.db_context.deleteItems([str(point.id)])
logger.success("Image {} deleted from database.", point.id)

if point.local and config.storage.method.enabled: # local image
image_files = [itm[0] async for itm in storage_service.active_storage.list_files("", f"{point.id}.*")]
image_files = [itm[0] async for itm in services.storage_service.active_storage.list_files("", f"{point.id}.*")]
assert len(image_files) <= 1

if not image_files:
logger.warning("Image {} is a local image but not found in static folder.", point.id)
else:
await storage_service.active_storage.move(image_files[0], f"_deleted/{image_files[0].name}")
await services.storage_service.active_storage.move(image_files[0], f"_deleted/{image_files[0].name}")
logger.success("Image {} removed.", image_files[0].name)
if point.thumbnail_url is not None:
thumbnail_file = PurePath(f"thumbnails/{point.id}.webp")
if await storage_service.active_storage.is_exist(thumbnail_file):
await storage_service.active_storage.delete(thumbnail_file)
if await services.storage_service.active_storage.is_exist(thumbnail_file):
await services.storage_service.active_storage.delete(thumbnail_file)
logger.success("Thumbnail {} removed.", thumbnail_file.name)
else:
logger.warning("Thumbnail {} not found.", thumbnail_file.name)
Expand All @@ -60,7 +61,7 @@ async def update_image(image_id: Annotated[UUID, params.Path(description="The id
if model.empty():
raise HTTPException(422, "Nothing to update.")
try:
point = await db_context.retrieve_by_id(str(image_id))
point = await services.db_context.retrieve_by_id(str(image_id))
except PointNotFoundError as ex:
raise HTTPException(404, "Cannot find the image with the given ID.") from ex

Expand All @@ -69,7 +70,7 @@ async def update_image(image_id: Annotated[UUID, params.Path(description="The id
if model.categories is not None:
point.categories = model.categories

await db_context.updatePayload(point)
await services.db_context.updatePayload(point)
logger.success("Image {} updated.", point.id)

return NekoProtocol(message="Image updated.")
Expand Down Expand Up @@ -99,7 +100,7 @@ async def upload_image(image_file: Annotated[UploadFile, File(description="The i
raise HTTPException(415, "Unsupported image format.")
img_bytes = await image_file.read()
img_id = generate_uuid(img_bytes)
if len(await db_context.validate_ids([str(img_id)])) != 0: # check for duplicate points
if len(await services.db_context.validate_ids([str(img_id)])) != 0: # check for duplicate points
raise HTTPException(409, f"The uploaded point is already contained in the database! entity id: {img_id}")

try:
Expand All @@ -116,11 +117,11 @@ async def upload_image(image_file: Annotated[UploadFile, File(description="The i
format=img_type,
index_date=datetime.now())

await upload_service.upload_image(image, image_data, img_bytes, model.skip_ocr)
await services.upload_service.upload_image(image, image_data, img_bytes, model.skip_ocr)
return ImageUploadResponse(message="OK. Image added to upload queue.", image_id=img_id)


@admin_router.get("/server_info", description="Get server information")
async def server_info():
return ServerInfoResponse(message="Successfully get server information!",
image_count=await db_context.get_counts(exact=True))
image_count=await services.db_context.get_counts(exact=True))
111 changes: 58 additions & 53 deletions app/Controllers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,14 @@
from app.Models.query_params import SearchPagingParams, FilterParams
from app.Models.search_result import SearchResult
from app.Services.authentication import force_access_token_verify
from app.Services.provider import db_context, transformers_service, storage_service
from app.Services.provider import ServiceProvider
from app.config import config
from app.util.calculate_vectors_cosine import calculate_vectors_cosine

searchRouter = APIRouter(dependencies=([Depends(force_access_token_verify)] if config.access_protected else None),
search_router = APIRouter(dependencies=([Depends(force_access_token_verify)] if config.access_protected else None),
tags=["Search"])


async def result_postprocessing(resp: SearchApiResponse) -> SearchApiResponse:
for item in resp.result:
if item.img.local and config.storage.method.enabled:
img_extension = item.img.format or item.img.url.split('.')[-1]
img_remote_filename = f"{item.img.id}.{img_extension}"
item.img.url = await storage_service.active_storage.presign_url(img_remote_filename)
if item.img.thumbnail_url is not None:
thumbnail_remote_filename = f"thumbnails/{item.img.id}.webp"
item.img.thumbnail_url = await storage_service.active_storage.presign_url(thumbnail_remote_filename)
return resp
services: ServiceProvider | None = None # The service provider will be injected in the webapp initialize


class SearchBasisParams:
Expand All @@ -52,7 +42,20 @@ def __init__(self,
self.basis = basis


@searchRouter.get("/text/{prompt}", description="Search images by text prompt")
async def result_postprocessing(resp: SearchApiResponse) -> SearchApiResponse:
for item in resp.result:
if item.img.local and config.storage.method.enabled:
img_extension = item.img.format or item.img.url.split('.')[-1]
img_remote_filename = f"{item.img.id}.{img_extension}"
item.img.url = await services.storage_service.active_storage.presign_url(img_remote_filename)
if item.img.thumbnail_url is not None:
thumbnail_remote_filename = f"thumbnails/{item.img.id}.webp"
item.img.thumbnail_url = await services.storage_service.active_storage.presign_url(
thumbnail_remote_filename)
return resp


@search_router.get("/text/{prompt}", description="Search images by text prompt")
async def textSearch(
prompt: Annotated[
str, Path(max_length=100, description="The image prompt text you want to search.")],
Expand All @@ -64,20 +67,20 @@ async def textSearch(
"criteria you have given. This won't take any effect in vision search.")] = False
) -> SearchApiResponse:
logger.info("Text search request received, prompt: {}", prompt)
text_vector = transformers_service.get_text_vector(prompt) if basis.basis == SearchBasisEnum.vision \
else transformers_service.get_bert_vector(prompt)
text_vector = services.transformers_service.get_text_vector(prompt) if basis.basis == SearchBasisEnum.vision \
else services.transformers_service.get_bert_vector(prompt)
if basis.basis == SearchBasisEnum.ocr and exact:
filter_param.ocr_text = prompt
results = await db_context.querySearch(text_vector,
query_vector_name=db_context.getVectorByBasis(basis.basis),
filter_param=filter_param,
top_k=paging.count,
skip=paging.skip)
results = await services.db_context.querySearch(text_vector,
query_vector_name=services.db_context.getVectorByBasis(basis.basis),
filter_param=filter_param,
top_k=paging.count,
skip=paging.skip)
return await result_postprocessing(
SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4()))


@searchRouter.post("/image", description="Search images by image")
@search_router.post("/image", description="Search images by image")
async def imageSearch(
image: Annotated[bytes, File(max_length=10 * 1024 * 1024, media_type="image/*",
description="The image you want to search.")],
Expand All @@ -87,16 +90,16 @@ async def imageSearch(
fakefile = BytesIO(image)
img = Image.open(fakefile)
logger.info("Image search request received")
image_vector = transformers_service.get_image_vector(img)
results = await db_context.querySearch(image_vector,
top_k=paging.count,
skip=paging.skip,
filter_param=filter_param)
image_vector = services.transformers_service.get_image_vector(img)
results = await services.db_context.querySearch(image_vector,
top_k=paging.count,
skip=paging.skip,
filter_param=filter_param)
return await result_postprocessing(
SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4()))


@searchRouter.get("/similar/{image_id}",
@search_router.get("/similar/{image_id}",
description="Search images similar to the image with given id. "
"Won't include the given image itself in the result.")
async def similarWith(
Expand All @@ -106,16 +109,17 @@ async def similarWith(
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]
) -> SearchApiResponse:
logger.info("Similar search request received, id: {}", image_id)
results = await db_context.querySimilar(search_id=str(image_id),
top_k=paging.count,
skip=paging.skip,
filter_param=filter_param,
query_vector_name=db_context.getVectorByBasis(basis.basis))
results = await services.db_context.querySimilar(search_id=str(image_id),
top_k=paging.count,
skip=paging.skip,
filter_param=filter_param,
query_vector_name=services.db_context.getVectorByBasis(
basis.basis))
return await result_postprocessing(
SearchApiResponse(result=results, message=f"Successfully get {len(results)} results.", query_id=uuid4()))


@searchRouter.post("/advanced", description="Search with multiple criteria")
@search_router.post("/advanced", description="Search with multiple criteria")
async def advancedSearch(
model: AdvancedSearchModel,
basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
Expand All @@ -129,7 +133,7 @@ async def advancedSearch(
SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4()))


@searchRouter.post("/combined", description="Search with combined criteria")
@search_router.post("/combined", description="Search with combined criteria")
async def combinedSearch(
model: CombinedSearchModel,
basis: Annotated[SearchCombinedParams, Depends(SearchCombinedParams)],
Expand All @@ -145,18 +149,18 @@ async def combinedSearch(
SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4()))


@searchRouter.get("/random", description="Get random images")
@search_router.get("/random", description="Get random images")
async def randomPick(
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
logger.info("Random pick request received")
random_vector = transformers_service.get_random_vector()
result = await db_context.querySearch(random_vector, top_k=paging.count, filter_param=filter_param)
random_vector = services.transformers_service.get_random_vector()
result = await services.db_context.querySearch(random_vector, top_k=paging.count, filter_param=filter_param)
return await result_postprocessing(
SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4()))


@searchRouter.get("/recall/{query_id}", description="Recall the query with given queryId")
@search_router.get("/recall/{query_id}", description="Recall the query with given queryId")
async def recallQuery(query_id: str):
raise NotImplementedError()

Expand All @@ -166,31 +170,32 @@ async def process_advanced_and_combined_search_query(model: Union[AdvancedSearch
filter_param: FilterParams,
paging: SearchPagingParams) -> List[SearchResult]:
if basis.basis == SearchBasisEnum.ocr:
positive_vectors = [transformers_service.get_bert_vector(t) for t in model.criteria]
negative_vectors = [transformers_service.get_bert_vector(t) for t in model.negative_criteria]
positive_vectors = [services.transformers_service.get_bert_vector(t) for t in model.criteria]
negative_vectors = [services.transformers_service.get_bert_vector(t) for t in model.negative_criteria]
else:
positive_vectors = [transformers_service.get_text_vector(t) for t in model.criteria]
negative_vectors = [transformers_service.get_text_vector(t) for t in model.negative_criteria]
positive_vectors = [services.transformers_service.get_text_vector(t) for t in model.criteria]
negative_vectors = [services.transformers_service.get_text_vector(t) for t in model.negative_criteria]
# In order to ensure the query effect of the combined query, modify the actual top_k
_query_top_k = min(max(30, paging.count * 3), 100) if isinstance(model, CombinedSearchModel) else paging.count
result = await db_context.querySimilar(query_vector_name=db_context.getVectorByBasis(basis.basis),
positive_vectors=positive_vectors,
negative_vectors=negative_vectors,
mode=model.mode,
filter_param=filter_param,
with_vectors=True if isinstance(basis, SearchCombinedParams) else False,
top_k=_query_top_k,
skip=paging.skip)
result = await services.db_context.querySimilar(query_vector_name=services.db_context.getVectorByBasis(basis.basis),
positive_vectors=positive_vectors,
negative_vectors=negative_vectors,
mode=model.mode,
filter_param=filter_param,
with_vectors=True if isinstance(basis,
SearchCombinedParams) else False,
top_k=_query_top_k,
skip=paging.skip)
return result


def calculate_and_sort_by_combined_scores(model: CombinedSearchModel,
basis: SearchCombinedParams,
result: List[SearchResult]) -> None:
# First, calculate the extra prompt vector
extra_prompt_vector = transformers_service.get_text_vector(model.extra_prompt) \
extra_prompt_vector = services.transformers_service.get_text_vector(model.extra_prompt) \
if basis.basis == SearchCombinedBasisEnum.ocr \
else transformers_service.get_bert_vector(model.extra_prompt)
else services.transformers_service.get_bert_vector(model.extra_prompt)
# Then, calculate combined_similar_score (original score * similar_score) and write to SearchResult.score
for itm in result:
extra_vector = itm.img.image_vector if itm.img.image_vector is not None else itm.img.text_contain_vector
Expand Down
Loading

0 comments on commit 0f8aed6

Please sign in to comment.