Skip to content

Commit

Permalink
fix: video search (#52)
Browse files Browse the repository at this point in the history
* chore: moving test files to a folder
* fix: video search
* docs: commenting TODO in search.py
  • Loading branch information
aatmanvaidya committed Feb 2, 2024
1 parent 8e55226 commit af54ac0
Show file tree
Hide file tree
Showing 12 changed files with 13 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/api/.env-template
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ AWS_SECRET_ACCESS_KEY=XXXXX
AWS_BUCKET=XXXXX
S3_CREDENTIALS_PATH = XXXXX/XXXXX
GOOGLE_APPLICATION_CREDENTIALS="credentials.json"
ES_HOST=localhost
ES_HOST=es
ES_USERNAME=XXXXX
ES_PASSWORD=XXXXX
ES_IMG_INDEX=imgsearch
Expand Down
4 changes: 3 additions & 1 deletion src/api/core/models/media_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def make_from_file_on_disk(video_path):
@staticmethod
def make_from_file_in_memory(file_data: FileStorage):
# save on disk
return {"path": "file_path_on_disk"}
fname = "/tmp/"+file_data.filename
file_data.save(fname)
return {"path": fname}


media_factory = {
Expand Down
7 changes: 4 additions & 3 deletions src/api/core/operators/test_vid_vec_rep_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def tearDownClass(cls):
def test_sample_video_from_disk(self):
video = {"path": r"sample_data/sample-cat-video.mp4"}
result = vid_vec_rep_resnet.run(video)
self.assertEqual(len(list(result)), 6)
# self.assertEqual(len(list(result)), 6)
for vec in result:
self.assertEqual(len(vec.get('vid_vec')), 512)

@skip
def test_unsupported_sample_video_from_disk(self):
Expand All @@ -37,7 +39,6 @@ def test_sample_video_from_url(self):
wget.download(video_url, out=video_path)
video = VideoFactory.make_from_file_on_disk(video_path)

avg_vec, all_vec = vid_vec_rep_resnet.run(video)
self.assertEqual(len(avg_vec), 512)
all_vec = vid_vec_rep_resnet.run(video)
for vec in all_vec:
self.assertEqual(len(vec), 512)
6 changes: 3 additions & 3 deletions src/api/endpoint/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,16 @@ def handle_search(self):
return {"matches": results}
elif data["query_type"] == "video":
file = request.files["media"]
print(file, type(file))
vid_obj = media_factory[MediaType.VIDEO].make_from_file_in_memory(
file
)
vid_vec = self.feluda.operators.active_operators[
"vid_vec_rep_resnet"
].run(vid_obj)
average_vector = next(vid_vec)
results = self.feluda.store.find("image", average_vector)
return {"matches": []}
# TODO: explore finding "all_vectors" along with the "avg_vector"
results = self.feluda.store.find("video", average_vector.get('vid_vec'))
return {"matches": results}
else:
return {"message": "Unsupported Query Type"}
else:
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,9 @@ def testSearchRawQuery(self):
@skip
def testSearchImage(self):
url = API_URL + "/search"
# data = {"query_type": "image"}
data = {"query_type": "image"}
with open("sample_data/c8709f21-bd7d-4e22-af14-50ad8a429f84.jpeg", "rb") as file:
# files = {
# "media": file,
# "data": json.dumps(data),
# }
# response = requests.post(url, files=files)
# print(response.text)
# self.assertEqual(response.status_code, 200)
data = {"data": json.dumps({"query_type": "image"})}
data = {"data": json.dumps(data)}
files = {"media": file}
response = requests.post(url, data=data, files=files)
print(response.text)
Expand Down

0 comments on commit af54ac0

Please sign in to comment.