From 62c26aeaca875645642acd67e63459c56c6b6a78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexandre=20D=C3=A9fossez?= Date: Mon, 11 Dec 2023 10:03:02 -0600 Subject: [PATCH 1/2] fix commitment loss --- README.md | 4 +++- audiocraft/quantization/core_vq.py | 5 +++++ tests/quantization/test_vq.py | 4 +++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e3687f1e..3c96c12f 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,9 @@ AudioCraft requires Python 3.9, PyTorch 2.0.0. To install AudioCraft, you can ru ```shell # Best to make sure you have torch installed first, in particular before installing xformers. # Don't run this if you already have PyTorch installed. -python -m pip install 'torch>=2.0' +python -m pip install 'torch==2.1.0' +# You might need the following before trying to install the packages +python -m pip install setuptools wheel # Then proceed to one of the following python -m pip install -U audiocraft # stable release python -m pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft # bleeding edge diff --git a/audiocraft/quantization/core_vq.py b/audiocraft/quantization/core_vq.py index da02a6ce..6aaa3b07 100644 --- a/audiocraft/quantization/core_vq.py +++ b/audiocraft/quantization/core_vq.py @@ -371,11 +371,16 @@ def forward(self, x, n_q: tp.Optional[int] = None): for i, layer in enumerate(self.layers[:n_q]): quantized, indices, loss = layer(residual) + quantized = quantized.detach() residual = residual - quantized quantized_out = quantized_out + quantized all_indices.append(indices) all_losses.append(loss) + if self.training: + # Solving subtle bug with STE and RVQ: https://github.com/facebookresearch/encodec/issues/25 + quantized_out = x + (quantized_out - x).detach() + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) return quantized_out, out_indices, out_losses diff --git a/tests/quantization/test_vq.py b/tests/quantization/test_vq.py index c215099f..e58fb0a1 100644 --- a/tests/quantization/test_vq.py +++ b/tests/quantization/test_vq.py @@ -12,7 +12,9 @@ class TestResidualVectorQuantizer: def test_rvq(self): - x = torch.randn(1, 16, 2048) + x = torch.randn(1, 16, 2048, requires_grad=True) vq = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8) res = vq(x, 1.) assert res.x.shape == torch.Size([1, 16, 2048]) + res.x.sum().backward() + assert torch.allclose(x.grad.data, torch.ones(1)) From 32fec6440755392db2db8ab39ea7942e690476c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexandre=20D=C3=A9fossez?= Date: Mon, 11 Dec 2023 18:12:32 -0600 Subject: [PATCH 2/2] changelog and version --- CHANGELOG.md | 2 ++ audiocraft/__init__.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6036b72f..29fc9244 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). Adding stereo models. +Fixed the commitment loss, which was until now only applied to the first RVQ layer. + ## [1.1.0] - 2023-11-06 diff --git a/audiocraft/__init__.py b/audiocraft/__init__.py index 8b7acf22..840aa263 100644 --- a/audiocraft/__init__.py +++ b/audiocraft/__init__.py @@ -23,4 +23,4 @@ # flake8: noqa from . import data, modules, models -__version__ = '1.2.0a1' +__version__ = '1.2.0a2'