fix beam search batch != 1
zszheng147 committed May 19, 2024
1 parent 22e7fc6 commit df55ebf
Showing 16 changed files with 1,287 additions and 6 deletions.
46 changes: 46 additions & 0 deletions examples/seld_spatialsoundqa/README.md
@@ -0,0 +1,46 @@
# <img src="assets/bat.png" alt="SELD_SpatialSoundQA" width="25" height="25"> SELD_SpatialSoundQA

This repo hosts the code and models of "[BAT: Learning to Reason about Spatial Sounds with Large Language Models](https://arxiv.org/abs/2402.01591)" [ICML 2024 [bib](https://github.com/zszheng147/Spatial-AST#citation)].

Check out our [demo page](https://zhishengzheng.com/BAT/) and enjoy a QA game with spatial audio.

## Performance and checkpoints
| Encoder | Projector | PEFT | LLM |
|---|---|---|---|
| [Spatial-AST](https://huggingface.co/zhisheng01/Bat/blob/main/spatial-ast.pth) | Q-Former | adapter | [llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) |

## Data preparation
You need to prepare the data as a JSONL file in the format shown below.
You can download the SpatialSoundQA dataset from [Hugging Face](https://huggingface.co/datasets/zhisheng01/SpatialSoundQA).
```
{"audio_id": "eval/audio/YI-HlrcP6Qg4", "reverb_id": "q9vSo1VnCiC/0.npy", "audio_id2": null, "reverb_id2": null, "question_id": 0, "question_type": "CLASSIFICATION", "question": "Enumerate the sound occurrences in the audio clip.", "answer": "accelerating, revving, vroom; car; vehicle"}
...
{"audio_id": "eval/audio/YZX2fVPmUidA", "reverb_id": "q9vSo1VnCiC/32.npy", "audio_id2": "eval/audio/YjNjUU01quLs", "reverb_id2": "q9vSo1VnCiC/31.npy", "question_id": 58, "question_type": "MIXUP_NONBINARY_DISTANCE", "question": "How far away is the sound of the banjo from the sound of the whack, thwack?", "answer": "2m"}
```
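
If you want a quick sanity check of an annotation file, a minimal sketch like the one below loads the JSONL entries and prints a sample. The path `SpatialSoundQA/closed-end/eval.jsonl` is only an assumed example; point it at whichever split you actually downloaded.

```python
import json

# Assumed example path; replace with the SpatialSoundQA split you downloaded.
qa_path = "SpatialSoundQA/closed-end/eval.jsonl"

entries = []
with open(qa_path) as f:
    for line in f:
        line = line.strip()
        if line:  # each non-empty line is one JSON object
            entries.append(json.loads(line))

# Each entry carries the audio/reverb ids, question type, question, and answer.
print(f"{len(entries)} QA pairs loaded")
print(entries[0]["question_type"], "->", entries[0]["question"], "->", entries[0]["answer"])
```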

## Train a new model
```bash
bash examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_qformer_llama_2_7b.sh
```

## Decoding with checkpoints
```bash
bash examples/seld_spatialsoundqa/scripts/decode_spatial-ast_qformer_llama_2_7b.sh
```


## TODO
- [x] Decode with checkpoints
- [ ] Upload SpatialSoundQA dataset
- [ ] Upload pretrained checkpoints
- [ ] Update model performance

## Citation
```bibtex
@article{zheng2024bat,
  author  = {Zheng, Zhisheng and Peng, Puyuan and Ma, Ziyang and Chen, Xie and Choi, Eunsol and Harwath, David},
  title   = {BAT: Learning to Reason about Spatial Sounds with Large Language Models},
  journal = {arXiv preprint arXiv:2402.01591},
  year    = {2024},
}
```
Binary file added examples/seld_spatialsoundqa/assets/bat.png
19 changes: 19 additions & 0 deletions examples/seld_spatialsoundqa/conf/ds_config.json
@@ -0,0 +1,19 @@
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
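
For context, this DeepSpeed config enables fp16 training and ZeRO stage 3 with the optimizer state offloaded to CPU. Below is a minimal, hedged sketch of how a file like this is commonly handed to `deepspeed.initialize`; the toy model is a placeholder, and the exact way SLAM-LLM wires the config in is an assumption rather than something shown in this diff.

```python
# Sketch only: placeholder model and assumed wiring; the repo's real entry point may differ.
import deepspeed
import torch.nn as nn

model = nn.Linear(512, 512)  # placeholder model

# DeepSpeed reads the Adam optimizer, fp16, and ZeRO-3 settings from the JSON above.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="examples/seld_spatialsoundqa/conf/ds_config.json",
)
```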
45 changes: 45 additions & 0 deletions examples/seld_spatialsoundqa/finetune_seld.py
@@ -0,0 +1,45 @@
import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf

from seld_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig, PeftConfig
from slam_llm.pipeline.finetune import main as train

@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})

@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)

    def to_plain_list(cfg_item):
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    # kwargs = to_plain_list(cfg)
    kwargs = cfg
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
    logging.basicConfig(level=log_level)

    if kwargs.get("debug", False):
        import pdb
        pdb.set_trace()

    train(kwargs)


if __name__ == "__main__":
    main_hydra()
53 changes: 53 additions & 0 deletions examples/seld_spatialsoundqa/inference_seld_batch.py
@@ -0,0 +1,53 @@
import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from typing import Optional

from slam_llm.pipeline.inference_batch import main as inference
from seld_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig, PeftConfig

@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    decode_log: str = field(
        default="output/decode_log",
        metadata={"help": "The prefix for the decode output"},
    )
    ckpt_path: str = field(
        default="output/model.pt", metadata={"help": "The path to the projector checkpoint"}
    )
    peft_ckpt: Optional[str] = field(
        default=None,
        metadata={
            "help": "The path to the PEFT checkpoint; should be a directory containing adapter_config.json"
        },
    )


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)
    # kwargs = to_plain_list(cfg)
    log_level = getattr(logging, cfg.get("log_level", "INFO").upper())
    logging.basicConfig(level=log_level)

    if cfg.get("debug", False):
        import pdb
        pdb.set_trace()

    inference(cfg)


if __name__ == "__main__":
    main_hydra()
154 changes: 154 additions & 0 deletions examples/seld_spatialsoundqa/model/slam_model_seld.py
@@ -0,0 +1,154 @@
import torch
import os
import logging
from slam_llm.models.slam_model import (
    slam_model,
    setup_tokenizer,
    setup_encoder,
    setup_encoder_projector,
    setup_llm,
)
from slam_llm.utils.train_utils import print_model_size

logger = logging.getLogger(__name__)

def model_factory(train_config, model_config, **kwargs):
    # return necessary components for training
    tokenizer = setup_tokenizer(train_config, model_config, **kwargs)

    encoder = setup_encoder(train_config, model_config, **kwargs)

    # llm
    llm = setup_llm(train_config, model_config, **kwargs)

    # projector
    encoder_projector = setup_encoder_projector(
        train_config, model_config, **kwargs
    )
    model = slam_model_seld(
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        train_config,
        model_config,
        **kwargs,
    )

    ckpt_path = kwargs.get(
        "ckpt_path", None
    )  # FIX(MZY): load model ckpt (mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft)
    if ckpt_path is not None:
        logger.info("loading other parts from: {}".format(ckpt_path))
        ckpt_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt_dict, strict=False)

    print_model_size(
        model,
        train_config,
        (
            int(os.environ["RANK"])
            if train_config.enable_fsdp or train_config.enable_ddp
            else 0
        ),
    )
    return model, tokenizer

class slam_model_seld(slam_model):
    def __init__(
        self,
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        train_config,
        model_config,
        **kwargs,
    ):
        super().__init__(
            encoder,
            llm,
            encoder_projector,
            tokenizer,
            train_config,
            model_config,
            **kwargs,
        )

    @torch.no_grad()
    def inference(
        self,
        wav_path=None,
        reverb_path=None,
        prompt=None,
        generation_config=None,
        logits_processor=None,
        stopping_criteria=None,
        prefix_allowed_tokens_fn=None,
        synced_gpus=None,
        assistant_model=None,
        streamer=None,
        negative_prompt_ids=None,
        negative_prompt_attention_mask=None,
        **kwargs,
    ):
        # TODO: inference for the SELD model
        device = kwargs.get("device", "cuda")
        if os.path.exists(wav_path):  # Audio-Text QA
            import whisper

            audio_raw = whisper.load_audio(wav_path)
            audio_raw = whisper.pad_or_trim(audio_raw)

            mel_size = getattr(
                self.dataset_config, "mel_size", 80
            )  # 80 for large v1 and v2, 128 for large v3
            audio_mel = (
                whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size)
                .permute(1, 0)[None, :, :]
                .to(device)
            )

            encoder_outs = self.encoder.extract_variable_length_features(
                audio_mel.permute(0, 2, 1)
            )

            if self.model_config.encoder_projector == "q-former":
                audio_mel_post_mask = torch.ones(
                    encoder_outs.size()[:-1], dtype=torch.long
                ).to(encoder_outs.device)
                encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
            if self.model_config.encoder_projector == "linear":
                encoder_outs = self.encoder_projector(encoder_outs)
        else:  # Text QA
            encoder_outs = torch.empty(
                1, 0, self.llm.model.embed_tokens.embedding_dim
            ).to(device)

        prompt = "USER: {}\n ASSISTANT:".format(prompt)
        prompt_ids = self.tokenizer.encode(prompt)
        prompt_length = len(prompt_ids)
        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device)

        if hasattr(self.llm.model, "embed_tokens"):
            inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
        elif hasattr(self.llm.model.model, "embed_tokens"):
            inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
        else:
            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)

        inputs_embeds = torch.cat(
            (encoder_outs, inputs_embeds[None, :, :]), dim=1
        )  # [audio, prompt]

        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(
            inputs_embeds.device
        )

        # generate
        model_outputs = self.generate(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs
        )

        return model_outputs
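
As a usage sketch (not part of this commit), the single-example `inference` path above could be exercised roughly as follows, assuming `model` and `tokenizer` were built by `model_factory` and that the wav path, question, and generation arguments are placeholders.

```python
# Hypothetical usage; model and tokenizer come from model_factory(train_config, model_config, ...).
wav_path = "eval/audio/YI-HlrcP6Qg4.wav"  # placeholder audio file
question = "Enumerate the sound occurrences in the audio clip."

output_ids = model.inference(
    wav_path=wav_path,
    prompt=question,
    max_new_tokens=64,  # forwarded to self.generate(**kwargs)
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```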
@@ -0,0 +1,58 @@
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
export TOKENIZERS_PARALLELISM=false
# export CUDA_LAUNCH_BLOCKING=1

SLAM_DIR=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/SLAM-LLM
cd $SLAM_DIR
code_dir=examples/seld_spatialsoundqa

stage=stage1-clsdoa
qa_data_root=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/data/SpatialAudio/closed-end
reverb_data_root=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/data/SpatialAudio/reverb/mp3d
anechoic_data_root=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/data/AudioSet

audio_encoder_path=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/models/SpatialAST/SpatialAST.pth
llm_path=/mnt/lustre/hpc_stor03/sjtu_pub/cxgroup/model/Llama-2-7b-hf

split=eval
output_dir=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/SLAM-LLM/outputs/bat-llama-2-spatialAST-8qformer-steplrwarmupkeep1e-4-stage1-clsdoa-20240519/
ckpt_path=$output_dir/bat_epoch_1_step_4000
decode_log=$ckpt_path/decode_${split}_beam4

# -m debugpy --listen 5678 --wait-for-client
python -u $code_dir/inference_seld_batch.py \
--config-path "conf" \
hydra.run.dir=$ckpt_path \
++model_config.llm_name=llama-2-7b \
++model_config.llm_path=$llm_path \
++model_config.llm_dim=4096 \
++model_config.encoder_name=SpatialAST \
++model_config.encoder_projector=q-former \
++model_config.encoder_ckpt=$audio_encoder_path \
++dataset_config.stage=$stage \
++dataset_config.qa_data_root=$qa_data_root \
++dataset_config.anechoic_data_root=$anechoic_data_root \
++dataset_config.reverb_data_root=$reverb_data_root \
++dataset_config.fix_length_audio=64 \
++dataset_config.inference_mode=true \
++train_config.model_name=bat \
++train_config.freeze_encoder=true \
++train_config.freeze_llm=true \
++train_config.batching_strategy=custom \
++train_config.num_epochs=1 \
++train_config.val_batch_size=8 \
++train_config.num_workers_dataloader=1 \
++train_config.output_dir=$output_dir \
++train_config.use_peft=true \
++peft_config.peft_method=llama_adapter \
++log_config.log_file=$output_dir/test.log \
++decode_log=$decode_log \
++ckpt_path=$ckpt_path/model.pt \
# ++peft_ckpt=$ckpt_path \
# ++train_config.use_peft=true \
# ++train_config.peft_config.r=32 \
# ++dataset_config.normalize=true \
# ++model_config.encoder_projector=q-former \
# ++dataset_config.fix_length_audio=64 \