Skip to content

Commit

Permalink
Merge pull request #36 from hv0905/undup
Browse files Browse the repository at this point in the history
Reproducible random pick and #35
  • Loading branch information
hv0905 authored Jun 21, 2024
2 parents 63a7ef3 + c12fb81 commit f6c9d85
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 23 deletions.
11 changes: 9 additions & 2 deletions app/Controllers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,22 @@ async def upload_image(image_file: Annotated[UploadFile, File(description="The i
if extension in {'.jpg', '.png', '.jpeg', '.jfif', '.webp', '.gif'}:
img_type = extension[1:]
if not img_type:
logger.warning("Failed to infer image format of the uploaded image. Content Type: {}, Filename: {}",
image_file.content_type, image_file.filename)
raise HTTPException(415, "Unsupported image format.")
img_bytes = await image_file.read()
img_id = generate_uuid(img_bytes)
if len(await services.db_context.validate_ids([str(img_id)])) != 0: # check for duplicate points
if img_id in services.upload_service.uploading_ids or len(
await services.db_context.validate_ids([str(img_id)])) != 0: # check for duplicate points
logger.warning("Duplicate upload request for image id: {}", img_id)
raise HTTPException(409, f"The uploaded point is already contained in the database! entity id: {img_id}")

try:
image = Image.open(BytesIO(img_bytes))
image.verify()
image.close()
except UnidentifiedImageError as ex:
logger.warning("Invalid image file from upload request. id: {}", img_id)
raise HTTPException(422, "Cannot open the image file.") from ex

image_data = ImageData(id=img_id,
Expand All @@ -128,7 +135,7 @@ async def upload_image(image_file: Annotated[UploadFile, File(description="The i
format=img_type,
index_date=datetime.now())

await services.upload_service.upload_image(image, image_data, img_bytes, model.skip_ocr, model.local_thumbnail)
await services.upload_service.upload_image(image_data, img_bytes, model.skip_ocr, model.local_thumbnail)
return ImageUploadResponse(message="OK. Image added to upload queue.", image_id=img_id)


Expand Down
10 changes: 7 additions & 3 deletions app/Controllers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,14 @@ async def combinedSearch(
@search_router.get("/random", description="Get random images")
async def randomPick(
filter_param: Annotated[FilterParams, Depends(FilterParams)],
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)]) -> SearchApiResponse:
paging: Annotated[SearchPagingParams, Depends(SearchPagingParams)],
seed: Annotated[int | None, Query(
description="The seed for random pick. This is helpful for generating a reproducible random pick.")] = None,
) -> SearchApiResponse:
logger.info("Random pick request received")
random_vector = services.transformers_service.get_random_vector()
result = await services.db_context.querySearch(random_vector, top_k=paging.count, filter_param=filter_param)
random_vector = services.transformers_service.get_random_vector(seed)
result = await services.db_context.querySearch(random_vector, top_k=paging.count, skip=paging.skip,
filter_param=filter_param)
return await result_postprocessing(
SearchApiResponse(result=result, message=f"Successfully get {len(result)} results.", query_id=uuid4()))

Expand Down
6 changes: 3 additions & 3 deletions app/Services/transformers_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_bert_vector(self, text: str) -> ndarray:
return vector.cpu().numpy()

@staticmethod
def get_random_vector() -> ndarray:
vec = np.random.rand(768)
vec -= vec.mean()
def get_random_vector(seed: int | None = None) -> ndarray:
generator = np.random.default_rng(seed)
vec = generator.uniform(-1, 1, 768)
return vec
14 changes: 9 additions & 5 deletions app/Services/upload_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,28 @@ def __init__(self, storage_service: StorageService, db_context: VectorDbContext,
self._queue = asyncio.Queue(config.admin_index_queue_max_length)
self._upload_worker_task = asyncio.create_task(self._upload_worker())

self.uploading_ids = set()
self._processed_count = 0

async def _upload_worker(self):
while True:
img, img_data, *args = await self._queue.get()
img_data, *args = await self._queue.get()
try:
await self._upload_task(img, img_data, *args)
await self._upload_task(img_data, *args)
logger.success("Image {} uploaded and indexed. Queue Length: {} [-1]", img_data.id, self._queue.qsize())
except Exception as ex:
logger.error("Error occurred while uploading image {}", img_data.id)
logger.exception(ex)
finally:
self._queue.task_done()
self.uploading_ids.remove(img_data.id)
self._processed_count += 1
if self._processed_count % 50 == 0:
gc.collect()

async def _upload_task(self, img: Image.Image, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
async def _upload_task(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
thumbnail_mode: UploadImageThumbnailMode):
img = Image.open(BytesIO(img_bytes))
logger.info('Start indexing image {}. Local: {}. Size: {}', img_data.id, img_data.local, len(img_bytes))
file_name = f"{img_data.id}.{img_data.format}"
thumb_path = f"thumbnails/{img_data.id}.webp"
Expand Down Expand Up @@ -71,9 +74,10 @@ async def _upload_task(self, img: Image.Image, img_data: ImageData, img_bytes: b

img.close()

async def upload_image(self, img: Image.Image, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
async def upload_image(self, img_data: ImageData, img_bytes: bytes, skip_ocr: bool,
thumbnail_mode: UploadImageThumbnailMode):
await self._queue.put((img, img_data, img_bytes, skip_ocr, thumbnail_mode))
self.uploading_ids.add(img_data.id)
await self._queue.put((img_data, img_bytes, skip_ocr, thumbnail_mode))
logger.success("Image {} added to upload queue. Queue Length: {} [+1]", img_data.id, self._queue.qsize())

def get_queue_size(self):
Expand Down
2 changes: 2 additions & 0 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
TEST_ACCESS_TOKEN = 'test_token'
TEST_ADMIN_TOKEN = 'test_admin_token'

assets_path = Path(__file__).parent / '..' / 'assets'


@pytest.fixture(scope="session")
def test_client(tmp_path_factory) -> TestClient:
Expand Down
6 changes: 1 addition & 5 deletions tests/api/integrate_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
from pathlib import Path

import pytest

from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN

assets_path = Path(__file__).parent / '..' / 'assets'
from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN, assets_path

test_images = {'bsn': ['bsn_0.jpg', 'bsn_1.jpg', 'bsn_2.jpg'],
'cat': ['cat_0.jpg', 'cat_1.jpg'],
Expand Down
34 changes: 29 additions & 5 deletions tests/api/test_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import pytest

from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN
from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN, assets_path

test_file_path = assets_path / 'test_images' / 'bsn_0.jpg'


def get_single_img_info(test_client, image_id):
Expand Down Expand Up @@ -36,6 +38,29 @@ def test_upload_unsupported_types(test_client):
assert resp.status_code == 415


@pytest.mark.asyncio
async def test_upload_duplicate(test_client, check_local_dir_empty, wait_for_background_task):
    """Check that re-uploading an identical image is rejected with HTTP 409.

    The same file is posted three times: the first upload must be accepted
    (200), the second is sent right away and must be rejected (409), and the
    third is sent after awaiting ``wait_for_background_task(1)`` and must be
    rejected as well (409). The accepted image is deleted at the end.
    """
    admin_headers = {'x-admin-token': TEST_ADMIN_TOKEN}

    def post_image(file):
        # Every attempt goes through the same endpoint with the same options.
        return test_client.post('/admin/upload',
                                files={'image_file': file},
                                headers=admin_headers,
                                params={'local': True})

    # NOTE(review): the same open handle is reused for all three posts; this
    # presumably relies on the HTTP client rewinding the file before each
    # send -- confirm.
    with open(test_file_path, 'rb') as image_file:
        first_resp = post_image(image_file)
        assert first_resp.status_code == 200
        image_id = first_resp.json()['image_id']

        # The previous image is still in queue
        queued_resp = post_image(image_file)
        assert queued_resp.status_code == 409

        await wait_for_background_task(1)

        # The previous image is indexed now
        indexed_resp = post_image(image_file)
        assert indexed_resp.status_code == 409

    # cleanup
    delete_resp = test_client.delete(f'/admin/delete/{image_id}', headers=admin_headers)
    assert delete_resp.status_code == 200


TEST_FAKE_URL = 'fake-url'
TEST_FAKE_THUMBNAIL_URL = 'fake-thumbnail-url'

Expand All @@ -55,10 +80,10 @@ def test_upload_unsupported_types(test_client):
@pytest.mark.asyncio
async def test_upload_thumbnails(test_client, check_local_dir_empty, wait_for_background_task, # Fixtures
add_trailing_bytes, params, expect_local_url, expect_thumbnail_mode): # Parameters
with open('tests/assets/test_images/bsn_0.jpg', 'rb') as f:
img_bytes = f.read()
with open(test_file_path, 'rb') as f:
# append 500KB to the image, to make it large enough to generate a thumbnail
if add_trailing_bytes:
img_bytes = f.read()
img_bytes += bytearray(random.getrandbits(8) for _ in range(1024 * 500))
f_patched = io.BytesIO(img_bytes)
f_patched.name = 'bsn_0.jpg'
Expand Down Expand Up @@ -118,8 +143,7 @@ async def test_upload_thumbnails(test_client, check_local_dir_empty, wait_for_ba
@pytest.mark.asyncio
async def test_update_opt(test_client, check_local_dir_empty, wait_for_background_task, # Fixtures
initial_param, update_param, expected_param, resp_code): # Parameters
with open('tests/assets/test_images/bsn_0.jpg', 'rb') as f:
img_bytes = f.read()
with open(test_file_path, 'rb') as f:
resp = test_client.post('/admin/upload',
files={'image_file': f},
headers={'x-admin-token': TEST_ADMIN_TOKEN},
Expand Down

0 comments on commit f6c9d85

Please sign in to comment.