diff --git a/audiocraft/models/lm.py b/audiocraft/models/lm.py index c6a25d4d..0719b0cc 100644 --- a/audiocraft/models/lm.py +++ b/audiocraft/models/lm.py @@ -24,6 +24,7 @@ ConditioningProvider, ConditioningAttributes, ConditionType, + _drop_text_condition ) from ..modules.codebooks_patterns import CodebooksPatternProvider from ..modules.activations import get_activation_fn @@ -95,55 +96,6 @@ def init_layer(m: nn.Module, init_fn(m.weight) -def merge_pairs_of_conditions(alphas: tp.Dict[str, float], num_conditions: int, cfg_conditions: CFGConditions): - """ - Given: - - alphas: dic where the keys are attributes that need to be merged and the values are - floats in [0, 1] - ex: {'description': 0.3, 'self_wav': 0.5, 'style_wav': 1.0} - - num_conditions: the number of conditions in parallel, 2 in case of cfg, 3 in case of - double_cfg and 4 in case of triple_cfg - - cfg_conditions - - Returns: - for each pair of condition (2i, 2i+1), - we compute sqrt{alpha}*condition_{2i} + sqrt{1 - alpha**2}*condition_{2i+1} - """ - new_cfg_conditions = deepcopy(cfg_conditions) - for attribute in alphas.keys(): - embed, mask = cfg_conditions[attribute] # type: ignore - B, T, C = embed.shape # type: ignore - assert B % (2 * num_conditions) == 0 - alpha = alphas[attribute] - new_embed = alpha ** 0.5 * embed[::2] + (1 - alpha)**0.5 * embed[1::2] - new_mask = mask[::2] - new_cfg_conditions[attribute] = (new_embed, new_mask) # type: ignore - return new_cfg_conditions - - -def sum_pairs_of_conditions(which_conditions: tp.List[str], num_conditions: int, cfg_conditions: CFGConditions): - """ - Given: - - which_conditions: list of the attributes that need to be merged - ex: ['description', 'self_wav', 'style_wav'] - - num_conditions: the number of conditions in parallel, 2 in case of cfg, 3 in case of - double_cfg and 4 in case of triple_cfg - - cfg_conditions - - Returns: for each pair of condition (2i, 2i+1), we compute - alpha*condition_{2i} + sqrt{1 - alpha**2}*condition_{2i+1} - """ - new_cfg_conditions = deepcopy(cfg_conditions) - for attribute in which_conditions: - embed, mask = cfg_conditions[attribute] # type: ignore - B, T, C = embed.shape # type: ignore - assert B % (2 * num_conditions) == 0 - new_embed = embed[::2] + embed[1::2] - new_mask = mask[::2] # type: ignore - new_cfg_conditions[attribute] = (new_embed, new_mask) # type: ignore - return new_cfg_conditions - - class ScaledEmbedding(nn.Embedding): """Boost learning rate for embeddings (with `scale`). """ @@ -377,9 +329,8 @@ def _sample_next_token(self, temp: float = 1.0, top_k: int = 0, top_p: float = 0.0, - double_cfg: bool = False, cfg_coef: tp.Optional[float] = None, - cfg_coef_2: tp.Optional[float] = None, + cfg_coef_beta: tp.Optional[float] = None, two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor: """Sample next token from the model given a sequence and a set of conditions. The model supports multiple sampling strategies (greedy sampling, softmax, top-k, top-p...). @@ -395,6 +346,11 @@ def _sample_next_token(self, top_k (int): K for "top-k" sampling. top_p (float): P for "top-p" sampling. cfg_coef (float, optional): classifier free guidance coefficient + cfg_coef_beta (float, optional): If None, simple classifier free guidance is used with cfg_coef. + If not None, we apply double classifier free guidance as introduced in MusicGen-Style + in paragraph 4.3 (https://arxiv.org/pdf/2407.12563). 
This beta coefficient is meant to + push the text condition more than the style condition in the case where both text and style + conditions are being used. Returns: next_token (torch.Tensor): Next token tensor of shape [B, K, 1]. """ @@ -402,11 +358,12 @@ def _sample_next_token(self, cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef model = self if self._fsdp is None else self._fsdp two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg - if double_cfg: + if cfg_coef_beta is not None: assert isinstance(cfg_conditions, dict) - assert cfg_coef_2 is not None condition_tensors = cfg_conditions if condition_tensors: + # Preparing for CFG, predicting conditional text and style, conditional style + # and unconditional sequence = torch.cat([sequence, sequence, sequence], dim=0) all_logits = model( sequence, @@ -414,7 +371,7 @@ def _sample_next_token(self, if condition_tensors: cond_logits, wav_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card] logits = uncond_logits + cfg_coef * ( - wav_logits + cfg_coef_2 * (cond_logits - wav_logits) - uncond_logits + wav_logits + cfg_coef_beta * (cond_logits - wav_logits) - uncond_logits ) elif two_step_cfg and cfg_conditions != {}: @@ -470,29 +427,30 @@ def generate(self, top_k: int = 250, top_p: float = 0.0, cfg_coef: tp.Optional[float] = None, + cfg_coef_beta: tp.Optional[float] = None, two_step_cfg: tp.Optional[bool] = None, - double_cfg: bool = False, - cfg_coef_2: tp.Optional[float] = None, remove_prompts: bool = False, check: bool = False, callback: tp.Optional[tp.Callable[[int, int], None]] = None, - postprocess_fn: tp.Optional[str] = None, - alphas: tp.Optional[tp.Dict[str, float]] = None, - which_conditions: tp.Optional[tp.List[str]] = None, ) -> torch.Tensor: """Generate tokens sampling from the model given a prompt or unconditionally. Generation can be performed in a greedy fashion or using sampling with top K and top P strategies. Args: prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T]. - conditions_tensors (list of ConditioningAttributes, optional): List of conditions. + conditions (list of ConditioningAttributes, optional): List of conditions. num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given. max_gen_len (int): Maximum generation length. use_sampling (bool): Whether to use a sampling strategy or not. temp (float): Sampling temperature. top_k (int): K for "top-k" sampling. top_p (float): P for "top-p" sampling. - cfg_coeff (float, optional): Classifier-free guidance coefficient. + cfg_coef (float, optional): Classifier-free guidance coefficient. + cfg_coef_beta (float, optional): If None, simple classifier free guidance is used with cfg_coef. + If not None, we apply double classifier free guidance as introduced in MusicGen-Style + in paragraph 4.3 (https://arxiv.org/pdf/2407.12563). This beta coefficient is meant to + push the text condition more than the style condition in the case where both text and style + conditions are being used. two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation. remove_prompts (bool): Whether to remove prompts from generation or not. check (bool): Whether to apply further checks on generated sequence. @@ -528,17 +486,14 @@ def generate(self, # With a batch size of 1, this can be slower though. 
cfg_conditions: CFGConditions cfg_conditions = {} - if double_cfg: - num_conditions = 3 + if cfg_coef_beta is not None: if conditions: - wav_conditions = AttributeDropout(p={'text': {'description': 1.0}, - 'wav': {'self_wav': 0.0}})(conditions) + wav_conditions = _drop_text_condition(conditions) null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions) conditions = conditions + wav_conditions + null_conditions tokenized = self.condition_provider.tokenize(conditions) cfg_conditions = self.condition_provider(tokenized) elif conditions: - num_conditions = 2 two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg if conditions: null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions) @@ -553,23 +508,10 @@ def generate(self, cfg_conditions = self.condition_provider(tokenized) else: cfg_conditions = {} - if postprocess_fn is not None: - if postprocess_fn == 'merge': - assert alphas is not None - cfg_conditions = merge_pairs_of_conditions(alphas, num_conditions, cfg_conditions) - elif postprocess_fn == 'sum': - assert which_conditions is not None - cfg_conditions = sum_pairs_of_conditions(which_conditions, num_conditions, cfg_conditions) - else: - assert False if prompt is None: assert num_samples > 0 - if postprocess_fn in ['merge', 'sum']: - assert num_samples % 2 == 0 - prompt = torch.zeros((num_samples // 2, self.num_codebooks, 0), dtype=torch.long, device=device) - else: - prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device) + prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device) B, K, T = prompt.shape start_offset = T @@ -606,7 +548,7 @@ def generate(self, # sample next token from the model, next token shape is [B, K, 1] next_token = self._sample_next_token( curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p, - double_cfg=double_cfg, cfg_coef=cfg_coef, cfg_coef_2=cfg_coef_2, two_step_cfg=two_step_cfg) + cfg_coef=cfg_coef, cfg_coef_beta=cfg_coef_beta, two_step_cfg=two_step_cfg) # ensure the tokens that should be masked are properly set to special_token_id # as the model never output special_token_id valid_mask = mask[..., offset:offset+1].expand(B, -1, -1) diff --git a/audiocraft/models/lm_magnet.py b/audiocraft/models/lm_magnet.py index 2940a361..65894434 100644 --- a/audiocraft/models/lm_magnet.py +++ b/audiocraft/models/lm_magnet.py @@ -126,25 +126,17 @@ def generate(self, top_p: float = 0.0, cfg_coef: tp.Optional[float] = None, two_step_cfg: tp.Optional[bool] = None, - double_cfg: bool = False, - cfg_coef_2: tp.Optional[float] = None, + cfg_coef_beta: tp.Optional[float] = None, remove_prompts: bool = False, check: bool = False, callback: tp.Optional[tp.Callable[[int, int], None]] = None, - postprocess_fn: tp.Optional[str] = None, - alphas: tp.Optional[tp.Dict[str, float]] = None, - which_conditions: tp.Optional[tp.List[str]] = None, **kwargs) -> torch.Tensor: assert cfg_coef is None, "Unsupported in MAGNeT. Use max_cfg_coef,min_cfg_coef instead." assert two_step_cfg is None, "MAGNeT currently doesn't support two step classifier-free-guidance." assert remove_prompts is False, "MAGNeT currently doesn't support the remove_prompts arg." assert check is False, "MAGNeT currently doesn't support the check arg." - assert postprocess_fn is None, "MAGNeT currently doesn't support the postprocess_fn arg." - assert alphas is None, "MAGNeT currently doesn't support the alphas arg." 
- assert which_conditions is None, "MAGNeT currently doesn't support the which_conditions arg." - assert double_cfg is False, "MAGNeT currently doesn't support the double_cfg arg." - assert cfg_coef_2 is None, "MAGNeT currently doesn't support the cfg_coef_2 arg." + assert cfg_coef_beta is None, "MAGNeT currently doesn't support the cfg_coef_beta arg." # Call the MAGNeT-specific generation method return self._generate_magnet(prompt=prompt, conditions=conditions, diff --git a/audiocraft/models/musicgen.py b/audiocraft/models/musicgen.py index 5fb8866c..8e770bfc 100644 --- a/audiocraft/models/musicgen.py +++ b/audiocraft/models/musicgen.py @@ -95,12 +95,9 @@ def get_pretrained(name: str = 'facebook/musicgen-melody', device=None): def set_generation_params(self, use_sampling: bool = True, top_k: int = 250, top_p: float = 0.0, temperature: float = 1.0, - duration: float = 30.0, double_cfg: bool = False, - cfg_coef: float = 3.0, cfg_coef_2: tp.Optional[float] = None, - two_step_cfg: bool = False, extend_stride: float = 18, - postprocess_fn: tp.Optional[str] = None, - alphas: tp.Optional[tp.Dict[str, float]] = None, - which_conditions: tp.Optional[tp.List[str]] = None): + duration: float = 30.0, cfg_coef: float = 3.0, + cfg_coef_beta: tp.Optional[float] = None, + two_step_cfg: bool = False, extend_stride: float = 18,): """Set the generation parameters for MusicGen. Args: @@ -110,6 +107,10 @@ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250, temperature (float, optional): Softmax temperature parameter. Defaults to 1.0. duration (float, optional): Duration of the generated waveform. Defaults to 30.0. cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0. + cfg_coef_beta (float, optional): beta coefficient in double classifier free guidance. + Should be only used for MusicGen melody if we want to push the text condition more than + the melody conditioning. See paragraph 4.3 in https://arxiv.org/pdf/2407.12563 to understand + double CFG. two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance, instead of batching together the two. This has some impact on how things are padded but seems to have little impact in practice. @@ -127,11 +128,7 @@ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250, 'top_p': top_p, 'cfg_coef': cfg_coef, 'two_step_cfg': two_step_cfg, - 'double_cfg': double_cfg, - 'cfg_coef_2': cfg_coef_2, - 'postprocess_fn': postprocess_fn, - 'alphas': alphas, - 'which_conditions': which_conditions + 'cfg_coef_beta': cfg_coef_beta, } def set_style_conditioner_params(self, eval_q: int = 3, excerpt_length: float = 3.0, diff --git a/audiocraft/modules/conditioners.py b/audiocraft/modules/conditioners.py index 19525d2e..70319f54 100644 --- a/audiocraft/modules/conditioners.py +++ b/audiocraft/modules/conditioners.py @@ -177,6 +177,16 @@ def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition: seek_time=[0] * embed.wav.shape[0], ) +def _drop_text_condition(conditions: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: + """Drop the text condition but keep the wav conditon on a list of ConditioningAttributes. + This is useful to calculate l_style in the double classifier free guidance formula. + See paragraph 4.3 in https://arxiv.org/pdf/2407.12563 + + Args: + conditions (tp.List[ConditioningAttributes]): List of conditions. 
+ """ + return AttributeDropout(p={'text': {'description': 1.0}, + 'wav': {'self_wav': 0.0}})(conditions) class Tokenizer: """Base tokenizer implementation diff --git a/config/solver/musicgen/musicgen_style_32khz.yaml b/config/solver/musicgen/musicgen_style_32khz.yaml index 35dc4927..5051658e 100644 --- a/config/solver/musicgen/musicgen_style_32khz.yaml +++ b/config/solver/musicgen/musicgen_style_32khz.yaml @@ -36,8 +36,7 @@ generate: top_k: 250 top_p: 0.0 cfg_coef: 3.0 - cfg_coef_2: - double_cfg: false + cfg_coef_beta: optim: epochs: 500 diff --git a/demos/musicgen_style_app.py b/demos/musicgen_style_app.py index e59602fc..d4df98d7 100644 --- a/demos/musicgen_style_app.py +++ b/demos/musicgen_style_app.py @@ -104,8 +104,8 @@ def load_diffusion(): MBD = MultiBandDiffusion.get_mbd_musicgen() -def _do_predictions(texts, melodies, duration, top_k, top_p, temperature, cfg_coef, double_cfg, cfg_coef_2, eval_q, excerpt_length, progress=False, gradio_progress=None): - MODEL.set_generation_params(duration=duration, top_k=top_k, top_p=top_p, temperature=temperature, cfg_coef=cfg_coef, double_cfg=double_cfg, cfg_coef_2=cfg_coef_2) +def _do_predictions(texts, melodies, duration, top_k, top_p, temperature, cfg_coef, double_cfg, cfg_coef_beta, eval_q, excerpt_length, progress=False, gradio_progress=None): + MODEL.set_generation_params(duration=duration, top_k=top_k, top_p=top_p, temperature=temperature, cfg_coef=cfg_coef, cfg_coef_beta=cfg_coef_beta) MODEL.set_style_conditioner_params(eval_q=eval_q, excerpt_length=excerpt_length) print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies]) be = time.time() @@ -161,7 +161,7 @@ def _do_predictions(texts, melodies, duration, top_k, top_p, temperature, cfg_co return out_videos, out_wavs -def predict_full(model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, double_cfg, cfg_coef_2, eval_q, excerpt_length, progress=gr.Progress()): +def predict_full(model, model_path, decoder, text, melody, duration, topk, topp, temperature, cfg_coef, double_cfg, cfg_coef_beta, eval_q, excerpt_length, progress=gr.Progress()): global INTERRUPTING global USE_DIFFUSION INTERRUPTING = False @@ -195,11 +195,8 @@ def predict_full(model, model_path, decoder, text, melody, duration, topk, topp, USE_DIFFUSION = False load_model(model) - if double_cfg == "Yes": - double_cfg = True - else: - double_cfg = False - cfg_coef_2 = None + if double_cfg != "Yes": + cfg_coef_beta = None max_generated = 0 def _progress(generated, to_generate): @@ -213,7 +210,7 @@ def _progress(generated, to_generate): videos, wavs = _do_predictions( [text], [melody], duration, progress=True, top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef, - double_cfg=double_cfg, cfg_coef_2=cfg_coef_2, eval_q=eval_q, excerpt_length=excerpt_length, + cfg_coef_beta=cfg_coef_beta, eval_q=eval_q, excerpt_length=excerpt_length, gradio_progress=progress) if USE_DIFFUSION: return videos[0], wavs[0], videos[1], wavs[1] @@ -274,7 +271,7 @@ def ui_full(launch_kwargs): cfg_coef = gr.Number(label="CFG alpha", value=3.0, interactive=True) double_cfg = gr.Radio(["Yes", "No"], label="Use Double Classifier Free Guidance (if No, CFG beta is useless)", value="Yes", interactive=True) - cfg_coef_2 = gr.Number(label="CFG beta (double CFG)", value=5.0, interactive=True) + cfg_coef_beta = gr.Number(label="CFG beta (double CFG)", value=5.0, interactive=True) excerpt_length = gr.Number(label="length used of the conditioning (has to be <= 4.5 seconds)", value=3.0, 
interactive=True) with gr.Column(): output = gr.Video(label="Generated Music") @@ -283,7 +280,7 @@ def ui_full(launch_kwargs): audio_diffusion = gr.Audio(label="MultiBand Diffusion Decoder (wav)", type='filepath') submit.click(toggle_diffusion, decoder, [diffusion_output, audio_diffusion], queue=False, show_progress=False).then(predict_full, inputs=[model, model_path, decoder, text, melody, duration, topk, topp, - temperature, cfg_coef, double_cfg, cfg_coef_2, eval_q, excerpt_length], + temperature, cfg_coef, double_cfg, cfg_coef_beta, eval_q, excerpt_length], outputs=[output, audio_output, diffusion_output, audio_diffusion]) radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False) @@ -378,4 +375,4 @@ def ui_full(launch_kwargs): logging.basicConfig(level=logging.INFO, stream=sys.stderr) # Show the interface - ui_full(launch_kwargs) + ui_full(launch_kwargs) \ No newline at end of file diff --git a/demos/musicgen_style_demo.ipynb b/demos/musicgen_style_demo.ipynb index fae82c31..3ec689c9 100644 --- a/demos/musicgen_style_demo.ipynb +++ b/demos/musicgen_style_demo.ipynb @@ -37,8 +37,7 @@ "* `temperature` (float, optional): softmax temperature parameter. Defaults to 1.0.\n", "* `duration` (float, optional): duration of the generated waveform. Defaults to 30.0.\n", "* `cfg_coef` (float, optional): coefficient used for classifier free guidance. Defaults to 3.0.\n", - "* `double_cfg` (bool, optional): If True, use double CFG. Defaults to False.\n", - "* `cfg_coef_2` (float, optional): If double_cfg is True, sets the beta parameter that pushes the text. Defaults to None, user should start at 5.\n", + "* `cfg_coef_beta` (float, optional): If not None, we use double CFG. cfg_coef_beta is the parameter that pushes the text. Defaults to None, user should start at 5.\n", " If the generated music adheres to much to the text, the user should reduce this parameter. If the music adheres too much to the style conditioning, \n", " the user should increase it\n", "\n", @@ -93,8 +92,7 @@ " use_sampling=True, \n", " top_k=250,\n", " cfg_coef=3., # Classifier Free Guidance coefficient \n", - " double_cfg=False, # double CFG is only useful for text-and-style conditioning\n", - " cfg_coef_2=None,\n", + " cfg_coef_beta=None, # double CFG is only useful for text-and-style conditioning\n", ")\n", "\n", "output = model.generate(\n", @@ -137,8 +135,7 @@ " use_sampling=True, \n", " top_k=250,\n", " cfg_coef=3., # Classifier Free Guidance coefficient \n", - " double_cfg=False, # double CFG is only useful for text-and-style conditioning\n", - " cfg_coef_2=None,\n", + " cfg_coef_beta=None, # double CFG is only useful for text-and-style conditioning\n", ")\n", "\n", "model.set_style_conditioner_params(\n", @@ -188,9 +185,9 @@ " use_sampling=True, \n", " top_k=250,\n", " cfg_coef=3., # Classifier Free Guidance coefficient \n", - " double_cfg=True, # double CFG is necessary for text-and-style conditioning\n", - " cfg_coef_2=5., # Beta in the double CFG formula. between 1 and 9. When set to 1 \n", - " # it is equivalent to normal CFG. \n", + " cfg_coef_beta=5., # double CFG is necessary for text-and-style conditioning\n", + " # Beta in the double CFG formula. between 1 and 9. When set to 1 \n", + " # it is equivalent to normal CFG. 
\n", ")\n", "\n", "model.set_style_conditioner_params(\n", diff --git a/docs/MUSICGEN_STYLE.md b/docs/MUSICGEN_STYLE.md index e945571a..34955b11 100644 --- a/docs/MUSICGEN_STYLE.md +++ b/docs/MUSICGEN_STYLE.md @@ -52,8 +52,7 @@ model.set_generation_params( use_sampling=True, top_k=250, cfg_coef=3., # Classifier Free Guidance coefficient - double_cfg=False, # double CFG is only useful for text-and-style conditioning - cfg_coef_2=None, + cfg_coef_beta=None, # double CFG is only useful for text-and-style conditioning ) descriptions = ['disco beat', 'energetic EDM', 'funky groove'] @@ -78,8 +77,7 @@ model.set_generation_params( use_sampling=True, top_k=250, cfg_coef=3., # Classifier Free Guidance coefficient - double_cfg=False, # double CFG is only useful for text-and-style conditioning - cfg_coef_2=None, + cfg_coef_beta=None, # double CFG is only useful for text-and-style conditioning ) model.set_style_conditioner_params( @@ -115,8 +113,8 @@ model.set_generation_params( use_sampling=True, top_k=250, cfg_coef=3., # Classifier Free Guidance coefficient - double_cfg=True, # double CFG is necessary for text-and-style conditioning - cfg_coef_2=5., # Beta in the double CFG formula. between 1 and 9. When set to 1 it is equivalent to normal CFG. + cfg_coef_beta=5., # double CFG is necessary for text-and-style conditioning + # Beta in the double CFG formula. between 1 and 9. When set to 1 it is equivalent to normal CFG. # When we increase this parameter, the text condition is pushed. See the bottom of https://musicgenstyle.github.io/ # to better understand the effects of the double CFG coefficients. )