diff --git a/app/Services/vector_db_context.py b/app/Services/vector_db_context.py index 0b78389..d3d8ea6 100644 --- a/app/Services/vector_db_context.py +++ b/app/Services/vector_db_context.py @@ -34,12 +34,15 @@ def __init__(self): grpc_port=config.qdrant.grpc_port, api_key=config.qdrant.api_key, prefer_grpc=config.qdrant.prefer_grpc) wrap_object(self._client, retry_async((AioRpcError, HTTPError))) + self._local = False case QdrantMode.LOCAL: self._client = AsyncQdrantClient(path=config.qdrant.local_path) + self._local = True case QdrantMode.MEMORY: logger.warning("Using in-memory Qdrant client. Data will be lost after application restart. " "This should only be used for testing and debugging.") self._client = AsyncQdrantClient(":memory:") + self._local = True case _: raise ValueError("Invalid Qdrant mode.") self.collection_name = config.qdrant.coll @@ -234,24 +237,22 @@ def _get_point_from_img_data(cls, img_data: ImageData) -> models.PointStruct: vector=cls._get_vector_from_img_data(img_data).vector ) - @classmethod - def _get_img_data_from_point(cls, point: AVAILABLE_POINT_TYPES) -> ImageData: + def _get_img_data_from_point(self, point: AVAILABLE_POINT_TYPES) -> ImageData: return (ImageData .from_payload(point.id, - point.payload, - image_vector=numpy.array(point.vector[cls.IMG_VECTOR], dtype=numpy.float32) - if point.vector and cls.IMG_VECTOR in point.vector else None, - text_contain_vector=numpy.array(point.vector[cls.TEXT_VECTOR], dtype=numpy.float32) - if point.vector and cls.TEXT_VECTOR in point.vector else None + # workaround: https://github.com/qdrant/qdrant-client/issues/624 + point.payload.copy() if self._local else point.payload, + image_vector=numpy.array(point.vector[self.IMG_VECTOR], dtype=numpy.float32) + if point.vector and self.IMG_VECTOR in point.vector else None, + text_contain_vector=numpy.array(point.vector[self.TEXT_VECTOR], dtype=numpy.float32) + if point.vector and self.TEXT_VECTOR in point.vector else None )) - @classmethod - def _get_img_data_from_points(cls, points: list[AVAILABLE_POINT_TYPES]) -> list[ImageData]: - return [cls._get_img_data_from_point(t) for t in points] + def _get_img_data_from_points(self, points: list[AVAILABLE_POINT_TYPES]) -> list[ImageData]: + return [self._get_img_data_from_point(t) for t in points] - @classmethod - def _get_search_result_from_scored_point(cls, point: models.ScoredPoint) -> SearchResult: - return SearchResult(img=cls._get_img_data_from_point(point), score=point.score) + def _get_search_result_from_scored_point(self, point: models.ScoredPoint) -> SearchResult: + return SearchResult(img=self._get_img_data_from_point(point), score=point.score) @classmethod def getVectorByBasis(cls, basis: SearchBasisEnum) -> str: diff --git a/tests/api/conftest.py b/tests/api/conftest.py new file mode 100644 index 0000000..1a60c83 --- /dev/null +++ b/tests/api/conftest.py @@ -0,0 +1,26 @@ +import pytest +from fastapi.testclient import TestClient + +from app import config + +TEST_ACCESS_TOKEN = 'test_token' +TEST_ADMIN_TOKEN = 'test_admin_token' + +config.config.qdrant.mode = "memory" +config.config.admin_api_enable = True +config.config.access_protected = True +config.config.access_token = TEST_ACCESS_TOKEN +config.config.admin_token = TEST_ADMIN_TOKEN +config.config.storage.method = config.StorageMode.LOCAL + + +@pytest.fixture(scope="session") +def test_client(tmp_path_factory) -> TestClient: + # Modify the configuration for testing + config.config.storage.local.path = tmp_path_factory.mktemp("static_files") + + from app.webapp import app + # Start the application + + with TestClient(app) as client: + yield client diff --git a/tests/api/integrate_test.py b/tests/api/integrate_test.py new file mode 100644 index 0000000..18e0621 --- /dev/null +++ b/tests/api/integrate_test.py @@ -0,0 +1,68 @@ +import asyncio +from pathlib import Path + +import pytest + +from .conftest import TEST_ADMIN_TOKEN, TEST_ACCESS_TOKEN + +assets_path = Path(__file__).parent / '..' / 'assets' + +test_images = {'bsn': ['bsn_0.jpg', 'bsn_1.jpg', 'bsn_2.jpg'], + 'cat': ['cat_0.jpg', 'cat_1.jpg'], + 'cg': ['cg_0.jpg', 'cg_1.png']} + + +@pytest.mark.asyncio +async def test_integrate(test_client): + credentials = {'x-admin-token': TEST_ADMIN_TOKEN, 'x-access-token': TEST_ACCESS_TOKEN} + resp = test_client.get("/", headers=credentials) + assert resp.status_code == 200 + img_ids = dict() + for img_cls in test_images: + img_ids[img_cls] = [] + for image in test_images[img_cls]: + print(f'upload image {image}...') + resp = test_client.post('/admin/upload', + files={'image_file': open(assets_path / 'test_images' / image, 'rb')}, + headers=credentials, + params={'local': True}) + assert resp.status_code == 200 + img_ids[img_cls].append(resp.json()['image_id']) + + print('Waiting for images to be processed...') + + while True: + resp = test_client.get('/admin/server_info', headers=credentials) + if resp.json()['image_count'] >= 7: + break + await asyncio.sleep(1) + + resp = test_client.get('/search/text/hatsune+miku', + headers=credentials) + assert resp.status_code == 200 + assert resp.json()['result'][0]['img']['id'] in img_ids['cg'] + + resp = test_client.post('/search/image', + files={'image': open(assets_path / 'test_images' / test_images['cat'][0], 'rb')}, + headers=credentials) + + assert resp.status_code == 200 + assert resp.json()['result'][0]['img']['id'] in img_ids['cat'] + + resp = test_client.get(f"/search/similar/{img_ids['bsn'][0]}", + headers=credentials) + + assert resp.status_code == 200 + assert resp.json()['result'][0]['img']['id'] in img_ids['bsn'] + + resp = test_client.put(f"/admin/update_opt/{img_ids['bsn'][0]}", json={'categories': ['bsn'], 'starred': True}, + headers=credentials) + assert resp.status_code == 200 + + resp = test_client.get(f"/search/text/cat", params={'categories': 'bsn'}, headers=credentials) + assert resp.status_code == 200 + assert resp.json()['result'][0]['img']['id'] in img_ids['bsn'] + + resp = test_client.get(f"/search/text/cat", params={'starred': True}, headers=credentials) + assert resp.status_code == 200 + assert resp.json()['result'][0]['img']['id'] in img_ids['bsn'] diff --git a/tests/api/test_home.py b/tests/api/test_home.py index d1b8646..1e86049 100644 --- a/tests/api/test_home.py +++ b/tests/api/test_home.py @@ -1,7 +1,6 @@ import pytest from fastapi.testclient import TestClient -from app.config import config from app.webapp import app client = TestClient(app) @@ -14,29 +13,22 @@ def anyio_backend(): class TestHome: - # noinspection PyMethodMayBeStatic - def setup_class(self): - config.admin_api_enable = True - config.access_protected = True - config.access_token = 'test_token' - config.admin_token = 'test_admin_token' - - def test_get_home_no_tokens(self): - response = client.get("/") + def test_get_home_no_tokens(self, test_client): + response = test_client.get("/") assert response.status_code == 200 assert response.json()['authorization']['required'] assert not response.json()['authorization']['passed'] assert response.json()['admin_api']['available'] assert not response.json()['admin_api']['passed'] - def test_get_home_access_token(self): - response = client.get("/", headers={'x-access-token': 'test_token'}) + def test_get_home_access_token(self, test_client): + response = test_client.get("/", headers={'x-access-token': 'test_token'}) assert response.status_code == 200 assert response.json()['authorization']['required'] assert response.json()['authorization']['passed'] - def test_get_home_admin_token(self): - response = client.get("/", headers={'x-admin-token': 'test_admin_token', 'x-access-token': 'test_token'}) + def test_get_home_admin_token(self, test_client): + response = test_client.get("/", headers={'x-admin-token': 'test_admin_token', 'x-access-token': 'test_token'}) assert response.status_code == 200 assert response.json()['admin_api']['available'] assert response.json()['admin_api']['passed'] diff --git a/tests/assets/test_images/bsn_0.jpg b/tests/assets/test_images/bsn_0.jpg new file mode 100644 index 0000000..966a7ef Binary files /dev/null and b/tests/assets/test_images/bsn_0.jpg differ diff --git a/tests/assets/test_images/bsn_1.jpg b/tests/assets/test_images/bsn_1.jpg new file mode 100644 index 0000000..a1eb294 Binary files /dev/null and b/tests/assets/test_images/bsn_1.jpg differ diff --git a/tests/assets/test_images/bsn_2.jpg b/tests/assets/test_images/bsn_2.jpg new file mode 100644 index 0000000..35f8cb5 Binary files /dev/null and b/tests/assets/test_images/bsn_2.jpg differ diff --git a/tests/assets/test_images/cg_0.jpg b/tests/assets/test_images/cg_0.jpg new file mode 100644 index 0000000..5580f09 Binary files /dev/null and b/tests/assets/test_images/cg_0.jpg differ diff --git a/tests/assets/test_images/cg_1.png b/tests/assets/test_images/cg_1.png new file mode 100644 index 0000000..ef66247 Binary files /dev/null and b/tests/assets/test_images/cg_1.png differ