prevent deadlock in dataloader
MoayedHajiAli committed Oct 13, 2024
1 parent a608256 commit 7503e62
Showing 7 changed files with 407 additions and 287 deletions.
148 changes: 68 additions & 80 deletions AutoCap/README.md
@@ -1,119 +1,106 @@
<!-- [![arXiv](ARXIV ICON)](ARXIV LINK) -->
[![arXiv](ARXIV ICON)](ARXIV LINK)

# GenAU inference, training, and evaluation
- [Introduction](#introduction)
- [Environment setup](#environment-initialization)
# AutoCap inference, training and evaluation
- [Inference](#inference)
* [Audio to text script](#text-to-audio)<!-- * [Gradio demo](#gradio-demo) -->
* [Inference a list of prompts](#inference-a-list-of-prompts)
* [Audio to text script](#audio-to-text)
* [Gradio demo](#gradio-demo)
* [Caption a list of audio files](#caption-list-of-audio-files)
* [Caption your custom dataset](#caption-your-custom-dataset)
- [Training](#training)
* [GenAU](#genau)
* [Finetuning GenAU](#finetuning-genau)
* [1D-VAE (optional)](#1d-vae-optional)
- [Evaluation](#evaluation)
- [Cite this work](#cite-this-work)
- [Acknowledgements](#acknowledgements)

# Introduction
We introduce GenAU, a transformer-based audio latent diffusion model leveraging the FIT architecture. Our model compresses mel-spectrogram data into a 1D representation and utilizes layered attention processes to achieve state-of-the-art audio generation results among open-source models.
<br/>

<div align="center">
<img src="../assets/genau.png" width="900" />
</div>

<br/>

# Environment initialization
For initializing your environment, please refer to the [general README](../README.md).

# Inference

## Text to Audio
To quickly generate audio from an input text prompt, run
## Audio to Text
To quickly generate a caption for an input audio file, run
```shell
python scripts/text_to_audio.py --prompt "Horses growl and clop hooves." --model "genau-full-l"
python scripts/audio_to_text.py --wav_path <path-to-wav-file>

# Example inference
python scripts/audio_to_text.py --wav_path samples/ood_samples/loudwhistle-91003.wav
```
- This will automatically download and use the model `genau-full-l` with default settings. You may change these parameters or provide your custom model config file and checkpoint path.
- Available models include `genau-full-l` (1.25B parameters) and `genau-full-s` (493M parameters)
- These models are trained to generate ambient sounds and are incapable of generating speech or music.
- Outputs will be saved by default at `samples/model_output` using the provided prompt as the file name.
- This will automatically download the `TODO` model and run inference with the default parameters. You may change these parameters or provide your custom model config file and checkpoint path.
- For more accurate captioning, provide metadata using the `--title`, `--description`, and `--video_caption` arguments, as sketched below.
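
A minimal sketch of passing metadata alongside an audio file; the metadata values below are placeholders, and the exact flag spellings should be checked against `scripts/audio_to_text.py`:

```shell
# hypothetical metadata values; flag names assumed from the list above
python scripts/audio_to_text.py --wav_path samples/ood_samples/loudwhistle-91003.wav \
    --title "Loud whistle" \
    --description "A short clip of someone whistling loudly outdoors" \
    --video_caption "A person whistles at the camera"
```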

<!-- ## Gradio Demo
Run a local interactive demo with Gradio:
## Gradio Demo
A local Gradio demo is also available by running
```shell
python app_text2audio.py
``` -->
python app_audio2text.py
```

## Inference a list of prompts
Optionally, you may prepare a `.txt` file with your target prompts and run
## Caption list of audio files
- Prepare all target audio files in a single folder
- Optionally, provide metadata in a `yaml` file using the following structure
```yaml
file_name.wav:
title: "video title"
description: "video description"
video_caption: "video caption"
```
Then run the following script
```shell
python scripts/inference_file.py --list_inference <path-to-prompts-file> --model <model_name>
python scripts/inference_folder.py --folder_path <path-to-audio-folder> --meta_data_file <path-to-metadata-yaml-file>

# Example
python scripts/inference_file.py --list_inference samples/prompts_list.txt --model "genau-full-l"
# Example inference
python scripts/inference_folder.py --folder_path samples/ood_samples --meta_data_file samples/ood_samples/meta_data.yaml
```

## Caption your custom dataset

## Training

### Dataset
Please refer to the [dataset preparation README](../dataset_preperation/README.md) for instructions on downloading our dataset or preparing your own.

### GenAU
- Prepare a yaml config file for your experiments. A sample config file is provided at `settings/simple_runs/genau.yaml`
- Specify your project name and provide your Wandb key in the config file. A Wandb key can be obtained from [https://wandb.ai/authorize](https://wandb.ai/authorize)
- Optionally, provide your S3 bucket and folder to save intermediate checkpoints.
- By default, checkpoints will be saved under `run_logs/genau/train` at the same level as the config file.
If you want to caption a large dataset, we provide a script that runs on multiple GPUs for faster inference.
- Prepare your custom dataset by following the instructions in the dataset preparation README (TODO) and run

```shell
# Training GenAU from scratch
python train/genau.py -c settings/simple_runs/genau.yaml
```
```shell
python scripts/caption_dataset.py \
        --caption_store_key <key-to-store-generated-captions> \
        --beam_size 2 \
        --start_idx 0 \
        --end_idx 1000000 \
        --dataset_keys "dataset_1" "dataset_2" ...
```

For multinode training, run
```shell
python -m torch.distributed.run --nproc_per_node=8 train/genau.py -c settings/simple_runs/genau.yaml
```
### Finetuning GenAU
- Provide your dataset keys as registered during dataset preparation (TODO)
- Captions will be generated and stored in each file's JSON file under the specified `caption_store_key`
- The `start_idx` and `end_idx` arguments can be used to resume or distribute captioning experiments
- Add your `caption_store_key` under `keys_synonyms:gt_audio_caption` in the target yaml config file so that it is selected whenever a ground-truth caption is not available in your audio captioning or audio generation experiments (see the sketch below).
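
A minimal sketch of that config entry; the exact nesting is an assumption and should be checked against your yaml config:

```yaml
# hypothetical layout: list the generated-caption key as a fallback for the ground-truth caption
keys_synonyms:
  gt_audio_caption:
    - gt_audio_caption
    - <your-caption-store-key>
```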

- Prepare your custom dataset and obtain the dataset keys following the [dataset preparation README](../dataset_preperation/README.md)
- Make a copy of the default `genau-full-l` config file, found under `pretrained_models/genau/genau-full-l.yaml`, and adjust it
- Add IDs for your dataset keys under the `dataset2id` attribute in the config file (see the sketch below).
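
A hedged sketch of that attribute; the dataset keys and IDs below are placeholders:

```yaml
# placeholder entries: map each registered dataset key to an unused integer id
dataset2id:
  <dataset_key_1>: 7
  <dataset_key_2>: 8
```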

# Training
### Dataset
Please refer to the dataset README (TODO) for instructions on downloading our dataset or preparing your own dataset.

### Stage 1 (pretraining)
- Specify your model parameters in a config yaml file. A sample yaml file is given under `settings/pretraining.yaml`
- Specify your project name and provide your wandb key in the config file. A wandb key can be obtained from [https://wandb.ai/authorize](https://wandb.ai/authorize)
- Optionally, provide your S3 bucket and folder to save intermediate checkpoints.
- By default, checkpoints will be saved under `run_logs/train`
```shell
# Finetuning GenAU
python train/genau.py --reload_from_ckpt 'genau-full-l' \
--config <path-to-config-file> \
--dataset_keys "<dataset_key_1>" "<dataset_key_2>" ...
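
# Stage 1 pretraining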
python train.py -c settings/pretraining.yaml
```


### 1D VAE (Optional)
By default, we offer a pre-trained 1D-VAE for GenAU training. If you prefer, you can train your own VAE by following the provided instructions.
- Prepare your own dataset following the instructions in the [dataset preparation README](../dataset_preperation/README.md)
- Prepare your yaml config file in a similar way to the GenAU config file
- A sample config file is provided at `settings/simple_runs/1d_vae.yaml`
### Stage 2 (finetuning)
- Prepare your finetuning config file in a similar way to the pretraining stage. Typically, you only need to set `pretrain_path` to your pretraining checkpoint, adjust the learning rate, and disable the freeze option for the `text_decoder`.
- A sample finetuning config is provided under `settings/finetuning.yaml`; a sketch of the key overrides follows.
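
A minimal sketch of those overrides, assuming these key names (the `optim_args:lr` key appears in `train.py`; the others are assumptions, so take the real layout from `settings/finetuning.yaml`):

```yaml
# key names and values are placeholders; mirror settings/finetuning.yaml
pretrain_path: run_logs/train/<pretraining-run>/checkpoints/last.ckpt
optim_args:
  lr: 1.0e-5             # typically lower than the pretraining learning rate
model:
  freeze_decoder: false  # un-freeze the text decoder for finetuning
```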

```shell
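# 1D-VAE training (optional)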
python train/1d_vae.py -c settings/simple_runs/1d_vae.yaml
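
# Stage 2 finetuning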
python train.py -c settings/finetuning.yaml
```

## Evaluation
- We follow [audioldm](https://github.com/haoheliu/AudioLDM-training-finetuning) to perform our evaluations.
- By default, the models will be evaluated periodically during training as specified in the config file. For each evaluation, a folder with the generated audio will be saved under `run_logs/train` at the same level as the specified config file.
- The code identifies the test dataset in an existing folder by its number of samples. If you would like to test on a new test dataset, register it in `scripts/generate_and_eval`

# Evaluation
- By default, the model will log metrics on the validation set to wandb periodically during training, as specified in the config file.
- We exclude the `spice`, `spider`, and `meteor` metrics during training as they tend to hang multi-GPU training. You may include them by changing the configuration.
- A file with the predicted captions during evaluation will be saved under `run_logs/train`, and the metrics can be found in a file named `output.txt` under the logging folder.
- To run evaluation on the test set after training finishes, run:
```shell

# Evaluate an existing generated folder
python scripts/evaluate.py --log_path <path-to-the-experiment-folder>

# Generate test audio from a pre-trained checkpoint and run evaluation
python scripts/generate_and_eval.py -c <path-to-config> -ckpt <path-to-pretrained-ckpt>
python evaluate.py -c <path-to-config> -ckpt <path-to-checkpoint>
```
The evaluation results will be saved in a JSON file at the same level as the generated audio folder.

# Cite this work
If you found this useful, please consider citing our work
@@ -122,5 +109,6 @@ If you found this useful, please consider citing our work
```

# Acknowledgements
Our audio generation and evaluation codebase relies on [audioldm](https://github.com/haoheliu/AudioLDM-training-finetuning). We sincerely appreciate the authors for sharing their code openly.

We sincerely thank the authors of the following works for sharing their code publicly:
- [WavCaps: A ChatGPT-Assisted Weakly-Labelled Audio Captioning Dataset for Audio-Language Multimodal Research](https://github.com/XinhaoMei/WavCaps)
- [Audio Captioning Transformer](https://github.com/XinhaoMei/ACT/tree/main/coco_caption)
31 changes: 18 additions & 13 deletions AutoCap/src/models/pl_htsat_q_bart_captioning.py
@@ -57,7 +57,6 @@ def forward(self, token_ids):
token_ids = torch.where(special_tokens_mask, self.base_tokenizer.pad_token_id, token_ids)
embeddings = self.model.encoder.embed_tokens(token_ids)


embeddings.view(-1, embeddings.shape[-1])[special_tokens_mask.view(-1)] = special_tokens_embeds
return embeddings

@@ -198,7 +197,7 @@ def __init__(self, config):
self.decoder = BartForConditionalGeneration.from_pretrained(decoder_name)
else:
bart_config = BartConfig.from_pretrained(decoder_name)
self.decoder = BartForConditionalGeneration.from_config(bart_config)
self.decoder = BartForConditionalGeneration(config=bart_config)

self.set_decoder_requires_grad(freeze=freeze_decoder, freeze_embed_layer=freeze_embed_layer)

@@ -381,7 +380,9 @@ def __init__(self, config):
else:
text_max_tokens = self.num_text_query_token

self.decoder = self.adjust_max_pos_embeds(self.decoder, audio_max_tokens+text_max_tokens+2+self.use_clap_embeds) # extra token for CLAP
# extra token for CLAP, two for the position offset in BART, two for BOS and EOS, and the rest as extra headroom
self.decoder = self.adjust_max_pos_embeds(self.decoder, audio_max_tokens+text_max_tokens+20+self.use_clap_embeds)


# dropout layer
self.audio_features_dropout = nn.Dropout(p=config['model']['audio_features_dropout_p'])
@@ -875,7 +876,7 @@ def get_meta_dict(self, batch, meta_keys=None, drop_keys=[]):
def training_step(self, batch, batch_idx):
audio = batch['waveform'].squeeze(1)
text = batch['gt_audio_caption'] # list of captions
audio = audio.to(self.device, non_blocking=True)
audio = audio.to(self.device)

# prepare meta
if 'meta_keys' in self.config['model'].keys():
@@ -999,11 +1000,15 @@ def on_validation_epoch_start(self):


def on_validation_epoch_end(self):

# place a barrier to ensure that all ranks reach the gather operation
torch.distributed.barrier()
gathered = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(gathered, self.val_outputs)
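# only rank zero computes the metrics below; the other ranks may return now that the collective calls are done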
if not self.trainer.is_global_zero:
return

torch.distributed.barrier()

metrics_log = {}
# all ranks should execute the blocks to avoid deadlock
self.gathered_output_dict = []
for idx in range(len(self.val_loaders_labels)):
val_loader_out = {}
Expand All @@ -1015,7 +1020,7 @@ def on_validation_epoch_end(self):

self.gathered_output_dict.append(val_loader_out)


val_logger = logger.bind(indent=1)
for split, dataloader_outputs in zip(self.val_loaders_labels, self.gathered_output_dict):
val_logger.info(f"[INFO] evaluating metrics for split: {split}")
@@ -1051,12 +1056,12 @@ def get_score(metrics, key):
f'Spider score using beam search (beam size:{beam_size}): {spider:7.4f}')

metrics_log = {f"{split}/spider_beam_{beam_size}" : spider,
f"{split}/cider_beam_{beam_size}":cider,
f"{split}/spice_beam_{beam_size}":spice,
f"{split}/bleu_1_beam_{beam_size}":bleu_1,
f"{split}/cider_beam_{beam_size}":cider,
f"{split}/spice_beam_{beam_size}":spice,
f"{split}/bleu_1_beam_{beam_size}":bleu_1,
f"{split}/bleu_4_beam_{beam_size}":bleu_4,
f"{split}/rouge_l_beam_{beam_size}":rouge_l,
f"{split}/meteor_beam_{beam_size}":meteor }
f"{split}/rouge_l_beam_{beam_size}":rouge_l,
f"{split}/meteor_beam_{beam_size}":meteor }
if 'bert_score' in metrics:
bert_score = metrics.pop('bert_score')
metrics_log[f"{split}/bertscore_beam_{beam_size}"] = bert_score
9 changes: 5 additions & 4 deletions AutoCap/train.py
@@ -49,7 +49,7 @@ def main():
config["optim_args"]["lr"] = args.lr

if args.num_workers is not None:
config['data_args']['num_workers']
config['data_args']['num_workers'] = args.num_workers

# set up model
devices = torch.cuda.device_count()
@@ -90,6 +90,7 @@ def main():
return_test=False,
cache_dir=None)


# print training settings
printer = PrettyPrinter()
main_logger.info('Training setting:\n'
Expand All @@ -112,15 +113,15 @@ def main():
main_logger.info(f'Size of {val_k} validation set: {len(val_loader.dataset)}, size of batches: {len(val_loader)}')


# update the model with data types, combine test and val loaders # TODO: delete the test

print("val_loaders", len(val_loaders), val_loaders.keys())
model.val_loaders_labels = list(val_loaders.keys())
val_loaders = list(val_loaders.values())
model.log_output_dir = log_output_dir


# ckpt
validation_every_n_epochs = config["step"]["validation_every_n_epochs"]
validation_every_n_epochs = config["step"].get("validation_every_n_epochs", None)
save_checkpoint_every_n_epochs = config["logging"]["save_checkpoint_every_n_epochs"]

checkpoint_callback = S3ModelCheckpoint(
Expand All @@ -146,7 +147,7 @@ def main():
limit_val_batches=config['step'].get('limit_val_batches', None),
limit_train_batches=config['step'].get('limit_train_batches', None),
check_val_every_n_epoch=validation_every_n_epochs,
strategy=DDPStrategy(find_unused_parameters=True),
strategy=DDPStrategy(find_unused_parameters=False),
callbacks=[checkpoint_callback],
gradient_clip_val=config["model"].get("clip_grad", None),
profiler=config['training'].get('profiler', None),
2 changes: 1 addition & 1 deletion GenAU/audioldm_eval/audioldm_eval/datasets/load_mel.py
@@ -130,7 +130,7 @@ def __init__(
sr=16000,
limit_num=None,
):
self.datalist = [os.path.join(datadir, x) for x in os.listdir(datadir)]
self.datalist = [os.path.join(datadir, x) for x in os.listdir(datadir) if x.endswith('.wav')]
self.datalist = sorted(self.datalist)
if limit_num is not None:
self.datalist = self.datalist[:limit_num]
4 changes: 2 additions & 2 deletions GenAU/src/tools/training_utils.py
@@ -61,7 +61,7 @@ def build_dataset_json_from_list(list_path):
wav = ""
data.append(
{
"wav": wav,
"fname": wav,
"caption": caption,
}
)
@@ -83,7 +83,7 @@ def read_json(dataset_json_file):
def copy_test_subset_data(metadata, testset_copy_target_path):
# metadata = read_json(testset_metadata)
os.makedirs(testset_copy_target_path, exist_ok=True)
if len(os.listdir(testset_copy_target_path)) == len(metadata):
if len(os.listdir(testset_copy_target_path)) >= len(metadata) - 1:
return
else:
# delete files in folder testset_copy_target_path
