From 6845a32f2557c98d553b0976c020f00ae921ad6b Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 28 Aug 2024 15:14:14 +0000 Subject: [PATCH] add segment support for inference --- amt/audio.py | 6 +- amt/data.py | 29 +++++++- amt/inference/transcribe.py | 133 +++++++++++++++++++++++------------- amt/run.py | 34 ++++++++- 4 files changed, 147 insertions(+), 55 deletions(-) diff --git a/amt/audio.py b/amt/audio.py index 9ab6170..b4369db 100644 --- a/amt/audio.py +++ b/amt/audio.py @@ -222,8 +222,12 @@ def _get_noise(self, noise_paths: list): for wav, sr in noises ] - for wav in noises: + for wav, path in zip(noises, noise_paths): assert wav.shape[-1] == self.num_samples, "noise wav too short" + assert not ( + torch.all(wav < 0.01).item() is True + and torch.all(wav > -0.01).item() is True + ), f"Loaded wav {path} is approximately silent which can cause NaN." return noises diff --git a/amt/data.py b/amt/data.py index 7440082..5eeaef4 100644 --- a/amt/data.py +++ b/amt/data.py @@ -1,7 +1,6 @@ import mmap import os import io -import math import random import shlex import base64 @@ -12,7 +11,7 @@ import torchaudio from multiprocessing import Pool, Queue, Process -from typing import Callable +from typing import Callable, Tuple from aria.data.midi import MidiDict from amt.tokenizer import AmtTokenizer @@ -69,6 +68,7 @@ def get_wav_segments( audio_path: str, stride_factor: int | None = None, pad_last=False, + segment: Tuple[int, int] | None = None, ): assert os.path.isfile(audio_path), "Audio file not found" config = load_config() @@ -83,6 +83,18 @@ def get_wav_segments( stride_samples = int(chunk_samples // stride_factor) assert chunk_samples % stride_samples == 0, "Invalid stride" + # Handle segmentation if provided + if segment is not None: + assert ( + segment[0] < segment[1] + ), "Invalid segment: start must be less than end" + start_time_s, end_time_s = segment + start_sample = int(start_time_s * sample_rate) + end_sample = int(end_time_s * sample_rate) + stream.seek(start_time_s) + else: + start_sample, end_sample = 0, None + stream.add_basic_audio_stream( frames_per_chunk=stride_samples, stream_index=0, @@ -90,9 +102,16 @@ def get_wav_segments( ) buffer = torch.tensor([], dtype=torch.float32) + total_samples = start_sample for stride_seg in stream.stream(): seg_chunk = stride_seg[0].mean(1) + if end_sample and total_samples + seg_chunk.shape[0] > end_sample: + samples_to_use = end_sample - total_samples + seg_chunk = seg_chunk[:samples_to_use] + + total_samples += seg_chunk.shape[0] + # Pad seg_chunk if required if seg_chunk.shape[0] < stride_samples: seg_chunk = F.pad( @@ -110,7 +129,10 @@ def get_wav_segments( if buffer.shape[0] == chunk_samples: yield buffer - if pad_last == True: + if end_sample and total_samples >= end_sample: + break + + if pad_last and buffer.shape[0] > stride_samples: yield torch.nn.functional.pad( buffer[stride_samples:], (0, chunk_samples - len(buffer[stride_samples:])), @@ -296,6 +318,7 @@ def build_synth_worker_fn( class AmtDataset(torch.utils.data.Dataset): def __init__(self, load_paths: str | list): + super().__init__() self.tokenizer = AmtTokenizer(return_tensors=True) self.config = load_config()["data"] self.mixup_fn = self.tokenizer.export_msg_mixup() diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py index 0050ef5..3f32a0d 100644 --- a/amt/inference/transcribe.py +++ b/amt/inference/transcribe.py @@ -373,7 +373,7 @@ def gpu_manager( # pid = -1 when its a pad sequence for result, (_, pid) in zip(results, batch): if pid != -1: - result_queue.put({"result": result, "pid": pid}) + result_queue.put((result, pid)) except Exception as e: logger.error(f"GPU manager failed with exception: {e}") @@ -681,12 +681,12 @@ def transcribe_file( result_queue: Queue, pid: int, tokenizer: AmtTokenizer = AmtTokenizer(), + segment: Tuple[int, int] | None = None, ): logger = logging.getLogger(__name__) logger.info(f"Getting wav segments: {file_path}") - res = [] seq = [tokenizer.bos_tok] concat_seq = [tokenizer.bos_tok] idx = 0 @@ -694,6 +694,7 @@ def transcribe_file( audio_path=file_path, stride_factor=STRIDE_FACTOR, pad_last=True, + segment=segment, ): init_idx = len(seq) # Add to gpu queue and wait for results @@ -706,25 +707,25 @@ def transcribe_file( except Exception as e: pass else: - if gpu_result["pid"] == pid: - seq = gpu_result["result"] + if gpu_result[1] == pid: + seq = gpu_result[0] break else: result_queue.put(gpu_result) if len(silent_intervals) > 0: logger.debug( - f"Seen silent intervals in segment {idx}: {silent_intervals}" + f"Seen silent intervals in audio chunk {idx}: {silent_intervals}" ) seq_adj = _process_silent_intervals( seq, intervals=silent_intervals, tokenizer=tokenizer ) - if len(seq_adj) < len(seq) - 5: + if len(seq_adj) < len(seq) - 15: logger.info( f"Removed tokens ({len(seq)} -> {len(seq_adj)}) " - f"in segment {idx} according to silence in intervals: " + f"in audio chunk {idx} according to silence in intervals: " f"{silent_intervals}", ) seq = seq_adj @@ -736,7 +737,9 @@ def transcribe_file( LEN_MS - CHUNK_LEN_MS, ) except Exception as e: - logger.info(f"Failed to reconcile segment {idx}: {file_path}") + logger.info( + f"Failed to reconcile sequences for audio chunk {idx}: {file_path}" + ) logger.debug(traceback.format_exc()) try: @@ -755,11 +758,13 @@ def transcribe_file( else: if seq[-1] == tokenizer.eos_tok: - logger.info(f"Seen eos_tok at segment {idx}: {file_path}") + logger.info(f"Seen eos_tok in audio chunk {idx}: {file_path}") seq = seq[:-1] if len(next_seq) == 1: - logger.info(f"Skipping segment {idx} (silence): {file_path}") + logger.info( + f"Skipping audio chunk {idx} (silence): {file_path}" + ) seq = [tokenizer.bos_tok] else: concat_seq += _shift_onset( @@ -770,9 +775,7 @@ def transcribe_file( idx += 1 - res.append(concat_seq) - - return res + return concat_seq def get_save_path( @@ -806,6 +809,7 @@ def process_file( save_dir: str, input_dir: str, logger: logging.Logger, + segments: List[Tuple[int, int]] | None = None, ): def _save_seq(_seq: List, _save_path: str): if os.path.exists(_save_path): @@ -846,26 +850,41 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int): return num_removed pid = threading.get_ident() - try: - seqs = transcribe_file(file_path, gpu_task_queue, result_queue, pid=pid) - except Exception as e: - logger.error(f"Failed to process {file_path}: {traceback.format_exc()}") - task_rmv_cnt = remove_failures_from_queue_(gpu_task_queue, pid) - res_rmv_cnt = remove_failures_from_queue_(result_queue, pid) - logger.info(f"Removed {task_rmv_cnt} from task queue") - logger.info(f"Removed {res_rmv_cnt} from result queue") - return + if segments is None: + segments = [None] + + if len(segments) == 0: + logger.info(f"No segments to transcribe, skipping file: {file_path}") + + for idx, segment in enumerate(segments): + try: + seq = transcribe_file( + file_path, + gpu_task_queue, + result_queue, + pid=pid, + segment=segment, + ) + except Exception as e: + logger.error( + f"Failed to process {file_path} segment {idx}: {traceback.format_exc()}" + ) + task_rmv_cnt = remove_failures_from_queue_(gpu_task_queue, pid) + res_rmv_cnt = remove_failures_from_queue_(result_queue, pid) + logger.info(f"Removed {task_rmv_cnt} from task queue") + logger.info(f"Removed {res_rmv_cnt} from result queue") + continue - logger.info(f"Finished file: {file_path}") - for seq in seqs: + logger.info(f"Finished file: {file_path} (segment: {idx})") if len(seq) < 500: - logger.info("Skipping seq - too short") + logger.info(f"Skipping seq - too short (segment {idx})") else: logger.debug( - f"Saving seq of length {len(seq)} from file: {file_path}" + f"Saving seq of length {len(seq)} from file: {file_path} (segment: {idx})" ) - - _save_seq(seq, get_save_path(file_path, input_dir, save_dir)) + idx = f"_{idx}" if segment is not None else "" + save_path = get_save_path(file_path, input_dir, save_dir, idx) + _save_seq(seq, save_path) logger.info(f"{file_queue.qsize()} file(s) remaining in queue") @@ -905,7 +924,7 @@ def worker( def process_file_wrapper(): while True: try: - file_path = file_queue.get(timeout=15) + file_to_process = file_queue.get(timeout=15) except Empty as e: if file_queue.empty(): logger.info("File queue empty") @@ -916,14 +935,15 @@ def process_file_wrapper(): continue process_file( - file_path, - file_queue, - gpu_task_queue, - result_queue, - tokenizer, - save_dir, - input_dir, - logger, + file_path=file_to_process["path"], + file_queue=file_queue, + gpu_task_queue=gpu_task_queue, + result_queue=result_queue, + tokenizer=tokenizer, + save_dir=save_dir, + input_dir=input_dir, + logger=logger, + segments=file_to_process.get("segments", None), ) if file_queue.empty(): @@ -943,7 +963,7 @@ def process_file_wrapper(): def batch_transcribe( - file_paths: List, + files_to_process: List[dict], model: AmtEncoderDecoder, save_dir: str, batch_size: int = 8, @@ -968,21 +988,36 @@ def batch_transcribe( os.remove("transcribe.log") if quantize is True: - logger.info("Quantising decoder weights to int8") + logger.info("Quantizing decoder weights to int8") model.decoder = quantize_int8(model.decoder) file_queue = Queue() - sorted(file_paths, key=lambda x: os.path.getsize(x), reverse=True) - for file_path in file_paths: + sorted( + files_to_process, key=lambda x: os.path.getsize(x["path"]), reverse=True + ) + for file_to_process in files_to_process: + # Only add to file_queue if transcription MIDI file doesn't exist if ( - os.path.isfile(get_save_path(file_path, input_dir, save_dir)) + os.path.isfile( + get_save_path(file_to_process["path"], input_dir, save_dir) + ) is False - ): - file_queue.put(file_path) - elif len(file_paths) == 1: - file_queue.put(file_path) + ) and os.path.isfile( + get_save_path( + file_to_process["path"], input_dir, save_dir, idx="_0" + ) + ) is False: + file_queue.put(file_to_process) + elif len(files_to_process) == 1: + file_queue.put(file_to_process) - logger.info(f"Files to process: {file_queue.qsize()}/{len(file_paths)}") + logger.info( + f"Files to process: {file_queue.qsize()}/{len(files_to_process)}" + ) + + if file_queue.qsize() == 0: + logger.info("No files to process") + return if num_workers is None: num_workers = min( @@ -1113,10 +1148,10 @@ def batch_transcribe( cleanup_processes(child_pids=child_pids) logger.info("Complete") finally: - gpu_batch_manager_process.terminate() - gpu_batch_manager_process.join() watchdog_process.terminate() watchdog_process.join() + gpu_batch_manager_process.terminate() + gpu_batch_manager_process.join() file_queue.close() file_queue.join_thread() gpu_task_queue.close() diff --git a/amt/run.py b/amt/run.py index b57d670..32bbd01 100644 --- a/amt/run.py +++ b/amt/run.py @@ -2,6 +2,7 @@ import argparse import os +import json import glob from csv import DictReader @@ -100,6 +101,12 @@ def _add_transcribe_args(subparser): help="numer of file worker processes", type=int, ) + subparser.add_argument( + "-segments_json_path", + help="", + type=str, + required=False, + ) subparser.add_argument("-bs", help="batch size", type=int, default=16) @@ -363,6 +370,7 @@ def transcribe( num_workers: int | None = None, quantize: bool = False, compile_mode: str | bool = False, + segments_json_path: str | None = None, ): """ Transcribe audio files to midi using the given model and checkpoint. @@ -448,13 +456,34 @@ def transcribe( file_paths = [load_path] batch_size = 1 + if segments_json_path and os.path.isfile(segments_json_path): + with open(segments_json_path, "r") as f: + segments_by_audio_file = json.load(f) + elif segments_json_path: + print(f"Couldn't find json file at {segments_json_path}") + segments_by_audio_file = {} + else: + segments_by_audio_file = {} + + files_to_process = [] + for audio_path in file_paths: + if segments_by_audio_file.get(audio_path, None): + file_info = { + "path": audio_path, + "segments": segments_by_audio_file[audio_path], + } + else: + file_info = {"path": audio_path} + + files_to_process.append(file_info) + if multi_gpu: gpu_ids = [ int(id) for id in os.getenv("CUDA_VISIBLE_DEVICES").split(",") ] print(f"Visible gpu_ids: {gpu_ids}") batch_transcribe( - file_paths=file_paths, + files_to_process=files_to_process, model=model, save_dir=save_dir, batch_size=batch_size, @@ -467,7 +496,7 @@ def transcribe( else: batch_transcribe( - file_paths=file_paths, + files_to_process=files_to_process, model=model, save_dir=save_dir, batch_size=batch_size, @@ -548,6 +577,7 @@ def main(): if args.compile and args.max_autotune else "reduce-overhead" if args.compile else False ), + segments_json_path=args.segments_json_path, ) else: print("Unrecognized command")