From b20a5d74f6093f525851089dbff77bf09771a672 Mon Sep 17 00:00:00 2001 From: mcomunita Date: Tue, 10 Oct 2023 16:47:21 +0200 Subject: [PATCH 1/4] [feature] add clap to available models --- frechet_audio_distance/fad.py | 144 ++++++++++++++++++++++++---------- 1 file changed, 103 insertions(+), 41 deletions(-) diff --git a/frechet_audio_distance/fad.py b/frechet_audio_distance/fad.py index 16cdd1b..e5f9c01 100644 --- a/frechet_audio_distance/fad.py +++ b/frechet_audio_distance/fad.py @@ -13,14 +13,19 @@ from tqdm import tqdm import soundfile as sf import resampy +import wget +import laion_clap from multiprocessing.dummy import Pool as ThreadPool from .models.pann import Cnn14_8k, Cnn14_16k, Cnn14 +# from .models.laion_clap.hook import CLAP_Module +# from .models.laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict -SAMPLE_RATE = 16000 +# SAMPLE_RATE = 16000 -def load_audio_task(fname, dtype="float32"): +def load_audio_task(fname, sample_rate, dtype="float32"): + # print("LOAD AUDIO TASK") if dtype not in ['float64', 'float32', 'int32', 'int16']: raise ValueError(f"dtype not supported: {dtype}") @@ -35,8 +40,8 @@ def load_audio_task(fname, dtype="float32"): if len(wav_data.shape) > 1: wav_data = np.mean(wav_data, axis=1) - if sr != SAMPLE_RATE: - wav_data = resampy.resample(wav_data, sr, SAMPLE_RATE) + if sr != sample_rate: + wav_data = resampy.resample(wav_data, sr, sample_rate) return wav_data @@ -45,23 +50,40 @@ class FrechetAudioDistance: def __init__( self, ckpt_dir=None, - model_name="vggish", + model_name="vggish", + submodel_name="630k-audioset", # only for CLAP sample_rate=16000, - use_pca=False, - use_activation=False, + use_pca=False, # only for VGGish + use_activation=False, # only for VGGish verbose=False, - audio_load_worker=8 + audio_load_worker=8, + enable_fusion=False, # only for CLAP ): - assert model_name in ["vggish", "pann"], "model_name must be either 'vggish' or 'pann'" + """Initialize FAD + + ckpt_dir: folder where the downloaded checkpoints are stored + model_name: one between vggish, pann or clap + submodel_name: only for clap models - determines which checkpoint to use. options: ["630k-audioset", "630k", "music_audioset", "music_speech", "music_speech_audioset"] + sample_rate: one between [8000, 16000, 32000, 48000]. 
Use the sample rate that matches the chosen model
+        use_pca: whether to apply PCA to the vggish embeddings
+        use_activation: whether to use the output activation in vggish
+        enable_fusion: whether to use fusion for clap models (valid depending on the specific submodel used)
+        """
+        assert model_name in ["vggish", "pann", "clap"], "model_name must be either 'vggish', 'pann' or 'clap'"
         if model_name == "vggish":
             assert sample_rate == 16000, "sample_rate must be 16000"
         elif model_name == "pann":
             assert sample_rate in [8000, 16000, 32000], "sample_rate must be 8000, 16000 or 32000"
+        elif model_name == "clap":
+            assert sample_rate == 48000, "sample_rate must be 48000"
+            assert submodel_name in ["630k-audioset", "630k", "music_audioset", "music_speech", "music_speech_audioset"]
         self.model_name = model_name
+        self.submodel_name = submodel_name
         self.sample_rate = sample_rate
         self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
         self.verbose = verbose
         self.audio_load_worker = audio_load_worker
+        self.enable_fusion = enable_fusion
         if ckpt_dir is not None:
             os.makedirs(ckpt_dir, exist_ok=True)
             torch.hub.set_dir(ckpt_dir)
@@ -79,6 +101,7 @@ def __get_model(self, model_name="vggish", use_pca=False, use_activation=False):
         (i) a string which is the directory of a set of audio files, or
         (ii) a np.ndarray of shape (num_samples, sample_length)
         """
+        # vggish
         if model_name == "vggish":
             # S. Hershey et al., "CNN Architectures for Large-Scale Audio Classification", ICASSP 2017
             self.model = torch.hub.load(repo_or_dir='harritaylor/torchvggish', model='vggish')
@@ -86,18 +109,13 @@ def __get_model(self, model_name="vggish", use_pca=False, use_activation=False):
             self.model.postprocess = False
         if not use_activation:
             self.model.embeddings = nn.Sequential(*list(self.model.embeddings.children())[:-1])
-
+        # pann
         elif model_name == "pann":
             # Kong et al., "PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition", IEEE/ACM Transactions on Audio, Speech, and Language Processing 28 (2020)
+
+            # choose the right checkpoint and model based on sample_rate
             if self.sample_rate == 8000:
-                model_path = os.path.join(self.ckpt_dir, "Cnn14_8k_mAP%3D0.416.pth")
-                if not(os.path.exists(model_path)):
-                    if self.verbose:
-                        print("[Frechet Audio Distance] Downloading {}...".format(model_path))
-                    torch.hub.download_url_to_file(
-                        url='https://zenodo.org/record/3987831/files/Cnn14_8k_mAP%3D0.416.pth',
-                        dst=model_path
-                    )
+                download_name = "Cnn14_8k_mAP%3D0.416.pth"
                 self.model = Cnn14_8k(
                     sample_rate=8000,
                     window_size=256,
                     hop_size=80,
                     mel_bins=64,
                     fmin=50,
                     fmax=4000,
                     classes_num=527
                 )
             elif self.sample_rate == 16000:
-                model_path = os.path.join(self.ckpt_dir, "Cnn14_16k_mAP%3D0.438.pth")
-                if not(os.path.exists(model_path)):
-                    if self.verbose:
-                        print("[Frechet Audio Distance] Downloading {}...".format(model_path))
-                    torch.hub.download_url_to_file(
-                        url='https://zenodo.org/record/3987831/files/Cnn14_16k_mAP%3D0.438.pth',
-                        dst=model_path
-                    )
+                download_name = "Cnn14_16k_mAP%3D0.438.pth"
                 self.model = Cnn14_16k(
                     sample_rate=16000,
                     window_size=512,
                     hop_size=160,
                     mel_bins=64,
                     fmin=50,
                     fmax=8000,
                     classes_num=527
                 )
             elif self.sample_rate == 32000:
-                model_path = os.path.join(self.ckpt_dir, "Cnn14_mAP%3D0.431.pth")
-                if not(os.path.exists(model_path)):
-                    if self.verbose:
-                        print("[Frechet Audio Distance] Downloading {}...".format(model_path))
-                    torch.hub.download_url_to_file(
-
url='https://zenodo.org/record/3987831/files/Cnn14_mAP%3D0.431.pth', - dst=model_path - ) + download_name = "Cnn14_mAP%3D0.431.pth" self.model = Cnn14( sample_rate=32000, window_size=1024, @@ -143,12 +147,65 @@ def __get_model(self, model_name="vggish", use_pca=False, use_activation=False): fmax=16000, classes_num=527 ) + + model_path = os.path.join(self.ckpt_dir, download_name) + + # download checkpoint + if not(os.path.exists(model_path)): + if self.verbose: + print("[Frechet Audio Distance] Downloading {}...".format(model_path)) + torch.hub.download_url_to_file( + url=f"https://zenodo.org/record/3987831/files/{download_name}", + dst=model_path + ) + + # load checkpoint checkpoint = torch.load(model_path, map_location=self.device) self.model.load_state_dict(checkpoint['model']) - + # clap + elif model_name == "clap": + # choose the right checkpoint and model + if self.submodel_name == "630k-audioset": + if self.enable_fusion: + download_name = "630k-audioset-fusion-best.pt" + else: + download_name = "630k-audioset-best.pt" + elif self.submodel_name == "630k": + if self.enable_fusion: + download_name = "630k-fusion-best.pt" + else: + download_name = "630k-best.pt" + elif self.submodel_name == "music_audioset": + download_name = "music_audioset_epoch_15_esc_90.14.pt" + elif self.submodel_name == "music_speech": + download_name = "music_speech_epoch_15_esc_89.25.pt" + elif self.submodel_name == "music_speech_audioset": + download_name = "music_speech_audioset_epoch_15_esc_89.98.pt" + + model_path = os.path.join(self.ckpt_dir, download_name) + + # download checkpoint + if not(os.path.exists(model_path)): + if self.verbose: + print("[Frechet Audio Distance] Downloading {}...".format(model_path)) + torch.hub.download_url_to_file( + url=f"https://huggingface.co/lukewys/laion_clap/resolve/main/{download_name}", + dst=model_path + ) + + # init model and load checkpoint + if self.submodel_name in ["630k-audioset", "630k"]: + self.model = laion_clap.CLAP_Module(enable_fusion=self.enable_fusion, + device=self.device) + elif self.submodel_name in ["music_audioset", "music_speech", "music_speech_audioset"]: + self.model = laion_clap.CLAP_Module(enable_fusion=self.enable_fusion, + amodel= 'HTSAT-base', + device=self.device) + self.model.load_ckpt(model_path) + self.model.eval() - def get_embeddings(self, x, sr=SAMPLE_RATE): + def get_embeddings(self, x, sr): """ Get embeddings using VGGish model. 
Params: @@ -162,11 +219,16 @@ def get_embeddings(self, x, sr=SAMPLE_RATE): embd = self.model.forward(audio, sr) elif self.model_name == "pann": with torch.no_grad(): - out = self.model(torch.tensor(audio).float().unsqueeze(0), None) + audio = torch.tensor(audio).float().unsqueeze(0) + out = self.model(audio, None) embd = out['embedding'].data[0] + elif self.model_name == "clap": + audio = np.expand_dims(audio, 0) + embd = self.model.get_audio_embedding_from_data(audio, use_tensor=False) if self.device == torch.device('cuda'): embd = embd.cpu() - embd = embd.detach().numpy() + if torch.is_tensor(embd): + embd = embd.detach().numpy() embd_lst.append(embd) except Exception as e: print("[Frechet Audio Distance] get_embeddings throw an exception: {}".format(str(e))) @@ -250,7 +312,7 @@ def update(*a): for fname in os.listdir(dir): res = pool.apply_async( load_audio_task, - args=(os.path.join(dir, fname), dtype,), + args=(os.path.join(dir, fname), self.sample_rate, dtype), callback=update ) task_results.append(res) @@ -288,7 +350,7 @@ def score(self, embds_background = np.load(background_embds_path) else: audio_background = self.__load_audio_files(background_dir, dtype=dtype) - embds_background = self.get_embeddings(audio_background) + embds_background = self.get_embeddings(audio_background, sr=self.sample_rate) if background_embds_path: os.makedirs(os.path.dirname(background_embds_path), exist_ok=True) np.save(background_embds_path, embds_background) @@ -300,7 +362,7 @@ def score(self, embds_eval = np.load(eval_embds_path) else: audio_eval = self.__load_audio_files(eval_dir, dtype=dtype) - embds_eval = self.get_embeddings(audio_eval) + embds_eval = self.get_embeddings(audio_eval, sr=self.sample_rate) if eval_embds_path: os.makedirs(os.path.dirname(eval_embds_path), exist_ok=True) np.save(eval_embds_path, embds_eval) From 3b1c72ddaf235f1d29f9a80551405082c760af40 Mon Sep 17 00:00:00 2001 From: mcomunita Date: Tue, 10 Oct 2023 16:48:33 +0200 Subject: [PATCH 2/4] [test] add test_all notebook to test all available models configs --- test/test_all.ipynb | 628 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 628 insertions(+) create mode 100644 test/test_all.ipynb diff --git a/test/test_all.ipynb b/test/test_all.ipynb new file mode 100644 index 0000000..fdb9809 --- /dev/null +++ b/test/test_all.ipynb @@ -0,0 +1,628 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "import shutil\n", + "import numpy as np\n", + "import soundfile as sf\n", + "\n", + "\n", + "module_path = os.path.abspath(os.path.join('..'))\n", + "if module_path not in sys.path:\n", + " sys.path.append(module_path)\n", + "\n", + "from frechet_audio_distance import FrechetAudioDistance\n", + "from utils import gen_sine_wave" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### VGGISH" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# STANDARD\n", + "SAMPLE_RATE = 16000\n", + "\n", + "frechet = FrechetAudioDistance(\n", + " ckpt_dir=\"../checkpoints/vggish\",\n", + " model_name=\"vggish\",\n", + " # submodel_name=\"630k-audioset\", # for CLAP only\n", + " sample_rate=SAMPLE_RATE,\n", + " use_pca=False, # for VGGish only\n", + " use_activation=False, # for VGGish only\n", + " verbose=True,\n", + " audio_load_worker=8,\n", + " # enable_fusion=False, # for CLAP only\n", + ")\n", + "\n", + "for traget, count, param in 
[(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n", + " os.makedirs(traget, exist_ok=True)\n", + " frequencies = np.linspace(100, 1000, count).tolist()\n", + " for freq in frequencies:\n", + " samples = gen_sine_wave(freq, param=param)\n", + " filename = os.path.join(traget, \"sin_%.0f.wav\" % freq)\n", + " print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n", + " sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n", + "\n", + "fad_score = frechet.score(\"background\", \"test1\")\n", + "print(\"FAD score test 1: %.8f\" % fad_score)\n", + "\n", + "fad_score = frechet.score(\"background\", \"test2\")\n", + "print(\"FAD score test 2: %.8f\" % fad_score)\n", + "\n", + "shutil.rmtree(\"background\")\n", + "shutil.rmtree(\"test1\")\n", + "shutil.rmtree(\"test2\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# with PCA\n", + "SAMPLE_RATE = 16000\n", + "\n", + "frechet = FrechetAudioDistance(\n", + " ckpt_dir=\"../checkpoints/vggish\",\n", + " model_name=\"vggish\",\n", + " submodel_name=\"630k-audioset\", # for CLAP only\n", + " sample_rate=SAMPLE_RATE,\n", + " use_pca=True, # for VGGish only\n", + " use_activation=False, # for VGGish only\n", + " verbose=True,\n", + " audio_load_worker=8,\n", + " enable_fusion=False, # for CLAP only\n", + ")\n", + "\n", + "for traget, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n", + " os.makedirs(traget, exist_ok=True)\n", + " frequencies = np.linspace(100, 1000, count).tolist()\n", + " for freq in frequencies:\n", + " samples = gen_sine_wave(freq, param=param)\n", + " filename = os.path.join(traget, \"sin_%.0f.wav\" % freq)\n", + " print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n", + " sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n", + "\n", + "fad_score = frechet.score(\"background\", \"test1\")\n", + "print(\"FAD score test 1: %.8f\" % fad_score)\n", + "\n", + "fad_score = frechet.score(\"background\", \"test2\")\n", + "print(\"FAD score test 2: %.8f\" % fad_score)\n", + "\n", + "shutil.rmtree(\"background\")\n", + "shutil.rmtree(\"test1\")\n", + "shutil.rmtree(\"test2\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# with ACTIVATIONS\n", + "SAMPLE_RATE = 16000\n", + "\n", + "frechet = FrechetAudioDistance(\n", + " ckpt_dir=\"../checkpoints/vggish\",\n", + " model_name=\"vggish\",\n", + " submodel_name=\"630k-audioset\", # for CLAP only\n", + " sample_rate=SAMPLE_RATE,\n", + " use_pca=False, # for VGGish only\n", + " use_activation=True, # for VGGish only\n", + " verbose=True,\n", + " audio_load_worker=8,\n", + " enable_fusion=False, # for CLAP only\n", + ")\n", + "\n", + "for traget, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n", + " os.makedirs(traget, exist_ok=True)\n", + " frequencies = np.linspace(100, 1000, count).tolist()\n", + " for freq in frequencies:\n", + " samples = gen_sine_wave(freq, param=param)\n", + " filename = os.path.join(traget, \"sin_%.0f.wav\" % freq)\n", + " print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n", + " sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n", + "\n", + "fad_score = frechet.score(\"background\", \"test1\")\n", + "print(\"FAD score test 1: %.8f\" % fad_score)\n", + "\n", + "fad_score = frechet.score(\"background\", \"test2\")\n", + "print(\"FAD score test 2: 
%.8f\" % fad_score)\n", + "\n", + "shutil.rmtree(\"background\")\n", + "shutil.rmtree(\"test1\")\n", + "shutil.rmtree(\"test2\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### PANN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 8kHz\n", + "\n", + "SAMPLE_RATE = 8000\n", + "\n", + "frechet = FrechetAudioDistance(\n", + " ckpt_dir=\"../checkpoints/pann\",\n", + " model_name=\"pann\",\n", + " # submodel_name=\"630k-audioset\", # for CLAP only\n", + " sample_rate=SAMPLE_RATE,\n", + " # use_pca=False, # for VGGish only\n", + " # use_activation=False, # for VGGish only\n", + " verbose=True,\n", + " audio_load_worker=8,\n", + " # enable_fusion=False, # for CLAP only\n", + ")\n", + "\n", + "for traget, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n", + " os.makedirs(traget, exist_ok=True)\n", + " frequencies = np.linspace(100, 1000, count).tolist()\n", + " for freq in frequencies:\n", + " samples = gen_sine_wave(freq, param=param)\n", + " filename = os.path.join(traget, \"sin_%.0f.wav\" % freq)\n", + " print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n", + " sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n", + "\n", + "fad_score = frechet.score(\"background\", \"test1\")\n", + "print(\"FAD score test 1: %.8f\" % fad_score)\n", + "\n", + "fad_score = frechet.score(\"background\", \"test2\")\n", + "print(\"FAD score test 2: %.8f\" % fad_score)\n", + "\n", + "shutil.rmtree(\"background\")\n", + "shutil.rmtree(\"test1\")\n", + "shutil.rmtree(\"test2\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 16kHz\n", + "\n", + "SAMPLE_RATE = 16000\n", + "\n", + "frechet = FrechetAudioDistance(\n", + " ckpt_dir=\"../checkpoints/pann\",\n", + " model_name=\"pann\",\n", + " # submodel_name=\"630k-audioset\", # for CLAP only\n", + " sample_rate=SAMPLE_RATE,\n", + " # use_pca=False, # for VGGish only\n", + " # use_activation=False, # for VGGish only\n", + " verbose=True,\n", + " audio_load_worker=8,\n", + " # enable_fusion=False, # for CLAP only\n", + ")\n", + "\n", + "for traget, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n", + " os.makedirs(traget, exist_ok=True)\n", + " frequencies = np.linspace(100, 1000, count).tolist()\n", + " for freq in frequencies:\n", + " samples = gen_sine_wave(freq, param=param)\n", + " filename = os.path.join(traget, \"sin_%.0f.wav\" % freq)\n", + " print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n", + " sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n", + "\n", + "fad_score = frechet.score(\"background\", \"test1\")\n", + "print(\"FAD score test 1: %.8f\" % fad_score)\n", + "\n", + "fad_score = frechet.score(\"background\", \"test2\")\n", + "print(\"FAD score test 2: %.8f\" % fad_score)\n", + "\n", + "shutil.rmtree(\"background\")\n", + "shutil.rmtree(\"test1\")\n", + "shutil.rmtree(\"test2\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 32kHz\n", + "\n", + "SAMPLE_RATE = 32000\n", + "\n", + "frechet = FrechetAudioDistance(\n", + " ckpt_dir=\"../checkpoints/pann\",\n", + " model_name=\"pann\",\n", + " # submodel_name=\"630k-audioset\", # for CLAP only\n", + " sample_rate=SAMPLE_RATE,\n", + " # use_pca=False, # for VGGish only\n", + " # use_activation=False, # for VGGish only\n", + " 
verbose=True,\n", + " audio_load_worker=8,\n", + " # enable_fusion=False, # for CLAP only\n", + ")\n", + "\n", + "for traget, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n", + " os.makedirs(traget, exist_ok=True)\n", + " frequencies = np.linspace(100, 1000, count).tolist()\n", + " for freq in frequencies:\n", + " samples = gen_sine_wave(freq, param=param)\n", + " filename = os.path.join(traget, \"sin_%.0f.wav\" % freq)\n", + " print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n", + " sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n", + "\n", + "fad_score = frechet.score(\"background\", \"test1\")\n", + "print(\"FAD score test 1: %.8f\" % fad_score)\n", + "\n", + "fad_score = frechet.score(\"background\", \"test2\")\n", + "print(\"FAD score test 2: %.8f\" % fad_score)\n", + "\n", + "shutil.rmtree(\"background\")\n", + "shutil.rmtree(\"test1\")\n", + "shutil.rmtree(\"test2\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CLAP" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 630k-audioset (for general audio less than 10-sec)\n", + "\n", + "SAMPLE_RATE = 48000\n", + "LENGTH_IN_SECONDS = 2\n", + "\n", + "frechet = FrechetAudioDistance(\n", + " ckpt_dir=\"../checkpoints/clap\",\n", + " model_name=\"clap\",\n", + " submodel_name=\"630k-audioset\", # for CLAP only\n", + " sample_rate=SAMPLE_RATE,\n", + " # use_pca=False, # for VGGish only\n", + " # use_activation=False, # for VGGish only\n", + " verbose=True,\n", + " audio_load_worker=8,\n", + " enable_fusion=False, # for CLAP only\n", + ")\n", + "\n", + "for traget, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n", + " os.makedirs(traget, exist_ok=True)\n", + " frequencies = np.linspace(100, 1000, count).tolist()\n", + " for freq in frequencies:\n", + " samples = gen_sine_wave(freq, LENGTH_IN_SECONDS, SAMPLE_RATE, param=param)\n", + " filename = os.path.join(traget, \"sin_%.0f.wav\" % freq)\n", + " print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n", + " sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n", + "\n", + "fad_score = frechet.score(\"background\", \"test1\")\n", + "print(\"FAD score test 1: %.8f\" % fad_score)\n", + "\n", + "fad_score = frechet.score(\"background\", \"test2\")\n", + "print(\"FAD score test 2: %.8f\" % fad_score)\n", + "\n", + "shutil.rmtree(\"background\")\n", + "shutil.rmtree(\"test1\")\n", + "shutil.rmtree(\"test2\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 630k-audioset + fusion (for general audio with variable-length)\n", + "\n", + "SAMPLE_RATE = 48000\n", + "LENGTH_IN_SECONDS = 12\n", + "\n", + "frechet = FrechetAudioDistance(\n", + " ckpt_dir=\"../checkpoints/clap\",\n", + " model_name=\"clap\",\n", + " submodel_name=\"630k-audioset\", # for CLAP only\n", + " sample_rate=SAMPLE_RATE,\n", + " # use_pca=False, # for VGGish only\n", + " # use_activation=False, # for VGGish only\n", + " verbose=True,\n", + " audio_load_worker=8,\n", + " enable_fusion=True, # for CLAP only\n", + ")\n", + "\n", + "for traget, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n", + " os.makedirs(traget, exist_ok=True)\n", + " frequencies = np.linspace(100, 1000, count).tolist()\n", + " for freq in frequencies:\n", + " samples = gen_sine_wave(freq, 
LENGTH_IN_SECONDS, SAMPLE_RATE, param=param)\n", + " filename = os.path.join(traget, \"sin_%.0f.wav\" % freq)\n", + " # print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n", + " sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n", + "\n", + "fad_score = frechet.score(\"background\", \"test1\")\n", + "print(\"FAD score test 1: %.8f\" % fad_score)\n", + "\n", + "fad_score = frechet.score(\"background\", \"test2\")\n", + "print(\"FAD score test 2: %.8f\" % fad_score)\n", + "\n", + "shutil.rmtree(\"background\")\n", + "shutil.rmtree(\"test1\")\n", + "shutil.rmtree(\"test2\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 630k (for general audio less than 10-sec)\n", + "\n", + "SAMPLE_RATE = 48000\n", + "LENGTH_IN_SECONDS = 2\n", + "\n", + "frechet = FrechetAudioDistance(\n", + " ckpt_dir=\"../checkpoints/clap\",\n", + " model_name=\"clap\",\n", + " submodel_name=\"630k\", # for CLAP only\n", + " sample_rate=SAMPLE_RATE,\n", + " # use_pca=False, # for VGGish only\n", + " # use_activation=False, # for VGGish only\n", + " verbose=True,\n", + " audio_load_worker=8,\n", + " enable_fusion=False, # for CLAP only\n", + ")\n", + "\n", + "for traget, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n", + " os.makedirs(traget, exist_ok=True)\n", + " frequencies = np.linspace(100, 1000, count).tolist()\n", + " for freq in frequencies:\n", + " samples = gen_sine_wave(freq, LENGTH_IN_SECONDS, SAMPLE_RATE, param=param)\n", + " filename = os.path.join(traget, \"sin_%.0f.wav\" % freq)\n", + " # print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n", + " sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n", + "\n", + "fad_score = frechet.score(\"background\", \"test1\")\n", + "print(\"FAD score test 1: %.8f\" % fad_score)\n", + "\n", + "fad_score = frechet.score(\"background\", \"test2\")\n", + "print(\"FAD score test 2: %.8f\" % fad_score)\n", + "\n", + "shutil.rmtree(\"background\")\n", + "shutil.rmtree(\"test1\")\n", + "shutil.rmtree(\"test2\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 630k + fusion (for general audio with variable-length)\n", + "\n", + "SAMPLE_RATE = 48000\n", + "LENGTH_IN_SECONDS = 12\n", + "\n", + "frechet = FrechetAudioDistance(\n", + " ckpt_dir=\"../checkpoints/clap\",\n", + " model_name=\"clap\",\n", + " submodel_name=\"630k\", # for CLAP only\n", + " sample_rate=SAMPLE_RATE,\n", + " # use_pca=False, # for VGGish only\n", + " # use_activation=False, # for VGGish only\n", + " verbose=True,\n", + " audio_load_worker=8,\n", + " enable_fusion=True, # for CLAP only\n", + ")\n", + "\n", + "for traget, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n", + " os.makedirs(traget, exist_ok=True)\n", + " frequencies = np.linspace(100, 1000, count).tolist()\n", + " for freq in frequencies:\n", + " samples = gen_sine_wave(freq, LENGTH_IN_SECONDS, SAMPLE_RATE, param=param)\n", + " filename = os.path.join(traget, \"sin_%.0f.wav\" % freq)\n", + " # print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n", + " sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n", + "\n", + "fad_score = frechet.score(\"background\", \"test1\")\n", + "print(\"FAD score test 1: %.8f\" % fad_score)\n", + "\n", + "fad_score = frechet.score(\"background\", \"test2\")\n", + "print(\"FAD score test 2: %.8f\" % 
fad_score)\n", + "\n", + "shutil.rmtree(\"background\")\n", + "shutil.rmtree(\"test1\")\n", + "shutil.rmtree(\"test2\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# music_audioset (for music)\n", + "# (trained on music + Audioset + LAION-Audio-630k. The zeroshot ESC50 performance is 90.14%, the GTZAN performance is 71%.)\n", + "\n", + "SAMPLE_RATE = 48000\n", + "LENGTH_IN_SECONDS = 2\n", + "\n", + "frechet = FrechetAudioDistance(\n", + " ckpt_dir=\"../checkpoints/clap\",\n", + " model_name=\"clap\",\n", + " submodel_name=\"music_audioset\", # for CLAP only\n", + " sample_rate=SAMPLE_RATE,\n", + " # use_pca=False, # for VGGish only\n", + " # use_activation=False, # for VGGish only\n", + " verbose=True,\n", + " audio_load_worker=8,\n", + " enable_fusion=False, # for CLAP only\n", + ")\n", + "\n", + "for traget, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n", + " os.makedirs(traget, exist_ok=True)\n", + " frequencies = np.linspace(100, 1000, count).tolist()\n", + " for freq in frequencies:\n", + " samples = gen_sine_wave(freq, LENGTH_IN_SECONDS, SAMPLE_RATE, param=param)\n", + " filename = os.path.join(traget, \"sin_%.0f.wav\" % freq)\n", + " # print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n", + " sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n", + "\n", + "fad_score = frechet.score(\"background\", \"test1\")\n", + "print(\"FAD score test 1: %.8f\" % fad_score)\n", + "\n", + "fad_score = frechet.score(\"background\", \"test2\")\n", + "print(\"FAD score test 2: %.8f\" % fad_score)\n", + "\n", + "shutil.rmtree(\"background\")\n", + "shutil.rmtree(\"test1\")\n", + "shutil.rmtree(\"test2\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# music_speech (for music and speech)\n", + "# trained on music + speech + LAION-Audio-630k. 
The zeroshot ESC50 performance is 89.25%, the GTZAN performance is 69%.\n", + "\n", + "SAMPLE_RATE = 48000\n", + "LENGTH_IN_SECONDS = 2\n", + "\n", + "frechet = FrechetAudioDistance(\n", + " ckpt_dir=\"../checkpoints/clap\",\n", + " model_name=\"clap\",\n", + " submodel_name=\"music_speech\", # for CLAP only\n", + " sample_rate=SAMPLE_RATE,\n", + " # use_pca=False, # for VGGish only\n", + " # use_activation=False, # for VGGish only\n", + " verbose=True,\n", + " audio_load_worker=8,\n", + " enable_fusion=False, # for CLAP only\n", + ")\n", + "\n", + "for traget, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n", + " os.makedirs(traget, exist_ok=True)\n", + " frequencies = np.linspace(100, 1000, count).tolist()\n", + " for freq in frequencies:\n", + " samples = gen_sine_wave(freq, LENGTH_IN_SECONDS, SAMPLE_RATE, param=param)\n", + " filename = os.path.join(traget, \"sin_%.0f.wav\" % freq)\n", + " # print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n", + " sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n", + "\n", + "fad_score = frechet.score(\"background\", \"test1\")\n", + "print(\"FAD score test 1: %.8f\" % fad_score)\n", + "\n", + "fad_score = frechet.score(\"background\", \"test2\")\n", + "print(\"FAD score test 2: %.8f\" % fad_score)\n", + "\n", + "shutil.rmtree(\"background\")\n", + "shutil.rmtree(\"test1\")\n", + "shutil.rmtree(\"test2\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# music_speech_audioset (for speech, music and general audio)\n", + "# trained on music + speech + Audioset + LAION-Audio-630k. The zeroshot ESC50 performance is 89.98%, the GTZAN performance is 51%.\n", + "\n", + "SAMPLE_RATE = 48000\n", + "LENGTH_IN_SECONDS = 2\n", + "\n", + "frechet = FrechetAudioDistance(\n", + " ckpt_dir=\"../checkpoints/clap\",\n", + " model_name=\"clap\",\n", + " submodel_name=\"music_speech_audioset\", # for CLAP only\n", + " sample_rate=SAMPLE_RATE,\n", + " # use_pca=False, # for VGGish only\n", + " # use_activation=False, # for VGGish only\n", + " verbose=True,\n", + " audio_load_worker=8,\n", + " enable_fusion=False, # for CLAP only\n", + ")\n", + "\n", + "for traget, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n", + " os.makedirs(traget, exist_ok=True)\n", + " frequencies = np.linspace(100, 1000, count).tolist()\n", + " for freq in frequencies:\n", + " samples = gen_sine_wave(freq, LENGTH_IN_SECONDS, SAMPLE_RATE, param=param)\n", + " filename = os.path.join(traget, \"sin_%.0f.wav\" % freq)\n", + " # print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n", + " sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n", + "\n", + "fad_score = frechet.score(\"background\", \"test1\")\n", + "print(\"FAD score test 1: %.8f\" % fad_score)\n", + "\n", + "fad_score = frechet.score(\"background\", \"test2\")\n", + "print(\"FAD score test 2: %.8f\" % fad_score)\n", + "\n", + "shutil.rmtree(\"background\")\n", + "shutil.rmtree(\"test1\")\n", + "shutil.rmtree(\"test2\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + 
"nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.12" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 0542001a2a4e0e423f94ff0ab9232ab53a54dfc8 Mon Sep 17 00:00:00 2001 From: mcomunita Date: Tue, 10 Oct 2023 16:48:58 +0200 Subject: [PATCH 3/4] [update] requirements --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a9ebb6e..c6e6370 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ scipy==1.10.1 tqdm soundfile resampy -torchlibrosa \ No newline at end of file +torchlibrosa +laion_clap \ No newline at end of file From e28877848177b1d9bf6ff9a914aa16daf3e8b778 Mon Sep 17 00:00:00 2001 From: mcomunita Date: Tue, 10 Oct 2023 18:17:57 +0200 Subject: [PATCH 4/4] [bug] fix requirements for clap module --- frechet_audio_distance/fad.py | 145 +++++++++++++++++----------------- requirements.txt | 6 +- test/test_all.ipynb | 1 - 3 files changed, 77 insertions(+), 75 deletions(-) diff --git a/frechet_audio_distance/fad.py b/frechet_audio_distance/fad.py index e5f9c01..8e260ea 100644 --- a/frechet_audio_distance/fad.py +++ b/frechet_audio_distance/fad.py @@ -7,19 +7,17 @@ """ import os import numpy as np -import torch -from torch import nn -from scipy import linalg -from tqdm import tqdm -import soundfile as sf import resampy -import wget +import soundfile as sf +import torch import laion_clap + from multiprocessing.dummy import Pool as ThreadPool -from .models.pann import Cnn14_8k, Cnn14_16k, Cnn14 -# from .models.laion_clap.hook import CLAP_Module -# from .models.laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict +from scipy import linalg +from torch import nn +from tqdm import tqdm +from .models.pann import Cnn14, Cnn14_8k, Cnn14_16k # SAMPLE_RATE = 16000 @@ -35,7 +33,7 @@ def load_audio_task(fname, sample_rate, dtype="float32"): wav_data = wav_data / 32768.0 elif dtype == 'int32': wav_data = wav_data / float(2**31) - + # Convert to mono if len(wav_data.shape) > 1: wav_data = np.mean(wav_data, axis=1) @@ -48,19 +46,19 @@ def load_audio_task(fname, sample_rate, dtype="float32"): class FrechetAudioDistance: def __init__( - self, + self, ckpt_dir=None, model_name="vggish", - submodel_name="630k-audioset", # only for CLAP + submodel_name="630k-audioset", # only for CLAP sample_rate=16000, - use_pca=False, # only for VGGish - use_activation=False, # only for VGGish - verbose=False, + use_pca=False, # only for VGGish + use_activation=False, # only for VGGish + verbose=False, audio_load_worker=8, - enable_fusion=False, # only for CLAP - ): + enable_fusion=False, # only for CLAP + ): """Initialize FAD - + ckpt_dir: folder where the downloaded checkpoints are stored model_name: one between vggish, pann or clap submodel_name: only for clap models - determines which checkpoint to use. 
options: ["630k-audioset", "630k", "music_audioset", "music_speech", "music_speech_audioset"] @@ -92,8 +90,7 @@ def __init__( # by default `ckpt_dir` is `torch.hub.get_dir()` self.ckpt_dir = torch.hub.get_dir() self.__get_model(model_name=model_name, use_pca=use_pca, use_activation=use_activation) - - + def __get_model(self, model_name="vggish", use_pca=False, use_activation=False): """ Params: @@ -117,45 +114,45 @@ def __get_model(self, model_name="vggish", use_pca=False, use_activation=False): if self.sample_rate == 8000: download_name = "Cnn14_8k_mAP%3D0.416.pth" self.model = Cnn14_8k( - sample_rate=8000, - window_size=256, - hop_size=80, - mel_bins=64, - fmin=50, - fmax=4000, + sample_rate=8000, + window_size=256, + hop_size=80, + mel_bins=64, + fmin=50, + fmax=4000, classes_num=527 ) elif self.sample_rate == 16000: download_name = "Cnn14_16k_mAP%3D0.438.pth" self.model = Cnn14_16k( - sample_rate=16000, - window_size=512, - hop_size=160, - mel_bins=64, - fmin=50, - fmax=8000, + sample_rate=16000, + window_size=512, + hop_size=160, + mel_bins=64, + fmin=50, + fmax=8000, classes_num=527 ) elif self.sample_rate == 32000: download_name = "Cnn14_mAP%3D0.431.pth" self.model = Cnn14( - sample_rate=32000, - window_size=1024, - hop_size=320, - mel_bins=64, - fmin=50, - fmax=16000, + sample_rate=32000, + window_size=1024, + hop_size=320, + mel_bins=64, + fmin=50, + fmax=16000, classes_num=527 ) model_path = os.path.join(self.ckpt_dir, download_name) # download checkpoint - if not(os.path.exists(model_path)): + if not (os.path.exists(model_path)): if self.verbose: print("[Frechet Audio Distance] Downloading {}...".format(model_path)) torch.hub.download_url_to_file( - url=f"https://zenodo.org/record/3987831/files/{download_name}", + url=f"https://zenodo.org/record/3987831/files/{download_name}", dst=model_path ) @@ -183,28 +180,28 @@ def __get_model(self, model_name="vggish", use_pca=False, use_activation=False): download_name = "music_speech_audioset_epoch_15_esc_89.98.pt" model_path = os.path.join(self.ckpt_dir, download_name) - + # download checkpoint - if not(os.path.exists(model_path)): + if not (os.path.exists(model_path)): if self.verbose: print("[Frechet Audio Distance] Downloading {}...".format(model_path)) torch.hub.download_url_to_file( - url=f"https://huggingface.co/lukewys/laion_clap/resolve/main/{download_name}", + url=f"https://huggingface.co/lukewys/laion_clap/resolve/main/{download_name}", dst=model_path ) - + # init model and load checkpoint if self.submodel_name in ["630k-audioset", "630k"]: - self.model = laion_clap.CLAP_Module(enable_fusion=self.enable_fusion, + self.model = laion_clap.CLAP_Module(enable_fusion=self.enable_fusion, device=self.device) elif self.submodel_name in ["music_audioset", "music_speech", "music_speech_audioset"]: - self.model = laion_clap.CLAP_Module(enable_fusion=self.enable_fusion, - amodel= 'HTSAT-base', + self.model = laion_clap.CLAP_Module(enable_fusion=self.enable_fusion, + amodel='HTSAT-base', device=self.device) self.model.load_ckpt(model_path) - + self.model.eval() - + def get_embeddings(self, x, sr): """ Get embeddings using VGGish model. 
@@ -223,29 +220,32 @@ def get_embeddings(self, x, sr): out = self.model(audio, None) embd = out['embedding'].data[0] elif self.model_name == "clap": - audio = np.expand_dims(audio, 0) - embd = self.model.get_audio_embedding_from_data(audio, use_tensor=False) + audio = torch.tensor(audio).float().unsqueeze(0) + embd = self.model.get_audio_embedding_from_data(audio, use_tensor=True) + if self.device == torch.device('cuda'): embd = embd.cpu() + if torch.is_tensor(embd): embd = embd.detach().numpy() + embd_lst.append(embd) except Exception as e: print("[Frechet Audio Distance] get_embeddings throw an exception: {}".format(str(e))) - + return np.concatenate(embd_lst, axis=0) - + def calculate_embd_statistics(self, embd_lst): if isinstance(embd_lst, list): embd_lst = np.array(embd_lst) mu = np.mean(embd_lst, axis=0) sigma = np.cov(embd_lst, rowvar=False) return mu, sigma - + def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6): """ Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py - + Numpy implementation of the Frechet Distance. The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is @@ -281,7 +281,7 @@ def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6): covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) if not np.isfinite(covmean).all(): msg = ('fid calculation produces singular product; ' - 'adding %s to diagonal of cov estimates') % eps + 'adding %s to diagonal of cov estimates') % eps print(msg) offset = np.eye(sigma1.shape[0]) * eps covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) @@ -297,7 +297,7 @@ def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6): return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean) - + def __load_audio_files(self, dir, dtype="float32"): task_results = [] @@ -311,24 +311,23 @@ def update(*a): print("[Frechet Audio Distance] Loading audio from {}...".format(dir)) for fname in os.listdir(dir): res = pool.apply_async( - load_audio_task, - args=(os.path.join(dir, fname), self.sample_rate, dtype), + load_audio_task, + args=(os.path.join(dir, fname), self.sample_rate, dtype), callback=update ) task_results.append(res) pool.close() - pool.join() - - return [k.get() for k in task_results] + pool.join() + return [k.get() for k in task_results] - def score(self, - background_dir, - eval_dir, - background_embds_path=None, - eval_embds_path=None, - dtype="float32" - ): + def score(self, + background_dir, + eval_dir, + background_embds_path=None, + eval_embds_path=None, + dtype="float32" + ): """ Computes the Frechet Audio Distance (FAD) between two directories of audio files. 
@@ -380,13 +379,13 @@ def score(self, mu_eval, sigma_eval = self.calculate_embd_statistics(embds_eval) fad_score = self.calculate_frechet_distance( - mu_background, - sigma_background, - mu_eval, + mu_background, + sigma_background, + mu_eval, sigma_eval ) return fad_score except Exception as e: print(f"[Frechet Audio Distance] An error occurred: {e}") - return -1 \ No newline at end of file + return -1 diff --git a/requirements.txt b/requirements.txt index c6e6370..baaf3a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,8 @@ tqdm soundfile resampy torchlibrosa -laion_clap \ No newline at end of file +torchvision + +laion_clap +transformers<=4.30.2 +torchaudio \ No newline at end of file diff --git a/test/test_all.ipynb b/test/test_all.ipynb index fdb9809..4251b6e 100644 --- a/test/test_all.ipynb +++ b/test/test_all.ipynb @@ -12,7 +12,6 @@ "import numpy as np\n", "import soundfile as sf\n", "\n", - "\n", "module_path = os.path.abspath(os.path.join('..'))\n", "if module_path not in sys.path:\n", " sys.path.append(module_path)\n",
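
Usage sketch (illustrative, not part of the patches): once this series is applied,
the new CLAP backend is driven through the same FrechetAudioDistance interface as
VGGish and PANN. The directory and cache paths below are hypothetical; everything
else follows the constructor and score() signatures added above. CLAP requires
sample_rate=48000, and per PATCH 1/4 only the "630k-audioset" and "630k" submodels
have separate fusion checkpoints.

    from frechet_audio_distance import FrechetAudioDistance

    # CLAP-backed FAD, as introduced in PATCH 1/4
    frechet = FrechetAudioDistance(
        ckpt_dir="checkpoints/clap",     # checkpoints are downloaded here on first use
        model_name="clap",
        submodel_name="music_audioset",  # or "630k-audioset", "630k", "music_speech", "music_speech_audioset"
        sample_rate=48000,               # asserted for CLAP in __init__
        enable_fusion=False,
        verbose=True,
    )

    # Audio in both folders is resampled to 48 kHz and embedded with CLAP, then the
    # Frechet distance is computed between the two embedding distributions. The
    # optional *_embds_path arguments cache embeddings to .npy files for reuse.
    fad_score = frechet.score(
        "background",                    # hypothetical folder of reference audio
        "eval",                          # hypothetical folder of generated audio
        background_embds_path="embds/background_clap.npy",
        eval_embds_path="embds/eval_clap.npy",
    )
    print("FAD: %.8f" % fad_score)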