diff --git a/AutoCap/README.md b/AutoCap/README.md
index a08d729..9454fe9 100644
--- a/AutoCap/README.md
+++ b/AutoCap/README.md
@@ -1,119 +1,106 @@
-
+[![arXiv](ARXIV ICON)](ARXIV LINK)
-# GenAU inference, training, and evaluation
-- [Introduction](#introduction)
-- [Environemnt setup](#environment-initalization)
+# AutoCap inference, training, and evaluation
 - [Inference](#inference)
-  * [Audio to text script](#text-to-audio)
-  * [Inference a list of promots](#inference-a-list-of-prompts)
+  * [Audio to text script](#audio-to-text)
+  * [Gradio demo](#gradio-demo)
+  * [Caption a list of audio files](#caption-list-of-audio-files)
+  * [Caption your custom dataset](#caption-your-custom-dataset)
 - [Training](#training)
-  * [GenAU](#genau)
-  * [Finetuning GenAU](#finetuning-genau)
-  * [1D-VAE (optional)](#1d-vae-optional)
 - [Evaluation](#evaluation)
 - [Cite this work](#cite-this-work)
 - [Acknowledgements](#acknowledgements)
-# Introduction
-We introduce GenAU, a transformer-based audio latent diffusion model leveraging the FIT architecture. Our model compresses mel-spectrogram data into a 1D representation and utilizes layered attention processes to achieve state-of-the-art audio generation results among open-source models.
-
- -
- -
- -
-
-# Environment initialization
+# Environment initialization
 For initializing your environment, please refer to the [general README](../README.md).

 # Inference
-## Text to Audio
-To quickly generate an audio based on an input text prompt, run
+## Audio to Text
+To quickly generate a caption for an input audio, run
 ```shell
-python scripts/text_to_audio.py --prompt "Horses growl and clop hooves." --model "genau-full-l"
+python scripts/audio_to_text.py --wav_path <path-to-wav-file>
+
+# Example inference
+python scripts/audio_to_text.py --wav_path samples/ood_samples/loudwhistle-91003.wav
 ```
-- This will automatically download and use the model `genau-full-l` with default settings. You may change these parameters or provide your custom model config file and checkpoint path.
-- Available models include `genau-full-l` (1.25B parameters) and `genau-full-s` (493M parameters)
-- These models are trained to generate ambient sounds and is incapable of generating speech or music.
-- Outputs will be saved by default at `samples/model_output` using the provided prompt as the file name.
+- This will automatically download the `TODO` model and run inference with the default parameters. You may change these parameters or provide your custom model config file and checkpoint path.
+- For more accurate captioning, provide metadata using the `--title`, `--description`, and `--video_caption` arguments.
-
+## Gradio demo
+To launch an interactive Gradio demo locally, run
+```shell
+python app_audio2text.py
+```
-## Inference a list of prompts
-Optionally, you may prepare a `.txt` file with your target prompts and run
+## Caption list of audio files
+- Prepare all target audio files in a single folder
+- Optionally, provide metadata in a `yaml` file with the following structure
+```yaml
+file_name.wav:
+  title: "video title"
+  description: "video description"
+  video_caption: "video caption"
+```
+Then, run the following script:
 ```shell
-python scripts/inference_file.py --list_inference --model
+python scripts/inference_folder.py --folder_path <path-to-audio-folder> --meta_data_file <path-to-meta-data-yaml>
-# Example
-python scripts/inference_file.py --list_inference samples/prompts_list.txt --model "genau-full-l"
+# Example inference
+python scripts/inference_folder.py --folder_path samples/ood_samples --meta_data_file samples/ood_samples/meta_data.yaml
 ```
+## Caption your custom dataset
-## Training
-
-### Dataset
-Please refer to the [dataset preparation README](../dataset_preperation/README.md) for instructions on downloading our dataset or preparing your own.
-
-### GenAU
-- Prepare a yaml config file for your experiments. A sample config file is provided at `settings/simple_runs/genau.yaml`
-- Specify your project name and provide your Wandb key in the config file. A Wandb key can be obtained from [https://wandb.ai/authorize](https://wandb.ai/authorize)
-- Optionally, provide your S3 bucket and folder to save intermediate checkpoints.
-- By default, checkpoints will be saved under `run_logs/genau/train` at the same level as the config file.
+If you want to caption a large dataset, we provide a script that uses multiple GPUs for faster inference.
+- Prepare your custom dataset by following the instructions in the dataset preparation README (TODO) and run
 ```shell
-# Training GenAU from scratch
-python train/genau.py -c settings/simple_runs/genau.yaml
-```
+python scripts/caption_dataset.py \
+    --caption_store_key <caption_store_key> \
+    --beam_size 2 \
+    --start_idx 0 \
+    --end_idx 1000000 \
+    --dataset_keys "dataset_1" "dataset_2" ...
-For multinode training, run
-```shell
-python -m torch.distributed.run --nproc_per_node=8 train/genau.py -c settings/simple_runs/genau.yaml
 ```
-### Finetuning GenAU
+- Provide your dataset keys as registered in the dataset preparation (TODO)
+- Captions will be generated and stored in each file's JSON file under the specified `caption_store_key`
+- The `start_idx` and `end_idx` arguments can be used to resume or distribute captioning experiments
+- Add your `caption_store_key` under `keys_synonyms:gt_audio_caption` in the target yaml config file so it is selected when a ground-truth caption is not available in your audio captioning or audio generation experiments (see the example snippet below).
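+For illustration, a minimal sketch of what that config entry might look like; the surrounding nesting and the name `my_caption_store_key` are placeholders, so match them to your own config file and to the `--caption_store_key` you used above:
+```yaml
+keys_synonyms:
+  gt_audio_caption:
+    - gt_audio_caption
+    - my_caption_store_key   # hypothetical caption_store_key used during captioning
+```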
-- Prepare your custom dataset and obtain the dataset keys following [dataset preparation README](../dataset_preperation/README.md)
-- Make a copy and adjust the default config file of `genau-full-l` which you can find under `pretrained_models/genau/genau-full-l.yaml`
-- Add ids for your dataset keys under `dataset2id` attribute in the config file.
+# Training
+### Dataset
+Please refer to the dataset README (TODO) for instructions on downloading our dataset or preparing your own dataset.
+
+### Stage 1 (pretraining)
+- Specify your model parameters in a config yaml file. A sample yaml file is given under `settings/pretraining.yaml`
+- Specify your project name and provide your wandb key in the config file. A wandb key can be obtained from [https://wandb.ai/authorize](https://wandb.ai/authorize)
+- Optionally, provide your S3 bucket and folder to save intermediate checkpoints.
+- By default, checkpoints will be saved under `run_logs/train`
 ```shell
-# Finetuning GenAU
-python train/genau.py --reload_from_ckpt 'genau-full-l' \
-    --config \
-    --dataset_keys "" "" ...
+python train.py -c settings/pretraining.yaml
 ```
-
-### 1D VAE (Optional)
-By default, we offer a pre-trained 1D-VAE for GenAU training. If you prefer, you can train your own VAE by following the provided instructions.
-- Prepare your own dataset following the instructions in the [dataset preparation README](../dataset_preperation/README.md)
-- Prepare your yaml config file in a similar way to the GenAU config file
-- A sample config file is provided at `settings/simple_runs/1d_vae.yaml`
+### Stage 2 (finetuning)
+- Prepare your finetuning config file in a similar way to the pretraining stage. Typically, you only need to set `pretrain_path` to your pretraining checkpoint, adjust the learning rate, and disable the freeze option for the `text_decoder` (see the example snippet below).
+- A sample finetuning config is provided under `settings/finetuning.yaml`
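+As a rough illustration only, the finetuning overrides might look like the sketch below; take the exact key names and nesting from `settings/finetuning.yaml` — the checkpoint path, learning-rate value, and the `freeze_decoder` key name here are placeholders:
+```yaml
+pretrain_path: "run_logs/train/<pretraining-exp>/checkpoints/last.ckpt"  # placeholder path
+optim_args:
+  lr: 1.0e-5             # typically lower than the pretraining learning rate
+model:
+  freeze_decoder: False  # unfreeze the text_decoder for finetuning; exact key may differ
+```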
 ```shell
-python train/1d_vae.py -c settings/simple_runs/1d_vae.yaml
+python train.py -c settings/finetuning.yaml
 ```
-## Evaluation
-- We follow [audioldm](https://github.com/haoheliu/AudioLDM-training-finetuning) to perform our evaulations.
-- By default, the models will be evaluated periodically during training as specified in the config file. For each evaluation, a folder with the generated audio will be saved under `run_logs/train' at the same levels as the specified config file.
-- The code identifies the test dataset in an already existing folder according to the number of samples. If you would like to test on a new test dataset, register it in `scripts/generate_and_eval`
+# Evaluation
+- By default, the model will log metrics on the validation set to wandb periodically during training, as specified in the config file.
+- We exclude the `spice`, `spider`, and `meteor` metrics during training as they tend to hang during multi-GPU training. You may include them by changing the configuration.
+- A file with the predicted captions during evaluation will be saved under `run_logs/train`, and metrics can be found in a file named `output.txt` under the logging folder.
+- To run the evaluation on the test set after training finishes, run:
 ```shell
-
-# Evaluate an existing generated folder
-python scripts/evaluate.py --log_path
-
-# Geneate test audios from a pre-trained checkpoint and run evaulation
-python scripts/generate_and_eval.py -c -ckpt
+python evaluate.py -c <path-to-config> -ckpt <path-to-checkpoint>
 ```
-The evaluation result will be saved in a JSON file at the same level of the generated audio folder.

 # Cite this work
 If you found this useful, please consider citing our work
@@ -122,5 +109,6 @@ If you found this useful, please consider citing our work
 ```

 # Acknowledgements
-Our audio generation and evaluation codebase relies on [audioldm](https://github.com/haoheliu/AudioLDM-training-finetuning). We sincerely appreciate the authors for sharing their code openly.
-
+We sincerely thank the authors of the following works for sharing their code publicly:
+- [WavCaps: A ChatGPT-Assisted Weakly-Labelled Audio Captioning Dataset for Audio-Language Multimodal Research](https://github.com/XinhaoMei/WavCaps)
+- [Audio Captioning Transformer](https://github.com/XinhaoMei/ACT/tree/main/coco_caption)
\ No newline at end of file
diff --git a/AutoCap/src/models/pl_htsat_q_bart_captioning.py b/AutoCap/src/models/pl_htsat_q_bart_captioning.py index c010656..aa29eed 100644 --- a/AutoCap/src/models/pl_htsat_q_bart_captioning.py +++ b/AutoCap/src/models/pl_htsat_q_bart_captioning.py @@ -57,7 +57,6 @@ def forward(self, token_ids): token_ids = torch.where(special_tokens_mask, self.base_tokenizer.pad_token_id, token_ids) embeddings = self.model.encoder.embed_tokens(token_ids) - embeddings.view(-1, embeddings.shape[-1])[special_tokens_mask.view(-1)] = special_tokens_embeds return embeddings @@ -198,7 +197,7 @@ def __init__(self, config): self.decoder = BartForConditionalGeneration.from_pretrained(decoder_name) else: bart_config = BartConfig.from_pretrained(decoder_name) - self.decoder = BartForConditionalGeneration.from_config(bart_config) + self.decoder = BartForConditionalGeneration(config=bart_config) self.set_decoder_requires_grad(freeze=freeze_decoder, freeze_embed_layer=freeze_embed_layer) @@ -381,7 +380,9 @@ def __init__(self, config): else: text_max_tokens = self.num_text_query_token - self.decoder = self.adjust_max_pos_embeds(self.decoder, audio_max_tokens+text_max_tokens+2+self.use_clap_embeds) # extra token for CLAP + # extra token for CLAP, two for the positional offset in BART, two for bos and eos, and the rest as extra headroom + self.decoder = self.adjust_max_pos_embeds(self.decoder, audio_max_tokens+text_max_tokens+20+self.use_clap_embeds) + # dropout layer self.audio_features_dropout = nn.Dropout(p=config['model']['audio_features_dropout_p']) @@ -875,7 +876,7 @@ def get_meta_dict(self, batch, meta_keys=None, drop_keys=[]): def training_step(self, batch, batch_idx): audio = batch['waveform'].squeeze(1) text = batch['gt_audio_caption'] # list of captions - audio = audio.to(self.device, non_blocking=True) + audio = audio.to(self.device) # prepare meta if 'meta_keys' in self.config['model'].keys(): @@ -999,11 +1000,15 @@ def on_validation_epoch_start(self): def on_validation_epoch_end(self): + + # place a barrier to ensure that all ranks reach
the gather operation + torch.distributed.barrier() gathered = [None] * torch.distributed.get_world_size() torch.distributed.all_gather_object(gathered, self.val_outputs) - if not self.trainer.is_global_zero: - return - + torch.distributed.barrier() + + metrics_log = {} + # all ranks should excute the blocks to avoid deadlock self.gathered_output_dict = [] for idx in range(len(self.val_loaders_labels)): val_loader_out = {} @@ -1015,7 +1020,7 @@ def on_validation_epoch_end(self): self.gathered_output_dict.append(val_loader_out) - + val_logger = logger.bind(indent=1) for split, dataloader_outputs in zip(self.val_loaders_labels, self.gathered_output_dict): val_logger.info(f"[INFO] evaluating metrics for split: {split}") @@ -1051,12 +1056,12 @@ def get_score(metrics, key): f'Spider score using beam search (beam size:{beam_size}): {spider:7.4f}') metrics_log = {f"{split}/spider_beam_{beam_size}" : spider, - f"{split}/cider_beam_{beam_size}":cider, - f"{split}/spice_beam_{beam_size}":spice, - f"{split}/bleu_1_beam_{beam_size}":bleu_1, + f"{split}/cider_beam_{beam_size}":cider, + f"{split}/spice_beam_{beam_size}":spice, + f"{split}/bleu_1_beam_{beam_size}":bleu_1, f"{split}/bleu_4_beam_{beam_size}":bleu_4, - f"{split}/rouge_l_beam_{beam_size}":rouge_l, - f"{split}/meteor_beam_{beam_size}":meteor } + f"{split}/rouge_l_beam_{beam_size}":rouge_l, + f"{split}/meteor_beam_{beam_size}":meteor } if 'bert_score' in metrics: bert_score = metrics.pop('bert_score') metrics_log[f"{split}/bertscore_beam_{beam_size}"] = bert_score diff --git a/AutoCap/train.py b/AutoCap/train.py index 9fd2069..57b0260 100644 --- a/AutoCap/train.py +++ b/AutoCap/train.py @@ -49,7 +49,7 @@ def main(): config["optim_args"]["lr"] = args.lr if args.num_workers is not None: - config['data_args']['num_workers'] + config['data_args']['num_workers'] = args.num_workers # set up model devices = torch.cuda.device_count() @@ -90,6 +90,7 @@ def main(): return_test=False, cache_dir=None) + # print training settings printer = PrettyPrinter() main_logger.info('Training setting:\n' @@ -112,7 +113,7 @@ def main(): main_logger.info(f'Size of {val_k} validation set: {len(val_loader.dataset)}, size of batches: {len(val_loader)}') - # update the model with data types, combine test and val loaders # TODO: delete the test + print("val_loaders", len(val_loaders), val_loaders.keys()) model.val_loaders_labels = list(val_loaders.keys()) val_loaders = list(val_loaders.values()) @@ -120,7 +121,7 @@ def main(): # ckpt - validation_every_n_epochs = config["step"]["validation_every_n_epochs"] + validation_every_n_epochs = config["step"].get("validation_every_n_epochs", None) save_checkpoint_every_n_epochs = config["logging"]["save_checkpoint_every_n_epochs"] checkpoint_callback = S3ModelCheckpoint( @@ -146,7 +147,7 @@ def main(): limit_val_batches=config['step'].get('limit_val_batches', None), limit_train_batches=config['step'].get('limit_train_batches', None), check_val_every_n_epoch=validation_every_n_epochs, - strategy=DDPStrategy(find_unused_parameters=True), + strategy=DDPStrategy(find_unused_parameters=False), callbacks=[checkpoint_callback], gradient_clip_val=config["model"].get("clip_grad", None), profiler=config['training'].get('profiler', None), diff --git a/GenAU/audioldm_eval/audioldm_eval/datasets/load_mel.py b/GenAU/audioldm_eval/audioldm_eval/datasets/load_mel.py index 8c35eb3..ed89118 100644 --- a/GenAU/audioldm_eval/audioldm_eval/datasets/load_mel.py +++ b/GenAU/audioldm_eval/audioldm_eval/datasets/load_mel.py @@ -130,7 +130,7 @@ def __init__( 
sr=16000, limit_num=None, ): - self.datalist = [os.path.join(datadir, x) for x in os.listdir(datadir)] + self.datalist = [os.path.join(datadir, x) for x in os.listdir(datadir) if x.endswith('.wav')] self.datalist = sorted(self.datalist) if limit_num is not None: self.datalist = self.datalist[:limit_num] diff --git a/GenAU/src/tools/training_utils.py b/GenAU/src/tools/training_utils.py index d8a7257..8340d25 100755 --- a/GenAU/src/tools/training_utils.py +++ b/GenAU/src/tools/training_utils.py @@ -61,7 +61,7 @@ def build_dataset_json_from_list(list_path): wav = "" data.append( { - "wav": wav, + "fname": wav, "caption": caption, } ) @@ -83,7 +83,7 @@ def read_json(dataset_json_file): def copy_test_subset_data(metadata, testset_copy_target_path): # metadata = read_json(testset_metadata) os.makedirs(testset_copy_target_path, exist_ok=True) - if len(os.listdir(testset_copy_target_path)) == len(metadata): + if len(os.listdir(testset_copy_target_path)) >= len(metadata) - 1: return else: # delete files in folder testset_copy_target_path diff --git a/GenAU/src/utilities/data/videoaudio_dataset.py b/GenAU/src/utilities/data/videoaudio_dataset.py index 1f76c41..6e960de 100644 --- a/GenAU/src/utilities/data/videoaudio_dataset.py +++ b/GenAU/src/utilities/data/videoaudio_dataset.py @@ -17,6 +17,9 @@ from src.tools.io import load_file, write_json, load_json from src.tools.torch_utils import spectral_normalize_torch, random_uniform from src.tools.training_utils import build_dataset_json_from_list +import gc +import librosa +import threading class VideoAudioDataset(Dataset): def __init__( @@ -33,6 +36,7 @@ def __init__( dataset_json=None, sample_single_caption=True, augment_p=0.0, + limit_data_percentage = None, cache_dir=None ): """ @@ -48,13 +52,19 @@ def __init__( self.load_audio = load_audio self.keep_audio_files = keep_audio_files self.sample_single_caption = sample_single_caption + self.limit_data_percentage = config['data'].get('limit_data_percentage', False) self.trim_wav = False self.waveform_only = waveform_only self.augment_p = augment_p self.add_ons = [eval(x) for x in add_ons] - self.cache_dir = cache_dir self.consistent_start_time = config['data'].get('consistent_start_time', False) + + self.cache_dir = config['data'].get('cache_dir', None) + if self.cache_dir is not None: + os.makedirs(self.cache_dir, exist_ok=True) + print("[INFO] Add-ons:", self.add_ons) + self.obtained_samples = 0 # transforms if video_transform is None: @@ -80,6 +90,20 @@ def __init__( % (split, self.config["data"].keys()) ) self.retrieve_paths() + + + if split=='train' and self.limit_data_percentage: + print(f"[INFO] limiting data to only {self.limit_data_percentage} of the total data {len(self.data)}") + num_datapoints = int(len(self.data) * self.limit_data_percentage) + + # fix the seed to make sure we select the same data. 
+ np.random.seed(42) + selected_idx = np.random.randint(0, len(self.data), size=num_datapoints) + + # select + self.video_json_paths = np.asarray(self.video_json_paths)[selected_idx] + self.data = np.asarray(self.data)[selected_idx] + self.datasets_of_datapoints = np.asarray(self.datasets_of_datapoints)[selected_idx] self.build_dsp() @@ -108,71 +132,130 @@ def get_data_from_keys(self, data, key, default_value=None): return data[key] return default_value # Or return a default value if none of the keys are found - def __getitem__(self, index, augment=True): - ( - index, - fname, - video_frames, - waveform, - stft, - log_mel_spec, - _, # the one-hot representation of the audio class - (datum, mix_datum), - random_start, - ) = self.feature_extraction(index) - - if '.json' in self.data[index]: - dataset_name = self.datasets_of_datapoints[index] - absolute_file_path = self._relative_path_to_absolute_path([self.data[index]], dataset_name)[0] - else: - dataset_name = absolute_file_path = "" - + + def default_sample(self): data = { - "dataset_name": dataset_name, - "json_path": absolute_file_path, - "fname": fname, # list - "waveform": "" if (not self.load_audio) else waveform.float(), - # tensor, [batchsize, t-steps, f-bins] - "stft": "" if (stft is None) else stft.float(), - # tensor, [batchsize, t-steps, mel-bins] - "log_mel_spec": "" if (log_mel_spec is None) else log_mel_spec.float(), - "duration": self.duration, - "sampling_rate": self.sampling_rate, - "random_start_sample_in_original_audio_file": random_start if random_start is not None else 0, - "labels": ', '.join(datum.get('labels', [])), - - # # video - "frames": video_frames if self.load_video else "", + "dataset_name": "UNK", + "json_path": "UNK", + "fname": "UNK", # list + "waveform": "" if (not self.load_audio) else torch.zeros(1, int(self.sampling_rate * self.duration)), + # "waveform": torch.zeros(1, int(self.sampling_rate * self.duration)), + # tensor, [batchsize, t-steps, f-bins] + "stft": "" if self.waveform_only else torch.zeros(int(self.duration * 100), 512), + # tensor, [batchsize, t-steps, mel-bins] + "log_mel_spec": "" if self.waveform_only else torch.zeros(int(self.duration * 100), 64), + "duration": self.duration, + "sampling_rate": self.sampling_rate, + "random_start_sample_in_original_audio_file": -1, + "labels": "UNK", + + # # video + "frames": "", + + # additional meta data + "title": "UNK", + "url": "UNK", + "description": "UNK", + "original_captions": "UNK", + "automatic_captions": "UNK", + "gt_audio_caption": "UNK" if self.sample_single_caption else ["UNK"] * 5, + "video_caption": "UNK", + "videollama_caption": "UNK", + "text": "UNK" if self.sample_single_caption else ["UNK"] * 5 + } - # additional meta data - "title": self.filter_text(datum.get('title', '')), - "url": self.filter_text(datum.get('url', '')), - "description": self.filter_text(self.get_sample_description(datum)), - "original_captions": self.filter_text(datum.get('original_captions', '')), - "automatic_captions": self.filter_text(datum.get('automatic_captions', '')), - "gt_audio_caption": self.get_sample_caption(datum, index=index), - "panda_caption": datum.get('panda70m_caption_0000', '').replace("", "").strip(), - "videollama_caption": datum.get('videollama_caption_0000', ''), - } - - # select one caption if multiple exists - if isinstance(data['gt_audio_caption'], list) and len(data['gt_audio_caption']) > 0 and self.sample_single_caption: - idx = np.random.randint(len(data['gt_audio_caption'])) - data['gt_audio_caption'] = 
data['gt_audio_caption'][idx] - - - for add_on in self.add_ons: - data.update(add_on(self.config, data, self.data[index])) + return data - # augment data - if augment and np.random.rand() < self.augment_p: - data = self.pair_augmentation(data) + def __getitem__(self, index, augment=True): - data['text'] = data['gt_audio_caption'] - if not data['fname']: - data['fname'] = data['text'] + retries = 0 + max_retries = 1 - return data + while retries < max_retries: + try: + if '.json' in self.data[index]: + dataset_name = self.datasets_of_datapoints[index] + absolute_file_path = self._relative_path_to_absolute_path([self.data[index]], dataset_name)[0] + if not os.path.exists(absolute_file_path): + print(f"file {absolute_file_path} does not exists. Retying..") + index = random.randint(0, len(self.data) - 1) + retries += 1 + continue + else: + dataset_name = absolute_file_path = "" + + ( + index, + fname, + video_frames, + waveform, + stft, + log_mel_spec, + _, # the one-hot representation of the audio class + (datum, mix_datum), + random_start, + ) = self.feature_extraction(index) + + data = { + "dataset_name": dataset_name, + "json_path": absolute_file_path, + "fname": fname, # list + "waveform": "" if (not self.load_audio) else waveform.float(), + # tensor, [batchsize, t-steps, f-bins] + "stft": "" if (stft is None) else stft.float(), + # tensor, [batchsize, t-steps, mel-bins] + "log_mel_spec": "" if (log_mel_spec is None) else log_mel_spec.float(), + "duration": self.duration, + "sampling_rate": self.sampling_rate, + "random_start_sample_in_original_audio_file": -1 if random_start is None else random_start, + "labels": ', '.join(datum.get('labels', [])), + + # # video + "frames": video_frames if self.load_video else "", + + # additional meta data + "title": self.filter_text(datum.get('title', '')), + "url": self.filter_text(datum.get('url', '')), + "description": self.filter_text(self.get_sample_description(datum)), + "original_captions": self.filter_text(datum.get('original_captions', '')), + "automatic_captions": self.filter_text(datum.get('automatic_captions', '')), + "gt_audio_caption": self.get_sample_caption(datum, index=index), + "video_caption": datum.get('panda70m_caption_0000', '').replace("", "").strip(), + "videollama_caption": datum.get('videollama_caption_0000', ''), + } + + # select one caption if multiple exists + if isinstance(data['gt_audio_caption'], list) and len(data['gt_audio_caption']) > 0 and self.sample_single_caption: + idx = np.random.randint(len(data['gt_audio_caption'])) + data['gt_audio_caption'] = data['gt_audio_caption'][idx] + + + for add_on in self.add_ons: + data.update(add_on(self.config, data, self.data[index])) + + # augment data + if augment and np.random.rand() < self.augment_p: + data = self.pair_augmentation(data) + + data['text'] = data['gt_audio_caption'] + + self.obtained_samples += 1 + + if self.obtained_samples % 20 == 0: + gc.collect() + return data + except Exception as e: + if '.json' in self.data[index]: + dataset_name = self.datasets_of_datapoints[index] + file_path = self._relative_path_to_absolute_path([self.data[index]], dataset_name)[0] + else: + file_path = "" + + index = random.randint(0, len(self.data) - 1) + retries += 1 + print("[ERROR, videoaudio_dataset] error while loading", file_path, e) + continue + return self.default_sample() def text_to_filename(self, text): return text.replace(" ", "_").replace("'", "_").replace('"', "_") @@ -198,134 +281,124 @@ def __len__(self): def replace_extension(self, path, new_ext): return 
f"{'/'.join(path.split('.')[:-1])}.{new_ext}" + def feature_extraction(self, index): - if index > len(self.data) - 1: - print( - "The index of the dataloader is out of range: %s/%s" - % (index, len(self.data)) - ) - index = random.randint(0, len(self.data) - 1) - # Read wave file and extract feature - while True: - try: - if isinstance(self.data[index], str) and '.json' in self.data[index]: - dataset_name = self.datasets_of_datapoints[index] - file_path = self._relative_path_to_absolute_path([self.data[index]], dataset_name)[0] - datum = load_json(file_path) - else: - datum = self.data[index] + if isinstance(self.data[index], str) and '.json' in self.data[index]: + dataset_name = self.datasets_of_datapoints[index] + file_path = self._relative_path_to_absolute_path([self.data[index]], dataset_name)[0] + datum = load_json(file_path) + else: + datum = self.data[index] - if 'path' in datum and datum['path']: - datum['path'] = self._relative_path_to_absolute_path([datum['path']], dataset_name)[0] + if 'path' in datum and datum['path']: + datum['path'] = self._relative_path_to_absolute_path([datum['path']], dataset_name)[0] - if 'wav' in datum and datum['wav']: - datum['wav'] = self._relative_path_to_absolute_path([datum['wav']], dataset_name)[0] + if 'wav' in datum and datum['wav']: + datum['wav'] = self._relative_path_to_absolute_path([datum['wav']], dataset_name)[0] + + random_start = None + log_mel_spec, stft, waveform, frames = None, None, None, None + audio_file = None + + if self.load_audio and not ('wav' in datum.keys() and os.path.exists(datum['wav'])): + # assume that a .wav file exists in the same location as the .json file + wav_path = self.replace_extension(file_path, 'wav') + flac_path = self.replace_extension(file_path, 'flac') + if os.path.exists(wav_path): + datum['wav'] = wav_path + elif os.path.exists(flac_path): + datum['wav'] = flac_path + elif 'wav' in datum: + del datum['wav'] + + # cache wav file: useful when there exists a local memory the is faster to do read operations on it + if self.load_audio and 'wav' in datum and self.cache_dir is not None: + target_audio_file_path = f"{self.cache_dir}{datum['wav']}" + if not os.path.exists(target_audio_file_path): + os.makedirs(os.path.dirname(target_audio_file_path), exist_ok=True) + shutil.copy2(datum['wav'] , target_audio_file_path) + + # update + datum['wav'] = target_audio_file_path + + save_random_start = False + random_start = None + if self.consistent_start_time: # always sample from the same start time + if 'random_start_t' in datum: + random_start = datum.get('random_start_t', None) + save_random_start = False + else: + save_random_start = True + + # load audio + if self.load_audio: + if 'wav' in datum: + ( + log_mel_spec, + stft, + waveform, + random_start, + ) = self.read_audio_file(datum["wav"], random_start=random_start) - random_start = None - log_mel_spec, stft, waveform, frames = None, None, None, None - audio_file = None - - if self.load_audio and not ('wav' in datum.keys() and os.path.exists(datum['wav'])): - # assume that a .wav file exists in the same location as the .json file - wav_path = self.replace_extension(file_path, 'wav') - flac_path = self.replace_extension(file_path, 'flac') - if os.path.exists(wav_path): - datum['wav'] = wav_path - elif os.path.exists(flac_path): - datum['wav'] = flac_path - elif 'wav' in datum: - del datum['wav'] - - # cache wav file: useful when there exists a local memory the is faster to do read operations on it - if self.load_audio and 'wav' in datum and self.cache_dir is not 
None: - target_audio_file_path = f"{self.cache_dir}{datum['wav']}" - if not os.path.exists(target_audio_file_path): - os.makedirs(os.path.dirname(target_audio_file_path), exist_ok=True) - shutil.copy2(datum['wav'] , target_audio_file_path) - - # update - datum['wav'] = target_audio_file_path - save_random_start = False - random_start = None - if self.consistent_start_time: # always sample from the same start time - if 'random_start_t' in datum: - random_start = datum.get('random_start_t', None) - save_random_start = False - else: - save_random_start = True + waveform = torch.FloatTensor(waveform) - # load audio - if self.load_audio: - if 'wav' in datum: - ( - log_mel_spec, - stft, - waveform, - random_start, - ) = self.read_audio_file(datum["wav"], random_start=random_start) - waveform = torch.FloatTensor(waveform) - - else: - ( - frames, - log_mel_spec, - stft, - waveform, - random_start, - audio_file - ) = self.read_video_file(datum["path"], random_start=random_start, load_audio=True) - waveform = torch.FloatTensor(waveform) - - # load video - if self.load_video and 'path' in datum: - (frames, _, _, _, _, _ ) = self.read_video_file(datum["path"], random_start=random_start, load_audio=self.load_audio and waveform is None) - elif self.load_video and 'path' in datum: - ( - frames, - log_mel_spec, - stft, - waveform, - random_start, - audio_file - ) = self.read_video_file(datum["path"], random_start=random_start, load_audio=True) - waveform = torch.FloatTensor(waveform) - if audio_file is not None: - # update json to include path to audio. Only effective if keep_audio_file is enabled - updated_json = load_json(file_path) - updated_json['wav'] = self._absolute_path_to_relative_path([audio_file], dataset_name)[0] - datum["wav"] = updated_json['wav'] - updated_json['random_start_t'] = random_start - write_json(updated_json, file_path) - - elif save_random_start and random_start is not None: - # update json to include the randomly sampled start time for future experiments - updated_json = load_json(file_path) - updated_json['random_start_t'] = random_start - write_json(updated_json, file_path) - - mix_datum = None - if self.load_video: - assert frames.shape == (3, self.target_frame_cnt, self.frame_width, self.frame_height) - break + + else: + ( + frames, + log_mel_spec, + stft, + waveform, + random_start, + audio_file + ) = self.read_video_file(datum["path"], random_start=random_start, load_audio=True) + waveform = torch.FloatTensor(waveform) + + # load video + if self.load_video and 'path' in datum: + (frames, _, _, _, _, _ ) = self.read_video_file(datum["path"], random_start=random_start, load_audio=self.load_audio and waveform is None) + + elif self.load_video and 'path' in datum: + ( + frames, + log_mel_spec, + stft, + waveform, + random_start, + audio_file + ) = self.read_video_file(datum["path"], random_start=random_start, load_audio=True) + waveform = torch.FloatTensor(waveform) + + if audio_file is not None: + # update json to include path to audio. 
Only effective if keep_audio_file is enabled + updated_json = load_json(file_path) + updated_json['wav'] = self._absolute_path_to_relative_path([audio_file], dataset_name)[0] + datum["wav"] = updated_json['wav'] + updated_json['random_start_t'] = random_start + # write_json(updated_json, file_path) + + elif save_random_start and random_start is not None: + # update json to include the randomly sampled start time for future experiments + updated_json = load_json(file_path) + updated_json['random_start_t'] = random_start + write_json(updated_json, file_path) + + mix_datum = None + if self.load_video: + assert frames.shape == (3, self.target_frame_cnt, self.frame_width, self.frame_height) - except Exception as e: - if '.json' in self.data[index]: - dataset_name = self.datasets_of_datapoints[index] - file_path = self._relative_path_to_absolute_path([self.data[index]], dataset_name)[0] - else: - file_path = "" - - index = (index + 1) % len(self.data) - print("[ERROR, videoaudio_dataset] error while loading", file_path, e) - continue # The filename of the wav file - fname = datum["path"] if 'path' in datum and self.load_video else datum["wav"] + fname = datum["path"] if 'path' in datum and self.load_video else datum.get('wav', '') + + if not fname: + fname = datum['fname'] + return ( index, @@ -338,7 +411,7 @@ def feature_extraction(self, index): (datum, mix_datum), random_start, ) - + def combine_captions(self, caption1, caption2, remove_duplicates=False, background=False): """ Useful function to combine two caption when doing mixup augmentation @@ -602,9 +675,58 @@ def process_wavform(self, waveform, sr): return waveform + def load_audio_with_timeout(self, file_path, timeout): + """ + Load an audio file with a specified timeout using threading. + + :param file_path: Path to the audio file. + :param timeout: Maximum time (in seconds) to allow for loading the file. + :return: (waveform, sample_rate) if successful, None if timeout occurs. 
+ """ + class AudioLoader(threading.Thread): + def __init__(self, file_path): + super().__init__() + self.file_path = file_path + self.result = None + + def run(self): + try: + waveform, sample_rate = torchaudio.load(self.file_path) + self.result = (waveform, sample_rate) + except Exception as e: + print(f"Failed to load audio: {e}") + self.result = None + + # Start the thread + audio_loader = AudioLoader(file_path) + audio_loader.start() + + # Wait for the thread to complete or timeout + audio_loader.join(timeout=timeout) + if audio_loader.is_alive(): + print(f"Timeout while loading {file_path}") + return None, None # Timeout case + + return audio_loader.result + + def read_wav_file(self, filename, random_start=None): + # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower - waveform, sr = torchaudio.load(filename) + # waveform = torch.from_numpy(waveform) + # print("waveform shape", waveform.shape) + waveform, sr = self.load_audio_with_timeout(filename, timeout=10) + if waveform is None: + print("[INFO] timeout when loading the audio") + # # # TODO Important, dummy audio + waveform = torch.zeros(1, int(self.sampling_rate * self.duration)) + sr = 16000 + + # waveform = torch.zeros(1, int(self.sampling_rate * self.duration)) + # sr = 16000 + # waveform, sr = torchaudio.load(filename) + # # # TODO Important, dummy audio + # waveform = torch.zeros(1, int(self.sampling_rate * self.duration)) waveform, random_start = self.random_segment_wav( waveform, target_length=int(sr * self.duration), random_start=random_start @@ -885,13 +1007,13 @@ def custom_collate_fn(batch): # for test # for k in batch[0].keys(): - # try: - # default_collate([{k:item[k]} for item in batch]) - # print("done") - # except: - # print("collect error in key", k) - # print("files", [b['fname'] for b in batch]) - # inp = [{k:item[k]} for item in batch] + # try: + # default_collate([{k:item[k]} for item in batch]) + # except Exception as e: + # print("collect error in key", k) + # print("files", [b['fname'] for b in batch]) + # print("shape", [item[k].shape for item in batch]) + # print("error", e) collated_batch = default_collate(batch) diff --git a/GenAU/train/genau.py b/GenAU/train/genau.py index d96d7e3..5515659 100755 --- a/GenAU/train/genau.py +++ b/GenAU/train/genau.py @@ -32,7 +32,7 @@ copy_test_subset_data, ) from src.utilities.model.model_util import instantiate_from_config -from src.utilities.data.videoaudio_dataset import VideoAudioDataset +from src.utilities.data.videoaudio_dataset import VideoAudioDataset, custom_collate_fn from src.tools.download_manager import get_checkpoint_path logging.basicConfig(level=logging.WARNING) @@ -69,6 +69,7 @@ def main(configs, config_yaml_path, exp_group_name, exp_name, debug=False): num_workers=configs['data'].get('num_workers', 32), pin_memory=True, shuffle=True, + collate_fn=custom_collate_fn ) print( @@ -82,6 +83,7 @@ def main(configs, config_yaml_path, exp_group_name, exp_name, debug=False): val_dataset, num_workers=configs['data'].get('num_workers', 32), batch_size=max(1, batch_size // configs['model']['params']['evaluation_params']['n_candidates_per_samples']), + collate_fn=custom_collate_fn ) # Copy test data @@ -96,7 +98,8 @@ def main(configs, config_yaml_path, exp_group_name, exp_name, debug=False): config_reload_from_ckpt = configs.get("reload_from_ckpt", None) limit_val_batches = configs["step"].get("limit_val_batches", None) limit_train_batches = configs["step"].get("limit_train_batches", None) - validation_every_n_epochs = 
configs["step"]["validation_every_n_epochs"] + validation_every_n_epochs = configs["step"].get("validation_every_n_epochs", None) + val_check_interval = configs["step"].get("val_check_interval", None) max_steps = configs["step"]["max_steps"] save_top_k = configs["logging"]["save_top_k"] save_checkpoint_every_n_steps = configs["logging"]["save_checkpoint_every_n_steps"] @@ -173,6 +176,7 @@ def main(configs, config_yaml_path, exp_group_name, exp_name, debug=False): limit_val_batches=limit_val_batches, limit_train_batches=limit_train_batches, check_val_every_n_epoch=validation_every_n_epochs, + val_check_interval=val_check_interval, strategy=DDPStrategy(find_unused_parameters=True), callbacks=[checkpoint_callback], gradient_clip_val=configs["model"]["params"].get("clip_grad", None)