[feature] add (LAION)-CLAP model embedding to FAD calculation #21
@@ -7,20 +7,23 @@
 """
 import os
 import numpy as np
+import resampy
+import soundfile as sf
 import torch
 from torch import nn
+import laion_clap
+
+from multiprocessing.dummy import Pool as ThreadPool
 from scipy import linalg
+from torch import nn
 from tqdm import tqdm
-import soundfile as sf
-import resampy
-from multiprocessing.dummy import Pool as ThreadPool
-from .models.pann import Cnn14_8k, Cnn14_16k, Cnn14
+
+from .models.pann import Cnn14, Cnn14_8k, Cnn14_16k
 
-SAMPLE_RATE = 16000
+# SAMPLE_RATE = 16000
 
 
-def load_audio_task(fname, dtype="float32"):
+def load_audio_task(fname, sample_rate, dtype="float32"):
+    # print("LOAD AUDIO TASK")
Review comment: This line can be removed too.

Reply: Yes, I forgot to remove the debugging print.
     if dtype not in ['float64', 'float32', 'int32', 'int16']:
         raise ValueError(f"dtype not supported: {dtype}")
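To make the new `sample_rate` argument concrete, here is a minimal, self-contained sketch of what `load_audio_task` boils down to for the CLAP case (48 kHz); the resampling step appears in the next hunk. The helper name `load_audio_48k` and the call pattern are illustrative only, not part of the PR.

```python
import numpy as np
import resampy
import soundfile as sf

def load_audio_48k(path):
    # Hypothetical helper mirroring load_audio_task: read the file,
    # downmix to mono, then resample to the rate the embedding model expects.
    wav, sr = sf.read(path, dtype="float32")
    if wav.ndim > 1:
        wav = np.mean(wav, axis=1)
    if sr != 48000:  # CLAP checkpoints expect 48 kHz input
        wav = resampy.resample(wav, sr, 48000)
    return wav
```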
@@ -30,38 +33,55 @@ def load_audio_task(fname, dtype="float32"):
         wav_data = wav_data / 32768.0
     elif dtype == 'int32':
         wav_data = wav_data / float(2**31)
 
     # Convert to mono
     if len(wav_data.shape) > 1:
         wav_data = np.mean(wav_data, axis=1)
 
-    if sr != SAMPLE_RATE:
-        wav_data = resampy.resample(wav_data, sr, SAMPLE_RATE)
+    if sr != sample_rate:
+        wav_data = resampy.resample(wav_data, sr, sample_rate)
 
     return wav_data
 
 
 class FrechetAudioDistance:
     def __init__(
         self,
         ckpt_dir=None,
         model_name="vggish",
+        submodel_name="630k-audioset",  # only for CLAP
Review comment: This might introduce too many optional arguments for the instantiation of `FrechetAudioDistance`. Moving forward, we probably need to introduce config files for different embedding models.

Reply: I guess once we also add OpenL3 we will have to do some refactoring to make the instantiation cleaner.

Reply: Yes. We might need to define some data classes for different models; one example might be the different types of model configs in huggingface's `transformers`.
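As a rough sketch of the data-class idea floated in this thread: per-model config objects could hold the model-specific options instead of growing the constructor signature. None of these classes exist in the PR; the names and fields below are purely illustrative.

```python
from dataclasses import dataclass

@dataclass
class VggishConfig:
    # Hypothetical per-model config, not part of this PR.
    sample_rate: int = 16000
    use_pca: bool = False
    use_activation: bool = False

@dataclass
class ClapConfig:
    # Groups the CLAP-only options (submodel_name, enable_fusion) in one place.
    sample_rate: int = 48000
    submodel_name: str = "630k-audioset"
    enable_fusion: bool = False

# FrechetAudioDistance(config=ClapConfig(enable_fusion=True)) could then
# replace the growing list of keyword arguments on __init__.
```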
         sample_rate=16000,
-        use_pca=False,
-        use_activation=False,
-        verbose=False,
-        audio_load_worker=8
-    ):
-        assert model_name in ["vggish", "pann"], "model_name must be either 'vggish' or 'pann'"
+        use_pca=False,  # only for VGGish
+        use_activation=False,  # only for VGGish
+        verbose=False,
+        audio_load_worker=8,
+        enable_fusion=False,  # only for CLAP
+    ):
+        """Initialize FAD
+
+        ckpt_dir: folder where the downloaded checkpoints are stored
+        model_name: one between vggish, pann or clap
+        submodel_name: only for clap models - determines which checkpoint to use. options: ["630k-audioset", "630k", "music_audioset", "music_speech", "music_speech_audioset"]
+        sample_rate: one between [8000, 16000, 32000, 48000]. depending on the model set the sample rate to use
+        use_pca: whether to apply PCA to the vggish embeddings
+        use_activation: whether to use the output activation in vggish
+        enable_fusion: whether to use fusion for clap models (valid depending on the specific submodel used)
+        """
+        assert model_name in ["vggish", "pann", "clap"], "model_name must be either 'vggish', 'pann' or 'clap'"
+        if model_name == "vggish":
+            assert sample_rate == 16000, "sample_rate must be 16000"
+        elif model_name == "pann":
+            assert sample_rate in [8000, 16000, 32000], "sample_rate must be 8000, 16000 or 32000"
+        elif model_name == "clap":
+            assert sample_rate == 48000, "sample_rate must be 48000"
+            assert submodel_name in ["630k-audioset", "630k", "music_audioset", "music_speech", "music_speech_audioset"]
         self.model_name = model_name
+        self.submodel_name = submodel_name
         self.sample_rate = sample_rate
         self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
         self.verbose = verbose
         self.audio_load_worker = audio_load_worker
+        self.enable_fusion = enable_fusion
         if ckpt_dir is not None:
             os.makedirs(ckpt_dir, exist_ok=True)
             torch.hub.set_dir(ckpt_dir)
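Given the docstring and assertions above, instantiating the class for CLAP would look roughly like the sketch below. The checkpoint directory is hypothetical, and the package-level import is assumed to match the existing usage of the library.

```python
from frechet_audio_distance import FrechetAudioDistance

# CLAP requires sample_rate=48000 per the assertion above; ckpt_dir is optional
# and falls back to torch.hub.get_dir() when omitted.
frechet = FrechetAudioDistance(
    ckpt_dir="./checkpoints",          # hypothetical local cache directory
    model_name="clap",
    submodel_name="630k-audioset",
    sample_rate=48000,
    enable_fusion=False,
    verbose=True,
)
```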
@@ -70,85 +90,119 @@ def __init__(
             # by default `ckpt_dir` is `torch.hub.get_dir()`
             self.ckpt_dir = torch.hub.get_dir()
         self.__get_model(model_name=model_name, use_pca=use_pca, use_activation=use_activation)
 
     def __get_model(self, model_name="vggish", use_pca=False, use_activation=False):
         """
         Params:
         -- x : Either
             (i) a string which is the directory of a set of audio files, or
             (ii) a np.ndarray of shape (num_samples, sample_length)
         """
         # vggish
         if model_name == "vggish":
             # S. Hershey et al., "CNN Architectures for Large-Scale Audio Classification", ICASSP 2017
             self.model = torch.hub.load(repo_or_dir='harritaylor/torchvggish', model='vggish')
             if not use_pca:
                 self.model.postprocess = False
             if not use_activation:
                 self.model.embeddings = nn.Sequential(*list(self.model.embeddings.children())[:-1])
 
         # pann
         elif model_name == "pann":
             # Kong et al., "PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition", IEEE/ACM Transactions on Audio, Speech, and Language Processing 28 (2020)
 
             # choose the right checkpoint and model based on sample_rate
             if self.sample_rate == 8000:
-                model_path = os.path.join(self.ckpt_dir, "Cnn14_8k_mAP%3D0.416.pth")
-                if not(os.path.exists(model_path)):
-                    if self.verbose:
-                        print("[Frechet Audio Distance] Downloading {}...".format(model_path))
-                    torch.hub.download_url_to_file(
-                        url='https://zenodo.org/record/3987831/files/Cnn14_8k_mAP%3D0.416.pth',
-                        dst=model_path
-                    )
+                download_name = "Cnn14_8k_mAP%3D0.416.pth"
                 self.model = Cnn14_8k(
                     sample_rate=8000,
                     window_size=256,
                     hop_size=80,
                     mel_bins=64,
                     fmin=50,
                     fmax=4000,
                     classes_num=527
                 )
             elif self.sample_rate == 16000:
-                model_path = os.path.join(self.ckpt_dir, "Cnn14_16k_mAP%3D0.438.pth")
-                if not(os.path.exists(model_path)):
-                    if self.verbose:
-                        print("[Frechet Audio Distance] Downloading {}...".format(model_path))
-                    torch.hub.download_url_to_file(
-                        url='https://zenodo.org/record/3987831/files/Cnn14_16k_mAP%3D0.438.pth',
-                        dst=model_path
-                    )
+                download_name = "Cnn14_16k_mAP%3D0.438.pth"
                 self.model = Cnn14_16k(
                     sample_rate=16000,
                     window_size=512,
                     hop_size=160,
                     mel_bins=64,
                     fmin=50,
                     fmax=8000,
                     classes_num=527
                 )
             elif self.sample_rate == 32000:
-                model_path = os.path.join(self.ckpt_dir, "Cnn14_mAP%3D0.431.pth")
-                if not(os.path.exists(model_path)):
-                    if self.verbose:
-                        print("[Frechet Audio Distance] Downloading {}...".format(model_path))
-                    torch.hub.download_url_to_file(
-                        url='https://zenodo.org/record/3987831/files/Cnn14_mAP%3D0.431.pth',
-                        dst=model_path
-                    )
+                download_name = "Cnn14_mAP%3D0.431.pth"
                 self.model = Cnn14(
                     sample_rate=32000,
                     window_size=1024,
                     hop_size=320,
                     mel_bins=64,
                     fmin=50,
                     fmax=16000,
                     classes_num=527
                 )
 
+            model_path = os.path.join(self.ckpt_dir, download_name)
+
+            # download checkpoint
+            if not (os.path.exists(model_path)):
+                if self.verbose:
+                    print("[Frechet Audio Distance] Downloading {}...".format(model_path))
+                torch.hub.download_url_to_file(
+                    url=f"https://zenodo.org/record/3987831/files/{download_name}",
+                    dst=model_path
+                )
+
             # load checkpoint
             checkpoint = torch.load(model_path, map_location=self.device)
             self.model.load_state_dict(checkpoint['model'])
+
+        # clap
+        elif model_name == "clap":
+            # choose the right checkpoint and model
+            if self.submodel_name == "630k-audioset":
+                if self.enable_fusion:
+                    download_name = "630k-audioset-fusion-best.pt"
+                else:
+                    download_name = "630k-audioset-best.pt"
+            elif self.submodel_name == "630k":
+                if self.enable_fusion:
+                    download_name = "630k-fusion-best.pt"
+                else:
+                    download_name = "630k-best.pt"
+            elif self.submodel_name == "music_audioset":
+                download_name = "music_audioset_epoch_15_esc_90.14.pt"
+            elif self.submodel_name == "music_speech":
+                download_name = "music_speech_epoch_15_esc_89.25.pt"
+            elif self.submodel_name == "music_speech_audioset":
+                download_name = "music_speech_audioset_epoch_15_esc_89.98.pt"
+
+            model_path = os.path.join(self.ckpt_dir, download_name)
+
+            # download checkpoint
+            if not (os.path.exists(model_path)):
+                if self.verbose:
+                    print("[Frechet Audio Distance] Downloading {}...".format(model_path))
+                torch.hub.download_url_to_file(
+                    url=f"https://huggingface.co/lukewys/laion_clap/resolve/main/{download_name}",
+                    dst=model_path
+                )
+
+            # init model and load checkpoint
+            if self.submodel_name in ["630k-audioset", "630k"]:
+                self.model = laion_clap.CLAP_Module(enable_fusion=self.enable_fusion,
+                                                    device=self.device)
+            elif self.submodel_name in ["music_audioset", "music_speech", "music_speech_audioset"]:
+                self.model = laion_clap.CLAP_Module(enable_fusion=self.enable_fusion,
+                                                    amodel='HTSAT-base',
+                                                    device=self.device)
+            self.model.load_ckpt(model_path)
 
         self.model.eval()
 
-    def get_embeddings(self, x, sr=SAMPLE_RATE):
+    def get_embeddings(self, x, sr):
         """
         Get embeddings using VGGish model.
         Params:
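For reference, the CLAP branch above condenses to roughly the following standalone steps for one of the music checkpoints. The local directory is hypothetical; the download URL and `laion_clap` calls simply mirror the diff.

```python
import os
import torch
import laion_clap

ckpt_dir = "./checkpoints"  # hypothetical cache directory
download_name = "music_audioset_epoch_15_esc_90.14.pt"
model_path = os.path.join(ckpt_dir, download_name)

# Fetch the checkpoint from the LAION-CLAP mirror if it is not cached yet.
os.makedirs(ckpt_dir, exist_ok=True)
if not os.path.exists(model_path):
    torch.hub.download_url_to_file(
        url=f"https://huggingface.co/lukewys/laion_clap/resolve/main/{download_name}",
        dst=model_path,
    )

# The music_* checkpoints use the HTSAT-base audio encoder, as in the branch above.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = laion_clap.CLAP_Module(enable_fusion=False, amodel="HTSAT-base", device=device)
model.load_ckpt(model_path)
model.eval()
```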
@@ -162,28 +216,36 @@ def get_embeddings(self, x, sr=SAMPLE_RATE):
                     embd = self.model.forward(audio, sr)
                 elif self.model_name == "pann":
                     with torch.no_grad():
-                        out = self.model(torch.tensor(audio).float().unsqueeze(0), None)
+                        audio = torch.tensor(audio).float().unsqueeze(0)
+                        out = self.model(audio, None)
                         embd = out['embedding'].data[0]
+                elif self.model_name == "clap":
+                    audio = torch.tensor(audio).float().unsqueeze(0)
+                    embd = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
 
                 if self.device == torch.device('cuda'):
                     embd = embd.cpu()
-                embd = embd.detach().numpy()
+
+                if torch.is_tensor(embd):
+                    embd = embd.detach().numpy()
+
                 embd_lst.append(embd)
             except Exception as e:
                 print("[Frechet Audio Distance] get_embeddings throw an exception: {}".format(str(e)))
 
         return np.concatenate(embd_lst, axis=0)
 
     def calculate_embd_statistics(self, embd_lst):
         if isinstance(embd_lst, list):
             embd_lst = np.array(embd_lst)
         mu = np.mean(embd_lst, axis=0)
         sigma = np.cov(embd_lst, rowvar=False)
         return mu, sigma
 
     def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
         """
         Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
 
         Numpy implementation of the Frechet Distance.
         The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
         and X_2 ~ N(mu_2, C_2) is
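The docstring is cut off by the hunk boundary; the quantity it refers to is the usual squared Frechet distance between two Gaussians, d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). Below is a toy numpy illustration of that formula on random embeddings, independent of the class; it is for illustration only, not the PR's code.

```python
import numpy as np
from scipy import linalg

rng = np.random.default_rng(0)
embd_a = rng.normal(size=(200, 16))            # toy "background" embeddings
embd_b = rng.normal(loc=0.5, size=(200, 16))   # toy "eval" embeddings

# Per-set statistics, as in calculate_embd_statistics.
mu1, sigma1 = embd_a.mean(axis=0), np.cov(embd_a, rowvar=False)
mu2, sigma2 = embd_b.mean(axis=0), np.cov(embd_b, rowvar=False)

# Squared Frechet distance between the two Gaussians.
diff = mu1 - mu2
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
fad = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean.real)
print(fad)
```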
@@ -219,7 +281,7 @@ def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
         covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
         if not np.isfinite(covmean).all():
             msg = ('fid calculation produces singular product; '
                    'adding %s to diagonal of cov estimates') % eps
             print(msg)
             offset = np.eye(sigma1.shape[0]) * eps
             covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
@@ -235,7 +297,7 @@ def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
 
         return (diff.dot(diff) + np.trace(sigma1)
                 + np.trace(sigma2) - 2 * tr_covmean)
 
     def __load_audio_files(self, dir, dtype="float32"):
         task_results = []
@@ -249,24 +311,23 @@ def update(*a):
             print("[Frechet Audio Distance] Loading audio from {}...".format(dir))
         for fname in os.listdir(dir):
             res = pool.apply_async(
                 load_audio_task,
-                args=(os.path.join(dir, fname), dtype,),
+                args=(os.path.join(dir, fname), self.sample_rate, dtype),
                 callback=update
             )
             task_results.append(res)
         pool.close()
         pool.join()
 
         return [k.get() for k in task_results]
 
     def score(self,
               background_dir,
               eval_dir,
               background_embds_path=None,
               eval_embds_path=None,
               dtype="float32"
               ):
         """
         Computes the Frechet Audio Distance (FAD) between two directories of audio files.
@@ -288,7 +349,7 @@ def score(self,
             embds_background = np.load(background_embds_path)
         else:
             audio_background = self.__load_audio_files(background_dir, dtype=dtype)
-            embds_background = self.get_embeddings(audio_background)
+            embds_background = self.get_embeddings(audio_background, sr=self.sample_rate)
             if background_embds_path:
                 os.makedirs(os.path.dirname(background_embds_path), exist_ok=True)
                 np.save(background_embds_path, embds_background)
@@ -300,7 +361,7 @@ def score(self,
             embds_eval = np.load(eval_embds_path)
         else:
             audio_eval = self.__load_audio_files(eval_dir, dtype=dtype)
-            embds_eval = self.get_embeddings(audio_eval)
+            embds_eval = self.get_embeddings(audio_eval, sr=self.sample_rate)
             if eval_embds_path:
                 os.makedirs(os.path.dirname(eval_embds_path), exist_ok=True)
                 np.save(eval_embds_path, embds_eval)
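Putting the two branches above together, a typical call with embedding caching might look like the sketch below, reusing the `frechet` instance from the earlier sketch; all paths are hypothetical. On the first run the embeddings are computed and saved; later runs load the `.npy` files instead of recomputing them.

```python
fad_score = frechet.score(
    "audio/background",                                    # hypothetical directories
    "audio/eval",
    background_embds_path="embeddings/background_clap.npy",
    eval_embds_path="embeddings/eval_clap.npy",
    dtype="float32",
)
print("FAD:", fad_score)
```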
@@ -318,13 +379,13 @@ def score(self,
             mu_eval, sigma_eval = self.calculate_embd_statistics(embds_eval)
 
             fad_score = self.calculate_frechet_distance(
                 mu_background,
                 sigma_background,
                 mu_eval,
                 sigma_eval
             )
 
             return fad_score
         except Exception as e:
             print(f"[Frechet Audio Distance] An error occurred: {e}")
             return -1
Review comment (on the `# SAMPLE_RATE = 16000` line): I don't think this is used anywhere after your changes on passing in the sample rate as an argument. In this case, we can remove this line.

Reply: Yes, we can remove it. I left it in case, for some reason, you didn't agree with the change.