refactor: migrate CLAP operator to Hugging Face Transformers
Chaithanya512 committed Sep 9, 2024
1 parent 0acfdb7 commit 9e5195a
Showing 3 changed files with 92 additions and 463 deletions.
49 changes: 32 additions & 17 deletions src/core/operators/audio_vec_embedding_clap.py
@@ -1,53 +1,68 @@
"""Operator to get audio representation using LAION-CLAP - https://github.com/LAION-AI/CLAP. """
"""
Operator to get audio representation using LAION-CLAP - https://huggingface.co/laion/larger_clap_general
"""

def initialize(param):
"""
Initializes the operator.
Args:
param (dict): A dict to initialize and load the the model.
param (dict): A dict to initialize and load the model.
"""
global model
global librosa
global np
global contextmanager
global os
global model, processor, librosa, contextmanager, os, torch, device

import numpy as np
import librosa
from contextlib import contextmanager
import os
import laion_clap
from transformers import ClapModel, ClapProcessor
import torch

# Load the model and processor
model = ClapModel.from_pretrained("laion/larger_clap_general")
processor = ClapProcessor.from_pretrained("laion/larger_clap_general")

model = laion_clap.CLAP_Module()
model.load_ckpt() # load the best checkpoint (HTSAT model) in the paper.
print("model successfully downloaded")
# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

print("audio CLAP Model successfully initialized and loaded onto", device)


def run(audio_file):
"""
Runs the operator and compute inference on the audio file.
Runs the operator and computes inference on the audio file.
Args:
audio_file (dict): `AudioFactory` file object.
Returns:
audio_emb (numpy.ndarray): A 512-length vector embedding representing the audio.
audio_emb (list): A 512-length vector embedding representing the audio.
"""
audio = audio_file["path"]

@contextmanager
def audio_load(fname):
"""
Loads audio and removes the file after use.
Args:
fname (str): Path to the audio file.
Yields:
numpy.ndarray: Loaded audio data.
"""
a, _ = librosa.load(fname, sr=48000)
try:
yield a
finally:
os.remove(fname)

with audio_load(audio) as audio_var:
query_audio = audio_var.reshape(1, -1)
audio_emb = model.get_audio_embedding_from_data(x = query_audio, use_tensor=False)
audio_emb = audio_emb.reshape(-1)
inputs = processor(audios=audio_var, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
audio_emb = model.get_audio_features(**inputs)
audio_emb = audio_emb.squeeze(0).tolist()
return audio_emb
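
For reference, a minimal sketch of how the migrated operator could be exercised end to end. The import path, the empty `param` dict, and the file names here are assumptions for illustration, not part of the commit; the `{"path": ...}` payload mirrors the `AudioFactory` dict the operator expects, and the disposable copy matters because `run` deletes the file after embedding it.

    # Usage sketch (assumed import path; adjust to the repository layout).
    import shutil
    import tempfile

    from core.operators import audio_vec_embedding_clap as clap_op

    clap_op.initialize(param={})  # downloads laion/larger_clap_general on first use

    # run() consumes the file at "path" and removes it, so pass a disposable copy.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    shutil.copyfile("sample.wav", tmp.name)  # "sample.wav" is a placeholder input

    embedding = clap_op.run({"path": tmp.name})
    print(len(embedding))  # 512, per the operator's docstring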
4 changes: 2 additions & 2 deletions src/core/operators/audio_vec_embedding_clap_requirements.in
@@ -1,3 +1,3 @@
-laion-clap==1.1.6
 librosa==0.10.2.post1
-torchvision==0.19.0
+transformers==4.44.0
+torch==2.4.0
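
The `.in` extension follows the pip-tools convention, so these loose pins would presumably be compiled into a fully pinned requirements file (plausibly the third changed file, which would account for the bulk of the 463 deleted lines) with something like:

    pip-compile src/core/operators/audio_vec_embedding_clap_requirements.in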