docs: update README for clarity and consistency; add encoder fairseq … #186

19 changes: 11 additions & 8 deletions examples/slam_aac/README.md
@@ -1,17 +1,18 @@
# SLAM-AAC

SLAM-AAC is a LLM-based model for Automated Audio Captioning (AAC) task. Inspired by techniques in machine translation and ASR, the model enhances audio captioning by incorporating paraphrasing augmentation and a plug-and-play CLAP-Refine strategy. For more details, please refer to the [paper](https://arxiv.org/abs/2410.09503).
SLAM-AAC is a LLM-based framework for Automated Audio Captioning (AAC) task. Inspired by techniques in machine translation and ASR, the model enhances audio captioning by incorporating **paraphrasing augmentation** and a plug-and-play **CLAP-Refine** strategy. For more details, please refer to the [paper](https://arxiv.org/abs/2410.09503).

## Model Architecture
SLAM-AAC uses EAT as the audio encoder and Vicuna-7B as the LLM decoder. During training, only the Linear Projector and LoRA modules are trainable. For inference, multiple candidates are generated using different beam sizes, which are then refined using the CLAP-Refine strategy.
SLAM-AAC uses **EAT** as the audio encoder and **Vicuna-7B** as the LLM decoder. During training, only the Linear Projector and LoRA modules are trainable. For inference, multiple candidates are generated using different beam sizes, which are then refined using the CLAP-Refine strategy.

![](./docs/model.png)

## Performance and checkpoints
We have released the pre-trained checkpoint of SLAM-AAC, as well as the fine-tuned checkpoints for the Clotho and AudioCaps datasets. The provided checkpoints include the model's Linear Projector and LoRA modules. Please note that when using each component, be sure to set up the corresponding environments according to the instructions provided in the respective repositories (e.g., for [EAT](https://github.com/cwx-worst-one/EAT)).
Pre-trained and fine-tuned checkpoints for the **Clotho** and **AudioCaps** datasets are available. These checkpoints include the Linear Projector and LoRA modules. Ensure proper setup of the corresponding environments (e.g., [EAT](https://github.com/cwx-worst-one/EAT)) before use.


### Pre-training
SLAM-AAC was pre-trained on a combination of AudioCaps, Clotho, WavCaps, and MACS datasets. For more information on these datasets, you can refer to [this repository](https://github.com/Labbeti/aac-datasets). Additionally, the Clotho dataset was augmented using a back-translation-based paraphrasing technique.
SLAM-AAC was pre-trained on AudioCaps, Clotho, WavCaps, and MACS datasets. For more information on these datasets, you can refer to [this repository](https://github.com/Labbeti/aac-datasets). Additionally, the Clotho dataset was augmented using a back-translation-based paraphrasing technique.
Audio Encoder | LLM | Checkpoint | Pre-training Dataset|
|:---:|:---:|:---:|:---:|
[EAT-base (fine-tuned)](https://drive.google.com/file/d/1aCYiQmoZv_Gh1FxnR-CCWpNAp6DIJzn6/view?usp=sharing) |[vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | [link](https://drive.google.com/drive/folders/10kOjB112AeGYA_0mIUr8f1-i5rSg08_O?usp=sharing) | AudioCaps, Clotho, WavCaps, MACS |
@@ -25,7 +26,7 @@ Dataset | Audio Encoder | LLM | Checkpoint | METEOR | CIDEr | SPICE | SPIDEr | S


## Data preparation
Ensure your `jsonl` data follows the structure outlined below:
Ensure your `jsonl` data follows this format:
```json
{"key": "Y7fmOlUlwoNg_1", "source": "/root/data/AudioCaps/waveforms/test/Y7fmOlUlwoNg.wav", "target": "Constant rattling noise and sharp vibrations"}
{"key": "Y6BJ455B1aAs_1", "source": "/root/data/AudioCaps/waveforms/test/Y6BJ455B1aAs.wav", "target": "A rocket flies by followed by a loud explosion and fire crackling as a truck engine runs idle"}
@@ -57,7 +58,7 @@ You can also fine-tune the model without loading any pre-trained weights, though
- Due to differences in dependency versions, there may be slight variations in the performance of the SLAM-AAC model.

## Inference
To perform inference with the trained models, you can use the following commands to decode using the common beam search method:
To perform inference with the trained models with beam search:
```bash
# Inference on AudioCaps (Beam Search)
bash scripts/inference_audiocaps_bs.sh
@@ -66,7 +67,9 @@ bash scripts/inference_audiocaps_bs.sh
bash scripts/inference_clotho_bs.sh
```
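
Before running the inference scripts, it may help to check that the `jsonl` file matches the `key` / `source` / `target` layout shown under Data preparation. The following is a minimal sketch and not part of the repository: `validate_jsonl_lines` and `REQUIRED_FIELDS` are hypothetical names, and the check only assumes that each record is a JSON object with those three non-empty string fields, as in the README example.

```python
import json

# Field names taken from the README example; treating all three as
# required, non-empty strings is an assumption of this sketch.
REQUIRED_FIELDS = ("key", "source", "target")


def validate_jsonl_lines(lines):
    """Return a list of (line_number, problem) tuples; an empty list means no problems were found."""
    problems = []
    for line_number, raw in enumerate(lines, start=1):
        raw = raw.strip()
        if not raw:
            continue  # tolerate blank lines
        try:
            record = json.loads(raw)
        except json.JSONDecodeError as exc:
            problems.append((line_number, f"invalid JSON: {exc.msg}"))
            continue
        if not isinstance(record, dict):
            problems.append((line_number, "record is not a JSON object"))
            continue
        for field in REQUIRED_FIELDS:
            value = record.get(field)
            if not isinstance(value, str) or not value:
                problems.append((line_number, f"missing or empty field: {field}"))
    return problems


sample = [
    '{"key": "Y7fmOlUlwoNg_1", "source": "/root/data/AudioCaps/waveforms/test/Y7fmOlUlwoNg.wav", "target": "Constant rattling noise and sharp vibrations"}',
    '{"key": "broken_example", "source": "/tmp/missing_target.wav"}',  # hypothetical bad record
]
print(validate_jsonl_lines(sample))  # → [(2, 'missing or empty field: target')]
```

To check a real file, the same function can be called on an open file handle, for example `validate_jsonl_lines(open(path, encoding="utf-8"))`. It does not verify that the `source` audio files exist on disk.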

For improved inference results, you can use the CLAP-Refine strategy, which utilizes multiple beam search decoding. To use this method, you need to download and use our pre-trained [CLAP](https://drive.google.com/drive/folders/1X4NYE08N-kbOy6s_Itb0wBR_3X8oZF56?usp=sharing) model. Note that CLAP-Refine may take longer to run, but it can provide better quality outputs. You can execute the following commands:
To generate better captions, use the CLAP-Refine strategy with multiple beam search decoding. This method leverages our pre-trained [CLAP](https://drive.google.com/drive/folders/1X4NYE08N-kbOy6s_Itb0wBR_3X8oZF56?usp=sharing) model. Though it takes more time, it ensures higher-quality results. Use the following commands to apply it:


```bash
# Inference on AudioCaps (CLAP-Refine)
bash scripts/inference_audiocaps_CLAP_Refine.sh
@@ -81,7 +84,7 @@ bash scripts/clap_refine.sh
```
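
As a rough illustration of the selection step, the sketch below assumes that CLAP-Refine keeps the candidate caption whose text embedding is closest, by cosine similarity, to the audio embedding. The names `cosine_similarity` and `select_caption` and the two-dimensional toy embeddings are hypothetical. In the real pipeline the embeddings would come from the pre-trained CLAP model, and `scripts/clap_refine.sh` may score candidates differently.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors (0.0 if either has zero norm)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def select_caption(audio_embedding, candidates):
    """Pick the caption whose text embedding best matches the audio embedding.

    `candidates` is a list of (caption, text_embedding) pairs, e.g. one
    caption per beam size. This pairing is an assumption of the sketch.
    """
    best_caption, _ = max(
        candidates,
        key=lambda item: cosine_similarity(audio_embedding, item[1]),
    )
    return best_caption


# Toy example: made-up embeddings standing in for CLAP outputs.
audio = [1.0, 0.0]
candidates = [
    ("a dog barks", [0.9, 0.1]),
    ("rain falls", [0.1, 0.9]),
    ("a dog barks loudly", [1.0, 0.0]),
]
print(select_caption(audio, candidates))  # → a dog barks loudly
```

Decoding once per beam size and then scoring every candidate is the likely reason this mode runs slower than plain beam search.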

## Citation
You can refer to the paper for more results.
If you find SLAM-AAC useful, please cite the following paper:
```
@article{chen2024slam,
title={SLAM-AAC: Enhancing Audio Captioning with Paraphrasing Augmentation and CLAP-Refine through LLMs},
3 changes: 3 additions & 0 deletions examples/slam_aac/scripts/finetune_audiocaps.sh
@@ -9,6 +9,8 @@ run_dir=/data/wenxi.chen/SLAM-LLM
cd $run_dir
code_dir=examples/slam_aac

encoder_fairseq_dir=/fairseq/EAT # path to the fairseq directory of the encoder model

audio_encoder_path=/data/xiquan.li/models/EAT-base_epoch30_ft.pt
llm_path=/data/xiquan.li/models/vicuna-7b-v1.5

@@ -38,6 +40,7 @@ hydra.run.dir=$output_dir \
++model_config.encoder_path=$audio_encoder_path \
++model_config.encoder_dim=768 \
++model_config.encoder_projector=linear \
++model_config.encoder_fairseq_dir=$encoder_fairseq_dir \
++dataset_config.encoder_projector_ds_rate=${encoder_projector_ds_rate} \
++dataset_config.dataset=audio_dataset \
++dataset_config.train_data_path=$train_jsonl_path \
3 changes: 3 additions & 0 deletions examples/slam_aac/scripts/finetune_clotho.sh
@@ -9,6 +9,8 @@ run_dir=/data/wenxi.chen/SLAM-LLM
cd $run_dir
code_dir=examples/slam_aac

encoder_fairseq_dir=/fairseq/EAT # path to the fairseq directory of the encoder model

audio_encoder_path=/data/xiquan.li/models/EAT-base_epoch30_ft.pt
llm_path=/data/xiquan.li/models/vicuna-7b-v1.5

@@ -38,6 +40,7 @@ hydra.run.dir=$output_dir \
++model_config.encoder_path=$audio_encoder_path \
++model_config.encoder_dim=768 \
++model_config.encoder_projector=linear \
++model_config.encoder_fairseq_dir=$encoder_fairseq_dir \
++dataset_config.encoder_projector_ds_rate=${encoder_projector_ds_rate} \
++dataset_config.dataset=audio_dataset \
++dataset_config.train_data_path=$train_jsonl_path \
3 changes: 3 additions & 0 deletions examples/slam_aac/scripts/inference_audiocaps_CLAP_Refine.sh
@@ -10,6 +10,8 @@ audio_encoder_path=/data/xiquan.li/models/EAT-base_epoch30_ft.pt
llm_path=/data/xiquan.li/models/vicuna-7b-v1.5
clap_dir=/data/xiquan.li/models/clap

encoder_fairseq_dir=/fairseq/EAT # path to the fairseq directory of the encoder model

encoder_projector_ds_rate=5

inference_data_path=/data/wenxi.chen/data/audiocaps/new_test.jsonl
@@ -41,6 +43,7 @@ for num_beams in "${beam_range[@]}"; do
++model_config.encoder_projector=linear \
++model_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
++model_config.normalize=true \
++model_config.encoder_fairseq_dir=$encoder_fairseq_dir \
++dataset_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
++dataset_config.dataset=audio_dataset \
++dataset_config.val_data_path=$inference_data_path \
3 changes: 3 additions & 0 deletions examples/slam_aac/scripts/inference_audiocaps_bs.sh
@@ -6,6 +6,8 @@ run_dir=/data/wenxi.chen/SLAM-LLM
cd $run_dir
code_dir=examples/slam_aac

encoder_fairseq_dir=/fairseq/EAT # path to the fairseq directory of the encoder model

audio_encoder_path=/data/xiquan.li/models/EAT-base_epoch30_ft.pt
llm_path=/data/xiquan.li/models/vicuna-7b-v1.5

@@ -31,6 +33,7 @@ python $code_dir/inference_aac_batch.py \
++model_config.encoder_projector=linear \
++model_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
++model_config.normalize=true \
++model_config.encoder_fairseq_dir=$encoder_fairseq_dir \
++dataset_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
++dataset_config.dataset=audio_dataset \
++dataset_config.val_data_path=$inference_data_path \
3 changes: 3 additions & 0 deletions examples/slam_aac/scripts/inference_clotho_CLAP_Refine.sh
@@ -6,6 +6,8 @@ run_dir=/data/wenxi.chen/SLAM-LLM
cd $run_dir
code_dir=examples/slam_aac

encoder_fairseq_dir=/fairseq/EAT # path to the fairseq directory of the encoder model

audio_encoder_path=/data/xiquan.li/models/EAT-base_epoch30_ft.pt
llm_path=/data/xiquan.li/models/vicuna-7b-v1.5
clap_dir=/data/xiquan.li/models/clap
@@ -41,6 +43,7 @@ for num_beams in "${beam_range[@]}"; do
++model_config.encoder_projector=linear \
++model_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
++model_config.normalize=true \
++model_config.encoder_fairseq_dir=$encoder_fairseq_dir \
++dataset_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
++dataset_config.dataset=audio_dataset \
++dataset_config.val_data_path=$inference_data_path \
3 changes: 3 additions & 0 deletions examples/slam_aac/scripts/inference_clotho_bs.sh
@@ -6,6 +6,8 @@ run_dir=/data/wenxi.chen/SLAM-LLM
cd $run_dir
code_dir=examples/slam_aac

encoder_fairseq_dir=/fairseq/EAT # path to the fairseq directory of the encoder model

audio_encoder_path=/data/xiquan.li/models/EAT-base_epoch30_ft.pt
llm_path=/data/xiquan.li/models/vicuna-7b-v1.5

@@ -31,6 +33,7 @@ python $code_dir/inference_aac_batch.py \
++model_config.encoder_projector=linear \
++model_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
++model_config.normalize=true \
++model_config.encoder_fairseq_dir=$encoder_fairseq_dir \
++dataset_config.encoder_projector_ds_rate=$encoder_projector_ds_rate \
++dataset_config.dataset=audio_dataset \
++dataset_config.val_data_path=$inference_data_path \
3 changes: 3 additions & 0 deletions examples/slam_aac/scripts/pretrain.sh
@@ -9,6 +9,8 @@ run_dir=/data/wenxi.chen/SLAM-LLM
cd $run_dir
code_dir=examples/slam_aac

encoder_fairseq_dir=/fairseq/EAT # path to the fairseq directory of the encoder model

audio_encoder_path=/data/xiquan.li/models/EAT-base_epoch30_ft.pt
llm_path=/data/xiquan.li/models/vicuna-7b-v1.5

@@ -34,6 +36,7 @@ hydra.run.dir=$output_dir \
++model_config.encoder_path=$audio_encoder_path \
++model_config.encoder_dim=768 \
++model_config.encoder_projector=linear \
++model_config.encoder_fairseq_dir=$encoder_fairseq_dir \
++dataset_config.encoder_projector_ds_rate=${encoder_projector_ds_rate} \
++dataset_config.dataset=audio_dataset \
++dataset_config.train_data_path=$train_jsonl_path \