
Commit

cleanup
simonrouard committed Nov 5, 2024
1 parent 4e4972d commit fd9fee4
Showing 8 changed files with 63 additions and 131 deletions.
104 changes: 23 additions & 81 deletions audiocraft/models/lm.py
@@ -24,6 +24,7 @@
ConditioningProvider,
ConditioningAttributes,
ConditionType,
_drop_text_condition
)
from ..modules.codebooks_patterns import CodebooksPatternProvider
from ..modules.activations import get_activation_fn
@@ -95,55 +96,6 @@ def init_layer(m: nn.Module,
init_fn(m.weight)


def merge_pairs_of_conditions(alphas: tp.Dict[str, float], num_conditions: int, cfg_conditions: CFGConditions):
"""
Given:
- alphas: dic where the keys are attributes that need to be merged and the values are
floats in [0, 1]
ex: {'description': 0.3, 'self_wav': 0.5, 'style_wav': 1.0}
- num_conditions: the number of conditions in parallel, 2 in case of cfg, 3 in case of
double_cfg and 4 in case of triple_cfg
- cfg_conditions
Returns:
for each pair of conditions (2i, 2i+1),
we compute sqrt{alpha}*condition_{2i} + sqrt{1 - alpha}*condition_{2i+1}
"""
new_cfg_conditions = deepcopy(cfg_conditions)
for attribute in alphas.keys():
embed, mask = cfg_conditions[attribute] # type: ignore
B, T, C = embed.shape # type: ignore
assert B % (2 * num_conditions) == 0
alpha = alphas[attribute]
new_embed = alpha ** 0.5 * embed[::2] + (1 - alpha)**0.5 * embed[1::2]
new_mask = mask[::2]
new_cfg_conditions[attribute] = (new_embed, new_mask) # type: ignore
return new_cfg_conditions


def sum_pairs_of_conditions(which_conditions: tp.List[str], num_conditions: int, cfg_conditions: CFGConditions):
"""
Given:
- which_conditions: list of the attributes that need to be merged
ex: ['description', 'self_wav', 'style_wav']
- num_conditions: the number of conditions in parallel, 2 in case of cfg, 3 in case of
double_cfg and 4 in case of triple_cfg
- cfg_conditions
Returns: for each pair of conditions (2i, 2i+1), we compute
condition_{2i} + condition_{2i+1}
"""
new_cfg_conditions = deepcopy(cfg_conditions)
for attribute in which_conditions:
embed, mask = cfg_conditions[attribute] # type: ignore
B, T, C = embed.shape # type: ignore
assert B % (2 * num_conditions) == 0
new_embed = embed[::2] + embed[1::2]
new_mask = mask[::2] # type: ignore
new_cfg_conditions[attribute] = (new_embed, new_mask) # type: ignore
return new_cfg_conditions


class ScaledEmbedding(nn.Embedding):
"""Boost learning rate for embeddings (with `scale`).
"""
@@ -377,9 +329,8 @@ def _sample_next_token(self,
temp: float = 1.0,
top_k: int = 0,
top_p: float = 0.0,
double_cfg: bool = False,
cfg_coef: tp.Optional[float] = None,
cfg_coef_2: tp.Optional[float] = None,
cfg_coef_beta: tp.Optional[float] = None,
two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
"""Sample next token from the model given a sequence and a set of conditions. The model supports
multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
@@ -395,26 +346,32 @@ def _sample_next_token(self,
top_k (int): K for "top-k" sampling.
top_p (float): P for "top-p" sampling.
cfg_coef (float, optional): classifier free guidance coefficient
cfg_coef_beta (float, optional): If None, simple classifier free guidance is used with cfg_coef.
If not None, we apply double classifier free guidance as introduced in MusicGen-Style
in paragraph 4.3 (https://arxiv.org/pdf/2407.12563). This beta coefficient is meant to
push the text condition more than the style condition in the case where both text and style
conditions are being used.
Returns:
next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
"""
B = sequence.shape[0]
cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
model = self if self._fsdp is None else self._fsdp
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
if double_cfg:
if cfg_coef_beta is not None:
assert isinstance(cfg_conditions, dict)
assert cfg_coef_2 is not None
condition_tensors = cfg_conditions
if condition_tensors:
# Preparing for CFG, predicting conditional text and style, conditional style
# and unconditional
sequence = torch.cat([sequence, sequence, sequence], dim=0)
all_logits = model(
sequence,
conditions=[], condition_tensors=condition_tensors)
if condition_tensors:
cond_logits, wav_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
logits = uncond_logits + cfg_coef * (
wav_logits + cfg_coef_2 * (cond_logits - wav_logits) - uncond_logits
wav_logits + cfg_coef_beta * (cond_logits - wav_logits) - uncond_logits
)

elif two_step_cfg and cfg_conditions != {}:
@@ -470,29 +427,30 @@ def generate(self,
top_k: int = 250,
top_p: float = 0.0,
cfg_coef: tp.Optional[float] = None,
cfg_coef_beta: tp.Optional[float] = None,
two_step_cfg: tp.Optional[bool] = None,
double_cfg: bool = False,
cfg_coef_2: tp.Optional[float] = None,
remove_prompts: bool = False,
check: bool = False,
callback: tp.Optional[tp.Callable[[int, int], None]] = None,
postprocess_fn: tp.Optional[str] = None,
alphas: tp.Optional[tp.Dict[str, float]] = None,
which_conditions: tp.Optional[tp.List[str]] = None,
) -> torch.Tensor:
"""Generate tokens sampling from the model given a prompt or unconditionally. Generation can
be performed in a greedy fashion or using sampling with top K and top P strategies.
Args:
prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
conditions_tensors (list of ConditioningAttributes, optional): List of conditions.
conditions (list of ConditioningAttributes, optional): List of conditions.
num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
max_gen_len (int): Maximum generation length.
use_sampling (bool): Whether to use a sampling strategy or not.
temp (float): Sampling temperature.
top_k (int): K for "top-k" sampling.
top_p (float): P for "top-p" sampling.
cfg_coeff (float, optional): Classifier-free guidance coefficient.
cfg_coef (float, optional): Classifier-free guidance coefficient.
cfg_coef_beta (float, optional): If None, simple classifier free guidance is used with cfg_coef.
If not None, we apply double classifier free guidance as introduced in MusicGen-Style
in paragraph 4.3 (https://arxiv.org/pdf/2407.12563). This beta coefficient is meant to
push the text condition more than the style condition in the case where both text and style
conditions are being used.
two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
remove_prompts (bool): Whether to remove prompts from generation or not.
check (bool): Whether to apply further checks on generated sequence.
@@ -528,17 +486,14 @@ def generate(self,
# With a batch size of 1, this can be slower though.
cfg_conditions: CFGConditions
cfg_conditions = {}
if double_cfg:
num_conditions = 3
if cfg_coef_beta is not None:
if conditions:
wav_conditions = AttributeDropout(p={'text': {'description': 1.0},
'wav': {'self_wav': 0.0}})(conditions)
wav_conditions = _drop_text_condition(conditions)
null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
conditions = conditions + wav_conditions + null_conditions
tokenized = self.condition_provider.tokenize(conditions)
cfg_conditions = self.condition_provider(tokenized)
elif conditions:
num_conditions = 2
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
if conditions:
null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
@@ -553,23 +508,10 @@ def generate(self,
cfg_conditions = self.condition_provider(tokenized)
else:
cfg_conditions = {}
if postprocess_fn is not None:
if postprocess_fn == 'merge':
assert alphas is not None
cfg_conditions = merge_pairs_of_conditions(alphas, num_conditions, cfg_conditions)
elif postprocess_fn == 'sum':
assert which_conditions is not None
cfg_conditions = sum_pairs_of_conditions(which_conditions, num_conditions, cfg_conditions)
else:
assert False

if prompt is None:
assert num_samples > 0
if postprocess_fn in ['merge', 'sum']:
assert num_samples % 2 == 0
prompt = torch.zeros((num_samples // 2, self.num_codebooks, 0), dtype=torch.long, device=device)
else:
prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)

B, K, T = prompt.shape
start_offset = T
@@ -606,7 +548,7 @@ def generate(self,
# sample next token from the model, next token shape is [B, K, 1]
next_token = self._sample_next_token(
curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
double_cfg=double_cfg, cfg_coef=cfg_coef, cfg_coef_2=cfg_coef_2, two_step_cfg=two_step_cfg)
cfg_coef=cfg_coef, cfg_coef_beta=cfg_coef_beta, two_step_cfg=two_step_cfg)
# ensure the tokens that should be masked are properly set to special_token_id
# as the model never output special_token_id
valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
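The double-CFG combination in _sample_next_token above can be read in isolation. The following is a minimal sketch with toy tensors; it is illustrative only and not part of the commit, and the coefficient values are arbitrary.

    import torch

    B, K, T, card = 2, 4, 1, 2048        # batch, codebooks, time steps, codebook size
    cfg_coef, cfg_coef_beta = 3.0, 5.0   # illustrative values only

    # One forward pass over the 3*B concatenated batch yields three stacked predictions:
    # full (text + style) conditioning, style-only conditioning, and unconditional.
    all_logits = torch.randn(3 * B, K, T, card)
    cond_logits, wav_logits, uncond_logits = all_logits.split(B, dim=0)

    # Double CFG as in paragraph 4.3 of MusicGen-Style:
    # l = l_uncond + cfg_coef * (l_style + beta * (l_cond - l_style) - l_uncond)
    logits = uncond_logits + cfg_coef * (
        wav_logits + cfg_coef_beta * (cond_logits - wav_logits) - uncond_logits
    )
    # With cfg_coef_beta = 1.0 this reduces to standard CFG on the joint text+style condition.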
12 changes: 2 additions & 10 deletions audiocraft/models/lm_magnet.py
@@ -126,25 +126,17 @@ def generate(self,
top_p: float = 0.0,
cfg_coef: tp.Optional[float] = None,
two_step_cfg: tp.Optional[bool] = None,
double_cfg: bool = False,
cfg_coef_2: tp.Optional[float] = None,
cfg_coef_beta: tp.Optional[float] = None,
remove_prompts: bool = False,
check: bool = False,
callback: tp.Optional[tp.Callable[[int, int], None]] = None,
postprocess_fn: tp.Optional[str] = None,
alphas: tp.Optional[tp.Dict[str, float]] = None,
which_conditions: tp.Optional[tp.List[str]] = None,
**kwargs) -> torch.Tensor:

assert cfg_coef is None, "Unsupported in MAGNeT. Use max_cfg_coef,min_cfg_coef instead."
assert two_step_cfg is None, "MAGNeT currently doesn't support two step classifier-free-guidance."
assert remove_prompts is False, "MAGNeT currently doesn't support the remove_prompts arg."
assert check is False, "MAGNeT currently doesn't support the check arg."
assert postprocess_fn is None, "MAGNeT currently doesn't support the postprocess_fn arg."
assert alphas is None, "MAGNeT currently doesn't support the alphas arg."
assert which_conditions is None, "MAGNeT currently doesn't support the which_conditions arg."
assert double_cfg is False, "MAGNeT currently doesn't support the double_cfg arg."
assert cfg_coef_2 is None, "MAGNeT currently doesn't support the cfg_coef_2 arg."
assert cfg_coef_beta is None, "MAGNeT currently doesn't support the cfg_coef_beta arg."
# Call the MAGNeT-specific generation method
return self._generate_magnet(prompt=prompt,
conditions=conditions,
19 changes: 8 additions & 11 deletions audiocraft/models/musicgen.py
@@ -95,12 +95,9 @@ def get_pretrained(name: str = 'facebook/musicgen-melody', device=None):

def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
top_p: float = 0.0, temperature: float = 1.0,
duration: float = 30.0, double_cfg: bool = False,
cfg_coef: float = 3.0, cfg_coef_2: tp.Optional[float] = None,
two_step_cfg: bool = False, extend_stride: float = 18,
postprocess_fn: tp.Optional[str] = None,
alphas: tp.Optional[tp.Dict[str, float]] = None,
which_conditions: tp.Optional[tp.List[str]] = None):
duration: float = 30.0, cfg_coef: float = 3.0,
cfg_coef_beta: tp.Optional[float] = None,
two_step_cfg: bool = False, extend_stride: float = 18,):
"""Set the generation parameters for MusicGen.
Args:
@@ -110,6 +107,10 @@ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
cfg_coef_beta (float, optional): beta coefficient in double classifier free guidance.
Should only be used with MusicGen melody when we want to push the text condition more than
the melody conditioning. See paragraph 4.3 in https://arxiv.org/pdf/2407.12563 to understand
double CFG.
two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
instead of batching together the two. This has some impact on how things
are padded but seems to have little impact in practice.
@@ -127,11 +128,7 @@ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
'top_p': top_p,
'cfg_coef': cfg_coef,
'two_step_cfg': two_step_cfg,
'double_cfg': double_cfg,
'cfg_coef_2': cfg_coef_2,
'postprocess_fn': postprocess_fn,
'alphas': alphas,
'which_conditions': which_conditions
'cfg_coef_beta': cfg_coef_beta,
}

def set_style_conditioner_params(self, eval_q: int = 3, excerpt_length: float = 3.0,
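The new cfg_coef_beta parameter is exposed through set_generation_params shown above. The sketch below is a hypothetical end-to-end call: the checkpoint id 'facebook/musicgen-style', the excerpt path, and the use of generate_with_chroma to pass the style excerpt are assumptions rather than part of this commit.

    import torchaudio
    from audiocraft.models import MusicGen

    # Assumed checkpoint id for the style-conditioned model.
    model = MusicGen.get_pretrained('facebook/musicgen-style')
    model.set_generation_params(
        duration=10.0,
        cfg_coef=3.0,        # overall guidance strength
        cfg_coef_beta=5.0,   # push the text condition harder than the style condition
    )
    # Placeholder style excerpt; any short waveform tensor would do.
    excerpt, sr = torchaudio.load('style_excerpt.wav')
    wav = model.generate_with_chroma(['energetic 80s synth pop'], excerpt[None], sr)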
10 changes: 10 additions & 0 deletions audiocraft/modules/conditioners.py
@@ -177,6 +177,16 @@ def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
seek_time=[0] * embed.wav.shape[0],
)

def _drop_text_condition(conditions: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
"""Drop the text condition but keep the wav conditon on a list of ConditioningAttributes.
This is useful to calculate l_style in the double classifier free guidance formula.
See paragraph 4.3 in https://arxiv.org/pdf/2407.12563
Args:
conditions (tp.List[ConditioningAttributes]): List of conditions.
"""
return AttributeDropout(p={'text': {'description': 1.0},
'wav': {'self_wav': 0.0}})(conditions)

class Tokenizer:
"""Base tokenizer implementation
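To make the role of the new helper concrete, here is a small sketch (assuming a checkout that includes this commit) of how _drop_text_condition produces the style-only branch used by double CFG; the description string is arbitrary.

    from audiocraft.modules.conditioners import (
        ClassifierFreeGuidanceDropout,
        ConditioningAttributes,
        _drop_text_condition,
    )

    conditions = [ConditioningAttributes(text={'description': '80s pop with heavy drums'})]
    style_only = _drop_text_condition(conditions)                      # text dropped, wav kept
    unconditional = ClassifierFreeGuidanceDropout(p=1.0)(conditions)   # everything dropped

    # LMModel.generate (see the lm.py hunk above) concatenates these three lists so that a
    # single forward pass returns cond_logits, wav_logits and uncond_logits in one batch.
    batched = conditions + style_only + unconditional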
3 changes: 1 addition & 2 deletions config/solver/musicgen/musicgen_style_32khz.yaml
@@ -36,8 +36,7 @@ generate:
top_k: 250
top_p: 0.0
cfg_coef: 3.0
cfg_coef_2:
double_cfg: false
cfg_coef_beta:

optim:
epochs: 500