[feature] add (LAION)-CLAP model embedding to FAD calculation #21

Merged: 4 commits, Oct 17, 2023
243 changes: 152 additions & 91 deletions frechet_audio_distance/fad.py
@@ -7,20 +7,23 @@
"""
import os
import numpy as np
import resampy
import soundfile as sf
import torch
from torch import nn
import laion_clap

from multiprocessing.dummy import Pool as ThreadPool
from scipy import linalg
from torch import nn
from tqdm import tqdm
import soundfile as sf
import resampy
from multiprocessing.dummy import Pool as ThreadPool
from .models.pann import Cnn14_8k, Cnn14_16k, Cnn14

from .models.pann import Cnn14, Cnn14_8k, Cnn14_16k

SAMPLE_RATE = 16000
# SAMPLE_RATE = 16000
Owner: I don't think this is used anywhere after your changes on passing in the sample rate as an argument. In this case, we can remove this line.

Contributor Author: Yes, we can remove it. I left it in case, for some reason, you didn't agree with the change.


def load_audio_task(fname, dtype="float32"):
def load_audio_task(fname, sample_rate, dtype="float32"):
# print("LOAD AUDIO TASK")
Owner: This line can be removed too.

Contributor Author: Yes, I forgot to remove the debugging print.

if dtype not in ['float64', 'float32', 'int32', 'int16']:
raise ValueError(f"dtype not supported: {dtype}")

@@ -30,38 +33,55 @@ def load_audio_task(fname, dtype="float32"):
wav_data = wav_data / 32768.0
elif dtype == 'int32':
wav_data = wav_data / float(2**31)

# Convert to mono
if len(wav_data.shape) > 1:
wav_data = np.mean(wav_data, axis=1)

if sr != SAMPLE_RATE:
wav_data = resampy.resample(wav_data, sr, SAMPLE_RATE)
if sr != sample_rate:
wav_data = resampy.resample(wav_data, sr, sample_rate)

return wav_data


class FrechetAudioDistance:
def __init__(
self,
self,
ckpt_dir=None,
model_name="vggish",
model_name="vggish",
submodel_name="630k-audioset", # only for CLAP
Owner: This might introduce too many optional arguments for the instantiation of FrechetAudioDistance. While it is clear with comments right now, it might get bloated if we add more models that need different arguments in the future.
Moving forward, we probably need to introduce config files for the different embedding models. Let's leave it as is for now, and add a note here for future code refactoring.

Contributor Author: I guess once we also add OpenL3 we will have to do some refactoring to make FrechetAudioDistance easy to understand.

Owner: Yes. We might need to define data classes for the different models; one example might be the different types of model configs in Hugging Face's transformers, and we can keep them as JSON files. If you have better ideas feel free to suggest / contribute too! :)
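For illustration, a minimal sketch of what such per-model config data classes could look like. The class names, fields, and JSON layout are assumptions for discussion, not code from this PR:

```python
# Hypothetical per-model config data classes, as discussed above.
# Names and fields are illustrative only and not part of this PR.
from dataclasses import dataclass, asdict
import json


@dataclass
class VggishConfig:
    sample_rate: int = 16000
    use_pca: bool = False
    use_activation: bool = False


@dataclass
class ClapConfig:
    submodel_name: str = "630k-audioset"
    sample_rate: int = 48000
    enable_fusion: bool = False


# Each config could be stored as a JSON file, mirroring the Hugging Face analogy.
with open("clap_config.json", "w") as f:
    json.dump(asdict(ClapConfig()), f, indent=2)
```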

sample_rate=16000,
use_pca=False,
use_activation=False,
verbose=False,
audio_load_worker=8
):
assert model_name in ["vggish", "pann"], "model_name must be either 'vggish' or 'pann'"
use_pca=False, # only for VGGish
use_activation=False, # only for VGGish
verbose=False,
audio_load_worker=8,
enable_fusion=False, # only for CLAP
):
"""Initialize FAD

ckpt_dir: folder where the downloaded checkpoints are stored
model_name: one between vggish, pann or clap
submodel_name: only for clap models - determines which checkpoint to use. options: ["630k-audioset", "630k", "music_audioset", "music_speech", "music_speech_audioset"]
sample_rate: one between [8000, 16000, 32000, 48000]. depending on the model set the sample rate to use
use_pca: whether to apply PCA to the vggish embeddings
use_activation: whether to use the output activation in vggish
enable_fusion: whether to use fusion for clap models (valid depending on the specific submodel used)
"""
assert model_name in ["vggish", "pann", "clap"], "model_name must be either 'vggish', 'pann' or 'clap'"
if model_name == "vggish":
assert sample_rate == 16000, "sample_rate must be 16000"
elif model_name == "pann":
assert sample_rate in [8000, 16000, 32000], "sample_rate must be 8000, 16000 or 32000"
elif model_name == "clap":
assert sample_rate == 48000, "sample_rate must be 48000"
assert submodel_name in ["630k-audioset", "630k", "music_audioset", "music_speech", "music_speech_audioset"]
self.model_name = model_name
self.submodel_name = submodel_name
self.sample_rate = sample_rate
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self.verbose = verbose
self.audio_load_worker = audio_load_worker
self.enable_fusion = enable_fusion
if ckpt_dir is not None:
os.makedirs(ckpt_dir, exist_ok=True)
torch.hub.set_dir(ckpt_dir)
@@ -70,85 +90,119 @@ def __init__(
# by default `ckpt_dir` is `torch.hub.get_dir()`
self.ckpt_dir = torch.hub.get_dir()
self.__get_model(model_name=model_name, use_pca=use_pca, use_activation=use_activation)



def __get_model(self, model_name="vggish", use_pca=False, use_activation=False):
"""
Params:
-- x : Either
(i) a string which is the directory of a set of audio files, or
(ii) a np.ndarray of shape (num_samples, sample_length)
"""
# vggish
if model_name == "vggish":
# S. Hershey et al., "CNN Architectures for Large-Scale Audio Classification", ICASSP 2017
self.model = torch.hub.load(repo_or_dir='harritaylor/torchvggish', model='vggish')
if not use_pca:
self.model.postprocess = False
if not use_activation:
self.model.embeddings = nn.Sequential(*list(self.model.embeddings.children())[:-1])

# pann
elif model_name == "pann":
# Kong et al., "PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition", IEEE/ACM Transactions on Audio, Speech, and Language Processing 28 (2020)

# choose the right checkpoint and model based on sample_rate
if self.sample_rate == 8000:
model_path = os.path.join(self.ckpt_dir, "Cnn14_8k_mAP%3D0.416.pth")
if not(os.path.exists(model_path)):
if self.verbose:
print("[Frechet Audio Distance] Downloading {}...".format(model_path))
torch.hub.download_url_to_file(
url='https://zenodo.org/record/3987831/files/Cnn14_8k_mAP%3D0.416.pth',
dst=model_path
)
download_name = "Cnn14_8k_mAP%3D0.416.pth"
self.model = Cnn14_8k(
sample_rate=8000,
window_size=256,
hop_size=80,
mel_bins=64,
fmin=50,
fmax=4000,
sample_rate=8000,
window_size=256,
hop_size=80,
mel_bins=64,
fmin=50,
fmax=4000,
classes_num=527
)
elif self.sample_rate == 16000:
model_path = os.path.join(self.ckpt_dir, "Cnn14_16k_mAP%3D0.438.pth")
if not(os.path.exists(model_path)):
if self.verbose:
print("[Frechet Audio Distance] Downloading {}...".format(model_path))
torch.hub.download_url_to_file(
url='https://zenodo.org/record/3987831/files/Cnn14_16k_mAP%3D0.438.pth',
dst=model_path
)
download_name = "Cnn14_16k_mAP%3D0.438.pth"
self.model = Cnn14_16k(
sample_rate=16000,
window_size=512,
hop_size=160,
mel_bins=64,
fmin=50,
fmax=8000,
sample_rate=16000,
window_size=512,
hop_size=160,
mel_bins=64,
fmin=50,
fmax=8000,
classes_num=527
)
elif self.sample_rate == 32000:
model_path = os.path.join(self.ckpt_dir, "Cnn14_mAP%3D0.431.pth")
if not(os.path.exists(model_path)):
if self.verbose:
print("[Frechet Audio Distance] Downloading {}...".format(model_path))
torch.hub.download_url_to_file(
url='https://zenodo.org/record/3987831/files/Cnn14_mAP%3D0.431.pth',
dst=model_path
)
download_name = "Cnn14_mAP%3D0.431.pth"
self.model = Cnn14(
sample_rate=32000,
window_size=1024,
hop_size=320,
mel_bins=64,
fmin=50,
fmax=16000,
sample_rate=32000,
window_size=1024,
hop_size=320,
mel_bins=64,
fmin=50,
fmax=16000,
classes_num=527
)

model_path = os.path.join(self.ckpt_dir, download_name)

# download checkpoint
if not (os.path.exists(model_path)):
if self.verbose:
print("[Frechet Audio Distance] Downloading {}...".format(model_path))
torch.hub.download_url_to_file(
url=f"https://zenodo.org/record/3987831/files/{download_name}",
dst=model_path
)

# load checkpoint
checkpoint = torch.load(model_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model'])
# clap
elif model_name == "clap":
# choose the right checkpoint and model
if self.submodel_name == "630k-audioset":
if self.enable_fusion:
download_name = "630k-audioset-fusion-best.pt"
else:
download_name = "630k-audioset-best.pt"
elif self.submodel_name == "630k":
if self.enable_fusion:
download_name = "630k-fusion-best.pt"
else:
download_name = "630k-best.pt"
elif self.submodel_name == "music_audioset":
download_name = "music_audioset_epoch_15_esc_90.14.pt"
elif self.submodel_name == "music_speech":
download_name = "music_speech_epoch_15_esc_89.25.pt"
elif self.submodel_name == "music_speech_audioset":
download_name = "music_speech_audioset_epoch_15_esc_89.98.pt"

model_path = os.path.join(self.ckpt_dir, download_name)

# download checkpoint
if not (os.path.exists(model_path)):
if self.verbose:
print("[Frechet Audio Distance] Downloading {}...".format(model_path))
torch.hub.download_url_to_file(
url=f"https://huggingface.co/lukewys/laion_clap/resolve/main/{download_name}",
dst=model_path
)

# init model and load checkpoint
if self.submodel_name in ["630k-audioset", "630k"]:
self.model = laion_clap.CLAP_Module(enable_fusion=self.enable_fusion,
device=self.device)
elif self.submodel_name in ["music_audioset", "music_speech", "music_speech_audioset"]:
self.model = laion_clap.CLAP_Module(enable_fusion=self.enable_fusion,
amodel='HTSAT-base',
device=self.device)
self.model.load_ckpt(model_path)

self.model.eval()
def get_embeddings(self, x, sr=SAMPLE_RATE):

def get_embeddings(self, x, sr):
"""
Get embeddings using VGGish model.
Params:
@@ -162,28 +216,36 @@ def get_embeddings(self, x, sr=SAMPLE_RATE):
embd = self.model.forward(audio, sr)
elif self.model_name == "pann":
with torch.no_grad():
out = self.model(torch.tensor(audio).float().unsqueeze(0), None)
audio = torch.tensor(audio).float().unsqueeze(0)
out = self.model(audio, None)
embd = out['embedding'].data[0]
elif self.model_name == "clap":
audio = torch.tensor(audio).float().unsqueeze(0)
embd = self.model.get_audio_embedding_from_data(audio, use_tensor=True)

if self.device == torch.device('cuda'):
embd = embd.cpu()
embd = embd.detach().numpy()

if torch.is_tensor(embd):
embd = embd.detach().numpy()

embd_lst.append(embd)
except Exception as e:
print("[Frechet Audio Distance] get_embeddings throw an exception: {}".format(str(e)))

return np.concatenate(embd_lst, axis=0)

def calculate_embd_statistics(self, embd_lst):
if isinstance(embd_lst, list):
embd_lst = np.array(embd_lst)
mu = np.mean(embd_lst, axis=0)
sigma = np.cov(embd_lst, rowvar=False)
return mu, sigma

def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
"""
Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py

Numpy implementation of the Frechet Distance.
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
and X_2 ~ N(mu_2, C_2) is
@@ -219,7 +281,7 @@ def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
if not np.isfinite(covmean).all():
msg = ('fid calculation produces singular product; '
'adding %s to diagonal of cov estimates') % eps
'adding %s to diagonal of cov estimates') % eps
print(msg)
offset = np.eye(sigma1.shape[0]) * eps
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
@@ -235,7 +297,7 @@ def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):

return (diff.dot(diff) + np.trace(sigma1)
+ np.trace(sigma2) - 2 * tr_covmean)
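The rest of the docstring is collapsed by the diff view before it states the expression. For reference, the closed-form quantity that the return statement above computes for X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is the standard squared Frechet (2-Wasserstein) distance:

```latex
d^2 = \lVert \mu_1 - \mu_2 \rVert_2^2
    + \operatorname{Tr}\!\left(C_1 + C_2 - 2\,(C_1 C_2)^{1/2}\right)
```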

def __load_audio_files(self, dir, dtype="float32"):
task_results = []

@@ -249,24 +311,23 @@ def update(*a):
print("[Frechet Audio Distance] Loading audio from {}...".format(dir))
for fname in os.listdir(dir):
res = pool.apply_async(
load_audio_task,
args=(os.path.join(dir, fname), dtype,),
load_audio_task,
args=(os.path.join(dir, fname), self.sample_rate, dtype),
callback=update
)
task_results.append(res)
pool.close()
pool.join()

return [k.get() for k in task_results]
pool.join()

return [k.get() for k in task_results]

def score(self,
background_dir,
eval_dir,
background_embds_path=None,
eval_embds_path=None,
dtype="float32"
):
def score(self,
background_dir,
eval_dir,
background_embds_path=None,
eval_embds_path=None,
dtype="float32"
):
"""
Computes the Frechet Audio Distance (FAD) between two directories of audio files.

@@ -288,7 +349,7 @@ def score(self,
embds_background = np.load(background_embds_path)
else:
audio_background = self.__load_audio_files(background_dir, dtype=dtype)
embds_background = self.get_embeddings(audio_background)
embds_background = self.get_embeddings(audio_background, sr=self.sample_rate)
if background_embds_path:
os.makedirs(os.path.dirname(background_embds_path), exist_ok=True)
np.save(background_embds_path, embds_background)
@@ -300,7 +361,7 @@ def score(self,
embds_eval = np.load(eval_embds_path)
else:
audio_eval = self.__load_audio_files(eval_dir, dtype=dtype)
embds_eval = self.get_embeddings(audio_eval)
embds_eval = self.get_embeddings(audio_eval, sr=self.sample_rate)
if eval_embds_path:
os.makedirs(os.path.dirname(eval_embds_path), exist_ok=True)
np.save(eval_embds_path, embds_eval)
@@ -318,13 +379,13 @@ def score(self,
mu_eval, sigma_eval = self.calculate_embd_statistics(embds_eval)

fad_score = self.calculate_frechet_distance(
mu_background,
sigma_background,
mu_eval,
mu_background,
sigma_background,
mu_eval,
sigma_eval
)

return fad_score
except Exception as e:
print(f"[Frechet Audio Distance] An error occurred: {e}")
return -1
return -1
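For context, a short usage sketch of the CLAP path this PR enables, assuming the package exports FrechetAudioDistance at the top level as in the repository README; the directory paths are placeholders:

```python
from frechet_audio_distance import FrechetAudioDistance

# CLAP embeddings require 48 kHz audio; submodel_name and enable_fusion follow
# the options documented in the __init__ docstring above.
frechet = FrechetAudioDistance(
    model_name="clap",
    submodel_name="630k-audioset",
    sample_rate=48000,
    enable_fusion=False,
    verbose=True,
)

# Audio in both directories is loaded, resampled to 48 kHz, embedded with CLAP,
# and compared via the Frechet distance between the two embedding distributions.
fad_score = frechet.score("path/to/background_audio", "path/to/eval_audio")
print(fad_score)
```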