Skip to content

Commit

Permalink
Merge pull request #38 from hv0905/dup_validate
Browse files Browse the repository at this point in the history
feat: Add duplication validate API
  • Loading branch information
hv0905 authored Jun 26, 2024
2 parents 2648e51 + dd8d0d3 commit a4fdb86
Show file tree
Hide file tree
Showing 11 changed files with 111 additions and 27 deletions.
26 changes: 21 additions & 5 deletions app/Controllers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
from fastapi import APIRouter, Depends, HTTPException, params, UploadFile, File
from loguru import logger

from app.Models.api_models.admin_api_model import ImageOptUpdateModel
from app.Models.api_models.admin_api_model import ImageOptUpdateModel, DuplicateValidationModel
from app.Models.api_models.admin_query_params import UploadImageModel
from app.Models.api_response.admin_api_response import ServerInfoResponse, ImageUploadResponse
from app.Models.api_response.admin_api_response import ServerInfoResponse, ImageUploadResponse, \
DuplicateValidationResponse
from app.Models.api_response.base import NekoProtocol
from app.Models.img_data import ImageData
from app.Services.authentication import force_admin_token_verify
from app.Services.provider import ServiceProvider
from app.Services.vector_db_context import PointNotFoundError
from app.config import config
from app.util.generate_uuid import generate_uuid
from app.util.generate_uuid import generate_uuid, generate_uuid_from_sha1

admin_router = APIRouter(dependencies=[Depends(force_admin_token_verify)], tags=["Admin"])

Expand Down Expand Up @@ -98,7 +99,7 @@ async def update_image(image_id: Annotated[UUID, params.Path(description="The id
description="Upload image to server. The image will be indexed and stored in the database. If "
"local is set to true, the image will be uploaded to local storage.")
async def upload_image(image_file: Annotated[UploadFile, File(description="The image to be uploaded.")],
model: Annotated[UploadImageModel, Depends()]):
model: Annotated[UploadImageModel, Depends()]) -> ImageUploadResponse:
# generate an ID for the image
img_type = None
if image_file.content_type.lower() in IMAGE_MIMES:
Expand Down Expand Up @@ -140,7 +141,22 @@ async def upload_image(image_file: Annotated[UploadFile, File(description="The i


@admin_router.get("/server_info", description="Get server information")
async def server_info():
async def server_info() -> ServerInfoResponse:
return ServerInfoResponse(message="Successfully get server information!",
image_count=await services.db_context.get_counts(exact=True),
index_queue_length=services.upload_service.get_queue_size())


@admin_router.post("/duplication_validate",
description="Check if an image exists in the server by its SHA1 hash. If the image exists, "
"the image ID will be returned.\n"
"This is helpful for checking if an image is already in the server without "
"uploading the image.")
async def duplication_validate(model: DuplicateValidationModel) -> DuplicateValidationResponse:
ids = [generate_uuid_from_sha1(t) for t in model.hashes]
valid_ids = await services.db_context.validate_ids([str(t) for t in ids])
exists_matrix = [str(t) in valid_ids or t in services.upload_service.uploading_ids for t in ids]
return DuplicateValidationResponse(
exists=exists_matrix,
entity_ids=[(str(t) if exists else None) for (t, exists) in zip(ids, exists_matrix)],
message="Validation completed.")
11 changes: 3 additions & 8 deletions app/Controllers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ async def advancedSearch(
basis: Annotated[SearchBasisParams, Depends(SearchBasisParams)],
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
if len(model.criteria) + len(model.negative_criteria) == 0:
raise HTTPException(status_code=422, detail="At least one criteria should be provided.")
logger.info("Advanced search request received: {}", model)
result = await process_advanced_and_combined_search_query(model, basis, filter_param, paging)
return await result_postprocessing(
Expand All @@ -141,8 +139,6 @@ async def combinedSearch(
basis: Annotated[SearchCombinedParams, Depends(SearchCombinedParams)],
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
if len(model.criteria) + len(model.negative_criteria) == 0:
raise HTTPException(status_code=422, detail="At least one criteria should be provided.")
logger.info("Combined search request received: {}", model)
result = await process_advanced_and_combined_search_query(model, basis, filter_param, paging)
calculate_and_sort_by_combined_scores(model, basis, result)
Expand All @@ -166,10 +162,9 @@ async def randomPick(
SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4()))


@search_router.get("/recall/{query_id}", description="Recall the query with given queryId")
async def recallQuery(query_id: str):
raise NotImplementedError()

# @search_router.get("/recall/{query_id}", description="Recall the query with given queryId")
# async def recallQuery(query_id: str):
# raise NotImplementedError()

async def process_advanced_and_combined_search_query(model: Union[AdvancedSearchModel, CombinedSearchModel],
basis: Union[SearchBasisParams, SearchCombinedParams],
Expand Down
12 changes: 10 additions & 2 deletions app/Models/api_models/admin_api_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional
from typing import Optional, Annotated

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, StringConstraints


class ImageOptUpdateModel(BaseModel):
Expand All @@ -21,3 +21,11 @@ class ImageOptUpdateModel(BaseModel):

def empty(self) -> bool:
return all([item is None for item in self.model_dump().values()])


Sha1HashString = Annotated[
str, StringConstraints(min_length=40, max_length=40, pattern=r"[0-9a-f]+", to_lower=True, strip_whitespace=True)]


class DuplicateValidationModel(BaseModel):
hashes: list[Sha1HashString] = Field(description="The SHA1 hash of the image.", min_length=1)
9 changes: 7 additions & 2 deletions app/Models/api_models/search_api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@ class SearchCombinedBasisEnum(str, Enum):


class AdvancedSearchModel(BaseModel):
criteria: list[str] = Field([], description="The positive criteria you want to search with", max_length=16)
negative_criteria: list[str] = Field([], description="The negative criteria you want to search with", max_length=16)
criteria: list[str] = Field([],
description="The positive criteria you want to search with",
max_length=16,
min_length=1)
negative_criteria: list[str] = Field([],
description="The negative criteria you want to search with",
max_length=16)
mode: SearchModelEnum = Field(SearchModelEnum.average,
description="The mode you want to use to combine the criteria.")

Expand Down
9 changes: 9 additions & 0 deletions app/Models/api_response/admin_api_response.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from uuid import UUID

from pydantic import Field

from .base import NekoProtocol


Expand All @@ -8,5 +10,12 @@ class ServerInfoResponse(NekoProtocol):
index_queue_length: int


class DuplicateValidationResponse(NekoProtocol):
entity_ids: list[UUID | None] = Field(
description="The image id for each hash. If the image does not exist in the server, the value will be null.")
exists: list[bool] = Field(
description="Whether the image exists in the server. True if the image exists, False otherwise.")


class ImageUploadResponse(NekoProtocol):
image_id: UUID
6 changes: 5 additions & 1 deletion app/util/generate_uuid.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,8 @@ def generate_uuid(file_input: pathlib.Path | io.BytesIO | bytes) -> UUID:
else:
raise ValueError("Unsupported file type. Must be pathlib.Path or io.BytesIO.")
file_hash = hashlib.sha1(file_content).hexdigest()
return uuid5(namespace_uuid, file_hash)
return generate_uuid_from_sha1(file_hash)


def generate_uuid_from_sha1(sha1_hash: str) -> UUID:
return uuid5(namespace_uuid, sha1_hash.lower())
2 changes: 0 additions & 2 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
TEST_ACCESS_TOKEN = 'test_token'
TEST_ADMIN_TOKEN = 'test_admin_token'

assets_path = Path(__file__).parent / '..' / 'assets'


@pytest.fixture(scope="session")
def test_client(tmp_path_factory) -> TestClient:
Expand Down
3 changes: 2 additions & 1 deletion tests/api/integrate_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN, assets_path
from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN
from ..assets import assets_path

test_images = {'bsn': ['bsn_0.jpg', 'bsn_1.jpg', 'bsn_2.jpg'],
'cat': ['cat_0.jpg', 'cat_1.jpg'],
Expand Down
38 changes: 32 additions & 6 deletions tests/api/test_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@

import pytest

from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN, assets_path
from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN
from ..assets import assets_path

test_file_path = assets_path / 'test_images' / 'bsn_0.jpg'
test_file_2_path = assets_path / 'test_images' / 'bsn_1.jpg'

test_file_hashes = ['648351F7CBD472D0CA23EADCCF3B9E619EC9ADDA', 'C5DE90DAC2F75FBDBE48023DF4DE7585A86B2392']


def get_single_img_info(test_client, image_id):
Expand Down Expand Up @@ -46,15 +50,37 @@ def upload(file):
headers={'x-admin-token': TEST_ADMIN_TOKEN},
params={'local': True})

def validate(hashes):
return test_client.post('/admin/duplication_validate',
json={'hashes': hashes},
headers={'x-admin-token': TEST_ADMIN_TOKEN})

with open(test_file_path, 'rb') as f:
# Validate 1#
val_resp = validate(test_file_hashes)
assert val_resp.status_code == 200
assert val_resp.json()['exists'] == [False, False]
assert val_resp.json()['entity_ids'] == [None, None]

# Upload
resp = upload(f)
assert resp.status_code == 200
image_id = resp.json()['image_id']
resp = upload(f) # The previous image is still in queue
assert resp.status_code == 409
await wait_for_background_task(1)
resp = upload(f) # The previous image is indexed now
assert resp.status_code == 409

for i in range(0, 2):
# Re-upload
resp = upload(f)
assert resp.status_code == 409, i

# Validate
val_resp = validate(test_file_hashes)
assert val_resp.status_code == 200, i
assert val_resp.json()['exists'] == [True, False], i
assert val_resp.json()['entity_ids'] == [str(image_id), None], i

# Wait for the image to be indexed
if i == 0:
await wait_for_background_task(1)

# cleanup
resp = test_client.delete(f'/admin/delete/{image_id}', headers={'x-admin-token': TEST_ADMIN_TOKEN})
Expand Down
3 changes: 3 additions & 0 deletions tests/assets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from pathlib import Path

assets_path = Path(__file__).parent
19 changes: 19 additions & 0 deletions tests/unit/test_image_uuid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import io
from uuid import UUID

from app.util.generate_uuid import generate_uuid
from ..assets import assets_path

BSN_UUID = UUID('b3aff1e9-8085-5300-8e06-37b522384659') # To test consistency of UUID across versions


def test_uuid_consistency():
file_path = assets_path / 'test_images' / 'bsn_0.jpg'
with open(file_path, 'rb') as f:
file_content = f.read()

uuid1 = generate_uuid(file_path)
uuid2 = generate_uuid(io.BytesIO(file_content))
uuid3 = generate_uuid(file_content)

assert uuid1 == uuid2 == uuid3 == BSN_UUID

0 comments on commit a4fdb86

Please sign in to comment.