diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6036b72f..6bc38939 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
 Adding stereo models.
 
+Fixed the commitment loss, which was until now only applied to the first RVQ layer.
+
+Removed compression model state from the LM checkpoints; for consistency, it
+should always be loaded from the original `compression_model_checkpoint`.
+
 
 ## [1.1.0] - 2023-11-06
 
diff --git a/README.md b/README.md
index 21b3f497..3c96c12f 100644
--- a/README.md
+++ b/README.md
@@ -13,11 +13,13 @@ AudioCraft requires Python 3.9, PyTorch 2.0.0. To install AudioCraft, you can ru
 ```shell
 # Best to make sure you have torch installed first, in particular before installing xformers.
 # Don't run this if you already have PyTorch installed.
-pip install 'torch>=2.0'
+python -m pip install 'torch==2.1.0'
+# You might need the following before trying to install the packages
+python -m pip install setuptools wheel
 # Then proceed to one of the following
-pip install -U audiocraft  # stable release
-pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft  # bleeding edge
-pip install -e .  # or if you cloned the repo locally (mandatory if you want to train).
+python -m pip install -U audiocraft  # stable release
+python -m pip install -U git+https://git@github.com/facebookresearch/audiocraft#egg=audiocraft  # bleeding edge
+python -m pip install -e .  # or if you cloned the repo locally (mandatory if you want to train).
 ```
 
 We also recommend having `ffmpeg` installed, either through your system or Anaconda:
@@ -72,11 +74,11 @@ Finally, if you use a model that relies on Demucs (e.g. `musicgen-melody`) and w
 
 For the general framework of AudioCraft, please cite the following.
 ```
-@article{copet2023simple,
+@inproceedings{copet2023simple,
     title={Simple and Controllable Music Generation},
     author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
+    booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
     year={2023},
-    journal={arXiv preprint arXiv:2306.05284},
 }
 ```
 
diff --git a/audiocraft/__init__.py b/audiocraft/__init__.py
index 8b7acf22..840aa263 100644
--- a/audiocraft/__init__.py
+++ b/audiocraft/__init__.py
@@ -23,4 +23,4 @@
 # flake8: noqa
 from . import data, modules, models
 
-__version__ = '1.2.0a1'
+__version__ = '1.2.0a2'
diff --git a/audiocraft/modules/transformer.py b/audiocraft/modules/transformer.py
index 691df6a2..e8100a4c 100644
--- a/audiocraft/modules/transformer.py
+++ b/audiocraft/modules/transformer.py
@@ -648,7 +648,6 @@ def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforwar
                 # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
                 # backward hook inside of FSDP...
                 layer._magma_checkpointed = True  # type: ignore
-                assert layer.layer_drop == 0., "Need further checking"  # type: ignore
 
     def _apply_layer(self, layer, *args, **kwargs):
         method = self.checkpointing
diff --git a/audiocraft/quantization/core_vq.py b/audiocraft/quantization/core_vq.py
index da02a6ce..6aaa3b07 100644
--- a/audiocraft/quantization/core_vq.py
+++ b/audiocraft/quantization/core_vq.py
@@ -371,11 +371,16 @@ def forward(self, x, n_q: tp.Optional[int] = None):
 
         for i, layer in enumerate(self.layers[:n_q]):
             quantized, indices, loss = layer(residual)
+            quantized = quantized.detach()
             residual = residual - quantized
             quantized_out = quantized_out + quantized
             all_indices.append(indices)
             all_losses.append(loss)
 
+        if self.training:
+            # Solving subtle bug with STE and RVQ: https://github.com/facebookresearch/encodec/issues/25
+            quantized_out = x + (quantized_out - x).detach()
+
         out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
         return quantized_out, out_indices, out_losses
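The essential pattern in the `core_vq.py` change above is the straight-through estimator (STE): each layer's codebook output is detached while accumulating `quantized_out`, and a single `x + (quantized_out - x).detach()` at the end returns the quantized values in the forward pass while acting as the identity in the backward pass. Previously each layer applied its own STE, which left the residual fed to layers 2..n detached from the encoder, so only the first layer's commitment loss could produce gradients (hence the CHANGELOG entry; see the linked encodec issue). A minimal standalone illustration of the STE trick, with `torch.round` as a toy stand-in for a quantizer:

```python
import torch

x = torch.randn(4, requires_grad=True)
quantized = torch.round(x)          # toy non-differentiable quantizer, zero gradient almost everywhere
out = x + (quantized - x).detach()  # forward pass: quantized values; backward pass: identity in x
out.sum().backward()
print(x.grad)                       # tensor([1., 1., 1., 1.]) -- gradients pass straight through
```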
diff --git a/audiocraft/solvers/musicgen.py b/audiocraft/solvers/musicgen.py
index 2439da33..72b65338 100644
--- a/audiocraft/solvers/musicgen.py
+++ b/audiocraft/solvers/musicgen.py
@@ -25,7 +25,7 @@
 from ..modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition
 from ..utils.cache import CachedBatchWriter, CachedBatchLoader
 from ..utils.samples.manager import SampleManager
-from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once
+from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once, model_hash
 
 
 class MusicGenSolver(base.StandardSolver):
@@ -143,7 +143,7 @@ def build_model(self) -> None:
         # initialize optimization
         self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim)
         self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates)
-        self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler')
+        self.register_stateful('model', 'optimizer', 'lr_scheduler')
         self.register_best_state('model')
         self.autocast_dtype = {
             'float16': torch.float16, 'bfloat16': torch.bfloat16
@@ -181,6 +181,22 @@ def load_state_dict(self, state: dict) -> None:
                 key = prefix + key
                 assert key not in model_state
                 model_state[key] = value
+        if 'compression_model' in state:
+            # We used to store the `compression_model` state in the checkpoint; however,
+            # this is in general not needed, as the compression model should always be
+            # readable from the original `cfg.compression_model_checkpoint` location.
+            compression_model_state = state.pop('compression_model')
+            before_hash = model_hash(self.compression_model)
+            self.compression_model.load_state_dict(compression_model_state)
+            after_hash = model_hash(self.compression_model)
+            if before_hash != after_hash:
+                raise RuntimeError(
+                    "The compression model state inside the checkpoint is different"
+                    " from the one obtained from compression_model_checkpoint. "
+                    "We do not support altering the compression model inside the LM "
+                    "checkpoint, as parts of the code, in particular for running eval "
+                    "post-training, will use the compression_model_checkpoint as the source of truth.")
+
         super().load_state_dict(state)
 
     def load_from_pretrained(self, name: str):
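The consistency check in `load_state_dict` above leans on `model_hash` (newly imported from `..utils.utils`) to detect old checkpoints whose embedded compression model diverged from `compression_model_checkpoint`. A rough sketch of what such a fingerprinting helper can look like, hashing raw parameter bytes (an illustrative assumption, not necessarily the actual implementation):

```python
import hashlib

import torch


def model_hash(model: torch.nn.Module) -> str:
    # Fingerprint a model by hashing the raw bytes of every parameter tensor:
    # two models hash equal only if their weights are byte-identical.
    hasher = hashlib.sha1()
    for p in model.parameters():
        hasher.update(p.data.cpu().numpy().tobytes())
    return hasher.hexdigest()
```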
+ "We do not support altering the compression model inside the LM " + "checkpoint as parts of the code, in particular for running eval post-training " + "will use the compression_model_checkpoint as the source of truth.") + super().load_state_dict(state) def load_from_pretrained(self, name: str): diff --git a/audiocraft/utils/cache.py b/audiocraft/utils/cache.py index f7f82064..6ba017a7 100644 --- a/audiocraft/utils/cache.py +++ b/audiocraft/utils/cache.py @@ -287,6 +287,7 @@ def _load_one(self, index: int): if isinstance(part[0], torch.Tensor): out.append(torch.stack(part)) else: + assert isinstance(part, torch.Tensor) out.append(part) return out except Exception: diff --git a/demos/audiogen_demo.ipynb b/demos/audiogen_demo.ipynb index d3ad73fb..e209fd7b 100644 --- a/demos/audiogen_demo.ipynb +++ b/demos/audiogen_demo.ipynb @@ -83,7 +83,7 @@ " \"\"\"Generates a series of bip bip at the given frequency.\"\"\"\n", " t = torch.arange(\n", " int(duration * sample_rate), device=\"cuda\", dtype=torch.float) / sample_rate\n", - " wav = torch.cos(2 * math.pi * 440 * t)[None]\n", + " wav = torch.cos(2 * math.pi * frequency * t)[None]\n", " tp = (t % (2 * bip_duration)) / (2 * bip_duration)\n", " envelope = (tp >= 0.5).float()\n", " return wav * envelope" diff --git a/demos/musicgen_app.py b/demos/musicgen_app.py index a10d52b5..88cd27dc 100644 --- a/demos/musicgen_app.py +++ b/demos/musicgen_app.py @@ -77,8 +77,7 @@ def _cleanup(self): self.files.pop(0) else: break - - + file_cleaner = FileCleaner() @@ -96,6 +95,9 @@ def load_model(version='facebook/musicgen-melody'): global MODEL print("Loading model", version) if MODEL is None or MODEL.name != version: + # Clear PyTorch CUDA cache and delete model + del MODEL + torch.cuda.empty_cache() MODEL = None # in case loading would crash MODEL = MusicGen.get_pretrained(version) @@ -256,7 +258,7 @@ def ui_full(launch_kwargs): with gr.Column(): radio = gr.Radio(["file", "mic"], value="file", label="Condition on a melody (optional) File or Mic") - melody = gr.Audio(source="upload", type="numpy", label="File", + melody = gr.Audio(sources=["upload"], type="numpy", label="File", interactive=True, elem_id="melody-input") with gr.Row(): submit = gr.Button("Submit") diff --git a/docs/MUSICGEN.md b/docs/MUSICGEN.md index fb12e324..9a6b1e74 100644 --- a/docs/MUSICGEN.md +++ b/docs/MUSICGEN.md @@ -340,9 +340,9 @@ Once you have launched some experiments, you can easily get access to the Solver with the latest trained model using the following snippet. ```python -from audiocraft.solvers.musicgen import MusicGen +from audiocraft.solvers.musicgen import MusicGenSolver -solver = MusicGen.get_eval_solver_from_sig('SIG', device='cpu', batch_size=8) +solver = MusicGenSolver.get_eval_solver_from_sig('SIG', device='cpu', batch_size=8) solver.model solver.dataloaders ``` @@ -401,11 +401,11 @@ activations by sharding the optimizer state. 
diff --git a/docs/MUSICGEN.md b/docs/MUSICGEN.md
index fb12e324..9a6b1e74 100644
--- a/docs/MUSICGEN.md
+++ b/docs/MUSICGEN.md
@@ -340,9 +340,9 @@ Once you have launched some experiments, you can easily get access
 to the Solver with the latest trained model using the following snippet.
 
 ```python
-from audiocraft.solvers.musicgen import MusicGen
+from audiocraft.solvers.musicgen import MusicGenSolver
 
-solver = MusicGen.get_eval_solver_from_sig('SIG', device='cpu', batch_size=8)
+solver = MusicGenSolver.get_eval_solver_from_sig('SIG', device='cpu', batch_size=8)
 solver.model
 solver.dataloaders
 ```
@@ -401,11 +401,11 @@ activations by sharding the optimizer state.
 
 ## Citation
 
 ```
-@article{copet2023simple,
+@inproceedings{copet2023simple,
     title={Simple and Controllable Music Generation},
     author={Jade Copet and Felix Kreuk and Itai Gat and Tal Remez and David Kant and Gabriel Synnaeve and Yossi Adi and Alexandre Défossez},
+    booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
     year={2023},
-    journal={arXiv preprint arXiv:2306.05284},
 }
 ```
diff --git a/requirements.txt b/requirements.txt
index e44fe159..e0f4759e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,8 +8,8 @@ julius
 num2words
 numpy
 sentencepiece
-spacy==3.5.2
-torch>=2.0.0
+spacy>=3.6.1
+torch==2.1.0
 torchaudio>=2.0.0
 huggingface_hub
 tqdm
@@ -20,4 +20,4 @@ librosa
 gradio
 torchmetrics
 encodec
-protobuf
\ No newline at end of file
+protobuf
diff --git a/tests/quantization/test_vq.py b/tests/quantization/test_vq.py
index c215099f..e58fb0a1 100644
--- a/tests/quantization/test_vq.py
+++ b/tests/quantization/test_vq.py
@@ -12,7 +12,9 @@ class TestResidualVectorQuantizer:
 
     def test_rvq(self):
-        x = torch.randn(1, 16, 2048)
+        x = torch.randn(1, 16, 2048, requires_grad=True)
         vq = ResidualVectorQuantizer(n_q=8, dimension=16, bins=8)
         res = vq(x, 1.)
         assert res.x.shape == torch.Size([1, 16, 2048])
+        res.x.sum().backward()
+        assert torch.allclose(x.grad.data, torch.ones(1))