Commit f317a5a

Merge branch 'facebookresearch:main' into main
0xlws authored Nov 8, 2023
2 parents e95bcf9 + 5905d2e commit f317a5a
Showing 26 changed files with 491 additions and 90 deletions.
4 changes: 2 additions & 2 deletions .github/actions/audiocraft_build/action.yml
@@ -21,8 +21,8 @@ runs:
        python3 -m venv env
        . env/bin/activate
        python -m pip install --upgrade pip
-       pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
-       pip install --pre xformers
+       pip install torch torchvision torchaudio
+       pip install xformers
        pip install -e '.[dev]'
    - name: System Dependencies
      shell: bash
22 changes: 21 additions & 1 deletion CHANGELOG.md
@@ -4,11 +4,31 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

+ ## [1.2.0a] - TBD

- ## [1.0.1] - TBD
Adding stereo models.


## [1.1.0] - 2023-11-06

No longer using torchaudio when writing audio files; we instead rely directly on the command-line ffmpeg. torchaudio is also no longer used for reading audio files, for similar reasons.
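For reference, writing audio by piping raw PCM to the ffmpeg CLI can look roughly like the sketch below; the helper name and exact flags are illustrative assumptions, not audiocraft's actual implementation.

```python
import subprocess

import torch


def write_wav_via_ffmpeg(wav: torch.Tensor, path: str, sample_rate: int) -> None:
    """Hypothetical helper: pipe raw float32 PCM to the ffmpeg CLI.

    `wav` is expected as [channels, time], float32 in [-1, 1].
    """
    channels, _ = wav.shape
    command = [
        'ffmpeg', '-y',
        '-f', 'f32le',            # raw 32-bit float little-endian PCM on stdin
        '-ar', str(sample_rate),  # input sample rate
        '-ac', str(channels),     # input channel count
        '-i', '-',                # read from stdin
        path,
    ]
    # ffmpeg expects interleaved samples, i.e. [time, channels] order.
    pcm = wav.t().contiguous().numpy().astype('float32').tobytes()
    subprocess.run(command, input=pcm, check=True, capture_output=True)
```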

Fixed DAC support with a non-default number of codebooks.

Fixed a bug where `two_step_cfg` was overridden when calling `generate()`.

Fixed samples always being prompted with audio, rather than including both prompted and unprompted samples.

**Backward incompatible change:** A `torch.no_grad` around the computation of the conditioning made its way into the public release.
The released models were trained without it; it impacts the linear layers applied to the output of the T5 or melody conditioners.
We removed it, so you might need to retrain models.

**Backward incompatible change:** Fixed a wrong sample rate in CLAP (WARNING if you trained a model with CLAP before).

**Backward incompatible change:** Renamed VALLEPattern to CoarseFirstPattern, as it was wrongly named. Probably no one
retrained a model with this pattern, so hopefully this won't impact you!


## [1.0.0] - 2023-09-07

Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
2 changes: 1 addition & 1 deletion audiocraft/__init__.py
@@ -23,4 +23,4 @@
# flake8: noqa
from . import data, modules, models

- __version__ = '1.0.0'
+ __version__ = '1.2.0a1'
57 changes: 57 additions & 0 deletions audiocraft/grids/musicgen/musicgen_stereo_finetune_32khz.py
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from pathlib import Path
from ._explorers import LMExplorer
from ...environment import AudioCraftEnvironment


@LMExplorer
def explorer(launcher):
    partitions = AudioCraftEnvironment.get_slurm_partitions(['team', 'global'])
    launcher.slurm_(gpus=32, partition=partitions)
    launcher.bind_(solver='musicgen/musicgen_base_32khz')
    # replace this by the desired music dataset, which needs to be stereo
    launcher.bind_(dset='audio/example')

    fsdp = {'autocast': False, 'fsdp.use': True}
    medium = {'model/lm/model_scale': 'medium'}
    large = {'model/lm/model_scale': 'large'}

    cfg_low = {'classifier_free_guidance.training_dropout': 0.2}
    wd_low = {'conditioners.description.t5.word_dropout': 0.2}

    adam = {'optim.optimizer': 'adamw', 'optim.lr': 1e-4}

    stereo = {
        'codebooks_pattern.delay.delays': [0, 0, 1, 1, 2, 2, 3, 3],
        'transformer_lm.n_q': 8,
        'interleave_stereo_codebooks.use': True,
        'channels': 2,
    }

    # You must follow the instructions in docs/MUSICGEN.md about the creation
    # of the proper fine-tuning checkpoints. We will assume they are stored under
    # ~/checkpoints/{model_name}.

    checkpoints = Path.home() / 'checkpoints'

    launcher.bind_(fsdp, stereo, {'optim.epochs': 100})

    launcher.slurm_(gpus=32).bind_(label='32gpus')
    with launcher.job_array():
        sub = launcher.bind({'continue_from': str(checkpoints / 'stereo_finetune_musicgen-small.th')})
        sub()

    launcher.slurm_(gpus=64).bind_(label='64gpus')
    with launcher.job_array():
        sub = launcher.bind({'continue_from': str(checkpoints / 'stereo_finetune_musicgen-medium.th')})
        sub(medium, adam)

    launcher.slurm_(gpus=96).bind_(label='96gpus')
    with launcher.job_array():
        sub = launcher.bind({'continue_from': str(checkpoints / 'stereo_finetune_musicgen-large.th')})
        sub(large, cfg_low, wd_low, adam, {'optim.max_norm': 3})
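As with the other grids in this folder, this file is presumably launched through dora, e.g. `dora grid musicgen.musicgen_stereo_finetune_32khz`; see docs/MUSICGEN.md for preparing the fine-tuning checkpoints it expects under `~/checkpoints/`.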
4 changes: 4 additions & 0 deletions audiocraft/models/audiogen.py
@@ -38,6 +38,10 @@ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
        self.name = name
        self.compression_model = compression_model
        self.lm = lm
+       # Just to be safe, let's put everything in eval mode.
+       self.compression_model.eval()
+       self.lm.eval()

        if max_duration is None:
            if hasattr(lm, 'cfg'):
                max_duration = lm.cfg.dataset.segment_duration  # type: ignore
16 changes: 11 additions & 5 deletions audiocraft/models/builders.py
@@ -15,15 +15,15 @@
import omegaconf
import torch

- from .encodec import CompressionModel, EncodecModel
+ from .encodec import CompressionModel, EncodecModel, InterleaveStereoCompressionModel
from .lm import LMModel
from ..modules.codebooks_patterns import (
    CodebooksPatternProvider,
    DelayedPatternProvider,
    MusicLMPattern,
    ParallelPatternProvider,
    UnrolledPatternProvider,
-   VALLEPattern,
+   CoarseFirstPattern,
)
from ..modules.conditioners import (
    BaseConditioner,
@@ -172,7 +172,7 @@ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
        'parallel': ParallelPatternProvider,
        'delay': DelayedPatternProvider,
        'unroll': UnrolledPatternProvider,
-       'valle': VALLEPattern,
+       'coarse_first': CoarseFirstPattern,
        'musiclm': MusicLMPattern,
    }
name = cfg.modeling
@@ -196,7 +196,6 @@ def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
        'dimension': 32,
        'ratios': ratios,
    }
-   print(seanet_kwargs)
    encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
    decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
    quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
@@ -248,5 +247,12 @@ def get_debug_lm_model(device='cpu'):
def get_wrapped_compression_model(
        compression_model: CompressionModel,
        cfg: omegaconf.DictConfig) -> CompressionModel:
-   # more to come.
+   if hasattr(cfg, 'interleave_stereo_codebooks'):
+       if cfg.interleave_stereo_codebooks.use:
+           kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
+           kwargs.pop('use')
+           compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs)
+   if hasattr(cfg, 'compression_model_n_q'):
+       if cfg.compression_model_n_q is not None:
+           compression_model.set_num_codebooks(cfg.compression_model_n_q)
    return compression_model
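For illustration, a minimal sketch of a config that would trigger the new wrapping path; the values are assumptions, only the key names come from the code above.

```python
from omegaconf import OmegaConf

# Hypothetical config fragment; only the key names match the code above.
cfg = OmegaConf.create({
    'interleave_stereo_codebooks': {'use': True, 'per_timestep': False},
    'compression_model_n_q': 4,
})
# get_wrapped_compression_model(mono_model, cfg) would wrap `mono_model` in an
# InterleaveStereoCompressionModel ('use' is popped, 'per_timestep' is forwarded)
# and set 4 active codebooks per channel, i.e. 8 codebooks after interleaving.
```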
115 changes: 114 additions & 1 deletion audiocraft/models/encodec.py
@@ -13,6 +13,7 @@
from pathlib import Path
import typing as tp

+ from einops import rearrange
import numpy as np
import torch
from torch import nn
@@ -276,7 +277,7 @@ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        codes = self.model.encode(x, self.n_quantizers)[1]
-       return codes, None
+       return codes[:, :self.n_quantizers], None

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        assert scale is None
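This is presumably the DAC fix noted in the changelog: DAC's `encode` appears to return codes for all codebooks regardless of `n_quantizers`, so slicing to `self.n_quantizers` enforces the configured number of codebooks.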
@@ -391,3 +392,115 @@ def set_num_codebooks(self, n: int):
        if n not in self.possible_num_codebooks:
            raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}")
        self._num_codebooks = n


class InterleaveStereoCompressionModel(CompressionModel):
    """Wraps a CompressionModel to support stereo inputs. The wrapped model
    will be applied independently to the left and right channels, and both codebooks
    will be interleaved. If the wrapped model returns a representation `[B, K, T]` per
    channel, then the output will be `[B, K * 2, T]` or `[B, K, T * 2]` depending on
    `per_timestep`.

    Args:
        model (CompressionModel): Compression model to wrap.
        per_timestep (bool): Whether to interleave on the timestep dimension
            or on the codebooks dimension.
    """
    def __init__(self, model: CompressionModel, per_timestep: bool = False):
        super().__init__()
        self.model = model
        self.per_timestep = per_timestep
        assert self.model.channels == 1, "Wrapped model is expected to be for monophonic audio"

    @property
    def total_codebooks(self):
        return self.model.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer.

        ..Warning:: this reports the number of codebooks after the interleaving
        of the codebooks!
        """
        return self.model.num_codebooks if self.per_timestep else self.model.num_codebooks * 2

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer.

        ..Warning:: this sets the number of codebooks before the interleaving!
        """
        self.model.set_num_codebooks(n)

    @property
    def num_virtual_steps(self) -> float:
        """Return the number of virtual steps, e.g. one real step
        will be split into that many steps.
        """
        return 2 if self.per_timestep else 1

    @property
    def frame_rate(self) -> float:
        return self.model.frame_rate * self.num_virtual_steps

    @property
    def sample_rate(self) -> int:
        return self.model.sample_rate

    @property
    def channels(self) -> int:
        return 2

    @property
    def cardinality(self):
        """Cardinality of each codebook."""
        return self.model.cardinality

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        raise NotImplementedError("Not supported, use encode and decode.")

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        B, C, T = x.shape
        assert C == self.channels, f"Expecting stereo audio but audio num channels is {C}"

        indices_c0, scales_c0 = self.model.encode(x[:, 0, ...].unsqueeze(1))
        indices_c1, scales_c1 = self.model.encode(x[:, 1, ...].unsqueeze(1))
        indices = torch.stack([indices_c0, indices_c1], dim=0)
        scales: tp.Optional[torch.Tensor] = None
        if scales_c0 is not None and scales_c1 is not None:
            scales = torch.stack([scales_c0, scales_c1], dim=1)

        if self.per_timestep:
            indices = rearrange(indices, 'c b k t -> b k (t c)', c=2)
        else:
            indices = rearrange(indices, 'c b k t -> b (k c) t', c=2)

        return (indices, scales)

    def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        if self.per_timestep:
            codes = rearrange(codes, 'b k (t c) -> c b k t', c=2)
        else:
            codes = rearrange(codes, 'b (k c) t -> c b k t', c=2)
        return codes[0], codes[1]

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        B, K, T = codes.shape
        assert T % self.num_virtual_steps == 0, "Provided codes' number of timesteps does not match"
        assert K == self.num_codebooks, "Provided codes' number of codebooks does not match"

        scale_c0, scale_c1 = None, None
        if scale is not None:
            assert scale.size(0) == B and scale.size(1) == 2, f"Scale has unexpected shape: {scale.shape}"
            scale_c0 = scale[0, ...]
            scale_c1 = scale[1, ...]

        codes_c0, codes_c1 = self.get_left_right_codes(codes)
        audio_c0 = self.model.decode(codes_c0, scale_c0)
        audio_c1 = self.model.decode(codes_c1, scale_c1)
        return torch.cat([audio_c0, audio_c1], dim=1)

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        raise NotImplementedError("Not supported by interleaved stereo wrapped models.")
8 changes: 5 additions & 3 deletions audiocraft/models/lm.py
@@ -314,7 +314,8 @@ def _sample_next_token(self,
                               temp: float = 1.0,
                               top_k: int = 0,
                               top_p: float = 0.0,
-                              cfg_coef: tp.Optional[float] = None) -> torch.Tensor:
+                              cfg_coef: tp.Optional[float] = None,
+                              two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
        """Sample next token from the model given a sequence and a set of conditions. The model supports
        multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
@@ -335,7 +336,8 @@ def _sample_next_token(self,
        B = sequence.shape[0]
        cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
        model = self if self._fsdp is None else self._fsdp
-       if self.two_step_cfg and cfg_conditions != {}:
+       two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
+       if two_step_cfg and cfg_conditions != {}:
            assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
            condition_tensors, null_condition_tensors = cfg_conditions
            cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
@@ -493,7 +495,7 @@ def generate(self,
            # sample next token from the model, next token shape is [B, K, 1]
            next_token = self._sample_next_token(
                curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
-               cfg_coef=cfg_coef)
+               cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
            # ensure the tokens that should be masked are properly set to special_token_id
            # as the model never output special_token_id
            valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
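In practice this means a per-call override now takes effect; a hedged sketch at the MusicGen wrapper level (assuming, as in released versions, that `set_generation_params` forwards `two_step_cfg` down to `LMModel.generate`):

```python
from audiocraft.models import MusicGen

model = MusicGen.get_pretrained('facebook/musicgen-small')
# Before the fix, this flag was silently ignored in favor of the
# model's constructor-time self.two_step_cfg; now it is honored per call.
model.set_generation_params(duration=8, two_step_cfg=True)
wav = model.generate(['lo-fi beat with a warm bassline'])
```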
5 changes: 4 additions & 1 deletion audiocraft/models/loaders.py
@@ -27,6 +27,7 @@
from omegaconf import OmegaConf, DictConfig
import torch

+ import audiocraft
from . import builders
from .encodec import CompressionModel

@@ -60,7 +61,9 @@ def _get_state_dict(
    else:
        assert filename is not None, "filename needs to be defined if using HF checkpoints"

-       file = hf_hub_download(repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir)
+       file = hf_hub_download(
+           repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir,
+           library_name="audiocraft", library_version=audiocraft.__version__)
    return torch.load(file, map_location=device)
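Passing `library_name` and `library_version` to `hf_hub_download` presumably tags the Hugging Face Hub request's user agent, so Hub-side download stats can be attributed to specific audiocraft releases.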


(Diff truncated: the remaining changed files of the 26 are not shown.)
