Merge pull request #500 from simonrouard/style_conditioner
MusicGen-Style
JadeCopet authored Nov 11, 2024
2 parents adf0b04 + dbbf222 commit 4ec3e79
Showing 22 changed files with 1,422 additions and 42 deletions.
2 changes: 1 addition & 1 deletion .github/actions/audiocraft_build/action.yml
@@ -22,7 +22,7 @@ runs:
. env/bin/activate
python -m pip install --upgrade pip
pip install 'numpy<2' torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0
-pip install xformers
+pip install xformers==0.0.22.post7
pip install -e '.[dev,wm]'
- name: System Dependencies
shell: bash
1 change: 1 addition & 0 deletions README.md
@@ -39,6 +39,7 @@ At the moment, AudioCraft contains the training code and inference code for:
* [Multi Band Diffusion](./docs/MBD.md): An EnCodec compatible decoder using diffusion.
* [MAGNeT](./docs/MAGNET.md): A state-of-the-art non-autoregressive model for text-to-music and text-to-sound.
* [AudioSeal](./docs/WATERMARKING.md): A state-of-the-art audio watermarking.
* [MusicGen Style](./docs/MUSICGEN_STYLE.md): A state-of-the-art text-and-style-to-music model.

## Training code

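Not part of the commit, but for orientation: a minimal usage sketch of the model this PR wires up, based on the API added below (MUSICGEN_STYLE.md is the authoritative reference; the parameter values here are illustrative assumptions):

import torchaudio
from audiocraft.models import MusicGen

# Load the style-conditioned checkpoint registered by this PR.
model = MusicGen.get_pretrained('facebook/musicgen-style')

# cfg_coef is the usual CFG weight; cfg_coef_beta switches on double CFG,
# pushing the text condition harder than the style condition.
model.set_generation_params(duration=8, cfg_coef=3.0, cfg_coef_beta=5.0)

# Narrow the information bottleneck of the style conditioner (illustrative values).
model.set_style_conditioner_params(eval_q=1, excerpt_length=3.0)

# Condition on a short audio excerpt (added in this commit) plus a text description.
excerpt, sr = torchaudio.load('assets/electronic.mp3')
wav = model.generate_with_chroma(
    descriptions=['8-bit old video game music'],
    melody_wavs=excerpt[None],  # [B, C, T]
    melody_sample_rate=sr,
)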
Binary file added assets/electronic.mp3
Binary file added assets/epic.wav
25 changes: 25 additions & 0 deletions audiocraft/grids/musicgen/musicgen_style_32khz.py
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from ._explorers import LMExplorer
from ...environment import AudioCraftEnvironment


@LMExplorer
def explorer(launcher):
    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=64, partition=partitions, constraint='volta32gb').bind_(label='64gpus')
    launcher.bind_(dset='internal/music_400k_32khz')

    sub = launcher.bind_({'solver': 'musicgen/musicgen_style_32khz',
                          'autocast': False,
                          'fsdp.use': True,
                          'model/lm/model_scale': 'medium',
                          'optim.optimizer': 'adamw',
                          'optim.lr': 1e-4,
                          'generate.every': 25,
                          'dataset.generate.num_samples': 64,
                          })
    sub()
8 changes: 7 additions & 1 deletion audiocraft/models/builders.py
@@ -26,7 +26,7 @@
from ..modules.conditioners import (BaseConditioner, ChromaStemConditioner,
CLAPEmbeddingConditioner, ConditionFuser,
ConditioningProvider, LUTConditioner,
-T5Conditioner)
+T5Conditioner, StyleConditioner)
from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor
from ..utils.utils import dict_from_config
from .encodec import (CompressionModel, EncodecModel,
@@ -161,6 +161,12 @@ def get_conditioner_provider(
conditioners[str(cond)] = CLAPEmbeddingConditioner(
output_dim=output_dim, device=device, **model_args
)
elif model_type == 'style':
conditioners[str(cond)] = StyleConditioner(
output_dim=output_dim,
device=device,
**model_args
)
else:
raise ValueError(f"Unrecognized conditioning model: {model_type}")
conditioner = ConditioningProvider(
72 changes: 56 additions & 16 deletions audiocraft/models/lm.py
@@ -23,6 +23,7 @@
ConditioningProvider,
ConditioningAttributes,
ConditionType,
_drop_description_condition
)
from ..modules.codebooks_patterns import CodebooksPatternProvider
from ..modules.activations import get_activation_fn
@@ -255,7 +256,7 @@ def forward(self, sequence: torch.Tensor,
input_, cross_attention_input = self.fuser(input_, condition_tensors)

out = self.transformer(input_, cross_attention_src=cross_attention_input,
-src_mask=(self.attn_mask_per_stage[stage] if stage >= 0 else None))
+src_mask=(self.attn_mask_per_stage[stage] if stage >= 0 else None)) # type: ignore
if self.out_norm:
out = self.out_norm(out)
logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card]
@@ -328,6 +329,7 @@ def _sample_next_token(self,
top_k: int = 0,
top_p: float = 0.0,
cfg_coef: tp.Optional[float] = None,
cfg_coef_beta: tp.Optional[float] = None,
two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
"""Sample next token from the model given a sequence and a set of conditions. The model supports
multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
@@ -343,14 +345,37 @@ def _sample_next_token(self,
top_k (int): K for "top-k" sampling.
top_p (float): P for "top-p" sampling.
cfg_coef (float, optional): classifier free guidance coefficient
cfg_coef_beta (float, optional): If None, simple classifier free guidance is used with cfg_coef.
If not None, we apply double classifier free guidance as introduced in paragraph 4.3 of the
MusicGen-Style paper (https://arxiv.org/pdf/2407.12563). This beta coefficient is meant to
push the text condition more than the style condition when both text and style conditions
are used.
two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
Returns:
next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
"""
B = sequence.shape[0]
cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
model = self if self._fsdp is None else self._fsdp
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
-if two_step_cfg and cfg_conditions != {}:
+if cfg_coef_beta is not None:
+assert isinstance(cfg_conditions, dict)
+condition_tensors = cfg_conditions
+if condition_tensors:
+# Preparing for double CFG: predict with text+style conditioning, with style-only
+# conditioning, and unconditionally.
+sequence = torch.cat([sequence, sequence, sequence], dim=0)
+all_logits = model(
+sequence,
+conditions=[], condition_tensors=condition_tensors)
+if condition_tensors:
+cond_logits, wav_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
+logits = uncond_logits + cfg_coef * (
+wav_logits + cfg_coef_beta * (cond_logits - wav_logits) - uncond_logits
+)
+
+elif two_step_cfg and cfg_conditions != {}:
assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
condition_tensors, null_condition_tensors = cfg_conditions
cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
@@ -403,24 +428,30 @@ def generate(self,
top_k: int = 250,
top_p: float = 0.0,
cfg_coef: tp.Optional[float] = None,
cfg_coef_beta: tp.Optional[float] = None,
two_step_cfg: tp.Optional[bool] = None,
remove_prompts: bool = False,
check: bool = False,
callback: tp.Optional[tp.Callable[[int, int], None]] = None,
-**kwargs) -> torch.Tensor:
+) -> torch.Tensor:
"""Generate tokens sampling from the model given a prompt or unconditionally. Generation can
be performed in a greedy fashion or using sampling with top K and top P strategies.
Args:
prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
-conditions_tensors (list of ConditioningAttributes, optional): List of conditions.
+conditions (list of ConditioningAttributes, optional): List of conditions.
num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
max_gen_len (int): Maximum generation length.
use_sampling (bool): Whether to use a sampling strategy or not.
temp (float): Sampling temperature.
top_k (int): K for "top-k" sampling.
top_p (float): P for "top-p" sampling.
-cfg_coeff (float, optional): Classifier-free guidance coefficient.
+cfg_coef (float, optional): Classifier-free guidance coefficient.
cfg_coef_beta (float, optional): If None, simple classifier free guidance is used with cfg_coef.
If not None, we apply double classifier free guidance as introduced in paragraph 4.3 of the
MusicGen-Style paper (https://arxiv.org/pdf/2407.12563). This beta coefficient is meant to
push the text condition more than the style condition when both text and style conditions
are used.
two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
remove_prompts (bool): Whether to remove prompts from generation or not.
check (bool): Whether to apply further checks on generated sequence.
@@ -455,18 +486,27 @@ def generate(self,
# the padding structure is exactly the same between train and test.
# With a batch size of 1, this can be slower though.
cfg_conditions: CFGConditions
-two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
-if conditions:
-null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
-if two_step_cfg:
-cfg_conditions = (
-self.condition_provider(self.condition_provider.tokenize(conditions)),
-self.condition_provider(self.condition_provider.tokenize(null_conditions)),
-)
-else:
-conditions = conditions + null_conditions
-cfg_conditions = {}
+if cfg_coef_beta is not None:
+if conditions:
+wav_conditions = _drop_description_condition(conditions)
+null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
+conditions = conditions + wav_conditions + null_conditions
+tokenized = self.condition_provider.tokenize(conditions)
+cfg_conditions = self.condition_provider(tokenized)
+elif conditions:
+two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
+if conditions:
+null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
+if two_step_cfg:
+cfg_conditions = (
+self.condition_provider(self.condition_provider.tokenize(conditions)),
+self.condition_provider(self.condition_provider.tokenize(null_conditions)),
+)
+else:
+conditions = conditions + null_conditions
+tokenized = self.condition_provider.tokenize(conditions)
+cfg_conditions = self.condition_provider(tokenized)
+else:
+cfg_conditions = {}

@@ -509,7 +549,7 @@ def generate(self,
# sample next token from the model, next token shape is [B, K, 1]
next_token = self._sample_next_token(
curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
-cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
+cfg_coef=cfg_coef, cfg_coef_beta=cfg_coef_beta, two_step_cfg=two_step_cfg)
# ensure the tokens that should be masked are properly set to special_token_id
# as the model never output special_token_id
valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
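As an editorial aside on the double-CFG combination added to _sample_next_token above: a toy sketch (shapes and values made up) of how the three logit sets are mixed, and of the fact that cfg_coef_beta == 1 collapses back to standard CFG between the full condition and the unconditional pass:

import torch

B, K, T, card = 2, 4, 1, 16
cond_logits = torch.randn(B, K, T, card)    # text + style conditioning
wav_logits = torch.randn(B, K, T, card)     # style-only conditioning (description dropped)
uncond_logits = torch.randn(B, K, T, card)  # fully unconditional

cfg_coef, cfg_coef_beta = 3.0, 5.0
logits = uncond_logits + cfg_coef * (
    wav_logits + cfg_coef_beta * (cond_logits - wav_logits) - uncond_logits
)

# With cfg_coef_beta == 1 the inner term reduces to cond_logits, i.e. plain CFG.
plain_cfg = uncond_logits + cfg_coef * (cond_logits - uncond_logits)
beta_one = uncond_logits + cfg_coef * (
    wav_logits + 1.0 * (cond_logits - wav_logits) - uncond_logits
)
assert torch.allclose(plain_cfg, beta_one)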
2 changes: 2 additions & 0 deletions audiocraft/models/lm_magnet.py
@@ -125,6 +125,7 @@ def generate(self,
top_k: int = 250,
top_p: float = 0.0,
cfg_coef: tp.Optional[float] = None,
cfg_coef_beta: tp.Optional[float] = None,
two_step_cfg: tp.Optional[bool] = None,
remove_prompts: bool = False,
check: bool = False,
@@ -135,6 +136,7 @@ def generate(self,
assert two_step_cfg is None, "MAGNeT currently doesn't support two step classifier-free-guidance."
assert remove_prompts is False, "MAGNeT currently doesn't support the remove_prompts arg."
assert check is False, "MAGNeT currently doesn't support the check arg."
assert cfg_coef_beta is None, "MAGNeT currently doesn't support the cfg_coef_beta arg."
# Call the MAGNeT-specific generation method
return self._generate_magnet(prompt=prompt,
conditions=conditions,
34 changes: 32 additions & 2 deletions audiocraft/models/musicgen.py
@@ -20,7 +20,7 @@
from .builders import get_debug_compression_model, get_debug_lm_model
from .loaders import load_compression_model, load_lm_model
from ..data.audio_utils import convert_audio
-from ..modules.conditioners import ConditioningAttributes, WavCondition
+from ..modules.conditioners import ConditioningAttributes, WavCondition, StyleConditioner


MelodyList = tp.List[tp.Optional[torch.Tensor]]
@@ -33,6 +33,7 @@
"medium": "facebook/musicgen-medium",
"large": "facebook/musicgen-large",
"melody": "facebook/musicgen-melody",
"style": "facebook/musicgen-style",
}


@@ -63,6 +64,8 @@ def get_pretrained(name: str = 'facebook/musicgen-melody', device=None):
# see: https://huggingface.co/facebook/musicgen-melody
- facebook/musicgen-large (3.3B), text to music,
# see: https://huggingface.co/facebook/musicgen-large
- facebook/musicgen-style (1.5B), text and style to music,
# see: https://huggingface.co/facebook/musicgen-style
"""
if device is None:
if torch.cuda.device_count():
@@ -93,7 +96,8 @@ def get_pretrained(name: str = 'facebook/musicgen-melody', device=None):
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
top_p: float = 0.0, temperature: float = 1.0,
duration: float = 30.0, cfg_coef: float = 3.0,
-two_step_cfg: bool = False, extend_stride: float = 18):
+cfg_coef_beta: tp.Optional[float] = None,
+two_step_cfg: bool = False, extend_stride: float = 18):
"""Set the generation parameters for MusicGen.
Args:
@@ -103,6 +107,10 @@ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
cfg_coef_beta (float, optional): Beta coefficient for double classifier free guidance.
Should only be used with audio-conditioned models (melody or style) when we want to push
the text condition more than the audio conditioning. See paragraph 4.3 in
https://arxiv.org/pdf/2407.12563 to understand double CFG.
two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
instead of batching together the two. This has some impact on how things
are padded but seems to have little impact in practice.
@@ -120,8 +128,30 @@ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
'top_p': top_p,
'cfg_coef': cfg_coef,
'two_step_cfg': two_step_cfg,
'cfg_coef_beta': cfg_coef_beta,
}

def set_style_conditioner_params(self, eval_q: int = 3, excerpt_length: float = 3.0,
ds_factor: tp.Optional[int] = None,
encodec_n_q: tp.Optional[int] = None) -> None:
"""Set the parameters of the style conditioner
Args:
eval_q (int): the number of residual quantization streams used to quantize the style condition
the smaller it is, the narrower is the information bottleneck
excerpt_length (float): the excerpt length in seconds that is extracted from the audio
conditioning
ds_factor: (int): the downsampling factor used to downsample the style tokens before
using them as a prefix
encodec_n_q: (int, optional): if encodec is used as a feature extractor, sets the number
of streams that is used to extract features
"""
assert isinstance(self.lm.condition_provider.conditioners.self_wav, StyleConditioner), \
"Only use this function if you model is MusicGen-Style"
self.lm.condition_provider.conditioners.self_wav.set_params(eval_q=eval_q,
excerpt_length=excerpt_length,
ds_factor=ds_factor,
encodec_n_q=encodec_n_q)

def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
melody_sample_rate: int, progress: bool = False,
return_tokens: bool = False) -> tp.Union[torch.Tensor,
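For context (again, not part of the diff): a hedged sketch of how the new set_style_conditioner_params hook trades off adherence to the style excerpt; the eval_q values are assumptions, not documented limits:

import torchaudio
from audiocraft.models import MusicGen

model = MusicGen.get_pretrained('facebook/musicgen-style')
model.set_generation_params(duration=8, cfg_coef=3.0, cfg_coef_beta=5.0)
excerpt, sr = torchaudio.load('assets/epic.wav')  # asset added in this commit

# Tighter bottleneck: fewer residual quantization streams keep only coarse style
# information, so generations follow the excerpt loosely.
model.set_style_conditioner_params(eval_q=1, excerpt_length=3.0)
loose = model.generate_with_chroma(
    descriptions=['epic orchestral trailer'],
    melody_wavs=excerpt[None], melody_sample_rate=sr,
)

# Wider bottleneck: more streams pass more of the excerpt through, so generations
# stick closer to its style.
model.set_style_conditioner_params(eval_q=6, excerpt_length=3.0)
close = model.generate_with_chroma(
    descriptions=['epic orchestral trailer'],
    melody_wavs=excerpt[None], melody_sample_rate=sr,
)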
