diff --git a/AutoCap/README.md b/AutoCap/README.md
index a08d729..9454fe9 100644
--- a/AutoCap/README.md
+++ b/AutoCap/README.md
@@ -1,119 +1,106 @@
-
+[![arXiv](ARXIV ICON)](ARXIV LINK)
-# GenAU inference, training, and evaluation
-- [Introduction](#introduction)
-- [Environemnt setup](#environment-initalization)
+# AutoCap inference, training, and evaluation
- [Inference](#inference)
- * [Audio to text script](#text-to-audio)
- * [Inference a list of promots](#inference-a-list-of-prompts)
+ * [Audio to text script](#audio-to-text)
+ * [Gradio demo](#gradio-demo)
+  * [Caption a list of audio files](#caption-a-list-of-audio-files)
+  * [Caption your custom dataset](#caption-your-custom-dataset)
- [Training](#training)
- * [GenAU](#genau)
- * [Finetuning GenAU](#finetuning-genau)
- * [1D-VAE (optional)](#1d-vae-optional)
- [Evaluation](#evaluation)
- [Cite this work](#cite-this-work)
- [Acknowledgements](#acknowledgements)
-# Introduction
-We introduce GenAU, a transformer-based audio latent diffusion model leveraging the FIT architecture. Our model compresses mel-spectrogram data into a 1D representation and utilizes layered attention processes to achieve state-of-the-art audio generation results among open-source models.
-
-
-
-
-
-
-
-
-# Environment initialization
+# Environment initialization
For initializing your environment, please refer to the [general README](../README.md).
# Inference
-## Text to Audio
-To quickly generate an audio based on an input text prompt, run
+## Audio to Text
+To quickly generate a caption for an input audio file, run
```shell
-python scripts/text_to_audio.py --prompt "Horses growl and clop hooves." --model "genau-full-l"
+python scripts/audio_to_text.py --wav_path
+
+# Example inference
+python scripts/audio_to_text.py --wav_path samples/ood_samples/loudwhistle-91003.wav
```
-- This will automatically download and use the model `genau-full-l` with default settings. You may change these parameters or provide your custom model config file and checkpoint path.
-- Available models include `genau-full-l` (1.25B parameters) and `genau-full-s` (493M parameters)
-- These models are trained to generate ambient sounds and is incapable of generating speech or music.
-- Outputs will be saved by default at `samples/model_output` using the provided prompt as the file name.
+- This will automatically download the `TODO` model and run inference with the default parameters. You may change these parameters or provide your custom model config file and checkpoint path.
+- For more accurate captioning, provide metadata using the `--title`, `--description`, and `--video_caption` arguments.
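+For example, a hypothetical invocation that passes metadata alongside the audio (the metadata strings below are made up for illustration; any of these flags can be omitted):
+```shell
+# illustrative metadata values; replace with information about your own recording
+python scripts/audio_to_text.py --wav_path samples/ood_samples/loudwhistle-91003.wav \
+    --title "loud whistle" \
+    --description "a person whistles loudly close to the microphone" \
+    --video_caption "a close-up shot of a person whistling"
+```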
-
+## Gradio demo
+To try audio captioning in an interactive Gradio demo, run
+```shell
+python app_audio2text.py
+```
-## Inference a list of prompts
-Optionally, you may prepare a `.txt` file with your target prompts and run
+## Caption a list of audio files
+- Prepare all target audio files in a single folder
+- Optionally, provide metadata in a `yaml` file using the following structure:
+```yaml
+file_name.wav:
+ title: "video title"
+ description: "video description"
+ video_caption: "video caption"
+```
+Then run the following script
```shell
-python scripts/inference_file.py --list_inference --model
+python scripts/inference_folder.py --folder_path --meta_data_file
-# Example
-python scripts/inference_file.py --list_inference samples/prompts_list.txt --model "genau-full-l"
+# Example inference
+python scripts/inference_folder.py --folder_path samples/ood_samples --meta_data_file samples/ood_samples/meta_data.yaml
```
+## Caption your custom dataset
-## Training
-
-### Dataset
-Please refer to the [dataset preparation README](../dataset_preperation/README.md) for instructions on downloading our dataset or preparing your own.
-
-### GenAU
-- Prepare a yaml config file for your experiments. A sample config file is provided at `settings/simple_runs/genau.yaml`
-- Specify your project name and provide your Wandb key in the config file. A Wandb key can be obtained from [https://wandb.ai/authorize](https://wandb.ai/authorize)
-- Optionally, provide your S3 bucket and folder to save intermediate checkpoints.
-- By default, checkpoints will be saved under `run_logs/genau/train` at the same level as the config file.
+To caption a large dataset, we provide a script that supports multi-GPU inference for faster processing.
+- Prepare your custom dataset by following the instructions in the dataset preparation README (TODO) and run
```shell
-# Training GenAU from scratch
-python train/genau.py -c settings/simple_runs/genau.yaml
-```
+python scripts/caption_dataset.py \
+ --caption_store_key \
+ --beam_size 2 \
+ --start_idx 0 \
+ --end_idx 1000000 \
+ --dataset_keys "dataset_1" "dataset_2" ...
-For multinode training, run
-```shell
-python -m torch.distributed.run --nproc_per_node=8 train/genau.py -c settings/simple_runs/genau.yaml
```
-### Finetuning GenAU
+- Provide your dataset keys as registered during dataset preparation (TODO)
+- Captions will be generated and stored in each sample's JSON file under the specified `caption_store_key`
+- The `start_idx` and `end_idx` arguments can be used to resume or distribute captioning experiments
+- Add your `caption_store_key` under `keys_synonyms:gt_audio_caption` in the target yaml config file (see the sketch below) so that it is selected when a ground-truth caption is not available in your audio captioning or audio generation experiments.
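+For illustration, a minimal sketch of that config entry, assuming `keys_synonyms` maps each target key to a list of accepted alternatives and using a hypothetical `my_autocap_caption` as the `caption_store_key`:
+```yaml
+keys_synonyms:
+  gt_audio_caption:
+    - gt_audio_caption
+    - my_autocap_caption   # hypothetical caption_store_key, used when no ground-truth caption exists
+```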
-- Prepare your custom dataset and obtain the dataset keys following [dataset preparation README](../dataset_preperation/README.md)
-- Make a copy and adjust the default config file of `genau-full-l` which you can find under `pretrained_models/genau/genau-full-l.yaml`
-- Add ids for your dataset keys under `dataset2id` attribute in the config file.
+# Training
+### Dataset
+Please refer to the dataset README (TODO) for instructions on downloading our dataset or preparing your own.
+
+### Stage 1 (pretraining)
+- Specify your model parameters in a config yaml file. A sample config file is provided at `settings/pretraining.yaml`
+- Specify your project name and provide your Wandb key in the config file (see the sketch below). A Wandb key can be obtained from [https://wandb.ai/authorize](https://wandb.ai/authorize)
+- Optionally, provide your S3 bucket and folder to save intermediate checkpoints.
+- By default, checkpoints will be saved under `run_logs/train`
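+A rough sketch of the logging-related fields mentioned above (most key names here are illustrative assumptions; `logging:save_checkpoint_every_n_epochs` mirrors `train.py`, and the exact schema is in `settings/pretraining.yaml`):
+```yaml
+project_name: "autocap"                  # assumed key: your wandb project name
+wandb_key: "<key from https://wandb.ai/authorize>"
+s3_bucket: "my-bucket"                   # optional: S3 bucket for intermediate checkpoints
+s3_folder: "autocap/checkpoints"
+logging:
+  save_checkpoint_every_n_epochs: 5      # checkpoints are written under run_logs/train by default
+```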
```shell
-# Finetuning GenAU
-python train/genau.py --reload_from_ckpt 'genau-full-l' \
- --config \
- --dataset_keys "" "" ...
+python train.py -c settings/pretraining.yaml
```
-
-### 1D VAE (Optional)
-By default, we offer a pre-trained 1D-VAE for GenAU training. If you prefer, you can train your own VAE by following the provided instructions.
-- Prepare your own dataset following the instructions in the [dataset preparation README](../dataset_preperation/README.md)
-- Prepare your yaml config file in a similar way to the GenAU config file
-- A sample config file is provided at `settings/simple_runs/1d_vae.yaml`
+### Stage 2 (finetuning)
+- Prepare your finetuning config file in a similar way to the pretraining stage. Typically, you only need to set `pretrain_path` to your pretraining checkpoint, adjust the learning rate, and disable the freeze option for the `text_decoder` (see the sketch below).
+- A sample finetuning config is provided under `settings/finetuning.yaml`
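+A minimal sketch of the fields this typically touches (`pretrain_path` and `optim_args:lr` follow the text above and `train.py`; the `text_decoder` freeze key and the exact nesting are assumptions, so check `settings/finetuning.yaml` for the real names):
+```yaml
+pretrain_path: "run_logs/train/<pretraining_run>/checkpoints/last.ckpt"   # hypothetical path
+optim_args:
+  lr: 1.0e-5              # illustrative value, usually lower than in pretraining
+model:
+  text_decoder:
+    freeze: false         # assumed key: unfreeze the text decoder for finetuning
+```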
```shell
-python train/1d_vae.py -c settings/simple_runs/1d_vae.yaml
+python train.py -c settings/finetuning.yaml
```
-## Evaluation
-- We follow [audioldm](https://github.com/haoheliu/AudioLDM-training-finetuning) to perform our evaulations.
-- By default, the models will be evaluated periodically during training as specified in the config file. For each evaluation, a folder with the generated audio will be saved under `run_logs/train' at the same levels as the specified config file.
-- The code identifies the test dataset in an already existing folder according to the number of samples. If you would like to test on a new test dataset, register it in `scripts/generate_and_eval`
+# Evaluation
+- By default, the model logs metrics on the validation set to wandb periodically during training, as specified in the config file.
+- We exclude the `spice`, `spider`, and `meteor` metrics during training as they tend to hang the training in multi-GPU runs. You may include them by changing the configuration.
+- A file with the predicted captions is saved under `run_logs/train` during evaluation, and the metrics can be found in a file named `output.txt` under the logging folder.
+- To run the evaluation on the test set after training finishes, run:
```shell
-
-# Evaluate an existing generated folder
-python scripts/evaluate.py --log_path
-
-# Geneate test audios from a pre-trained checkpoint and run evaulation
-python scripts/generate_and_eval.py -c -ckpt
+python evaluate.py -c -ckpt
```
-The evaluation result will be saved in a JSON file at the same level of the generated audio folder.
# Cite this work
If you found this useful, please consider citing our work
@@ -122,5 +109,6 @@ If you found this useful, please consider citing our work
```
# Acknowledgements
-Our audio generation and evaluation codebase relies on [audioldm](https://github.com/haoheliu/AudioLDM-training-finetuning). We sincerely appreciate the authors for sharing their code openly.
-
+We sincerely thank the authors of the following work for sharing their code publicly:
+- [WavCaps: A ChatGPT-Assisted Weakly-Labelled Audio Captioning Dataset for Audio-Language Multimodal Research](https://github.com/XinhaoMei/WavCaps)
+- [Audio Captioning Transformer](https://github.com/XinhaoMei/ACT/tree/main/coco_caption)
\ No newline at end of file
diff --git a/AutoCap/src/models/pl_htsat_q_bart_captioning.py b/AutoCap/src/models/pl_htsat_q_bart_captioning.py
index c010656..aa29eed 100644
--- a/AutoCap/src/models/pl_htsat_q_bart_captioning.py
+++ b/AutoCap/src/models/pl_htsat_q_bart_captioning.py
@@ -57,7 +57,6 @@ def forward(self, token_ids):
token_ids = torch.where(special_tokens_mask, self.base_tokenizer.pad_token_id, token_ids)
embeddings = self.model.encoder.embed_tokens(token_ids)
-
embeddings.view(-1, embeddings.shape[-1])[special_tokens_mask.view(-1)] = special_tokens_embeds
return embeddings
@@ -198,7 +197,7 @@ def __init__(self, config):
self.decoder = BartForConditionalGeneration.from_pretrained(decoder_name)
else:
bart_config = BartConfig.from_pretrained(decoder_name)
- self.decoder = BartForConditionalGeneration.from_config(bart_config)
+ self.decoder = BartForConditionalGeneration(config=bart_config)
self.set_decoder_requires_grad(freeze=freeze_decoder, freeze_embed_layer=freeze_embed_layer)
@@ -381,7 +380,9 @@ def __init__(self, config):
else:
text_max_tokens = self.num_text_query_token
- self.decoder = self.adjust_max_pos_embeds(self.decoder, audio_max_tokens+text_max_tokens+2+self.use_clap_embeds) # extra token for CLAP
+        # extra token for CLAP, two for the position offset in BART, two for BOS and EOS, plus a small safety margin
+ self.decoder = self.adjust_max_pos_embeds(self.decoder, audio_max_tokens+text_max_tokens+20+self.use_clap_embeds)
+
# dropout kayer
self.audio_features_dropout = nn.Dropout(p=config['model']['audio_features_dropout_p'])
@@ -875,7 +876,7 @@ def get_meta_dict(self, batch, meta_keys=None, drop_keys=[]):
def training_step(self, batch, batch_idx):
audio = batch['waveform'].squeeze(1)
text = batch['gt_audio_caption'] # list of captions
- audio = audio.to(self.device, non_blocking=True)
+ audio = audio.to(self.device)
# prepare meta
if 'meta_keys' in self.config['model'].keys():
@@ -999,11 +1000,15 @@ def on_validation_epoch_start(self):
def on_validation_epoch_end(self):
+
+ # place a barrier to ensure that all ranks reach the gather operation
+ torch.distributed.barrier()
gathered = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(gathered, self.val_outputs)
- if not self.trainer.is_global_zero:
- return
-
+ torch.distributed.barrier()
+
+ metrics_log = {}
+        # all ranks should execute the following blocks to avoid a deadlock
self.gathered_output_dict = []
for idx in range(len(self.val_loaders_labels)):
val_loader_out = {}
@@ -1015,7 +1020,7 @@ def on_validation_epoch_end(self):
self.gathered_output_dict.append(val_loader_out)
-
+
val_logger = logger.bind(indent=1)
for split, dataloader_outputs in zip(self.val_loaders_labels, self.gathered_output_dict):
val_logger.info(f"[INFO] evaluating metrics for split: {split}")
@@ -1051,12 +1056,12 @@ def get_score(metrics, key):
f'Spider score using beam search (beam size:{beam_size}): {spider:7.4f}')
metrics_log = {f"{split}/spider_beam_{beam_size}" : spider,
- f"{split}/cider_beam_{beam_size}":cider,
- f"{split}/spice_beam_{beam_size}":spice,
- f"{split}/bleu_1_beam_{beam_size}":bleu_1,
+ f"{split}/cider_beam_{beam_size}":cider,
+ f"{split}/spice_beam_{beam_size}":spice,
+ f"{split}/bleu_1_beam_{beam_size}":bleu_1,
f"{split}/bleu_4_beam_{beam_size}":bleu_4,
- f"{split}/rouge_l_beam_{beam_size}":rouge_l,
- f"{split}/meteor_beam_{beam_size}":meteor }
+ f"{split}/rouge_l_beam_{beam_size}":rouge_l,
+ f"{split}/meteor_beam_{beam_size}":meteor }
if 'bert_score' in metrics:
bert_score = metrics.pop('bert_score')
metrics_log[f"{split}/bertscore_beam_{beam_size}"] = bert_score
diff --git a/AutoCap/train.py b/AutoCap/train.py
index 9fd2069..57b0260 100644
--- a/AutoCap/train.py
+++ b/AutoCap/train.py
@@ -49,7 +49,7 @@ def main():
config["optim_args"]["lr"] = args.lr
if args.num_workers is not None:
- config['data_args']['num_workers']
+ config['data_args']['num_workers'] = args.num_workers
# set up model
devices = torch.cuda.device_count()
@@ -90,6 +90,7 @@ def main():
return_test=False,
cache_dir=None)
+
# print training settings
printer = PrettyPrinter()
main_logger.info('Training setting:\n'
@@ -112,7 +113,7 @@ def main():
main_logger.info(f'Size of {val_k} validation set: {len(val_loader.dataset)}, size of batches: {len(val_loader)}')
- # update the model with data types, combine test and val loaders # TODO: delete the test
+
print("val_loaders", len(val_loaders), val_loaders.keys())
model.val_loaders_labels = list(val_loaders.keys())
val_loaders = list(val_loaders.values())
@@ -120,7 +121,7 @@ def main():
# ckpt
- validation_every_n_epochs = config["step"]["validation_every_n_epochs"]
+ validation_every_n_epochs = config["step"].get("validation_every_n_epochs", None)
save_checkpoint_every_n_epochs = config["logging"]["save_checkpoint_every_n_epochs"]
checkpoint_callback = S3ModelCheckpoint(
@@ -146,7 +147,7 @@ def main():
limit_val_batches=config['step'].get('limit_val_batches', None),
limit_train_batches=config['step'].get('limit_train_batches', None),
check_val_every_n_epoch=validation_every_n_epochs,
- strategy=DDPStrategy(find_unused_parameters=True),
+ strategy=DDPStrategy(find_unused_parameters=False),
callbacks=[checkpoint_callback],
gradient_clip_val=config["model"].get("clip_grad", None),
profiler=config['training'].get('profiler', None),
diff --git a/GenAU/audioldm_eval/audioldm_eval/datasets/load_mel.py b/GenAU/audioldm_eval/audioldm_eval/datasets/load_mel.py
index 8c35eb3..ed89118 100644
--- a/GenAU/audioldm_eval/audioldm_eval/datasets/load_mel.py
+++ b/GenAU/audioldm_eval/audioldm_eval/datasets/load_mel.py
@@ -130,7 +130,7 @@ def __init__(
sr=16000,
limit_num=None,
):
- self.datalist = [os.path.join(datadir, x) for x in os.listdir(datadir)]
+ self.datalist = [os.path.join(datadir, x) for x in os.listdir(datadir) if x.endswith('.wav')]
self.datalist = sorted(self.datalist)
if limit_num is not None:
self.datalist = self.datalist[:limit_num]
diff --git a/GenAU/src/tools/training_utils.py b/GenAU/src/tools/training_utils.py
index d8a7257..8340d25 100755
--- a/GenAU/src/tools/training_utils.py
+++ b/GenAU/src/tools/training_utils.py
@@ -61,7 +61,7 @@ def build_dataset_json_from_list(list_path):
wav = ""
data.append(
{
- "wav": wav,
+ "fname": wav,
"caption": caption,
}
)
@@ -83,7 +83,7 @@ def read_json(dataset_json_file):
def copy_test_subset_data(metadata, testset_copy_target_path):
# metadata = read_json(testset_metadata)
os.makedirs(testset_copy_target_path, exist_ok=True)
- if len(os.listdir(testset_copy_target_path)) == len(metadata):
+ if len(os.listdir(testset_copy_target_path)) >= len(metadata) - 1:
return
else:
# delete files in folder testset_copy_target_path
diff --git a/GenAU/src/utilities/data/videoaudio_dataset.py b/GenAU/src/utilities/data/videoaudio_dataset.py
index 1f76c41..6e960de 100644
--- a/GenAU/src/utilities/data/videoaudio_dataset.py
+++ b/GenAU/src/utilities/data/videoaudio_dataset.py
@@ -17,6 +17,9 @@
from src.tools.io import load_file, write_json, load_json
from src.tools.torch_utils import spectral_normalize_torch, random_uniform
from src.tools.training_utils import build_dataset_json_from_list
+import gc
+import librosa
+import threading
class VideoAudioDataset(Dataset):
def __init__(
@@ -33,6 +36,7 @@ def __init__(
dataset_json=None,
sample_single_caption=True,
augment_p=0.0,
+        limit_data_percentage=None,
cache_dir=None
):
"""
@@ -48,13 +52,19 @@ def __init__(
self.load_audio = load_audio
self.keep_audio_files = keep_audio_files
self.sample_single_caption = sample_single_caption
+ self.limit_data_percentage = config['data'].get('limit_data_percentage', False)
self.trim_wav = False
self.waveform_only = waveform_only
self.augment_p = augment_p
self.add_ons = [eval(x) for x in add_ons]
- self.cache_dir = cache_dir
self.consistent_start_time = config['data'].get('consistent_start_time', False)
+
+ self.cache_dir = config['data'].get('cache_dir', None)
+ if self.cache_dir is not None:
+ os.makedirs(self.cache_dir, exist_ok=True)
+
print("[INFO] Add-ons:", self.add_ons)
+ self.obtained_samples = 0
# transforms
if video_transform is None:
@@ -80,6 +90,20 @@ def __init__(
% (split, self.config["data"].keys())
)
self.retrieve_paths()
+
+
+ if split=='train' and self.limit_data_percentage:
+ print(f"[INFO] limiting data to only {self.limit_data_percentage} of the total data {len(self.data)}")
+ num_datapoints = int(len(self.data) * self.limit_data_percentage)
+
+ # fix the seed to make sure we select the same data.
+ np.random.seed(42)
+ selected_idx = np.random.randint(0, len(self.data), size=num_datapoints)
+
+ # select
+ self.video_json_paths = np.asarray(self.video_json_paths)[selected_idx]
+ self.data = np.asarray(self.data)[selected_idx]
+ self.datasets_of_datapoints = np.asarray(self.datasets_of_datapoints)[selected_idx]
self.build_dsp()
@@ -108,71 +132,130 @@ def get_data_from_keys(self, data, key, default_value=None):
return data[key]
return default_value # Or return a default value if none of the keys are found
- def __getitem__(self, index, augment=True):
- (
- index,
- fname,
- video_frames,
- waveform,
- stft,
- log_mel_spec,
- _, # the one-hot representation of the audio class
- (datum, mix_datum),
- random_start,
- ) = self.feature_extraction(index)
-
- if '.json' in self.data[index]:
- dataset_name = self.datasets_of_datapoints[index]
- absolute_file_path = self._relative_path_to_absolute_path([self.data[index]], dataset_name)[0]
- else:
- dataset_name = absolute_file_path = ""
-
+
+ def default_sample(self):
data = {
- "dataset_name": dataset_name,
- "json_path": absolute_file_path,
- "fname": fname, # list
- "waveform": "" if (not self.load_audio) else waveform.float(),
- # tensor, [batchsize, t-steps, f-bins]
- "stft": "" if (stft is None) else stft.float(),
- # tensor, [batchsize, t-steps, mel-bins]
- "log_mel_spec": "" if (log_mel_spec is None) else log_mel_spec.float(),
- "duration": self.duration,
- "sampling_rate": self.sampling_rate,
- "random_start_sample_in_original_audio_file": random_start if random_start is not None else 0,
- "labels": ', '.join(datum.get('labels', [])),
-
- # # video
- "frames": video_frames if self.load_video else "",
+ "dataset_name": "UNK",
+ "json_path": "UNK",
+ "fname": "UNK", # list
+ "waveform": "" if (not self.load_audio) else torch.zeros(1, int(self.sampling_rate * self.duration)),
+ # "waveform": torch.zeros(1, int(self.sampling_rate * self.duration)),
+ # tensor, [batchsize, t-steps, f-bins]
+ "stft": "" if self.waveform_only else torch.zeros(int(self.duration * 100), 512),
+ # tensor, [batchsize, t-steps, mel-bins]
+ "log_mel_spec": "" if self.waveform_only else torch.zeros(int(self.duration * 100), 64),
+ "duration": self.duration,
+ "sampling_rate": self.sampling_rate,
+ "random_start_sample_in_original_audio_file": -1,
+ "labels": "UNK",
+
+ # # video
+ "frames": "",
+
+ # additional meta data
+ "title": "UNK",
+ "url": "UNK",
+ "description": "UNK",
+ "original_captions": "UNK",
+ "automatic_captions": "UNK",
+ "gt_audio_caption": "UNK" if self.sample_single_caption else ["UNK"] * 5,
+ "video_caption": "UNK",
+ "videollama_caption": "UNK",
+ "text": "UNK" if self.sample_single_caption else ["UNK"] * 5
+ }
- # additional meta data
- "title": self.filter_text(datum.get('title', '')),
- "url": self.filter_text(datum.get('url', '')),
- "description": self.filter_text(self.get_sample_description(datum)),
- "original_captions": self.filter_text(datum.get('original_captions', '')),
- "automatic_captions": self.filter_text(datum.get('automatic_captions', '')),
- "gt_audio_caption": self.get_sample_caption(datum, index=index),
- "panda_caption": datum.get('panda70m_caption_0000', '').replace("", "").strip(),
- "videollama_caption": datum.get('videollama_caption_0000', ''),
- }
-
- # select one caption if multiple exists
- if isinstance(data['gt_audio_caption'], list) and len(data['gt_audio_caption']) > 0 and self.sample_single_caption:
- idx = np.random.randint(len(data['gt_audio_caption']))
- data['gt_audio_caption'] = data['gt_audio_caption'][idx]
-
-
- for add_on in self.add_ons:
- data.update(add_on(self.config, data, self.data[index]))
+ return data
- # augment data
- if augment and np.random.rand() < self.augment_p:
- data = self.pair_augmentation(data)
+ def __getitem__(self, index, augment=True):
- data['text'] = data['gt_audio_caption']
- if not data['fname']:
- data['fname'] = data['text']
+ retries = 0
+ max_retries = 1
- return data
+ while retries < max_retries:
+ try:
+ if '.json' in self.data[index]:
+ dataset_name = self.datasets_of_datapoints[index]
+ absolute_file_path = self._relative_path_to_absolute_path([self.data[index]], dataset_name)[0]
+ if not os.path.exists(absolute_file_path):
+                        print(f"file {absolute_file_path} does not exist. Retrying...")
+ index = random.randint(0, len(self.data) - 1)
+ retries += 1
+ continue
+ else:
+ dataset_name = absolute_file_path = ""
+
+ (
+ index,
+ fname,
+ video_frames,
+ waveform,
+ stft,
+ log_mel_spec,
+ _, # the one-hot representation of the audio class
+ (datum, mix_datum),
+ random_start,
+ ) = self.feature_extraction(index)
+
+ data = {
+ "dataset_name": dataset_name,
+ "json_path": absolute_file_path,
+ "fname": fname, # list
+ "waveform": "" if (not self.load_audio) else waveform.float(),
+ # tensor, [batchsize, t-steps, f-bins]
+ "stft": "" if (stft is None) else stft.float(),
+ # tensor, [batchsize, t-steps, mel-bins]
+ "log_mel_spec": "" if (log_mel_spec is None) else log_mel_spec.float(),
+ "duration": self.duration,
+ "sampling_rate": self.sampling_rate,
+ "random_start_sample_in_original_audio_file": -1 if random_start is None else random_start,
+ "labels": ', '.join(datum.get('labels', [])),
+
+ # # video
+ "frames": video_frames if self.load_video else "",
+
+ # additional meta data
+ "title": self.filter_text(datum.get('title', '')),
+ "url": self.filter_text(datum.get('url', '')),
+ "description": self.filter_text(self.get_sample_description(datum)),
+ "original_captions": self.filter_text(datum.get('original_captions', '')),
+ "automatic_captions": self.filter_text(datum.get('automatic_captions', '')),
+ "gt_audio_caption": self.get_sample_caption(datum, index=index),
+ "video_caption": datum.get('panda70m_caption_0000', '').replace("", "").strip(),
+ "videollama_caption": datum.get('videollama_caption_0000', ''),
+ }
+
+ # select one caption if multiple exists
+ if isinstance(data['gt_audio_caption'], list) and len(data['gt_audio_caption']) > 0 and self.sample_single_caption:
+ idx = np.random.randint(len(data['gt_audio_caption']))
+ data['gt_audio_caption'] = data['gt_audio_caption'][idx]
+
+
+ for add_on in self.add_ons:
+ data.update(add_on(self.config, data, self.data[index]))
+
+ # augment data
+ if augment and np.random.rand() < self.augment_p:
+ data = self.pair_augmentation(data)
+
+ data['text'] = data['gt_audio_caption']
+
+ self.obtained_samples += 1
+
+ if self.obtained_samples % 20 == 0:
+ gc.collect()
+ return data
+ except Exception as e:
+ if '.json' in self.data[index]:
+ dataset_name = self.datasets_of_datapoints[index]
+ file_path = self._relative_path_to_absolute_path([self.data[index]], dataset_name)[0]
+ else:
+ file_path = ""
+
+ index = random.randint(0, len(self.data) - 1)
+ retries += 1
+ print("[ERROR, videoaudio_dataset] error while loading", file_path, e)
+ continue
+ return self.default_sample()
def text_to_filename(self, text):
return text.replace(" ", "_").replace("'", "_").replace('"', "_")
@@ -198,134 +281,124 @@ def __len__(self):
def replace_extension(self, path, new_ext):
return f"{'/'.join(path.split('.')[:-1])}.{new_ext}"
+
def feature_extraction(self, index):
- if index > len(self.data) - 1:
- print(
- "The index of the dataloader is out of range: %s/%s"
- % (index, len(self.data))
- )
- index = random.randint(0, len(self.data) - 1)
-
# Read wave file and extract feature
- while True:
- try:
- if isinstance(self.data[index], str) and '.json' in self.data[index]:
- dataset_name = self.datasets_of_datapoints[index]
- file_path = self._relative_path_to_absolute_path([self.data[index]], dataset_name)[0]
- datum = load_json(file_path)
- else:
- datum = self.data[index]
+ if isinstance(self.data[index], str) and '.json' in self.data[index]:
+ dataset_name = self.datasets_of_datapoints[index]
+ file_path = self._relative_path_to_absolute_path([self.data[index]], dataset_name)[0]
+ datum = load_json(file_path)
+ else:
+ datum = self.data[index]
- if 'path' in datum and datum['path']:
- datum['path'] = self._relative_path_to_absolute_path([datum['path']], dataset_name)[0]
+ if 'path' in datum and datum['path']:
+ datum['path'] = self._relative_path_to_absolute_path([datum['path']], dataset_name)[0]
- if 'wav' in datum and datum['wav']:
- datum['wav'] = self._relative_path_to_absolute_path([datum['wav']], dataset_name)[0]
+ if 'wav' in datum and datum['wav']:
+ datum['wav'] = self._relative_path_to_absolute_path([datum['wav']], dataset_name)[0]
+
+ random_start = None
+ log_mel_spec, stft, waveform, frames = None, None, None, None
+ audio_file = None
+
+ if self.load_audio and not ('wav' in datum.keys() and os.path.exists(datum['wav'])):
+ # assume that a .wav file exists in the same location as the .json file
+ wav_path = self.replace_extension(file_path, 'wav')
+ flac_path = self.replace_extension(file_path, 'flac')
+ if os.path.exists(wav_path):
+ datum['wav'] = wav_path
+ elif os.path.exists(flac_path):
+ datum['wav'] = flac_path
+ elif 'wav' in datum:
+ del datum['wav']
+
+        # cache wav file: useful when local storage is available that is faster for read operations
+ if self.load_audio and 'wav' in datum and self.cache_dir is not None:
+ target_audio_file_path = f"{self.cache_dir}{datum['wav']}"
+ if not os.path.exists(target_audio_file_path):
+ os.makedirs(os.path.dirname(target_audio_file_path), exist_ok=True)
+ shutil.copy2(datum['wav'] , target_audio_file_path)
+
+ # update
+ datum['wav'] = target_audio_file_path
+
+ save_random_start = False
+ random_start = None
+ if self.consistent_start_time: # always sample from the same start time
+ if 'random_start_t' in datum:
+ random_start = datum.get('random_start_t', None)
+ save_random_start = False
+ else:
+ save_random_start = True
+
+ # load audio
+ if self.load_audio:
+ if 'wav' in datum:
+ (
+ log_mel_spec,
+ stft,
+ waveform,
+ random_start,
+ ) = self.read_audio_file(datum["wav"], random_start=random_start)
- random_start = None
- log_mel_spec, stft, waveform, frames = None, None, None, None
- audio_file = None
-
- if self.load_audio and not ('wav' in datum.keys() and os.path.exists(datum['wav'])):
- # assume that a .wav file exists in the same location as the .json file
- wav_path = self.replace_extension(file_path, 'wav')
- flac_path = self.replace_extension(file_path, 'flac')
- if os.path.exists(wav_path):
- datum['wav'] = wav_path
- elif os.path.exists(flac_path):
- datum['wav'] = flac_path
- elif 'wav' in datum:
- del datum['wav']
-
- # cache wav file: useful when there exists a local memory the is faster to do read operations on it
- if self.load_audio and 'wav' in datum and self.cache_dir is not None:
- target_audio_file_path = f"{self.cache_dir}{datum['wav']}"
- if not os.path.exists(target_audio_file_path):
- os.makedirs(os.path.dirname(target_audio_file_path), exist_ok=True)
- shutil.copy2(datum['wav'] , target_audio_file_path)
-
- # update
- datum['wav'] = target_audio_file_path
- save_random_start = False
- random_start = None
- if self.consistent_start_time: # always sample from the same start time
- if 'random_start_t' in datum:
- random_start = datum.get('random_start_t', None)
- save_random_start = False
- else:
- save_random_start = True
+ waveform = torch.FloatTensor(waveform)
- # load audio
- if self.load_audio:
- if 'wav' in datum:
- (
- log_mel_spec,
- stft,
- waveform,
- random_start,
- ) = self.read_audio_file(datum["wav"], random_start=random_start)
- waveform = torch.FloatTensor(waveform)
-
- else:
- (
- frames,
- log_mel_spec,
- stft,
- waveform,
- random_start,
- audio_file
- ) = self.read_video_file(datum["path"], random_start=random_start, load_audio=True)
- waveform = torch.FloatTensor(waveform)
-
- # load video
- if self.load_video and 'path' in datum:
- (frames, _, _, _, _, _ ) = self.read_video_file(datum["path"], random_start=random_start, load_audio=self.load_audio and waveform is None)
- elif self.load_video and 'path' in datum:
- (
- frames,
- log_mel_spec,
- stft,
- waveform,
- random_start,
- audio_file
- ) = self.read_video_file(datum["path"], random_start=random_start, load_audio=True)
- waveform = torch.FloatTensor(waveform)
- if audio_file is not None:
- # update json to include path to audio. Only effective if keep_audio_file is enabled
- updated_json = load_json(file_path)
- updated_json['wav'] = self._absolute_path_to_relative_path([audio_file], dataset_name)[0]
- datum["wav"] = updated_json['wav']
- updated_json['random_start_t'] = random_start
- write_json(updated_json, file_path)
-
- elif save_random_start and random_start is not None:
- # update json to include the randomly sampled start time for future experiments
- updated_json = load_json(file_path)
- updated_json['random_start_t'] = random_start
- write_json(updated_json, file_path)
-
- mix_datum = None
- if self.load_video:
- assert frames.shape == (3, self.target_frame_cnt, self.frame_width, self.frame_height)
- break
+
+ else:
+ (
+ frames,
+ log_mel_spec,
+ stft,
+ waveform,
+ random_start,
+ audio_file
+ ) = self.read_video_file(datum["path"], random_start=random_start, load_audio=True)
+ waveform = torch.FloatTensor(waveform)
+
+ # load video
+ if self.load_video and 'path' in datum:
+ (frames, _, _, _, _, _ ) = self.read_video_file(datum["path"], random_start=random_start, load_audio=self.load_audio and waveform is None)
+
+ elif self.load_video and 'path' in datum:
+ (
+ frames,
+ log_mel_spec,
+ stft,
+ waveform,
+ random_start,
+ audio_file
+ ) = self.read_video_file(datum["path"], random_start=random_start, load_audio=True)
+ waveform = torch.FloatTensor(waveform)
+
+ if audio_file is not None:
+ # update json to include path to audio. Only effective if keep_audio_file is enabled
+ updated_json = load_json(file_path)
+ updated_json['wav'] = self._absolute_path_to_relative_path([audio_file], dataset_name)[0]
+ datum["wav"] = updated_json['wav']
+ updated_json['random_start_t'] = random_start
+ # write_json(updated_json, file_path)
+
+ elif save_random_start and random_start is not None:
+ # update json to include the randomly sampled start time for future experiments
+ updated_json = load_json(file_path)
+ updated_json['random_start_t'] = random_start
+ write_json(updated_json, file_path)
+
+ mix_datum = None
+ if self.load_video:
+ assert frames.shape == (3, self.target_frame_cnt, self.frame_width, self.frame_height)
- except Exception as e:
- if '.json' in self.data[index]:
- dataset_name = self.datasets_of_datapoints[index]
- file_path = self._relative_path_to_absolute_path([self.data[index]], dataset_name)[0]
- else:
- file_path = ""
-
- index = (index + 1) % len(self.data)
- print("[ERROR, videoaudio_dataset] error while loading", file_path, e)
- continue
# The filename of the wav file
- fname = datum["path"] if 'path' in datum and self.load_video else datum["wav"]
+ fname = datum["path"] if 'path' in datum and self.load_video else datum.get('wav', '')
+
+ if not fname:
+ fname = datum['fname']
+
return (
index,
@@ -338,7 +411,7 @@ def feature_extraction(self, index):
(datum, mix_datum),
random_start,
)
-
+
def combine_captions(self, caption1, caption2, remove_duplicates=False, background=False):
"""
Useful function to combine two caption when doing mixup augmentation
@@ -602,9 +675,58 @@ def process_wavform(self, waveform, sr):
return waveform
+ def load_audio_with_timeout(self, file_path, timeout):
+ """
+ Load an audio file with a specified timeout using threading.
+
+ :param file_path: Path to the audio file.
+ :param timeout: Maximum time (in seconds) to allow for loading the file.
+ :return: (waveform, sample_rate) if successful, None if timeout occurs.
+ """
+ class AudioLoader(threading.Thread):
+ def __init__(self, file_path):
+ super().__init__()
+ self.file_path = file_path
+ self.result = None
+
+ def run(self):
+ try:
+ waveform, sample_rate = torchaudio.load(self.file_path)
+ self.result = (waveform, sample_rate)
+ except Exception as e:
+ print(f"Failed to load audio: {e}")
+ self.result = None
+
+ # Start the thread
+ audio_loader = AudioLoader(file_path)
+ audio_loader.start()
+
+ # Wait for the thread to complete or timeout
+ audio_loader.join(timeout=timeout)
+ if audio_loader.is_alive():
+ print(f"Timeout while loading {file_path}")
+ return None, None # Timeout case
+
+ return audio_loader.result
+
+
def read_wav_file(self, filename, random_start=None):
+
# waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
- waveform, sr = torchaudio.load(filename)
+ # waveform = torch.from_numpy(waveform)
+ waveform, sr = self.load_audio_with_timeout(filename, timeout=10)
+ if waveform is None:
+            print(f"[INFO] timeout while loading {filename}")
+            # fall back to a silent dummy waveform so the dataloader does not stall
+ waveform = torch.zeros(1, int(self.sampling_rate * self.duration))
+ sr = 16000
+
waveform, random_start = self.random_segment_wav(
waveform, target_length=int(sr * self.duration), random_start=random_start
@@ -885,13 +1007,13 @@ def custom_collate_fn(batch):
# for test
# for k in batch[0].keys():
- # try:
- # default_collate([{k:item[k]} for item in batch])
- # print("done")
- # except:
- # print("collect error in key", k)
- # print("files", [b['fname'] for b in batch])
- # inp = [{k:item[k]} for item in batch]
+ # try:
+ # default_collate([{k:item[k]} for item in batch])
+ # except Exception as e:
+ # print("collect error in key", k)
+ # print("files", [b['fname'] for b in batch])
+ # print("shape", [item[k].shape for item in batch])
+ # print("error", e)
collated_batch = default_collate(batch)
diff --git a/GenAU/train/genau.py b/GenAU/train/genau.py
index d96d7e3..5515659 100755
--- a/GenAU/train/genau.py
+++ b/GenAU/train/genau.py
@@ -32,7 +32,7 @@
copy_test_subset_data,
)
from src.utilities.model.model_util import instantiate_from_config
-from src.utilities.data.videoaudio_dataset import VideoAudioDataset
+from src.utilities.data.videoaudio_dataset import VideoAudioDataset, custom_collate_fn
from src.tools.download_manager import get_checkpoint_path
logging.basicConfig(level=logging.WARNING)
@@ -69,6 +69,7 @@ def main(configs, config_yaml_path, exp_group_name, exp_name, debug=False):
num_workers=configs['data'].get('num_workers', 32),
pin_memory=True,
shuffle=True,
+ collate_fn=custom_collate_fn
)
print(
@@ -82,6 +83,7 @@ def main(configs, config_yaml_path, exp_group_name, exp_name, debug=False):
val_dataset,
num_workers=configs['data'].get('num_workers', 32),
batch_size=max(1, batch_size // configs['model']['params']['evaluation_params']['n_candidates_per_samples']),
+ collate_fn=custom_collate_fn
)
# Copy test data
@@ -96,7 +98,8 @@ def main(configs, config_yaml_path, exp_group_name, exp_name, debug=False):
config_reload_from_ckpt = configs.get("reload_from_ckpt", None)
limit_val_batches = configs["step"].get("limit_val_batches", None)
limit_train_batches = configs["step"].get("limit_train_batches", None)
- validation_every_n_epochs = configs["step"]["validation_every_n_epochs"]
+ validation_every_n_epochs = configs["step"].get("validation_every_n_epochs", None)
+ val_check_interval = configs["step"].get("val_check_interval", None)
max_steps = configs["step"]["max_steps"]
save_top_k = configs["logging"]["save_top_k"]
save_checkpoint_every_n_steps = configs["logging"]["save_checkpoint_every_n_steps"]
@@ -173,6 +176,7 @@ def main(configs, config_yaml_path, exp_group_name, exp_name, debug=False):
limit_val_batches=limit_val_batches,
limit_train_batches=limit_train_batches,
check_val_every_n_epoch=validation_every_n_epochs,
+ val_check_interval=val_check_interval,
strategy=DDPStrategy(find_unused_parameters=True),
callbacks=[checkpoint_callback],
gradient_clip_val=configs["model"]["params"].get("clip_grad", None)