[feature] add EnCodec model embeddings to FAD calculation (#23)
* [chore] added gitignore for cleaner committing

* [feature] added encodec embeddings to FAD calculation

* [feature] fixed-up encodec FAD and added 24khz test to notebook

* [feature] added encodec as requirement and 48khz model test to notebook

* [fix] added missing channels variable definition
ivanlmh authored Nov 20, 2023
1 parent 47dbd5c commit cf50298
Showing 5 changed files with 171 additions and 11 deletions.
8 changes: 8 additions & 0 deletions .gitignore
@@ -0,0 +1,8 @@
# Python Project
checkpoints/
__pycache__/
.conda/
.pytest_cache/

*.egg-info/
70 changes: 61 additions & 9 deletions frechet_audio_distance/fad.py
@@ -19,8 +19,10 @@

from .models.pann import Cnn14, Cnn14_8k, Cnn14_16k

from encodec import EncodecModel

def load_audio_task(fname, sample_rate, dtype="float32"):

def load_audio_task(fname, sample_rate, channels, dtype="float32"):
if dtype not in ['float64', 'float32', 'int32', 'int16']:
raise ValueError(f"dtype not supported: {dtype}")

@@ -32,7 +34,8 @@ def load_audio_task(fname, sample_rate, dtype="float32"):
wav_data = wav_data / float(2**31)

# Convert to mono if a single channel is expected
if len(wav_data.shape) > 1:
assert channels in [1, 2], "channels must be 1 or 2"
if len(wav_data.shape) > channels:
wav_data = np.mean(wav_data, axis=1)

if sr != sample_rate:
@@ -48,6 +51,7 @@ def __init__(
model_name="vggish",
submodel_name="630k-audioset", # only for CLAP
sample_rate=16000,
channels=1,
use_pca=False, # only for VGGish
use_activation=False, # only for VGGish
verbose=False,
@@ -57,24 +61,29 @@ def __init__(
"""Initialize FAD
ckpt_dir: folder where the downloaded checkpoints are stored
model_name: one between vggish, pann or clap
model_name: one between vggish, pann, clap or encodec
submodel_name: only for clap models - determines which checkpoint to use. options: ["630k-audioset", "630k", "music_audioset", "music_speech", "music_speech_audioset"]
sample_rate: one between [8000, 16000, 32000, 48000]. depending on the model set the sample rate to use
use_pca: whether to apply PCA to the vggish embeddings
use_activation: whether to use the output activation in vggish
enable_fusion: whether to use fusion for clap models (valid depending on the specific submodel used)
"""
assert model_name in ["vggish", "pann", "clap"], "model_name must be either 'vggish', 'pann' or 'clap"
assert model_name in ["vggish", "pann", "clap", "encodec"], "model_name must be either 'vggish', 'pann', 'clap', or 'encodec'"
if model_name == "vggish":
assert sample_rate == 16000, "sample_rate must be 16000"
elif model_name == "pann":
assert sample_rate in [8000, 16000, 32000], "sample_rate must be 8000, 16000 or 32000"
elif model_name == "clap":
assert sample_rate == 48000, "sample_rate must be 48000"
assert submodel_name in ["630k-audioset", "630k", "music_audioset", "music_speech", "music_speech_audioset"]
elif model_name == "encodec":
assert sample_rate in [24000, 48000], "sample_rate must be 24000 or 48000"
if sample_rate == 48000:
assert channels == 2, "channels must be 2 for 48khz encodec model"
self.model_name = model_name
self.submodel_name = submodel_name
self.sample_rate = sample_rate
self.channels = channels
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
self.verbose = verbose
self.audio_load_worker = audio_load_worker
@@ -197,11 +206,23 @@ def __get_model(self, model_name="vggish", use_pca=False, use_activation=False):
device=self.device)
self.model.load_ckpt(model_path)

# encodec
elif model_name == "encodec":
# choose the right model based on sample_rate
# weights are loaded from the encodec repo: https://github.com/facebookresearch/encodec/
if self.sample_rate == 24000:
self.model = EncodecModel.encodec_model_24khz()
elif self.sample_rate == 48000:
self.model = EncodecModel.encodec_model_48khz()
# 24kbps is the max bandwidth supported by both versions
# these models use 32 residual quantizers
self.model.set_target_bandwidth(24.0)

self.model.eval()

def get_embeddings(self, x, sr):
"""
Get embeddings using VGGish model.
Get embeddings using VGGish, PANN, CLAP or EnCodec models.
Params:
-- x : a list of np.ndarray audio samples
-- sr : Sampling rate, if x is a list of audio samples. Default value is 16000.
@@ -219,8 +240,39 @@ def get_embeddings(self, x, sr):
elif self.model_name == "clap":
audio = torch.tensor(audio).float().unsqueeze(0)
embd = self.model.get_audio_embedding_from_data(audio, use_tensor=True)

if self.device == torch.device('cuda'):
elif self.model_name == "encodec":
# add two dimensions
audio = torch.tensor(audio).float().unsqueeze(0).unsqueeze(0)
# the 48 kHz model expects stereo input, so duplicate mono audio
if self.model.sample_rate == 48000:
if audio.shape[-1] != 2:
print(
"[Frechet Audio Distance] Audio is mono, converting to stereo for 48khz model..."
)
audio = torch.cat((audio, audio), dim=1)
else:
# transpose to (batch, channels, samples)
audio = audio[:, 0].transpose(1, 2)

if self.verbose:
print(
"[Frechet Audio Distance] Audio shape: {}".format(
audio.shape
)
)

with torch.no_grad():
# encodec embedding (before quantization)
embd = self.model.encoder(audio)
embd = embd.squeeze(0)

if self.verbose:
print(
"[Frechet Audio Distance] Embedding shape: {}".format(
embd.shape
)
)
if self.device == torch.device("cuda"):
embd = embd.cpu()

if torch.is_tensor(embd):
@@ -309,8 +361,8 @@ def update(*a):
for fname in os.listdir(dir):
res = pool.apply_async(
load_audio_task,
args=(os.path.join(dir, fname), self.sample_rate, dtype),
callback=update
args=(os.path.join(dir, fname), self.sample_rate, self.channels, dtype),
callback=update,
)
task_results.append(res)
pool.close()
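For reference, here is a minimal standalone sketch of the embedding path added above. It assumes the encodec package is installed (pip install encodec) and uses only the calls that appear in the diff; the embedding is the encoder output before quantization, and the shape for one second of 24 kHz audio matches the 75-frames-of-128-dimensions figure quoted in the test notebook below.

import torch
from encodec import EncodecModel

# 24 kHz EnCodec model with 24 kbps target bandwidth (32 residual quantizers),
# mirroring the selection logic in fad.py above
model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(24.0)
model.eval()

# one second of mono audio, shaped (batch, channels, samples)
audio = torch.randn(1, 1, 24000)

with torch.no_grad():
    embd = model.encoder(audio)  # encoder embedding, before quantization

print(embd.shape)  # torch.Size([1, 128, 75]): 75 frames of 128 dimensions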
1 change: 1 addition & 0 deletions pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
'laion_clap',
'transformers<=4.30.2',
'torchaudio',
'encodec',
]

[project.urls]
3 changes: 2 additions & 1 deletion requirements.txt
@@ -9,4 +9,5 @@ torchvision

laion_clap
transformers<=4.30.2
torchaudio
torchaudio
encodec
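With the dependency added, the new option can be exercised end to end. A hedged usage sketch (the import path is assumed from the package layout; "background" and "test1" are hypothetical directories of .wav files at the matching sample rate):

from frechet_audio_distance import FrechetAudioDistance

# 48 kHz EnCodec model; channels=2 because this variant expects stereo input
frechet = FrechetAudioDistance(
    ckpt_dir="./checkpoints/encodec",
    model_name="encodec",
    sample_rate=48000,
    channels=2,
    verbose=True,
)

fad_score = frechet.score("background", "test1")
print("FAD score: %.8f" % fad_score)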
100 changes: 99 additions & 1 deletion test/test_all.ipynb
@@ -594,6 +594,104 @@
"shutil.rmtree(\"test2\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EnCodec"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# EnCodec is a model trained as a neural codec, that is, it is trained to compress audio into a latent space and then reconstruct it.\n",
"# One is able to obtain high quality reconstruction from the generated embeddings.\n",
"# It encodes 1 second of audio into 75 embeddings of 128 dimensions each.\n",
"SAMPLE_RATE = 24000\n",
"LENGTH_IN_SECONDS = 1\n",
"\n",
"frechet = FrechetAudioDistance(\n",
" ckpt_dir=\"../checkpoints/encodec\",\n",
" model_name=\"encodec\",\n",
" # submodel_name=\"music_speech_audioset\", # for CLAP only\n",
" sample_rate=SAMPLE_RATE,\n",
" # use_pca=True, # for VGGish only\n",
" # use_activation=False, # for VGGish only\n",
" verbose=True,\n",
" audio_load_worker=8,\n",
" # enable_fusion=False, # for CLAP only\n",
")\n",
"\n",
"for traget, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n",
" os.makedirs(traget, exist_ok=True)\n",
" frequencies = np.linspace(100, 1000, count).tolist()\n",
" for freq in frequencies:\n",
" samples = gen_sine_wave(freq, LENGTH_IN_SECONDS, SAMPLE_RATE, param=param)\n",
" filename = os.path.join(traget, \"sin_%.0f.wav\" % freq)\n",
" print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n",
" sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n",
"\n",
"fad_score = frechet.score(\"background\", \"test1\")\n",
"print(\"FAD score test 1: %.8f\" % fad_score)\n",
"\n",
"fad_score = frechet.score(\"background\", \"test2\")\n",
"print(\"FAD score test 2: %.8f\" % fad_score)\n",
"\n",
"shutil.rmtree(\"background\")\n",
"shutil.rmtree(\"test1\")\n",
"shutil.rmtree(\"test2\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# EnCodec's 48kHz version maintains the embedding size, so compression is doubled.\n",
"# The model available for 48kHz audio expects stereo audio, so the input audio must have 2 channels.\n",
"SAMPLE_RATE = 48000\n",
"LENGTH_IN_SECONDS = 1\n",
"\n",
"frechet = FrechetAudioDistance(\n",
" ckpt_dir=\"../checkpoints/encodec\",\n",
" model_name=\"encodec\",\n",
" # submodel_name=\"music_speech_audioset\", # for CLAP only\n",
" sample_rate=SAMPLE_RATE,\n",
" channels=2,\n",
" # use_pca=True, # for VGGish only\n",
" # use_activation=False, # for VGGish only\n",
" verbose=True,\n",
" audio_load_worker=8,\n",
" # enable_fusion=False, # for CLAP only\n",
")\n",
"\n",
"for traget, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n",
" os.makedirs(traget, exist_ok=True)\n",
" frequencies = np.linspace(100, 1000, count).tolist()\n",
" for freq in frequencies:\n",
" samples = gen_sine_wave(freq, LENGTH_IN_SECONDS, SAMPLE_RATE, param=param)\n",
" filename = os.path.join(traget, \"sin_%.0f.wav\" % freq)\n",
" # make audio stereo\n",
" samples = np.stack([samples, samples], axis=1)\n",
"\n",
" print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n",
" sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n",
"\n",
"fad_score = frechet.score(\"background\", \"test1\")\n",
"print(\"FAD score test 1: %.8f\" % fad_score)\n",
"\n",
"fad_score = frechet.score(\"background\", \"test2\")\n",
"print(\"FAD score test 2: %.8f\" % fad_score)\n",
"\n",
"shutil.rmtree(\"background\")\n",
"shutil.rmtree(\"test1\")\n",
"shutil.rmtree(\"test2\")"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -618,7 +716,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.10.13"
},
"orig_nbformat": 4
},
