Commit df55ebf (1 parent: 22e7fc6): 16 changed files with 1,287 additions and 6 deletions.
New file (+46): README for the SELD_SpatialSoundQA example
# <img src="assets/bat.png" alt="SELD_SpatialSoundQA" width="25" height="25"> SELD_SpatialSoundQA

This repo hosts the code and models of "[BAT: Learning to Reason about Spatial Sounds with Large Language Models](https://arxiv.org/abs/2402.01591)" [ICML 2024 [bib](https://github.com/zszheng147/Spatial-AST#citation)].

Check out our [demo page](https://zhishengzheng.com/BAT/) and enjoy a QA game with spatial audio.
## Performance and checkpoints
| Encoder | Projector | PEFT | LLM |
|---|---|---|---|
| [Spatial-AST](https://huggingface.co/zhisheng01/Bat/blob/main/spatial-ast.pth) | Q-Former | adapter | [llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) |
## Data preparation
Prepare the data as a JSONL file, one record per line, in the format shown below (a minimal loading sketch follows the example).
You can download the SpatialSoundQA dataset from [huggingface](https://huggingface.co/datasets/zhisheng01/SpatialSoundQA).
```jsonl
{"audio_id": "eval/audio/YI-HlrcP6Qg4", "reverb_id": "q9vSo1VnCiC/0.npy", "audio_id2": null, "reverb_id2": null, "question_id": 0, "question_type": "CLASSIFICATION", "question": "Enumerate the sound occurrences in the audio clip.", "answer": "accelerating, revving, vroom; car; vehicle"}
...
{"audio_id": "eval/audio/YZX2fVPmUidA", "reverb_id": "q9vSo1VnCiC/32.npy", "audio_id2": "eval/audio/YjNjUU01quLs", "reverb_id2": "q9vSo1VnCiC/31.npy", "question_id": 58, "question_type": "MIXUP_NONBINARY_DISTANCE", "question": "How far away is the sound of the banjo from the sound of the whack, thwack?", "answer": "2m"}
```
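Each record pairs an audio clip (`audio_id`) and a room impulse response (`reverb_id`, a `.npy` file) with a question/answer pair; mixture questions additionally fill `audio_id2`/`reverb_id2`. A minimal loading sketch (the path `eval.jsonl` is a placeholder; field names follow the example above):

```python
import json
from collections import Counter

def load_spatialsoundqa(jsonl_path):
    """Read SpatialSoundQA records: one JSON object per line."""
    with open(jsonl_path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

records = load_spatialsoundqa("eval.jsonl")  # placeholder path to a downloaded split
print(Counter(r["question_type"] for r in records))   # e.g. CLASSIFICATION, MIXUP_NONBINARY_DISTANCE
print(records[0]["question"], "->", records[0]["answer"])
```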

## Train a new model
```bash
bash examples/seld_spatialsoundqa/scripts/finetune_spatial-ast_qformer_llama_2_7b.sh
```

## Decoding with checkpoints
```bash
bash examples/seld_spatialsoundqa/scripts/decode_spatial-ast_qformer_llama_2_7b.sh
```

## TODO
- [x] Decode with checkpoints
- [ ] Upload SpatialSoundQA dataset
- [ ] Upload pretrained checkpoints
- [ ] Update model performance
## Citation
```
@article{zheng2024bat,
  author  = {Zheng, Zhisheng and Peng, Puyuan and Ma, Ziyang and Chen, Xie and Choi, Eunsol and Harwath, David},
  title   = {BAT: Learning to Reason about Spatial Sounds with Large Language Models},
  journal = {arXiv preprint arXiv:2402.01591},
  year    = {2024},
}
```
(One file in this commit cannot be displayed in the diff view.)
New file (+19): DeepSpeed configuration
```json
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
```
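This configuration trains in fp16 with ZeRO stage-3 partitioning and optimizer-state offload to CPU. The diff does not show how SLAM-LLM consumes it, so the following is only a generic sketch of how a DeepSpeed config like this is usually wired in (the model and the config path are placeholders):

```python
import deepspeed
import torch

model = torch.nn.Linear(512, 512)  # placeholder; the real model comes from model_factory below

# deepspeed.initialize reads the JSON (micro-batch size, Adam, fp16, ZeRO-3 + CPU offload)
# and returns an engine that handles sharding, mixed precision, and the optimizer step.
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="examples/seld_spatialsoundqa/conf/ds_config.json",  # hypothetical path to the JSON above
)
```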
New file (+45): Hydra entry point for fine-tuning
```python
import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf

from seld_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig, PeftConfig
from slam_llm.pipeline.finetune import main as train


@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)

    def to_plain_list(cfg_item):
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    # kwargs = to_plain_list(cfg)
    kwargs = cfg
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
    logging.basicConfig(level=log_level)

    if kwargs.get("debug", False):
        import pdb
        pdb.set_trace()

    train(kwargs)


if __name__ == "__main__":
    main_hydra()
```
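The entry point builds a structured `RunConfig` with defaults and merges it with whatever Hydra parsed from the command line, so dotted `++key=value` overrides land on the nested dataclass fields. A toy sketch of that merge pattern (the dataclasses here are illustrative stand-ins, not the real SLAM-LLM configs):

```python
from dataclasses import dataclass, field
from omegaconf import OmegaConf

@dataclass
class ToyTrainConfig:  # stand-in for TrainConfig
    num_epochs: int = 1
    use_peft: bool = False

@dataclass
class ToyRunConfig:  # stand-in for RunConfig
    train_config: ToyTrainConfig = field(default_factory=ToyTrainConfig)

defaults = OmegaConf.structured(ToyRunConfig())
overrides = OmegaConf.from_dotlist(["train_config.use_peft=true", "train_config.num_epochs=3"])
cfg = OmegaConf.merge(defaults, overrides)  # command-line values win over dataclass defaults
print(cfg.train_config.use_peft, cfg.train_config.num_epochs)  # True 3
```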
New file (+53): Hydra entry point for batch inference
```python
import hydra
import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from typing import Optional

from slam_llm.pipeline.inference_batch import main as inference
from seld_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig, PeftConfig


@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    fsdp_config: FSDPConfig = field(default_factory=FSDPConfig)
    peft_config: PeftConfig = field(default_factory=PeftConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    decode_log: str = field(
        default="output/decode_log",
        metadata={"help": "The prefix for the decode output"},
    )
    ckpt_path: str = field(
        default="output/model.pt", metadata={"help": "The path to projector checkpoint"}
    )
    peft_ckpt: Optional[str] = field(
        default=None,
        metadata={
            "help": "The path to peft checkpoint, should be a directory including adapter_config.json"
        },
    )


@hydra.main(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)
    # kwargs = to_plain_list(cfg)
    log_level = getattr(logging, cfg.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if cfg.get("debug", False):
        import pdb
        pdb.set_trace()

    inference(cfg)


if __name__ == "__main__":
    main_hydra()
```
New file (+154): model factory and `slam_model_seld` definition
```python
import torch
import os
import logging
from slam_llm.models.slam_model import (
    slam_model,
    setup_tokenizer,
    setup_encoder,
    setup_encoder_projector,
    setup_llm,
)
from slam_llm.utils.train_utils import print_model_size

logger = logging.getLogger(__name__)


def model_factory(train_config, model_config, **kwargs):
    # return necessary components for training
    tokenizer = setup_tokenizer(train_config, model_config, **kwargs)

    encoder = setup_encoder(train_config, model_config, **kwargs)

    # llm
    llm = setup_llm(train_config, model_config, **kwargs)

    # projector
    encoder_projector = setup_encoder_projector(
        train_config, model_config, **kwargs
    )
    model = slam_model_seld(
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        train_config,
        model_config,
        **kwargs,
    )

    ckpt_path = kwargs.get(
        "ckpt_path", None
    )  # FIX(MZY): load model ckpt (mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft)
    if ckpt_path is not None:
        logger.info("loading other parts from: {}".format(ckpt_path))
        ckpt_dict = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt_dict, strict=False)

    print_model_size(
        model,
        train_config,
        (
            int(os.environ["RANK"])
            if train_config.enable_fsdp or train_config.enable_ddp
            else 0
        ),
    )
    return model, tokenizer


class slam_model_seld(slam_model):
    def __init__(
        self,
        encoder,
        llm,
        encoder_projector,
        tokenizer,
        train_config,
        model_config,
        **kwargs,
    ):
        super().__init__(
            encoder,
            llm,
            encoder_projector,
            tokenizer,
            train_config,
            model_config,
            **kwargs,
        )

    @torch.no_grad()
    def inference(
        self,
        wav_path=None,
        reverb_path=None,
        prompt=None,
        generation_config=None,
        logits_processor=None,
        stopping_criteria=None,
        prefix_allowed_tokens_fn=None,
        synced_gpus=None,
        assistant_model=None,
        streamer=None,
        negative_prompt_ids=None,
        negative_prompt_attention_mask=None,
        **kwargs,
    ):
        # !TODO: inference for SELD model
        device = kwargs.get("device", "cuda")
        if os.path.exists(wav_path):  # Audio-Text QA
            import whisper

            audio_raw = whisper.load_audio(wav_path)
            audio_raw = whisper.pad_or_trim(audio_raw)

            mel_size = getattr(
                self.dataset_config, "mel_size", 80
            )  # 80 for large v1 and v2, 128 for large v3
            audio_mel = (
                whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size)
                .permute(1, 0)[None, :, :]
                .to(device)
            )

            encoder_outs = self.encoder.extract_variable_length_features(
                audio_mel.permute(0, 2, 1)
            )

            if self.model_config.encoder_projector == "q-former":
                audio_mel_post_mask = torch.ones(
                    encoder_outs.size()[:-1], dtype=torch.long
                ).to(encoder_outs.device)
                encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
            if self.model_config.encoder_projector == "linear":
                encoder_outs = self.encoder_projector(encoder_outs)
        else:  # Text QA
            encoder_outs = torch.empty(
                1, 0, self.llm.model.embed_tokens.embedding_dim
            ).to(device)

        prompt = "USER: {}\n ASSISTANT:".format(prompt)
        prompt_ids = self.tokenizer.encode(prompt)
        prompt_length = len(prompt_ids)
        prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device)

        if hasattr(self.llm.model, "embed_tokens"):
            inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
        elif hasattr(self.llm.model.model, "embed_tokens"):
            inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
        else:
            inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)

        inputs_embeds = torch.cat(
            (encoder_outs, inputs_embeds[None, :, :]), dim=1
        )  # [audio, prompt]

        attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(
            inputs_embeds.device
        )

        # generate
        model_outputs = self.generate(
            inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs
        )

        return model_outputs
```
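The `inference` method loads the waveform, extracts features with the encoder, maps them through the Q-Former (or linear) projector, prepends the result to the embedded `USER: ... ASSISTANT:` prompt, and generates from the concatenated embeddings. A hedged usage sketch; the configs, checkpoint, and audio paths below are placeholders, and real runs build everything through the Hydra entry points above:

```python
from seld_config import ModelConfig, TrainConfig  # config dataclasses used by the entry points above

# Default configs shown only for illustration; real runs override encoder/LLM paths via Hydra.
train_config, model_config = TrainConfig(), ModelConfig()

# model_factory is defined in the file above.
model, tokenizer = model_factory(train_config, model_config, ckpt_path="outputs/bat/model.pt")
model = model.to("cuda").eval()

answer_ids = model.inference(
    wav_path="data/eval/YI-HlrcP6Qg4.wav",  # hypothetical local clip
    prompt="Enumerate the sound occurrences in the audio clip.",
    max_new_tokens=64,                      # forwarded to self.generate
)
print(tokenizer.decode(answer_ids[0], skip_special_tokens=True))
```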
New file (+58): examples/seld_spatialsoundqa/scripts/decode_spatial-ast_qformer_llama_2_7b.sh
```bash
#!/bin/bash
#export PYTHONPATH=/root/whisper:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0
export TOKENIZERS_PARALLELISM=false
# export CUDA_LAUNCH_BLOCKING=1

SLAM_DIR=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/SLAM-LLM
cd $SLAM_DIR
code_dir=examples/seld_spatialsoundqa

stage=stage1-clsdoa
qa_data_root=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/data/SpatialAudio/closed-end
reverb_data_root=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/data/SpatialAudio/reverb/mp3d
anechoic_data_root=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/data/AudioSet

audio_encoder_path=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/models/SpatialAST/SpatialAST.pth
llm_path=/mnt/lustre/hpc_stor03/sjtu_pub/cxgroup/model/Llama-2-7b-hf

split=eval
output_dir=/mnt/lustre/hpc_stor03/sjtu_home/zhisheng.zheng/SLAM-LLM/outputs/bat-llama-2-spatialAST-8qformer-steplrwarmupkeep1e-4-stage1-clsdoa-20240519/
ckpt_path=$output_dir/bat_epoch_1_step_4000
decode_log=$ckpt_path/decode_${split}_beam4

# -m debugpy --listen 5678 --wait-for-client
python -u $code_dir/inference_seld_batch.py \
        --config-path "conf" \
        hydra.run.dir=$ckpt_path \
        ++model_config.llm_name=llama-2-7b \
        ++model_config.llm_path=$llm_path \
        ++model_config.llm_dim=4096 \
        ++model_config.encoder_name=SpatialAST \
        ++model_config.encoder_projector=q-former \
        ++model_config.encoder_ckpt=$audio_encoder_path \
        ++dataset_config.stage=$stage \
        ++dataset_config.qa_data_root=$qa_data_root \
        ++dataset_config.anechoic_data_root=$anechoic_data_root \
        ++dataset_config.reverb_data_root=$reverb_data_root \
        ++dataset_config.fix_length_audio=64 \
        ++dataset_config.inference_mode=true \
        ++train_config.model_name=bat \
        ++train_config.freeze_encoder=true \
        ++train_config.freeze_llm=true \
        ++train_config.batching_strategy=custom \
        ++train_config.num_epochs=1 \
        ++train_config.val_batch_size=8 \
        ++train_config.num_workers_dataloader=1 \
        ++train_config.output_dir=$output_dir \
        ++train_config.use_peft=true \
        ++peft_config.peft_method=llama_adapter \
        ++log_config.log_file=$output_dir/test.log \
        ++decode_log=$decode_log \
        ++ckpt_path=$ckpt_path/model.pt \
        # ++peft_ckpt=$ckpt_path \
        # ++train_config.use_peft=true \
        # ++train_config.peft_config.r=32 \
        # ++dataset_config.normalize=true \
        # ++model_config.encoder_projector=q-former \
        # ++dataset_config.fix_length_audio=64 \
```