diff --git a/.github/actions/audiocraft_build/action.yml b/.github/actions/audiocraft_build/action.yml index b13c3626..b412cd02 100644 --- a/.github/actions/audiocraft_build/action.yml +++ b/.github/actions/audiocraft_build/action.yml @@ -21,8 +21,8 @@ runs: python3 -m venv env . env/bin/activate python -m pip install --upgrade pip - pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu - pip install --pre xformers + pip install torch torchvision torchaudio + pip install xformers pip install -e '.[dev]' - name: System Dependencies shell: bash diff --git a/CHANGELOG.md b/CHANGELOG.md index 01026a6f..6036b72f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,11 +4,31 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## [1.2.0a] - TBD -## [1.0.1] - TBD +Adding stereo models. + + +## [1.1.0] - 2023-11-06 Not using torchaudio anymore when writing audio files, relying instead directly on the commandline ffmpeg. Also not using it anymore for reading audio files, for similar reasons. +Fixed DAC support with non-default number of codebooks. + +Fixed bug when `two_step_cfg` was overridden when calling `generate()`. + +Fixed samples being always prompted with audio, rather than having both prompted and unprompted. + +**Backward incompatible change:** A `torch.no_grad` around the computation of the conditioning made its way into the public release. + The released models were trained without this. This impacts the linear layers applied to the output of the T5 or melody conditioners. + We removed it, so you might need to retrain models. + +**Backward incompatible change:** Fixing wrong sample rate in CLAP (WARNING if you trained a model with CLAP before). + +**Backward incompatible change:** Renamed VALLEPattern to CoarseFirstPattern, as it was wrongly named. Probably no one + retrained a model with this pattern, so hopefully this won't impact you! 
+ + ## [1.0.0] - 2023-09-07 Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion. diff --git a/audiocraft/__init__.py b/audiocraft/__init__.py index 6ab34607..8b7acf22 100644 --- a/audiocraft/__init__.py +++ b/audiocraft/__init__.py @@ -23,4 +23,4 @@ # flake8: noqa from . import data, modules, models -__version__ = '1.0.0' +__version__ = '1.2.0a1' diff --git a/audiocraft/grids/musicgen/musicgen_stereo_finetune_32khz.py b/audiocraft/grids/musicgen/musicgen_stereo_finetune_32khz.py new file mode 100644 index 00000000..2904e73d --- /dev/null +++ b/audiocraft/grids/musicgen/musicgen_stereo_finetune_32khz.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path +from ._explorers import LMExplorer +from ...environment import AudioCraftEnvironment + + +@LMExplorer +def explorer(launcher): + partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global']) + launcher.slurm_(gpus=32, partition=partitions) + launcher.bind_(solver='musicgen/musicgen_base_32khz') + # replace this by the desired music dataset, which needs to be stereo + launcher.bind_(dset='audio/example') + + fsdp = {'autocast': False, 'fsdp.use': True} + medium = {'model/lm/model_scale': 'medium'} + large = {'model/lm/model_scale': 'large'} + + cfg_low = {'classifier_free_guidance.training_dropout': 0.2} + wd_low = {'conditioners.description.t5.word_dropout': 0.2} + + adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4} + + stereo = { + 'codebooks_pattern.delay.delays': [0, 0, 1, 1, 2, 2, 3, 3], + 'transformer_lm.n_q': 8, + 'interleave_stereo_codebooks.use': True, + 'channels': 2, + } + + # You must follow the instructions in docs/MUSICGEN.md about the creation + # of the proper fine tuning checkpoints. We will assume they are stored under + # ~/checkpoints/{mode_name}. 
+ + checkpoints = Path.home() / 'checkpoints' + + launcher.bind_(fsdp, stereo, {'optim.epochs': 100}) + + launcher.slurm_(gpus=32).bind_(label='32gpus') + with launcher.job_array(): + sub = launcher.bind({'continue_from': str(checkpoints / 'stereo_finetune_musicgen-small.th')}) + sub() + + launcher.slurm_(gpus=64).bind_(label='64gpus') + with launcher.job_array(): + sub = launcher.bind({'continue_from': str(checkpoints / 'stereo_finetune_musicgen-medium.th')}) + sub(medium, adam) + + launcher.slurm_(gpus=96).bind_(label='96gpus') + with launcher.job_array(): + sub = launcher.bind({'continue_from': str(checkpoints / 'stereo_finetune_musicgen-large.th')}) + sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3}) diff --git a/audiocraft/models/audiogen.py b/audiocraft/models/audiogen.py index 5cb88998..b4df536e 100644 --- a/audiocraft/models/audiogen.py +++ b/audiocraft/models/audiogen.py @@ -38,6 +38,10 @@ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, self.name = name self.compression_model = compression_model self.lm = lm + # Just to be safe, let's put everything in eval mode. 
+ self.compression_model.eval() + self.lm.eval() + if max_duration is None: if hasattr(lm, 'cfg'): max_duration = lm.cfg.dataset.segment_duration # type: ignore diff --git a/audiocraft/models/builders.py b/audiocraft/models/builders.py index 038bf99c..b7144874 100644 --- a/audiocraft/models/builders.py +++ b/audiocraft/models/builders.py @@ -15,7 +15,7 @@ import omegaconf import torch -from .encodec import CompressionModel, EncodecModel +from .encodec import CompressionModel, EncodecModel, InterleaveStereoCompressionModel from .lm import LMModel from ..modules.codebooks_patterns import ( CodebooksPatternProvider, @@ -23,7 +23,7 @@ MusicLMPattern, ParallelPatternProvider, UnrolledPatternProvider, - VALLEPattern, + CoarseFirstPattern, ) from ..modules.conditioners import ( BaseConditioner, @@ -172,7 +172,7 @@ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> Codeb 'parallel': ParallelPatternProvider, 'delay': DelayedPatternProvider, 'unroll': UnrolledPatternProvider, - 'valle': VALLEPattern, + 'coarse_first': CoarseFirstPattern, 'musiclm': MusicLMPattern, } name = cfg.modeling @@ -196,7 +196,6 @@ def get_debug_compression_model(device='cpu', sample_rate: int = 32000): 'dimension': 32, 'ratios': ratios, } - print(seanet_kwargs) encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs) decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs) quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4) @@ -248,5 +247,12 @@ def get_debug_lm_model(device='cpu'): def get_wrapped_compression_model( compression_model: CompressionModel, cfg: omegaconf.DictConfig) -> CompressionModel: - # more to come. 
+ if hasattr(cfg, 'interleave_stereo_codebooks'): + if cfg.interleave_stereo_codebooks.use: + kwargs = dict_from_config(cfg.interleave_stereo_codebooks) + kwargs.pop('use') + compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs) + if hasattr(cfg, 'compression_model_n_q'): + if cfg.compression_model_n_q is not None: + compression_model.set_num_codebooks(cfg.compression_model_n_q) return compression_model diff --git a/audiocraft/models/encodec.py b/audiocraft/models/encodec.py index 4020ee26..cb0484ee 100644 --- a/audiocraft/models/encodec.py +++ b/audiocraft/models/encodec.py @@ -13,6 +13,7 @@ from pathlib import Path import typing as tp +from einops import rearrange import numpy as np import torch from torch import nn @@ -276,7 +277,7 @@ def forward(self, x: torch.Tensor) -> qt.QuantizedResult: def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: codes = self.model.encode(x, self.n_quantizers)[1] - return codes, None + return codes[:, :self.n_quantizers], None def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): assert scale is None @@ -391,3 +392,115 @@ def set_num_codebooks(self, n: int): if n not in self.possible_num_codebooks: raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}") self._num_codebooks = n + + +class InterleaveStereoCompressionModel(CompressionModel): + """Wraps a CompressionModel to support stereo inputs. The wrapped model + will be applied independently to the left and right channels, and both codebooks + will be interleaved. If the wrapped model returns a representation `[B, K ,T]` per + channel, then the output will be `[B, K * 2, T]` or `[B, K, T * 2]` depending on + `per_timestep`. + + Args: + model (CompressionModel): Compression model to wrap. + per_timestep (bool): Whether to interleave on the timestep dimension + or on the codebooks dimension. 
+ """ + def __init__(self, model: CompressionModel, per_timestep: bool = False): + super().__init__() + self.model = model + self.per_timestep = per_timestep + assert self.model.channels == 1, "Wrapped model is expected to be for monophonic audio" + + @property + def total_codebooks(self): + return self.model.total_codebooks + + @property + def num_codebooks(self): + """Active number of codebooks used by the quantizer. + + ..Warning:: this reports the number of codebooks after the interleaving + of the codebooks! + """ + return self.model.num_codebooks if self.per_timestep else self.model.num_codebooks * 2 + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + + ..Warning:: this sets the number of codebooks before the interleaving! + """ + self.model.set_num_codebooks(n) + + @property + def num_virtual_steps(self) -> float: + """Return the number of virtual steps, e.g. one real step + will be split into that many steps. + """ + return 2 if self.per_timestep else 1 + + @property + def frame_rate(self) -> float: + return self.model.frame_rate * self.num_virtual_steps + + @property + def sample_rate(self) -> int: + return self.model.sample_rate + + @property + def channels(self) -> int: + return 2 + + @property + def cardinality(self): + """Cardinality of each codebook. 
+ """ + return self.model.cardinality + + def forward(self, x: torch.Tensor) -> qt.QuantizedResult: + raise NotImplementedError("Not supported, use encode and decode.") + + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + B, C, T = x.shape + assert C == self.channels, f"Expecting stereo audio but audio num channels is {C}" + + indices_c0, scales_c0 = self.model.encode(x[:, 0, ...].unsqueeze(1)) + indices_c1, scales_c1 = self.model.encode(x[:, 1, ...].unsqueeze(1)) + indices = torch.stack([indices_c0, indices_c1], dim=0) + scales: tp.Optional[torch.Tensor] = None + if scales_c0 is not None and scales_c1 is not None: + scales = torch.stack([scales_c0, scales_c1], dim=1) + + if self.per_timestep: + indices = rearrange(indices, 'c b k t -> b k (t c)', c=2) + else: + indices = rearrange(indices, 'c b k t -> b (k c) t', c=2) + + return (indices, scales) + + def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]: + if self.per_timestep: + codes = rearrange(codes, 'b k (t c) -> c b k t', c=2) + else: + codes = rearrange(codes, 'b (k c) t -> c b k t', c=2) + return codes[0], codes[1] + + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + B, K, T = codes.shape + assert T % self.num_virtual_steps == 0, "Provided codes' number of timesteps does not match" + assert K == self.num_codebooks, "Provided codes' number of codebooks does not match" + + scale_c0, scale_c1 = None, None + if scale is not None: + assert scale.size(0) == B and scale.size(1) == 2, f"Scale has unexpected shape: {scale.shape}" + scale_c0 = scale[0, ...] + scale_c1 = scale[1, ...] 
+ + codes_c0, codes_c1 = self.get_left_right_codes(codes) + audio_c0 = self.model.decode(codes_c0, scale_c0) + audio_c1 = self.model.decode(codes_c1, scale_c1) + return torch.cat([audio_c0, audio_c1], dim=1) + + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + raise NotImplementedError("Not supported by interleaved stereo wrapped models.") diff --git a/audiocraft/models/lm.py b/audiocraft/models/lm.py index 8cefd2c5..c4ea2e5e 100644 --- a/audiocraft/models/lm.py +++ b/audiocraft/models/lm.py @@ -314,7 +314,8 @@ def _sample_next_token(self, temp: float = 1.0, top_k: int = 0, top_p: float = 0.0, - cfg_coef: tp.Optional[float] = None) -> torch.Tensor: + cfg_coef: tp.Optional[float] = None, + two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor: """Sample next token from the model given a sequence and a set of conditions. The model supports multiple sampling strategies (greedy sampling, softmax, top-k, top-p...). @@ -335,7 +336,8 @@ def _sample_next_token(self, B = sequence.shape[0] cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef model = self if self._fsdp is None else self._fsdp - if self.two_step_cfg and cfg_conditions != {}: + two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg + if two_step_cfg and cfg_conditions != {}: assert isinstance(cfg_conditions, tuple), type(cfg_conditions) condition_tensors, null_condition_tensors = cfg_conditions cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors) @@ -493,7 +495,7 @@ def generate(self, # sample next token from the model, next token shape is [B, K, 1] next_token = self._sample_next_token( curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p, - cfg_coef=cfg_coef) + cfg_coef=cfg_coef, two_step_cfg=two_step_cfg) # ensure the tokens that should be masked are properly set to special_token_id # as the model never output special_token_id valid_mask = mask[..., 
offset:offset+1].expand(B, -1, -1) diff --git a/audiocraft/models/loaders.py b/audiocraft/models/loaders.py index 7fd49d84..f02ba115 100644 --- a/audiocraft/models/loaders.py +++ b/audiocraft/models/loaders.py @@ -27,6 +27,7 @@ from omegaconf import OmegaConf, DictConfig import torch +import audiocraft from . import builders from .encodec import CompressionModel @@ -60,7 +61,9 @@ def _get_state_dict( else: assert filename is not None, "filename needs to be defined if using HF checkpoints" - file = hf_hub_download(repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir) + file = hf_hub_download( + repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir, + library_name="audiocraft", library_version=audiocraft.__version__) return torch.load(file, map_location=device) diff --git a/audiocraft/models/musicgen.py b/audiocraft/models/musicgen.py index 557d1196..88ee13b6 100644 --- a/audiocraft/models/musicgen.py +++ b/audiocraft/models/musicgen.py @@ -12,11 +12,12 @@ import typing as tp import warnings +import omegaconf import torch from .encodec import CompressionModel from .lm import LMModel -from .builders import get_debug_compression_model, get_debug_lm_model +from .builders import get_debug_compression_model, get_debug_lm_model, get_wrapped_compression_model from .loaders import load_compression_model, load_lm_model from ..data.audio_utils import convert_audio from ..modules.conditioners import ConditioningAttributes, WavCondition @@ -52,14 +53,28 @@ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, self.name = name self.compression_model = compression_model self.lm = lm + self.cfg: tp.Optional[omegaconf.DictConfig] = None + # Just to be safe, let's put everything in eval mode. 
+ self.compression_model.eval() + self.lm.eval() + + if hasattr(lm, 'cfg'): + cfg = lm.cfg + assert isinstance(cfg, omegaconf.DictConfig) + self.cfg = cfg + + if self.cfg is not None: + self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg) + if max_duration is None: - if hasattr(lm, 'cfg'): + if self.cfg is not None: max_duration = lm.cfg.dataset.segment_duration # type: ignore else: raise ValueError("You must provide max_duration when building directly MusicGen") assert max_duration is not None self.max_duration: float = max_duration self.device = next(iter(lm.parameters())).device + self.generation_params: dict = {} self.set_generation_params(duration=15) # 15 seconds by default self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None @@ -118,6 +133,7 @@ def get_pretrained(name: str = 'facebook/musicgen-melody', device=None): compression_model = load_compression_model(name, device=device) if 'self_wav' in lm.condition_provider.conditioners: lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True + lm.condition_provider.conditioners['self_wav']._use_masking = False return MusicGen(name, compression_model, lm) diff --git a/audiocraft/modules/codebooks_patterns.py b/audiocraft/modules/codebooks_patterns.py index 3cf3bb41..61362588 100644 --- a/audiocraft/modules/codebooks_patterns.py +++ b/audiocraft/modules/codebooks_patterns.py @@ -486,9 +486,14 @@ def get_pattern(self, timesteps: int) -> Pattern: return Pattern(out, n_q=self.n_q, timesteps=timesteps) -class VALLEPattern(CodebooksPatternProvider): - """Almost VALL-E style pattern. - We further allow some delays for the codebooks other than the first one. +class CoarseFirstPattern(CodebooksPatternProvider): + """First generates all the codebooks #1 (e.g. coarser), then the remaining ones, + potentially with delays. 
+ + ..Warning:: You must always generate the full training duration at test time, for instance, + 30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected + location. This is due to the non causality of the remaining codebooks with respect to + the first ones. Args: n_q (int): Number of codebooks. diff --git a/audiocraft/modules/conditioners.py b/audiocraft/modules/conditioners.py index d10ac8dc..178957d1 100644 --- a/audiocraft/modules/conditioners.py +++ b/audiocraft/modules/conditioners.py @@ -469,6 +469,8 @@ class WaveformConditioner(BaseConditioner): def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]): super().__init__(dim, output_dim) self.device = device + # if False no masking is done, used in ChromaStemConditioner when completing by periodicity a sample. + self._use_masking = True def tokenize(self, x: WavCondition) -> WavCondition: wav, length, sample_rate, path, seek_time = x @@ -496,13 +498,12 @@ def forward(self, x: WavCondition) -> ConditionType: embeds = embeds.to(self.output_proj.weight) embeds = self.output_proj(embeds) - if lengths is not None: + if lengths is not None and self._use_masking: lengths = lengths / self._downsampling_factor() mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore else: - mask = torch.ones_like(embeds) - embeds = (embeds * mask.unsqueeze(2).to(self.device)) - + mask = torch.ones_like(embeds[..., 0]) + embeds = (embeds * mask.unsqueeze(-1)) return embeds, mask @@ -537,6 +538,8 @@ def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32) self.sample_rate = sample_rate self.match_len_on_eval = match_len_on_eval + if match_len_on_eval: + self._use_masking = False self.duration = duration self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device) stem_sources: list = self.demucs.sources # type: ignore @@ 
-792,6 +795,8 @@ def __init__(self, dim: int, output_dim: int, device: str, attribute: str, import laion_clap # type: ignore except ImportError: raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'") + warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). " + "Please retrain all models.") checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint) clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base') clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch) diff --git a/audiocraft/optim/fsdp.py b/audiocraft/optim/fsdp.py index b3c1a55b..1090d3d7 100644 --- a/audiocraft/optim/fsdp.py +++ b/audiocraft/optim/fsdp.py @@ -143,8 +143,8 @@ def _name_without_fsdp_prefix(name: str) -> str: new_parts = [part for part in parts if part != FSDP_WRAPPED_MODULE] return '.'.join(new_parts) - def state_dict(self) -> tp.Dict[str, tp.Any]: # type: ignore - state = dict(super().state_dict()) + def state_dict(self, *args, **kwargs) -> tp.Dict[str, tp.Any]: # type: ignore + state = dict(super().state_dict(*args, **kwargs)) for key, value in list(state.items()): if is_sharded_tensor(value): del state[key] diff --git a/audiocraft/solvers/musicgen.py b/audiocraft/solvers/musicgen.py index bb615abf..2439da33 100644 --- a/audiocraft/solvers/musicgen.py +++ b/audiocraft/solvers/musicgen.py @@ -7,6 +7,7 @@ from pathlib import Path import time import typing as tp +import warnings import flashy import math @@ -226,7 +227,6 @@ def _compute_cross_entropy( ce = ce / K return ce, ce_per_codebook - @torch.no_grad() def _prepare_tokens_and_attributes( self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], check_synchronization_points: bool = False @@ -243,6 +243,12 @@ def _prepare_tokens_and_attributes( with B the batch size, K the number of codebooks, T_s the token timesteps. 
Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s]. """ + if self.model.training: + warnings.warn( + "Up to version 1.0.1, the _prepare_tokens_and_attributes was evaluated with `torch.no_grad()`. " + "This is inconsistent with how model were trained in the MusicGen paper. We removed the " + "`torch.no_grad()` in version 1.1.0. Small changes to the final performance are expected. " + "Really sorry about that.") if self._cached_batch_loader is None or self.current_stage != "train": audio, infos = batch audio = audio.to(self.device) @@ -533,7 +539,7 @@ def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]): rtf = 1. else: gen_unprompted_outputs = self.run_generate_step( - batch, gen_duration=target_duration, prompt_duration=prompt_duration, + batch, gen_duration=target_duration, prompt_duration=None, **self.generation_params) gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu() rtf = gen_unprompted_outputs['rtf'] diff --git a/audiocraft/utils/cache.py b/audiocraft/utils/cache.py index 2fccc0ac..f7f82064 100644 --- a/audiocraft/utils/cache.py +++ b/audiocraft/utils/cache.py @@ -57,7 +57,7 @@ class EmbeddingCache: specify the index corresponding to the current embedding in the object that can represent batch metadata. If not specified, will return the full embedding unmodified. 
""" - def __init__(self, cache_path: tp.Union[Path], device: tp.Union[str, torch.device], + def __init__(self, cache_path: tp.Union[str, Path], device: tp.Union[str, torch.device], compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor], extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None): self.cache_path = Path(cache_path) diff --git a/audiocraft/utils/export_legacy.py b/audiocraft/utils/export_legacy.py index 52f145f3..367c3f3c 100644 --- a/audiocraft/utils/export_legacy.py +++ b/audiocraft/utils/export_legacy.py @@ -14,13 +14,21 @@ from omegaconf import OmegaConf, DictConfig import torch +from audiocraft import __version__ + def _clean_lm_cfg(cfg: DictConfig): OmegaConf.set_struct(cfg, False) # This used to be set automatically in the LM solver, need a more robust solution # for the future. cfg['transformer_lm']['card'] = 2048 - cfg['transformer_lm']['n_q'] = 4 + n_q = 4 + stereo_cfg = getattr(cfg, 'interleave_stereo_codebooks', None) + if stereo_cfg is not None and stereo_cfg.use: + if 'downsample' in stereo_cfg: + del stereo_cfg['downsample'] + n_q = 8 + cfg['transformer_lm']['n_q'] = n_q # Experimental params no longer supported. bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters', 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop'] @@ -30,27 +38,33 @@ def _clean_lm_cfg(cfg: DictConfig): return cfg -def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): - sig = Path(checkpoint_path).parent.name - assert len(sig) == 8, "Not a valid Dora signature" +def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): pkg = torch.load(checkpoint_path, 'cpu') new_pkg = { 'best_state': pkg['ema']['state']['model'], 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), + # The following params were NOT exported for the first release of MusicGen. 
+ 'version': __version__, + 'exported': True, } - out_file = Path(out_folder) / f'{sig}.th' + Path(out_file).parent.mkdir(exist_ok=True, parents=True) torch.save(new_pkg, out_file) return out_file -def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]): - sig = Path(checkpoint_path).parent.name - assert len(sig) == 8, "Not a valid Dora signature" +def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): pkg = torch.load(checkpoint_path, 'cpu') + if pkg['fsdp_best_state']: + best_state = pkg['fsdp_best_state']['model'] + else: + best_state = pkg['best_state']['model'] new_pkg = { - 'best_state': pkg['fsdp_best_state']['model'], - 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])) + 'best_state': best_state, + 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])), + # The following params were NOT exported for the first release of MusicGen. + 'version': __version__, + 'exported': True, } - out_file = Path(out_folder) / f'{sig}.th' + Path(out_file).parent.mkdir(exist_ok=True, parents=True) torch.save(new_pkg, out_file) return out_file diff --git a/audiocraft/utils/utils.py b/audiocraft/utils/utils.py index 3135d70e..2c5799f8 100644 --- a/audiocraft/utils/utils.py +++ b/audiocraft/utils/utils.py @@ -185,7 +185,7 @@ def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> t assert len(lengths.shape) == 1, "Length shape should be 1 dimensional." 
final_length = lengths.max().item() if not max_len else max_len final_length = max(final_length, 1) # if all seqs are of len zero we don't want a zero-size tensor - return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None] + return torch.arange(final_length, device=lengths.device)[None, :] < lengths[:, None] def hash_trick(word: str, vocab_size: int) -> int: diff --git a/config/conditioner/clapemb2music.yaml b/config/conditioner/clapemb2music.yaml index 8500a826..d44ac774 100644 --- a/config/conditioner/clapemb2music.yaml +++ b/config/conditioner/clapemb2music.yaml @@ -23,7 +23,7 @@ conditioners: checkpoint: //reference/clap/music_audioset_epoch_15_esc_90.14.pt model_arch: 'HTSAT-base' enable_fusion: false - sample_rate: 44100 + sample_rate: 48000 max_audio_length: 10 audio_stride: 1 dim: 512 diff --git a/config/model/lm/audiogen_lm.yaml b/config/model/lm/audiogen_lm.yaml index 696f7462..d17e7a93 100644 --- a/config/model/lm/audiogen_lm.yaml +++ b/config/model/lm/audiogen_lm.yaml @@ -18,7 +18,7 @@ codebooks_pattern: delays: [0, 0, 0, 0] music_lm: group_by: 2 - valle: + coarse_first: delays: [0, 0, 0] transformer_lm: diff --git a/config/model/lm/musicgen_lm.yaml b/config/model/lm/musicgen_lm.yaml index 5bc87a62..be1fbc14 100644 --- a/config/model/lm/musicgen_lm.yaml +++ b/config/model/lm/musicgen_lm.yaml @@ -18,7 +18,7 @@ codebooks_pattern: delays: [0, 0, 0, 0] music_lm: group_by: 2 - valle: + coarse_first: delays: [0, 0, 0] transformer_lm: diff --git a/config/solver/musicgen/default.yaml b/config/solver/musicgen/default.yaml index 59e01137..8bdf9c74 100644 --- a/config/solver/musicgen/default.yaml +++ b/config/solver/musicgen/default.yaml @@ -14,10 +14,20 @@ solver: musicgen sample_rate: ??? channels: ??? compression_model_checkpoint: ??? 
+# The following will set the num codebooks on the underlying +# model, this might be different from the actual value for n_q +# given to the transformer, when the model output is postprocessed, for instance +# for stereo channels. If not provided, default value for the compression model +# will be used. +compression_model_n_q: null tokens: padding_with_special_token: false +interleave_stereo_codebooks: + use: false + per_timestep: false + cache: path: write: false diff --git a/demos/musicgen_app.py b/demos/musicgen_app.py index 74c893e7..a10d52b5 100644 --- a/demos/musicgen_app.py +++ b/demos/musicgen_app.py @@ -9,24 +9,29 @@ import argparse from concurrent.futures import ProcessPoolExecutor +import logging import os from pathlib import Path import subprocess as sp +import sys from tempfile import NamedTemporaryFile import time import typing as tp import warnings +from einops import rearrange import torch import gradio as gr from audiocraft.data.audio_utils import convert_audio from audiocraft.data.audio import audio_write +from audiocraft.models.encodec import InterleaveStereoCompressionModel from audiocraft.models import MusicGen, MultiBandDiffusion MODEL = None # Last used model -IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '') +SPACE_ID = os.environ.get('SPACE_ID', '') +IS_BATCHED = "facebook/MusicGen" in SPACE_ID or 'musicgen-internal/musicgen_dev' in SPACE_ID print(IS_BATCHED) MAX_BATCH_SIZE = 12 BATCHED_DURATION = 15 @@ -91,6 +96,7 @@ def load_model(version='facebook/musicgen-melody'): global MODEL print("Loading model", version) if MODEL is None or MODEL.name != version: + MODEL = None # in case loading would crash MODEL = MusicGen.get_pretrained(version) @@ -101,7 +107,7 @@ def load_diffusion(): MBD = MultiBandDiffusion.get_mbd_musicgen() -def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs): +def _do_predictions(texts, melodies, duration, progress=False, gradio_progress=None, **gen_kwargs): 
MODEL.set_generation_params(duration=duration, **gen_kwargs) print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies]) be = time.time() @@ -119,18 +125,30 @@ def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs): melody = convert_audio(melody, sr, target_sr, target_ac) processed_melodies.append(melody) - if any(m is not None for m in processed_melodies): - outputs = MODEL.generate_with_chroma( - descriptions=texts, - melody_wavs=processed_melodies, - melody_sample_rate=target_sr, - progress=progress, - return_tokens=USE_DIFFUSION - ) - else: - outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION) + try: + if any(m is not None for m in processed_melodies): + outputs = MODEL.generate_with_chroma( + descriptions=texts, + melody_wavs=processed_melodies, + melody_sample_rate=target_sr, + progress=progress, + return_tokens=USE_DIFFUSION + ) + else: + outputs = MODEL.generate(texts, progress=progress, return_tokens=USE_DIFFUSION) + except RuntimeError as e: + raise gr.Error("Error while generating " + e.args[0]) if USE_DIFFUSION: - outputs_diffusion = MBD.tokens_to_wav(outputs[1]) + if gradio_progress is not None: + gradio_progress(1, desc='Running MultiBandDiffusion...') + tokens = outputs[1] + if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel): + left, right = MODEL.compression_model.get_left_right_codes(tokens) + tokens = torch.cat([left, right]) + outputs_diffusion = MBD.tokens_to_wav(tokens) + if isinstance(MODEL.compression_model, InterleaveStereoCompressionModel): + assert outputs_diffusion.shape[1] == 1 # output is mono + outputs_diffusion = rearrange(outputs_diffusion, '(s b) c t -> b (s c) t', s=2) outputs = torch.cat([outputs[0], outputs_diffusion], dim=0) outputs = outputs.detach().cpu().float() pending_videos = [] @@ -154,15 +172,24 @@ def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs): def predict_batched(texts, 
melodies): max_text_length = 512 texts = [text[:max_text_length] for text in texts] - load_model('facebook/musicgen-melody') + load_model('facebook/musicgen-stereo-melody') res = _do_predictions(texts, melodies, BATCHED_DURATION) return res -def predict_full(model, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()): +def predict_full(model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()): global INTERRUPTING global USE_DIFFUSION INTERRUPTING = False + progress(0, desc="Loading model...") + model_path = model_path.strip() + if model_path: + if not Path(model_path).exists(): + raise gr.Error(f"Model path {model_path} doesn't exist.") + if not Path(model_path).is_dir(): + raise gr.Error(f"Model path {model_path} must be a folder containing " + "state_dict.bin and compression_state_dict_.bin.") + model = model_path if temperature < 0: raise gr.Error("Temperature must be >= 0.") if topk < 0: @@ -173,20 +200,26 @@ def predict_full(model, decoder, text, melody, duration, topk, topp, temperature topk = int(topk) if decoder == "MultiBand_Diffusion": USE_DIFFUSION = True + progress(0, desc="Loading diffusion model...") load_diffusion() else: USE_DIFFUSION = False load_model(model) + max_generated = 0 + def _progress(generated, to_generate): - progress((min(generated, to_generate), to_generate)) + nonlocal max_generated + max_generated = max(generated, max_generated) + progress((min(max_generated, to_generate), to_generate)) if INTERRUPTING: raise gr.Error("Interrupted.") MODEL.set_custom_progress_callback(_progress) videos, wavs = _do_predictions( [text], [melody], duration, progress=True, - top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef) + top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef, + gradio_progress=progress) if USE_DIFFUSION: return videos[0], wavs[0], videos[1], wavs[1] return videos[0], wavs[0], None, None @@ -231,8 +264,12 @@ def 
ui_full(launch_kwargs): _ = gr.Button("Interrupt").click(fn=interrupt, queue=False) with gr.Row(): model = gr.Radio(["facebook/musicgen-melody", "facebook/musicgen-medium", "facebook/musicgen-small", - "facebook/musicgen-large"], - label="Model", value="facebook/musicgen-melody", interactive=True) + "facebook/musicgen-large", "facebook/musicgen-melody-large", + "facebook/musicgen-stereo-small", "facebook/musicgen-stereo-medium", + "facebook/musicgen-stereo-melody", "facebook/musicgen-stereo-large", + "facebook/musicgen-stereo-melody-large"], + label="Model", value="facebook/musicgen-stereo-melody", interactive=True) + model_path = gr.Text(label="Model Path (custom models)") with gr.Row(): decoder = gr.Radio(["Default", "MultiBand_Diffusion"], label="Decoder", value="Default", interactive=True) @@ -249,7 +286,7 @@ def ui_full(launch_kwargs): diffusion_output = gr.Video(label="MultiBand Diffusion Decoder") audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath') submit.click(toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False, - show_progress=False).then(predict_full, inputs=[model, decoder, text, melody, duration, topk, topp, + show_progress=False).then(predict_full, inputs=[model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output, audio_output, diffusion_output, audio_diffusion]) radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False) @@ -260,37 +297,37 @@ def ui_full(launch_kwargs): [ "An 80s driving pop song with heavy drums and synth pads in the background", "./assets/bach.mp3", - "facebook/musicgen-melody", + "facebook/musicgen-stereo-melody", "Default" ], [ "A cheerful country song with acoustic guitars", "./assets/bolero_ravel.mp3", - "facebook/musicgen-melody", + "facebook/musicgen-stereo-melody", "Default" ], [ "90s rock song with electric guitar and heavy drums", None, - "facebook/musicgen-medium", + 
"facebook/musicgen-stereo-medium", "Default" ], [ "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions", "./assets/bach.mp3", - "facebook/musicgen-melody", + "facebook/musicgen-stereo-melody", "Default" ], [ "lofi slow bpm electro chill with organic samples", None, - "facebook/musicgen-medium", + "facebook/musicgen-stereo-medium", "Default" ], [ "Punk rock with loud drum and power guitar", None, - "facebook/musicgen-medium", + "facebook/musicgen-stereo-medium", "MultiBand_Diffusion" ], ], @@ -302,8 +339,18 @@ def ui_full(launch_kwargs): ### More details The model will generate a short music extract based on the description you provided. - The model can generate up to 30 seconds of audio in one pass. It is now possible - to extend the generation by feeding back the end of the previous chunk of audio. + The model can generate up to 30 seconds of audio in one pass. + + The model was trained with description from a stock music catalog, descriptions that will work best + should include some level of details on the instruments present, along with some intended use case + (e.g. adding "perfect for a commercial" can somehow help). + + Using one of the `melody` model (e.g. `musicgen-melody-*`), you can optionally provide a reference audio + from which a broad melody will be extracted. + The model will then try to follow both the description and melody provided. + For best results, the melody should be 30 seconds long (I know, the samples we provide are not...) + + It is now possible to extend the generation by feeding back the end of the previous chunk of audio. This can take a long time, and the model might lose consistency. The model might also decide at arbitrary positions that the song ends. @@ -311,23 +358,23 @@ def ui_full(launch_kwargs): An overlap of 12 seconds is kept with the previously generated chunk, and 18 "new" seconds are generated each time. - We present 4 model variations: + We present 10 model variations: 1. 
facebook/musicgen-melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only. 2. facebook/musicgen-small -- a 300M transformer decoder conditioned on text only. 3. facebook/musicgen-medium -- a 1.5B transformer decoder conditioned on text only. 4. facebook/musicgen-large -- a 3.3B transformer decoder conditioned on text only. + 5. facebook/musicgen-melody-large -- a 3.3B transformer decoder conditioned on and melody. + 6. facebook/musicgen-stereo-*: same as the previous models but fine tuned to output stereo audio. We also present two way of decoding the audio tokens - 1. Use the default GAN based compression model - 2. Use MultiBand Diffusion from (paper linknano ) + 1. Use the default GAN based compression model. It can suffer from artifacts especially + for crashes, snares etc. + 2. Use [MultiBand Diffusion](https://arxiv.org/abs/2308.02560). Should improve the audio quality, + at an extra computational cost. When this is selected, we provide both the GAN based decoded + audio, and the one obtained with MBD. - When using `facebook/musicgen-melody`, you can optionally provide a reference audio from - which a broad melody will be extracted. The model will then try to follow both - the description and melody provided. - - You can also use your own GPU or a Google Colab by following the instructions on our repo. - See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft) + See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md) for more details. 
""" ) @@ -341,7 +388,7 @@ def ui_batched(launch_kwargs): """ # MusicGen - This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), + This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md), a simple and controllable model for music generation presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
@@ -399,15 +446,27 @@ def ui_batched(launch_kwargs): gr.Markdown(""" ### More details - The model will generate 12 seconds of audio based on the description you provided. + The model will generate 15 seconds of audio based on the description you provided. + The model was trained with description from a stock music catalog, descriptions that will work best + should include some level of details on the instruments present, along with some intended use case + (e.g. adding "perfect for a commercial" can somehow help). + You can optionally provide a reference audio from which a broad melody will be extracted. The model will then try to follow both the description and melody provided. - All samples are generated with the `melody` model. + For best results, the melody should be 30 seconds long (I know, the samples we provide are not...) - You can also use your own GPU or a Google Colab by following the instructions on our repo. + You can access more control (longer generation, more models etc.) by clicking + the + Duplicate Space + (you will then need a paid GPU from HuggingFace). + If you have a GPU, you can run the gradio demo locally (click the link to our repo below for more info). + Finally, you can get a GPU for free from Google + and run the demo in [a Google Colab.](https://ai.honu.io/red/musicgen-colab). - See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft) - for more details. + See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/MUSICGEN.md) + for more details. All samples are generated with the `stereo-melody` model. 
""") demo.queue(max_size=8 * 4).launch(**launch_kwargs) @@ -454,6 +513,8 @@ def ui_batched(launch_kwargs): if args.share: launch_kwargs['share'] = args.share + logging.basicConfig(level=logging.INFO, stream=sys.stderr) + # Show the interface if IS_BATCHED: global USE_DIFFUSION diff --git a/docs/MBD.md b/docs/MBD.md index 4288a89d..b6629184 100644 --- a/docs/MBD.md +++ b/docs/MBD.md @@ -113,5 +113,5 @@ Learn more about AudioCraft training pipelines in the [dedicated section](./TRAI See license information in the [README](../README.md). -[arxiv]: https://dl.fbaipublicfiles.com/encodec/Diffusion/paper.pdf +[arxiv]: https://arxiv.org/abs/2308.02560 [mbd_samples]: https://ai.honu.io/papers/mbd/ diff --git a/docs/MUSICGEN.md b/docs/MUSICGEN.md index 606ce858..fb12e324 100644 --- a/docs/MUSICGEN.md +++ b/docs/MUSICGEN.md @@ -9,7 +9,7 @@ a small delay between the codebooks, we show we can predict them in parallel, th steps per second of audio. Check out our [sample page][musicgen_samples] or test the available demo! - + Open In Colab @@ -38,7 +38,7 @@ We offer a number of way to interact with MusicGen: 1. A demo is also available on the [`facebook/MusicGen` Hugging Face Space](https://huggingface.co/spaces/facebook/MusicGen) (huge thanks to all the HF team for their support). 2. You can run the extended demo on a Colab: -[colab notebook](https://colab.research.google.com/drive/1JlTOjB-G0A2Hz3h8PK63vLZk4xdCI5QB?usp=sharing) +[colab notebook](https://ai.honu.io/red/musicgen-colab) 3. You can use the gradio demo locally by running [`python -m demos.musicgen_app --share`](../demos/musicgen_app.py). 4. You can play with MusicGen by running the jupyter notebook at [`demos/musicgen_demo.ipynb`](../demos/musicgen_demo.ipynb) locally (if you have a GPU). 5. Finally, checkout [@camenduru Colab page](https://github.com/camenduru/MusicGen-colab) @@ -47,11 +47,18 @@ which is regularly updated with contributions from @camenduru and the community. 
## API -We provide a simple API and 4 pre-trained models. The pre trained models are: +We provide a simple API and 10 pre-trained models. The pre trained models are: - `facebook/musicgen-small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-small) - `facebook/musicgen-medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-medium) - `facebook/musicgen-melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody) - `facebook/musicgen-large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-large) +- `facebook/musicgen-melody-large`: 3.3B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody-large) +- `facebook/musicgen-stereo-*`: All the previous models fine tuned for stereo generation - + [small](https://huggingface.co/facebook/musicgen-stereo-small), + [medium](https://huggingface.co/facebook/musicgen-stereo-medium), + [large](https://huggingface.co/facebook/musicgen-stereo-large), + [melody](https://huggingface.co/facebook/musicgen-stereo-melody), + [melody large](https://huggingface.co/facebook/musicgen-stereo-melody-large). We observe the best trade-off between quality and compute with the `facebook/musicgen-medium` or `facebook/musicgen-melody` model. In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller @@ -143,6 +150,10 @@ We provide a dummy dataset containing just a few examples for illustrative purpo Please read first the [TRAINING documentation](./TRAINING.md), in particular the Environment Setup section. + +**Warning:** As of version 1.1.0, a few breaking changes were introduced. Check the [CHANGELOG.md](../CHANGELOG.md) +file for more information. You might need to retrain some of your models. 
+ ### Example configurations and grids We provide configurations to reproduce the released models and our research. @@ -205,6 +216,19 @@ dora run solver=musicgen/debug \ **Warning:** you are responsible for setting the proper value for `transformer_lm.n_q` and `transformer_lm.card` (cardinality of the codebooks). You also have to update the codebook_pattern to match `n_q` as shown in the example for using DAC. . +### Training stereo models + +Use the option `interleave_stereo_codebooks.use` set to `True` to activate stereo training along with `channels=2`. Left and right channels will be +encoded separately by the compression model, then their codebook will be interleaved, e.g. order of codebook is +`[1_L, 1_R, 2_L, 2_R, ...]`. You will also need to update the delays for the codebook patterns to match the number of codebooks, and the `n_q` value passed to the transformer LM: +``` +dora run solver=musicgen/debug \ + compression_model_checkpoint=//pretrained/facebook/encodec_32khz \ + channels=2 interleave_stereo_codebooks.use=True \ + transformer_lm.n_q=8 transformer_lm.card=2048 \ + codebooks_pattern.delay.delays='[0, 0, 1, 1, 2, 2, 3, 3]' +``` + ### Fine tuning existing models You can initialize your model to one of the pretrained models by using the `continue_from` argument, in particular @@ -228,6 +252,39 @@ dora run solver=musicgen/musicgen_base_32khz model/lm/model_scale=medium continu If you decide to do so, make sure your checkpoint is saved with `torch.save` and contains a dict `{'best_state': {'model': model_state_dict_here}}`. Directly give the path to `continue_from` without a `//pretrained/` prefix. + +#### Fine tuning mono model to stereo + +You will not be able to `continue_from` a mono model with stereo training, as the shape of the embeddings and output linears +would not match. You can use the following snippet to prepare a proper finetuning checkpoint. + +```python +from pathlib import Path +import torch + +# Download the pretrained model, e.g. 
from +# https://huggingface.co/facebook/musicgen-melody/blob/main/state_dict.bin + +model_name = 'musicgen-melody' +root = Path.home() / 'checkpoints' +# You are responsible for downloading the following checkpoint in the proper location +input_state_dict_path = root / model_name / 'state_dict.bin' +state = torch.load(input_state_dict_path, 'cpu') +bs = state['best_state'] +# there is a slight difference in format between training checkpoints and exported public checkpoints. +# If you want to use your own mono models from one of your training checkpoint, follow the instructions +# for exporting a model explained later on this page. +assert 'model' not in bs, 'The following code is for using an exported pretrained model' +nbs = dict(bs) +for k in range(8): + # We will just copy mono embeddings and linears twice, once for left and right channels. + nbs[f'linears.{k}.weight'] = bs[f'linears.{k//2}.weight'] + nbs[f'emb.{k}.weight'] = bs[f'emb.{k//2}.weight'] +torch.save({'best_state': {'model': nbs}}, root / f'stereo_finetune_{model_name}.th') +``` + +Now, you can use `$HOME/checkpoints/stereo_finetune_musicgen-melody.th` as a `continue_from` target (without a `//pretrained` prefix!). + ### Caching of EnCodec tokens It is possible to precompute the EnCodec tokens and other metadata. diff --git a/model_cards/MUSICGEN_MODEL_CARD.md b/model_cards/MUSICGEN_MODEL_CARD.md index 95431368..68e81d44 100644 --- a/model_cards/MUSICGEN_MODEL_CARD.md +++ b/model_cards/MUSICGEN_MODEL_CARD.md @@ -87,4 +87,19 @@ More information can be found in the paper [Simple and Controllable Music Genera **Use cases:** Users must be aware of the biases, limitations and risks of the model. MusicGen is a model developed for artificial intelligence research on controllable music generation. As such, it should not be used for downstream applications without further investigation and mitigation of risks. +## Update: stereo models and large melody.
+ +We further release a set of stereophonic capable models. Those were fine tuned for 200k updates starting +from the mono models. The training data is otherwise identical and capabilities and limitations are shared with the base models. The stereo models work by getting 2 streams of tokens from the EnCodec model, and interleaving those using +the delay pattern. We also release a mono large model with melody conditioning capabilities. The list of new models +is as follows: + +- facebook/musicgen-stereo-small +- facebook/musicgen-stereo-medium +- facebook/musicgen-stereo-large +- facebook/musicgen-stereo-melody +- facebook/musicgen-melody-large +- facebook/musicgen-stereo-melody-large + + [arxiv]: https://arxiv.org/abs/2306.05284 diff --git a/tests/models/test_musicgen.py b/tests/models/test_musicgen.py index 65618a9e..2b32ac5d 100644 --- a/tests/models/test_musicgen.py +++ b/tests/models/test_musicgen.py @@ -56,3 +56,10 @@ def test_generate_long(self): wav = mg.generate( ['youpi', 'lapin dort']) assert list(wav.shape) == [2, 1, 32000 * 4] + + def test_generate_two_step_cfg(self): + mg = self.get_musicgen() + mg.set_generation_params(duration=2.0, extend_stride=2., two_step_cfg=True) + wav = mg.generate( + ['youpi', 'lapin dort']) + assert list(wav.shape) == [2, 1, 64000]