[81] - add worker for media clustering #379

Merged · 14 commits · Sep 9, 2024
14 changes: 7 additions & 7 deletions src/core/models/media_factory.py
@@ -83,7 +83,7 @@ def make_from_url(video_url):
         try:
             print("Downloading video from URL")
             wget.download(video_url, out=file_path)
-            print("Video downloaded")
+            print("\nVideo downloaded")
         except Exception as e:
             print("Error downloading video:", e)
             raise Exception("Error Downloading Video")
@@ -95,7 +95,7 @@ def make_from_url(video_url):
         try:
             print("Downloading video from S3")
             AWSS3Utils.download_file_from_s3(bucket_name, file_key, file_path)
-            print("Video downloaded")
+            print("\nVideo downloaded")
         except Exception as e:
             print("Error downloading video from S3:", e)
             raise Exception("Error Downloading Video")
@@ -126,7 +126,7 @@ def make_from_url(audio_url):
         try:
             print("Downloading audio from URL")
             wget.download(audio_url, out=file_path)
-            print("Audio downloaded")
+            print("\nAudio downloaded")
         except Exception as e:
             print("Error downloading audio:", e)
             raise Exception("Error Downloading audio")
@@ -138,13 +138,13 @@ def make_from_url(audio_url):
         try:
             print("Downloading audio from S3")
             AWSS3Utils.download_file_from_s3(bucket_name, file_key, file_path)
-            print("Audio downloaded")
+            print("\nAudio downloaded")
         except Exception as e:
             print("Error downloading audio from S3:", e)
             raise Exception("Error Downloading audio")

         return {"path": file_path}

     @staticmethod
     def make_from_url_to_wav(audio_url):
         temp_dir = tempfile.gettempdir()
@@ -156,7 +156,7 @@ def make_from_url_to_wav(audio_url):
             print("Downloading audio from URL")
             wget.download(audio_url, out=audio_file)
             print("\naudio downloaded")

             _, file_extension = os.path.splitext(file_name)
             if file_extension != '.wav':
                 audio = AudioSegment.from_file(audio_file, format=file_extension[1:])
@@ -172,7 +172,7 @@ def make_from_url_to_wav(audio_url):
     @staticmethod
     def make_from_file_on_disk(audio_path):
         return {"path": audio_path}



 media_factory = {
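Context for the repeated one-character change above: `wget.download` draws its progress bar with carriage returns and does not emit a trailing newline, so any message printed right after it would be glued to the bar. A minimal sketch of that behaviour, assuming a reachable URL (the URL and filename below are hypothetical):

```python
import os
import tempfile

import wget

# Hypothetical URL and filename, used only to illustrate the print behaviour.
video_url = "https://example.com/sample.mp4"
file_path = os.path.join(tempfile.gettempdir(), "sample.mp4")

# wget.download renders its progress bar without a final newline, so the next
# print() would otherwise continue on the progress-bar line.
wget.download(video_url, out=file_path)
print("\nVideo downloaded")  # the leading "\n" puts the message on its own line
```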
49 changes: 32 additions & 17 deletions src/core/operators/audio_vec_embedding_clap.py
@@ -1,53 +1,68 @@
-"""Operator to get audio representation using LAION-CLAP - https://github.com/LAION-AI/CLAP. """
+"""
+Operator to get audio representation using LAION-CLAP - https://huggingface.co/laion/larger_clap_general
+"""

 def initialize(param):
     """
     Initializes the operator.
     Args:
-        param (dict): A dict to initialize and load the the model.
+        param (dict): A dict to initialize and load the model.
     """
-    global model
-    global librosa
-    global np
-    global contextmanager
-    global os
+    global model, processor, librosa, contextmanager, os, torch, device

-    import numpy as np
     import librosa
     from contextlib import contextmanager
     import os
-    import laion_clap
+    from transformers import ClapModel, ClapProcessor
+    import torch

+    # Load the model and processor
+    model = ClapModel.from_pretrained("laion/larger_clap_general")
+    processor = ClapProcessor.from_pretrained("laion/larger_clap_general")
+
-    model = laion_clap.CLAP_Module()
-    model.load_ckpt() # load the best checkpoint (HTSAT model) in the paper.
-    print("model successfully downloaded")
+    # Set the device
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
+    print("audio CLAP Model successfully initialized and loaded onto", device)


 def run(audio_file):
     """
-    Runs the operator and compute inference on the audio file.
+    Runs the operator and computes inference on the audio file.
     Args:
         audio_file (dict): `AudioFactory` file object.
     Returns:
-        audio_emb (numpy.ndarray): A 512-length vector embedding representing the audio.
+        audio_emb (list): A 512-length vector embedding representing the audio.
     """
     audio = audio_file["path"]

     @contextmanager
     def audio_load(fname):
+        """
+        Loads audio and removes the file after use.
+        Args:
+            fname (str): Path to the audio file.
+        Yields:
+            numpy.ndarray: Loaded audio data.
+        """
         a, _ = librosa.load(fname, sr=48000)
         try:
             yield a
         finally:
             os.remove(fname)

     with audio_load(audio) as audio_var:
-        query_audio = audio_var.reshape(1, -1)
-        audio_emb = model.get_audio_embedding_from_data(x = query_audio, use_tensor=False)
-        audio_emb = audio_emb.reshape(-1)
+        inputs = processor(audios=audio_var, return_tensors="pt")
+        inputs = {k: v.to(device) for k, v in inputs.items()}
+        with torch.no_grad():
+            audio_emb = model.get_audio_features(**inputs)
+        audio_emb = audio_emb.squeeze(0).tolist()
         return audio_emb
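Taken together, the rewritten operator loads the Hugging Face checkpoint once in `initialize` and returns a plain Python list from `run`, with the temporary file removed after embedding. A sketch of how a worker might call it; the module path and the empty `param` dict are assumptions based on the file location in this PR:

```python
# Assumed import path, mirroring src/core/operators/audio_vec_embedding_clap.py
from core.operators import audio_vec_embedding_clap

audio_vec_embedding_clap.initialize(param={})  # downloads laion/larger_clap_general, picks CPU or GPU

# run() expects an AudioFactory-style dict; note it deletes the file after loading it.
audio_file = {"path": "/tmp/sample.wav"}       # hypothetical local audio file
embedding = audio_vec_embedding_clap.run(audio_file)

print(len(embedding))  # a 512-length list, ready to index for clustering
```

Switching from `laion_clap.CLAP_Module` to `transformers.ClapModel` also keeps device placement explicit (`model.to(device)`, `torch.no_grad()`), which matters for the media-clustering worker this PR adds.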
4 changes: 2 additions & 2 deletions src/core/operators/audio_vec_embedding_clap_requirements.in
@@ -1,3 +1,3 @@
-laion-clap==1.1.6
 librosa==0.10.2.post1
-torchvision==0.19.0
+transformers==4.44.0
+torch==2.4.0
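A quick sanity check of the new pins, assuming they are compiled into the operator's `requirements.txt` and installed in the worker image:

```python
# Verify the pinned stack exposes the classes the rewritten operator imports.
import librosa
import torch
import transformers
from transformers import ClapModel, ClapProcessor  # present in transformers 4.44.0

print(transformers.__version__, torch.__version__, librosa.__version__)
```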