Commit 4e4972d: cleanup

simonrouard committed Nov 5, 2024
1 parent 47a5b14
Showing 5 changed files with 7 additions and 66 deletions.
.gitignore (3 changes: 0 additions & 3 deletions)

@@ -60,6 +60,3 @@ dataset/*
 /notebooks
 /local_scripts
 /notes
-
-Untitled.ipynb
-magnet_demo.ipynb
audiocraft/modules/conditioners.py (16 changes: 2 additions & 14 deletions)

@@ -708,8 +708,7 @@ class FeatureExtractor(WaveformConditioner):
     Then, we feed this excerpt to a feature extractor.

     Args:
-        model_name (str): 'encodec', 'musicfm' or 'mert'. For now 'musicfm'
-            is not supported.
+        model_name (str): 'encodec' or 'mert'.
         sample_rate (str): sample rate of the input audio. (32000)
         encodec_checkpoint (str): if encodec is used as a feature extractor, checkpoint
             of the model. ('//pretrained/facebook/encodec_32khz' is the default)
@@ -736,13 +735,10 @@ def __init__(
             use_middle_of_segment: bool = False, ds_rate_compression: int = 640,
             num_codebooks_lm: int = 4
     ):
-        assert model_name in ['encodec', 'musicfm', 'mert']
+        assert model_name in ['encodec', 'mert']
         if model_name == 'encodec':
             from ..solvers.compression import CompressionSolver
             feat_extractor = CompressionSolver.model_from_checkpoint(encodec_checkpoint, device)
-        # elif model_name == 'musicfm':
-        #     from ..musicfmbis.model.musicfm_25hz import MusicFM25Hz
-        #     feat_extractor = MusicFM25Hz()
         elif model_name == 'mert':
             from transformers import AutoModel
             feat_extractor = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True)
@@ -759,9 +755,6 @@ def __init__(
            self.__dict__["feat_extractor"] = feat_extractor.to(device)
            self.encodec_n_q = encodec_n_q
            self.embed = nn.ModuleList([nn.Embedding(feat_extractor.cardinality, dim) for _ in range(encodec_n_q)])
-        elif model_name == 'musicfm':
-            self.__dict__["feat_extractor"] = feat_extractor.eval().to(device)
-            self.embed = nn.Linear(1024, dim) # hardcoded
        if model_name == 'mert':
            self.__dict__["feat_extractor"] = feat_extractor.eval().to(device)
            self.embed = nn.Linear(768, dim) # hardcoded
@@ -787,9 +780,6 @@ def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
            self.temp_mask = self._get_mask_wav(x, start)
        if self.model_name == 'encodec':
            tokens = self.feat_extractor.encode(wav)[0] # type: ignore
-        elif self.model_name == 'musicfm':
-            wav = convert_audio(wav, from_rate=x.sample_rate[0], to_rate=24000, to_channels=1)
-            embeds = self.feat_extractor.get_latent(wav, layer_ix=6) # type: ignore
        elif self.model_name == 'mert':
            wav = convert_audio(wav, from_rate=x.sample_rate[0], to_rate=24000, to_channels=1)
            embeds = self.feat_extractor(wav.squeeze(-2)).last_hidden_state
@@ -804,8 +794,6 @@ def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
    def _downsampling_factor(self):
        if self.model_name == 'encodec':
            return self.sample_rate / self.feat_extractor.frame_rate
-        elif self.model_name == 'musicfm':
-            return self.sample_rate / 25
        elif self.model_name == 'mert':
            return self.sample_rate / 75
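With the musicfm branch gone, only the 'encodec' and 'mert' paths remain. For readers who want to see the surviving MERT path in isolation, here is a minimal standalone sketch: the input file name and the projection width are hypothetical, torchaudio resampling stands in for audiocraft's `convert_audio`, and only the checkpoint name, the 24 kHz target rate, and the 768-wide hidden state come from the diff above.

```python
import torch
import torchaudio
from transformers import AutoModel

# Load the MERT checkpoint exactly as the surviving 'mert' branch does.
feat_extractor = AutoModel.from_pretrained("m-a-p/MERT-v1-95M", trust_remote_code=True).eval()

# "example.wav" is a hypothetical input; audiocraft feeds a WavCondition here.
wav, sr = torchaudio.load("example.wav")
wav = wav.mean(dim=0, keepdim=True)                    # downmix to mono, shape [1, T]
wav = torchaudio.functional.resample(wav, sr, 24000)   # MERT expects 24 kHz audio

with torch.no_grad():
    # [batch, time] in -> last_hidden_state of shape [batch, frames, 768]
    embeds = feat_extractor(wav).last_hidden_state

# Mirror of self.embed = nn.Linear(768, dim); dim is an illustrative value here.
dim = 1024
proj = torch.nn.Linear(768, dim)
cond = proj(embeds)                                    # [batch, frames, dim]
```

With 32 kHz conditioning audio, the remaining MERT case of `_downsampling_factor` works out to 32000 / 75 ≈ 427 input samples per feature frame, since MERT emits 75 feature frames per second.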
config/solver/musicgen/musicgen_style_32khz.yaml (2 changes: 1 addition & 1 deletion)

@@ -1,7 +1,7 @@
 # @package __global__

 # This is the training loop solver
-# for the base MusicGen model (text-to-music)
+# for MusicGen-Style model (text-and-style-to-music)
 # on monophonic audio sampled at 32 kHz
 defaults:
   - musicgen/default
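This comment-only change brings the header in line with how the config is actually used: it is the solver selected by the `dora run solver=musicgen/musicgen_style_32khz ...` command quoted in the docs diff further down, which trains the text-and-style-to-music model rather than the base text-to-music one.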
config/teams/labs.yaml (18 changes: 4 additions & 14 deletions)

@@ -1,29 +1,19 @@
 aws:
-  dora_dir: /fsx-codegen/${oc.env:USER}/experiments/audiocraft/outputs
+  dora_dir: /fsx-audio-craft-llm/${oc.env:USER}/experiments/audiocraft/outputs
   partitions:
     global: learnlab
     team: learnlab
-  reference_dir: /fsx-codegen/shared/audiocraft/reference
+  reference_dir: /fsx-audio-craft-llm/shared/audiocraft/reference
   dataset_mappers:
-    "^/checkpoint/jadecopet/datasets/mmi/mmi_11k_32khz": "/fsx-shutterstock-music-resampled/dataset/mmi_11k_32khz"
-    # "^/fsx-audio-craft-llm/datasets/mmi_nv/mmi_11k"
-    "^/fsx-audio-craft-llm/datasets/shutterstock_nv/p5": "/fsx-shutterstock-music-original/dataset/p5"
-    "^/fsx-audio-craft-llm/datasets/shutterstock_nv/shutterstock": "/fsx-shutterstock-music-original/dataset/shutterstock"
+    "^/checkpoint/[a-z]+": "/fsx-audio-craft-llm"
 fair:
   dora_dir: /checkpoint/${oc.env:USER}/experiments/audiocraft/outputs
   partitions:
     global: learnlab
     team: learnlab
-  reference_dir: /large_experiments/audiocraft/audiocraft_magma/reference
+  reference_dir: /large_experiments/audiocraft/reference
   dataset_mappers:
     "^/datasets01/datasets01": "/datasets01"
-    "^/datasets01/shutterstock-music-resampled/shutterstock_32khz_wav": "/large_experiments/audiocraft/datasets/datasets_32khz_mono/shutterstock"
-    "^/datasets01/shutterstock-music-resampled/p5_32khz_wav": "/large_experiments/audiocraft/datasets/datasets_32khz_mono/p5"
-    "^/checkpoint/jadecopet/datasets/mmi/mmi_11k_32khz": "/large_experiments/audiocraft/datasets/datasets_32khz_mono/mmi_11k"
-    "^/fsx-shutterstock-music-resampled/dataset/p5_32khz_wav": "/large_experiments/audiocraft/datasets/datasets_32khz_mono/p5"
-    "^/fsx-shutterstock-music-resampled/dataset/shutterstock_32khz_wav": "/large_experiments/audiocraft/datasets/datasets_32khz_mono/shutterstock"
-    "^/large_experiments/audiocraft/datasets/datasets_32khz_mono/shutterstock/db/e54/c55/f6f3/5329/14443.wav": "/datasets01/shutterstock-music-resampled/shutterstock_32khz_wav/db/e54/c55/f6f3/5329/14443.wav"
-    # last line is to fix the missing file '/large_experiments/audiocraft/datasets/datasets_32khz_mono/shutterstock/db/e54/c55/f6f3/5329/14443.wav'
 darwin:
   dora_dir: /tmp/audiocraft_${oc.env:USER}
   partitions:
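The `dataset_mappers` entries are regex-to-replacement rewrites of dataset paths (hence the `^` anchors); this change collapses several per-dataset rules into one catch-all rule on aws. A minimal sketch of how such a table could be applied, assuming plain `re.sub` semantics rather than audiocraft's actual mapper code, with a hypothetical user path:

```python
import re

# The single mapper left in the aws config: any per-user /checkpoint/<user>
# prefix is rewritten onto the shared /fsx-audio-craft-llm mount.
DATASET_MAPPERS = {
    r"^/checkpoint/[a-z]+": "/fsx-audio-craft-llm",
}

def remap_path(path: str) -> str:
    """Apply the first regex mapper whose pattern matches the path."""
    for pattern, replacement in DATASET_MAPPERS.items():
        if re.match(pattern, path):
            return re.sub(pattern, replacement, path, count=1)
    return path

print(remap_path("/checkpoint/someuser/datasets/mmi/mmi_11k_32khz"))
# -> /fsx-audio-craft-llm/datasets/mmi/mmi_11k_32khz
```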
docs/MUSICGEN_STYLE.md (34 changes: 0 additions & 34 deletions)

@@ -196,37 +196,3 @@ dora run solver=musicgen/musicgen_style_32khz model/lm/model_scale=medium contin
 to change some parts, like the conditioning or some other parts of the model, you are responsible for manually crafting a checkpoint file from which we can safely run `load_state_dict`.
 If you decide to do so, make sure your checkpoint is saved with `torch.save` and contains a dict
 `{'best_state': {'model': model_state_dict_here}}`. Directly give the path to `continue_from` without a `//pretrained/` prefix.
-
-
-## FAQ
-
-#### What are top-k, top-p, temperature and classifier-free guidance?
-
-Check out [@FurkanGozukara tutorial](https://github.com/FurkanGozukara/Stable-Diffusion/blob/main/Tutorials/AI-Music-Generation-Audiocraft-Tutorial.md#more-info-about-top-k-top-p-temperature-and-classifier-free-guidance-from-chatgpt).
-
-#### Should I use FSDP or autocast ?
-
-The two are mutually exclusive (because FSDP does autocast on its own).
-You can use autocast up to 1.5B (medium), if you have enough RAM on your GPU.
-FSDP makes everything more complex but will free up some memory for the actual
-activations by sharding the optimizer state.
-
-## Citation
-```
-@misc{rouard2024audioconditioningmusicgeneration,
-      title={Audio Conditioning for Music Generation via Discrete Bottleneck Features},
-      author={Simon Rouard and Yossi Adi and Jade Copet and Axel Roebel and Alexandre Défossez},
-      year={2024},
-      eprint={2407.12563},
-      archivePrefix={arXiv},
-      primaryClass={cs.SD},
-      url={https://arxiv.org/abs/2407.12563},
-}
-```
-
-## License
-
-See license information in the [model card](../model_cards/MUSICGEN_STYLE_MODEL_CARD.md).
-
-[arxiv]: https://arxiv.org/abs/2407.12563
-[musicgen_style_samples]: https://musicgenstyle.github.io/
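The removed FAQ entry on FSDP vs. autocast still describes the underlying tradeoff accurately: FSDP shards optimizer state across GPUs at the cost of complexity, while autocast alone suffices up to the medium (1.5B) model. As a reference for the autocast side, here is a minimal mixed-precision training step on a stand-in model, not MusicGen's actual training loop, using standard PyTorch `torch.autocast` with gradient scaling:

```python
import torch

model = torch.nn.Linear(512, 512).cuda()      # stand-in for the actual LM
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()          # rescales the loss so fp16 gradients don't underflow

x = torch.randn(8, 512, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(x).pow(2).mean()             # forward pass runs in mixed precision

scaler.scale(loss).backward()                 # backward on the scaled loss
scaler.step(opt)                              # unscales gradients, then steps
scaler.update()
opt.zero_grad(set_to_none=True)
```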
