Skip to content

Commit

Permalink
Merge pull request #379 from Snehil-Shah/worker
Browse files Browse the repository at this point in the history
worker for clustering media
  • Loading branch information
aatmanvaidya committed Sep 9, 2024
2 parents 1da1c27 + ba01287 commit 4abae4d
Show file tree
Hide file tree
Showing 13 changed files with 494 additions and 538 deletions.
14 changes: 7 additions & 7 deletions src/core/models/media_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def make_from_url(video_url):
try:
print("Downloading video from URL")
wget.download(video_url, out=file_path)
print("Video downloaded")
print("\nVideo downloaded")
except Exception as e:
print("Error downloading video:", e)
raise Exception("Error Downloading Video")
Expand All @@ -95,7 +95,7 @@ def make_from_url(video_url):
try:
print("Downloading video from S3")
AWSS3Utils.download_file_from_s3(bucket_name, file_key, file_path)
print("Video downloaded")
print("\nVideo downloaded")
except Exception as e:
print("Error downloading video from S3:", e)
raise Exception("Error Downloading Video")
Expand Down Expand Up @@ -126,7 +126,7 @@ def make_from_url(audio_url):
try:
print("Downloading audio from URL")
wget.download(audio_url, out=file_path)
print("Audio downloaded")
print("\nAudio downloaded")
except Exception as e:
print("Error downloading audio:", e)
raise Exception("Error Downloading audio")
Expand All @@ -138,13 +138,13 @@ def make_from_url(audio_url):
try:
print("Downloading audio from S3")
AWSS3Utils.download_file_from_s3(bucket_name, file_key, file_path)
print("Audio downloaded")
print("\nAudio downloaded")
except Exception as e:
print("Error downloading audio from S3:", e)
raise Exception("Error Downloading audio")

return {"path": file_path}

@staticmethod
def make_from_url_to_wav(audio_url):
temp_dir = tempfile.gettempdir()
Expand All @@ -156,7 +156,7 @@ def make_from_url_to_wav(audio_url):
print("Downloading audio from URL")
wget.download(audio_url, out=audio_file)
print("\naudio downloaded")

_, file_extension = os.path.splitext(file_name)
if file_extension != '.wav':
audio = AudioSegment.from_file(audio_file, format=file_extension[1:])
Expand All @@ -172,7 +172,7 @@ def make_from_url_to_wav(audio_url):
@staticmethod
def make_from_file_on_disk(audio_path):
return {"path": audio_path}



media_factory = {
Expand Down
49 changes: 32 additions & 17 deletions src/core/operators/audio_vec_embedding_clap.py
Original file line number Diff line number Diff line change
@@ -1,53 +1,68 @@
"""Operator to get audio representation using LAION-CLAP - https://github.com/LAION-AI/CLAP. """
"""
Operator to get audio representation using LAION-CLAP - https://huggingface.co/laion/larger_clap_general
"""

def initialize(param):
"""
Initializes the operator.
Args:
param (dict): A dict to initialize and load the the model.
param (dict): A dict to initialize and load the model.
"""
global model
global librosa
global np
global contextmanager
global os
global model, processor, librosa, contextmanager, os, torch, device

import numpy as np
import librosa
from contextlib import contextmanager
import os
import laion_clap
from transformers import ClapModel, ClapProcessor
import torch

# Load the model and processor
model = ClapModel.from_pretrained("laion/larger_clap_general")
processor = ClapProcessor.from_pretrained("laion/larger_clap_general")

model = laion_clap.CLAP_Module()
model.load_ckpt() # load the best checkpoint (HTSAT model) in the paper.
print("model successfully downloaded")
# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print("audio CLAP Model successfully initialized and loaded onto", device)


def run(audio_file):
"""
Runs the operator and compute inference on the audio file.
Runs the operator and computes inference on the audio file.
Args:
audio_file (dict): `AudioFactory` file object.
Returns:
audio_emb (numpy.ndarray): A 512-length vector embedding representing the audio.
audio_emb (list): A 512-length vector embedding representing the audio.
"""
audio = audio_file["path"]

@contextmanager
def audio_load(fname):
"""
Loads audio and removes the file after use.
Args:
fname (str): Path to the audio file.
Yields:
numpy.ndarray: Loaded audio data.
"""
a, _ = librosa.load(fname, sr=48000)
try:
yield a
finally:
os.remove(fname)

with audio_load(audio) as audio_var:
query_audio = audio_var.reshape(1, -1)
audio_emb = model.get_audio_embedding_from_data(x = query_audio, use_tensor=False)
audio_emb = audio_emb.reshape(-1)
inputs = processor(audios=audio_var, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
audio_emb = model.get_audio_features(**inputs)
audio_emb = audio_emb.squeeze(0).tolist()
return audio_emb
4 changes: 2 additions & 2 deletions src/core/operators/audio_vec_embedding_clap_requirements.in
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
laion-clap==1.1.6
librosa==0.10.2.post1
torchvision==0.19.0
transformers==4.44.0
torch==2.4.0
Loading

0 comments on commit 4abae4d

Please sign in to comment.