From f6987a6d3aa4ff018b5ebac248a4469437df80d3 Mon Sep 17 00:00:00 2001 From: Aatman Vaidya Date: Fri, 16 Feb 2024 09:58:27 +0530 Subject: [PATCH] feat: feluda store supports audio (#78) * feat: feluda store supports audio * fix: delete and refresh for ES * dhore: profiling audio operator --- .gitignore | 2 +- src/api/config-server.yml | 1 + src/api/core/config.py | 1 + src/api/core/models/media.py | 3 + .../audio_vec_embedding_requirements.in | 4 +- .../audio_vec_embedding_requirements.txt | 37 ++++- src/api/core/store/es_vec.py | 7 +- src/api/core/store/es_vec_mappings.py | 22 +++ src/api/test_audio_es_vec.py | 157 +++++++++--------- src/api/test_es_vec.py | 19 ++- 10 files changed, 164 insertions(+), 89 deletions(-) diff --git a/.gitignore b/.gitignore index 2d614fec..840db038 100644 --- a/.gitignore +++ b/.gitignore @@ -135,4 +135,4 @@ src/api/core/operators/yolov8n-seg.pt # Audio Files **100_audio_files/ **50_audio_files/ - +**audio_files/ diff --git a/src/api/config-server.yml b/src/api/config-server.yml index 9260e96b..390cb986 100644 --- a/src/api/config-server.yml +++ b/src/api/config-server.yml @@ -6,6 +6,7 @@ store : image_index_name : "image" text_index_name : "text" video_index_name : "video" + audio_index_name : "audio" queue : label : "Queue" diff --git a/src/api/core/config.py b/src/api/core/config.py index 58dcbac6..c3921228 100644 --- a/src/api/core/config.py +++ b/src/api/core/config.py @@ -22,6 +22,7 @@ class StoreParameters: image_index_name: str text_index_name: str video_index_name: str + audio_index_name: str @dataclass diff --git a/src/api/core/models/media.py b/src/api/core/models/media.py index cbcfdb77..ff621224 100644 --- a/src/api/core/models/media.py +++ b/src/api/core/models/media.py @@ -13,6 +13,7 @@ class MediaType(Enum): TEXT = "text" IMAGE = "image" VIDEO = "video" + AUDIO = "audio" @classmethod def make(media_type): @@ -22,5 +23,7 @@ def make(media_type): return MediaType.IMAGE elif media_type is "video": return MediaType.VIDEO + elif media_type is "audio": + return MediaType.AUDIO else: return MediaType.UNSUPPORTED diff --git a/src/api/core/operators/audio_vec_embedding_requirements.in b/src/api/core/operators/audio_vec_embedding_requirements.in index 38719284..acf5dfed 100644 --- a/src/api/core/operators/audio_vec_embedding_requirements.in +++ b/src/api/core/operators/audio_vec_embedding_requirements.in @@ -1,3 +1,5 @@ numpy==1.26.4 librosa==0.10.1 -panns-inference==0.1.1 \ No newline at end of file +panns-inference==0.1.1 +torch==2.2.0+cpu +torchvision==0.17.0+cpu \ No newline at end of file diff --git a/src/api/core/operators/audio_vec_embedding_requirements.txt b/src/api/core/operators/audio_vec_embedding_requirements.txt index 35b87d71..0a3aee8a 100644 --- a/src/api/core/operators/audio_vec_embedding_requirements.txt +++ b/src/api/core/operators/audio_vec_embedding_requirements.txt @@ -2,8 +2,10 @@ # This file is autogenerated by pip-compile with Python 3.12 # by the following command: # -# pip-compile audio_vec_embedding_requirements.in +# pip-compile --find-links=https://download.pytorch.org/whl/torch_stable.html audio_vec_embedding_requirements.in # +--find-links https://download.pytorch.org/whl/torch_stable.html + audioread==3.0.1 # via librosa certifi==2024.2.2 @@ -18,10 +20,16 @@ cycler==0.12.1 # via matplotlib decorator==5.1.1 # via librosa +filelock==3.13.1 + # via torch fonttools==4.48.1 # via matplotlib +fsspec==2024.2.0 + # via torch idna==3.6 # via requests +jinja2==3.1.3 + # via torch joblib==1.3.2 # via # librosa @@ -37,10 +45,16 @@ librosa==0.10.1 # torchlibrosa llvmlite==0.42.0 # via numba +markupsafe==2.1.5 + # via jinja2 matplotlib==3.8.2 # via panns-inference +mpmath==1.3.0 + # via sympy msgpack==1.0.7 # via librosa +networkx==3.2.1 + # via torch numba==0.59.0 # via librosa numpy==1.26.4 @@ -54,6 +68,7 @@ numpy==1.26.4 # scipy # soxr # torchlibrosa + # torchvision packaging==23.2 # via # matplotlib @@ -61,7 +76,9 @@ packaging==23.2 panns-inference==0.1.1 # via -r audio_vec_embedding_requirements.in pillow==10.2.0 - # via matplotlib + # via + # matplotlib + # torchvision platformdirs==4.2.0 # via pooch pooch==1.8.0 @@ -73,7 +90,9 @@ pyparsing==3.1.1 python-dateutil==2.8.2 # via matplotlib requests==2.31.0 - # via pooch + # via + # pooch + # torchvision scikit-learn==1.4.0 # via librosa scipy==1.12.0 @@ -86,11 +105,21 @@ soundfile==0.12.1 # via librosa soxr==0.3.7 # via librosa +sympy==1.12 + # via torch threadpoolctl==3.2.0 # via scikit-learn +torch==2.2.0+cpu + # via + # -r audio_vec_embedding_requirements.in + # torchvision torchlibrosa==0.1.0 # via panns-inference +torchvision==0.17.0+cpu + # via -r audio_vec_embedding_requirements.in typing-extensions==4.9.0 - # via librosa + # via + # librosa + # torch urllib3==2.2.0 # via requests diff --git a/src/api/core/store/es_vec.py b/src/api/core/store/es_vec.py index 827bc3a5..2d97b4a4 100644 --- a/src/api/core/store/es_vec.py +++ b/src/api/core/store/es_vec.py @@ -19,6 +19,7 @@ def __init__(self, config: StoreConfig): "text": config.parameters.text_index_name, "image": config.parameters.image_index_name, "video": config.parameters.video_index_name, + "audio": config.parameters.audio_index_name, } def connect(self): @@ -57,7 +58,7 @@ def optionally_create_index(self): def delete_indices(self): for index in self.indices: - self.client.indices.delete(index=index) + self.client.indices.delete(index=self.indices[index]) def get_indices(self): index_list = "" @@ -78,7 +79,7 @@ def store(self, media_type: MediaType, doc): def refresh(self): for index in self.indices: - self.client.indices.refresh(index=index) + self.client.indices.refresh() def find(self, index_name, vec): if type(vec) == np.ndarray: @@ -91,6 +92,8 @@ def find(self, index_name, vec): calculation = "1 / (1 + l2norm(params.query_vector, 'image_vec'))" elif index_name == self.indices["video"]: calculation = "1 / (1 + l2norm(params.query_vector, 'vid_vec'))" + elif index_name == self.indices["audio"]: + calculation = "1 / (1 + l2norm(params.query_vector, 'audio_vec'))" print("calculation:", calculation) q = { diff --git a/src/api/core/store/es_vec_mappings.py b/src/api/core/store/es_vec_mappings.py index bc4f0a1b..88f1f27d 100644 --- a/src/api/core/store/es_vec_mappings.py +++ b/src/api/core/store/es_vec_mappings.py @@ -87,4 +87,26 @@ } } }""", + "audio": """{ + "mappings": { + "properties":{ + "e_kosh_id": { + "type": "keyword" + }, + "dataset": { + "type": "keyword" + }, + "metadata": { + "type": "object" + }, + "audio_vec": { + "type":"dense_vector", + "dims": 2048 + }, + "date_added": { + "type": "date" + } + } + } + }""", } diff --git a/src/api/test_audio_es_vec.py b/src/api/test_audio_es_vec.py index 0a369b5b..a179bb0b 100644 --- a/src/api/test_audio_es_vec.py +++ b/src/api/test_audio_es_vec.py @@ -1,11 +1,14 @@ import unittest from unittest.case import skip import requests +from core.store.es_vec import ES +from core.config import StoreConfig, StoreParameters +from core.models.media import MediaType import pprint -import os -from elasticsearch import Elasticsearch +from datetime import datetime from core.operators import audio_vec_embedding from time import sleep +import os pp = pprint.PrettyPrinter(indent=4) ''' @@ -13,6 +16,8 @@ curl -X GET "http://es:9200/_cat/indices?v" Delete all the documents in an index curl -X POST "http://es:9200/test_audio/_delete_by_query" -H 'Content-Type: application/json' -d'{"query":{"match_all":{}}}' +Delete the indice +curl -X DELETE "http://es:9200/test_audio" ''' class TestAudioES(unittest.TestCase): @@ -25,111 +30,113 @@ def setUpClass(cls) -> None: print("Elastic search server is running") else: print("No elasticsearch service found. Tests are bound to fail.") + param_dict = { + "host_name": "es", + "text_index_name": "test_text", + "image_index_name": "test_image", + "video_index_name": "test_video", + "audio_index_name": "test_audio", + } + cls.param = StoreConfig( + label="test", + type="es", + parameters=StoreParameters( + host_name=param_dict["host_name"], + image_index_name=param_dict["image_index_name"], + text_index_name=param_dict["text_index_name"], + video_index_name=param_dict["video_index_name"], + audio_index_name=param_dict["audio_index_name"], + ) + ) - cls.es_host = os.environ.get("ES_HOST") - try: - cls.config = {"host": cls.es_host, "port": 9200, "scheme": "http"} - cls.client = Elasticsearch([cls.config,]) - print("Success Connecting to Elasticsearch") - except Exception: - print("Error Connecting to Elasticsearch") - @classmethod def tearDownClass(cls) -> None: print("TEARING DOWN CLASS") pass - def create_test_audio_index(self): - global index_name - index_name = "test_audio" - index_config = { - "mappings": { - "_source": { - "excludes": ["audio-embedding"] - }, - "properties": { - "audio-embedding": { - "type": "dense_vector", - "dims": 2048, - "index": True, - "similarity": "cosine" - }, - } - } - } - try: - if self.client.indices.exists(index=index_name): - print(f"Index '{index_name}' already exists.") - return - response = self.client.indices.create(index=index_name, body=index_config) - if response["acknowledged"]: - print(f"Index '{index_name}' created successfully.") - else: - print(f"Failed to create index '{index_name}'.") - except Exception as e: - print(f"Error creating index '{index_name}': {e}") - + def test_create_audio_indice(self): + es = ES(self.param) + es.connect() + es.optionally_create_index() + indices = es.get_indices() + self.assertEqual(indices["test_audio"]["mappings"]["properties"]["audio_vec"]["dims"], 2048) + @skip - def test_store_audio_vector(self): - # create the audio indice - self.create_test_audio_index() - # generate an audio vector + def test_store_audio(self): + es = ES(self.param) + es.connect() audio_vec_embedding.initialize(param=None) audio_file_path = r'core/operators/sample_data/audio.wav' audio_emb = audio_vec_embedding.run(audio_file_path) audio_emb_vec = audio_emb.tolist() - # index the vector - body = { - 'audio-embedding' : audio_emb_vec, + doc = { + "e_kosh_id": str(1231231), + "dataset": "test-dataset-id", + "metadata": {}, + "audio_vec": audio_emb_vec, + "date_added": datetime.utcnow(), } - result = self.client.index(index=index_name, document=body) + mediaType = MediaType.AUDIO + result = es.store(mediaType, doc) # print(result) self.assertEqual(result["result"], "created") - + # @skip - def test_search_audio_vector(self): - # create the audio indice - self.create_test_audio_index() - # generate an audio vector + def test_store_and_search_audio(self): + es = ES(self.param) + es.connect() audio_vec_embedding.initialize(param=None) audio_file_path = r'core/operators/sample_data/audio.wav' audio_emb = audio_vec_embedding.run(audio_file_path) audio_emb_vec = audio_emb.tolist() - # index the vector - body = { - 'audio-embedding' : audio_emb_vec, + doc = { + "e_kosh_id": str(1233333333), + "dataset": "test-dataset-id", + "metadata": {}, + "audio_vec": audio_emb_vec, + "date_added": datetime.utcnow(), } - self.client.index(index=index_name, document=body) - # search for it - query = { - "query": { - "script_score": { - "query": {"match_all": {}}, - "script": { - "source": "cosineSimilarity(params.query_vector, 'audio-embedding') + 1.0", - "params": {"query_vector": audio_emb_vec} - } - } - } - } - search_result = self.client.search(index="test_audio", body=query) + mediaType = MediaType.AUDIO + es.store(mediaType, doc) + sleep(4) + search_result = es.find("test_audio", audio_emb_vec) print(search_result) + self.assertEqual(search_result[0]['dataset'], "test-dataset-id") @skip def test_store_and_search_50files(self): - self.create_test_audio_index() + es = ES(self.param) + es.connect() + mediaType = MediaType.AUDIO audio_vec_embedding.initialize(param=None) audio_folder_path = r'core/operators/sample_data/50_audio_files' + count = 1 + # store 50 files for file_name in os.listdir(audio_folder_path): audio_file_path = os.path.join(audio_folder_path, file_name) + # generate an audio vector audio_emb = audio_vec_embedding.run(audio_file_path) audio_emb_vec = audio_emb.tolist() - body = { - 'audio-embedding' : audio_emb_vec, + doc = { + "e_kosh_id": str(count), + "dataset": "test-dataset-id", + "metadata": {}, + "audio_vec": audio_emb_vec, + "date_added": datetime.utcnow(), } - self.client.index(index=index_name, document=body) - sleep(0.5) + es.store(mediaType, doc) + print(f"----------{count}---------------") + count = count + 1 + print(f"Indexed {count} files") + sleep(3) + audio_to_search = r'core/operators/sample_data/100_audio_files/a-cappella-chorus.wav' + audio_to_search_emb = audio_vec_embedding.run(audio_to_search) + audio_to_search_emb_vec = audio_to_search_emb.tolist() + search_result = es.find("test_audio", audio_to_search_emb_vec) + print(search_result) + self.assertEqual(search_result[0]['dataset'], "test-dataset-id") + diff --git a/src/api/test_es_vec.py b/src/api/test_es_vec.py index 3bd80dc7..7fdd4278 100644 --- a/src/api/test_es_vec.py +++ b/src/api/test_es_vec.py @@ -7,6 +7,7 @@ import pprint from datetime import datetime import numpy as np +from time import sleep pp = pprint.PrettyPrinter(indent=4) @@ -26,6 +27,7 @@ def setUpClass(cls) -> None: "text_index_name": "test_text", "image_index_name": "test_image", "video_index_name": "test_video", + "audio_index_name": "test_audio", } cls.param = StoreConfig( label="test", @@ -35,6 +37,7 @@ def setUpClass(cls) -> None: image_index_name=param_dict["image_index_name"], text_index_name=param_dict["text_index_name"], video_index_name=param_dict["video_index_name"], + audio_index_name=param_dict["audio_index_name"], ) ) @@ -65,8 +68,11 @@ def test_create_indices(self): self.assertEqual( indices["test_video"]["mappings"]["properties"]["vid_vec"]["dims"], 512 ) + self.assertEqual( + indices["test_audio"]["mappings"]["properties"]["audio_vec"]["dims"], 2048 + ) - # @skip + @skip def test_store_image(self): es = ES(self.param) es.connect() @@ -99,13 +105,14 @@ def test_search_vectors(self): "date_added": datetime.utcnow(), } result = es.store(MediaType.IMAGE, doc) - pp.pprint(result) + # pp.pprint(result) + sleep(2) search_result = es.find("test_image", vec) - # es.refresh() + es.refresh() print("SEARCH RESULTS \n : ") - pp.pprint(search_result) - self.assertEqual(search_result[0].get('dataset'), "test-dataset-id") - # es.delete_indices() + print(search_result) + self.assertEqual(search_result[0]['dataset'], "test-dataset-id") + es.delete_indices() def test_store_text(self): pass