diff --git a/src/api/.env-template b/src/api/.env-template index 3671ceda..d076d5d2 100755 --- a/src/api/.env-template +++ b/src/api/.env-template @@ -6,7 +6,7 @@ AWS_SECRET_ACCESS_KEY=XXXXX AWS_BUCKET=XXXXX S3_CREDENTIALS_PATH = XXXXX/XXXXX GOOGLE_APPLICATION_CREDENTIALS="credentials.json" -ES_HOST=localhost +ES_HOST=es ES_USERNAME=XXXXX ES_PASSWORD=XXXXX ES_IMG_INDEX=imgsearch diff --git a/src/api/core/models/media_factory.py b/src/api/core/models/media_factory.py index 98fa33d4..e9325f39 100644 --- a/src/api/core/models/media_factory.py +++ b/src/api/core/models/media_factory.py @@ -80,7 +80,9 @@ def make_from_file_on_disk(video_path): @staticmethod def make_from_file_in_memory(file_data: FileStorage): # save on disk - return {"path": "file_path_on_disk"} + fname = "/tmp/"+file_data.filename + file_data.save(fname) + return {"path": fname} media_factory = { diff --git a/src/api/core/operators/test_vid_vec_rep_resnet.py b/src/api/core/operators/test_vid_vec_rep_resnet.py index 1a576b06..e4033e3c 100644 --- a/src/api/core/operators/test_vid_vec_rep_resnet.py +++ b/src/api/core/operators/test_vid_vec_rep_resnet.py @@ -19,7 +19,9 @@ def tearDownClass(cls): def test_sample_video_from_disk(self): video = {"path": r"sample_data/sample-cat-video.mp4"} result = vid_vec_rep_resnet.run(video) - self.assertEqual(len(list(result)), 6) + # self.assertEqual(len(list(result)), 6) + for vec in result: + self.assertEqual(len(vec.get('vid_vec')), 512) @skip def test_unsupported_sample_video_from_disk(self): @@ -37,7 +39,6 @@ def test_sample_video_from_url(self): wget.download(video_url, out=video_path) video = VideoFactory.make_from_file_on_disk(video_path) - avg_vec, all_vec = vid_vec_rep_resnet.run(video) - self.assertEqual(len(avg_vec), 512) + all_vec = vid_vec_rep_resnet.run(video) for vec in all_vec: self.assertEqual(len(vec), 512) diff --git a/src/api/endpoint/search.py b/src/api/endpoint/search.py index 7cce58a9..1245948a 100644 --- a/src/api/endpoint/search.py +++ b/src/api/endpoint/search.py @@ -49,7 +49,6 @@ def handle_search(self): return {"matches": results} elif data["query_type"] == "video": file = request.files["media"] - print(file, type(file)) vid_obj = media_factory[MediaType.VIDEO].make_from_file_in_memory( file ) @@ -57,8 +56,9 @@ def handle_search(self): "vid_vec_rep_resnet" ].run(vid_obj) average_vector = next(vid_vec) - results = self.feluda.store.find("image", average_vector) - return {"matches": []} + # TODO: explore finding "all_vectors" along with the "avg_vector" + results = self.feluda.store.find("video", average_vector.get('vid_vec')) + return {"matches": results} else: return {"message": "Unsupported Query Type"} else: diff --git a/src/api/sample_data/07ba4a2f-c0a2-44ba-96d8-7b4cc94c8ee7.mp4 b/src/api/tests/sample_data/07ba4a2f-c0a2-44ba-96d8-7b4cc94c8ee7.mp4 similarity index 100% rename from src/api/sample_data/07ba4a2f-c0a2-44ba-96d8-7b4cc94c8ee7.mp4 rename to src/api/tests/sample_data/07ba4a2f-c0a2-44ba-96d8-7b4cc94c8ee7.mp4 diff --git a/src/api/sample_data/c8709f21-bd7d-4e22-af14-50ad8a429f84.jpeg b/src/api/tests/sample_data/c8709f21-bd7d-4e22-af14-50ad8a429f84.jpeg similarity index 100% rename from src/api/sample_data/c8709f21-bd7d-4e22-af14-50ad8a429f84.jpeg rename to src/api/tests/sample_data/c8709f21-bd7d-4e22-af14-50ad8a429f84.jpeg diff --git a/src/api/sample_data/simple-text.txt b/src/api/tests/sample_data/simple-text.txt similarity index 100% rename from src/api/sample_data/simple-text.txt rename to src/api/tests/sample_data/simple-text.txt diff --git a/src/api/test_handlers.py b/src/api/tests/test_handlers.py similarity index 100% rename from src/api/test_handlers.py rename to src/api/tests/test_handlers.py diff --git a/src/api/test_health_api.py b/src/api/tests/test_health_api.py similarity index 100% rename from src/api/test_health_api.py rename to src/api/tests/test_health_api.py diff --git a/src/api/test_index_api.py b/src/api/tests/test_index_api.py similarity index 100% rename from src/api/test_index_api.py rename to src/api/tests/test_index_api.py diff --git a/src/api/test_index_api_as_client.py b/src/api/tests/test_index_api_as_client.py similarity index 100% rename from src/api/test_index_api_as_client.py rename to src/api/tests/test_index_api_as_client.py diff --git a/src/api/test_search_api_as_client.py b/src/api/tests/test_search_api_as_client.py similarity index 80% rename from src/api/test_search_api_as_client.py rename to src/api/tests/test_search_api_as_client.py index 6d5db606..5f977d00 100644 --- a/src/api/test_search_api_as_client.py +++ b/src/api/tests/test_search_api_as_client.py @@ -27,16 +27,9 @@ def testSearchRawQuery(self): @skip def testSearchImage(self): url = API_URL + "/search" - # data = {"query_type": "image"} + data = {"query_type": "image"} with open("sample_data/c8709f21-bd7d-4e22-af14-50ad8a429f84.jpeg", "rb") as file: - # files = { - # "media": file, - # "data": json.dumps(data), - # } - # response = requests.post(url, files=files) - # print(response.text) - # self.assertEqual(response.status_code, 200) - data = {"data": json.dumps({"query_type": "image"})} + data = {"data": json.dumps(data)} files = {"media": file} response = requests.post(url, data=data, files=files) print(response.text)