From f57b0b580835c68a97768a79125ef1ee388812ee Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 10 Jul 2024 17:06:12 +0000 Subject: [PATCH 1/4] fix prefill CUDA mem leakage --- amt/inference/transcribe.py | 103 +++++++++++++++++++++--------------- amt/run.py | 36 ++++++++----- requirements.txt | 2 +- 3 files changed, 84 insertions(+), 57 deletions(-) diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py index 6ea9a26..52b5bd9 100644 --- a/amt/inference/transcribe.py +++ b/amt/inference/transcribe.py @@ -42,7 +42,6 @@ def _setup_logger(name: str | None = None): logger.propagate = False logger.setLevel(logging.DEBUG) - # Adjust the formatter to include the name before the PID if provided formatter = logging.Formatter( f"[%(asctime)s] {logger_name}%(process)d: [%(levelname)s] %(message)s", ) @@ -169,7 +168,7 @@ def prefill( x_input_pos: torch.Tensor, xa_input_pos: torch.Tensor, ): - # This is the same as decode_token, however we don't compile the prefill + # This is the same as decode_token and is separate for compilation reasons logits = model.decoder.forward( x=x, xa=xa, @@ -216,18 +215,20 @@ def process_segments( ) ): # for idx in range(min_prefix_len, MAX_BLOCK_LEN - 1): - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - if idx == min_prefix_len: - logits, next_tok_ids = decode_token( - model, - x=seq[:, :idx], - xa=audio_features, - x_input_pos=torch.arange(0, idx, device=seq.device), - xa_input_pos=torch.arange( - 0, audio_features.shape[1], device=seq.device - ), - ) - else: + if idx == min_prefix_len: + logits, next_tok_ids = prefill( + model, + x=seq[:, :idx], + xa=audio_features, + x_input_pos=torch.arange(0, idx, device=seq.device), + xa_input_pos=torch.arange( + 0, audio_features.shape[1], device=seq.device + ), + ) + else: + with torch.nn.attention.sdpa_kernel( + torch.nn.attention.SDPBackend.MATH + ): logits, next_tok_ids = decode_token( model, x=seq[:, idx - 1 : idx], @@ -241,7 +242,7 @@ def process_segments( ) assert not torch.isnan(logits).any(), "NaN seen in logits" - logits[:, 389] *= 1.05 + logits[:, 389] *= 1.05 # Increase pedal-off msg logits next_tok_ids = torch.argmax(logits, dim=-1) next_tok_ids = recalculate_tok_ids( @@ -278,10 +279,10 @@ def gpu_manager( result_queue: Queue, model: AmtEncoderDecoder, batch_size: int, - compile: bool = False, + compile_mode: str | bool = False, gpu_id: int | None = None, ): - if gpu_id: + if gpu_id is not None: logger = _setup_logger(name=f"GPU-{gpu_id}") else: logger = _setup_logger(name=f"GPU") @@ -289,7 +290,7 @@ def gpu_manager( logger.info("Started GPU manager") if gpu_id is not None: - os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + torch.cuda.set_device(gpu_id) model.decoder.setup_cache( batch_size=batch_size, @@ -298,16 +299,11 @@ def gpu_manager( ) model.cuda() model.eval() - if compile is True: - global decode_token, recalculate_tok_ids - if batch_size == 1: - recalculate_tok_ids = torch.compile( - recalculate_tok_ids, mode="max-autotune-no-cudagraphs" - ) + if compile_mode is not False: + global decode_token decode_token = torch.compile( decode_token, - mode="reduce-overhead", - # mode="max-autotune", + mode=compile_mode, fullgraph=True, ) @@ -695,7 +691,7 @@ def transcribe_file( def get_save_path( file_path: str, - input_dir: str, + input_dir: str | None, save_dir: str, idx: int | str = "", ): @@ -848,7 +844,6 @@ def worker( logger.info(f"File worker terminated") -# Needs to test this for multi-gpu def batch_transcribe( file_paths: list, model: AmtEncoderDecoder, @@ -857,9 
+852,15 @@ def batch_transcribe( input_dir: str | None = None, gpu_ids: int | None = None, quantize: bool = False, - compile: bool = False, + compile_mode: str | bool = False, ): assert os.name == "posix", "UNIX/LINUX is the only supported OS" + assert compile_mode in { + "reduce-overhead", + "max-autotune", + False, + }, "Invalid value for compile_mode" + torch.multiprocessing.set_start_method("spawn") num_gpus = len(gpu_ids) if gpu_ids is not None else 1 logger = _setup_logger() @@ -933,39 +934,55 @@ def batch_transcribe( result_queue, model, batch_size, - compile, + compile_mode, gpu_id, ), ) - for gpu_id in gpu_ids + for gpu_id in range(len(gpu_ids)) ] for p in gpu_manager_processes: + child_pids.append(gpu_manager_processes.pid) p.start() watchdog_process = multiprocessing.Process( - target=watchdog, args=(gpu_batch_manager_process.pid, child_pids) + target=watchdog, args=(os.getpid(), child_pids) ) watchdog_process.start() else: - gpu_manager_processes = None + _gpu_manager_process = multiprocessing.Process( + target=gpu_manager, + args=( + gpu_batch_queue, + result_queue, + model, + batch_size, + compile_mode, + ), + ) + child_pids.append(_gpu_manager_process.pid) + _gpu_manager_process.start() + gpu_manager_processes = [_gpu_manager_process] + watchdog_process = multiprocessing.Process( target=watchdog, args=(os.getpid(), child_pids) ) watchdog_process.start() - gpu_manager( - gpu_batch_queue, - result_queue, - model, - batch_size, - compile, - ) + + for p in worker_processes: + p.join() if gpu_manager_processes is not None: for p in gpu_manager_processes: + p.terminate() p.join() - for p in worker_processes: - p.terminate() - p.join() + file_queue.close() + file_queue.join_thread() + gpu_task_queue.close() + gpu_task_queue.join_thread() + gpu_batch_queue.close() + gpu_batch_queue.join_thread() + result_queue.close() + result_queue.join_thread() gpu_batch_manager_process.terminate() gpu_batch_manager_process.join() diff --git a/amt/run.py b/amt/run.py index abd545d..d22c319 100644 --- a/amt/run.py +++ b/amt/run.py @@ -89,6 +89,12 @@ def _add_transcribe_args(subparser): action="store_true", default=False, ) + subparser.add_argument( + "-max_autotune", + help="use mode=max_autotune when compiling", + action="store_true", + default=False, + ) subparser.add_argument("-bs", help="batch size", type=int, default=16) @@ -341,16 +347,16 @@ def build_maestro( def transcribe( - model_name, - checkpoint_path, - save_dir, - load_path=None, - load_dir=None, - maestro=False, - batch_size=16, - multi_gpu=False, - quantize=False, - compile=False, + model_name: str, + checkpoint_path: str, + save_dir: str, + load_path: str | None = None, + load_dir: str | None = None, + maestro: bool = False, + batch_size: int = 8, + multi_gpu: bool = False, + quantize: bool = False, + compile_mode: str | bool = False, ): """ Transcribe audio files to midi using the given model and checkpoint. 
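# A minimal, hypothetical sketch of the teardown ordering the batch_transcribe()
# changes above move towards: join the file workers, stop the GPU manager
# processes, flush each queue's feeder thread with close()/join_thread(), and
# only then stop the batch manager. The names below (shutdown, worker_procs,
# gpu_procs, batch_manager_proc) are placeholders for illustration and are not
# part of the module's API.
import multiprocessing


def shutdown(
    worker_procs: list[multiprocessing.Process],
    gpu_procs: list[multiprocessing.Process],
    queues: list,
    batch_manager_proc: multiprocessing.Process,
):
    for p in worker_procs:  # file workers exit on their own once the file queue drains
        p.join()
    for p in gpu_procs:  # GPU managers block waiting on their batch queue, so terminate first
        p.terminate()
        p.join()
    for q in queues:  # close() then join_thread() flushes each queue's background feeder thread
        q.close()
        q.join_thread()
    batch_manager_proc.terminate()
    batch_manager_proc.join()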
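# A small sketch, under stated assumptions, of how the new -compile and
# -max_autotune flags are intended to resolve into a compile_mode value and
# reach torch.compile: gpu_manager() compiles decode_token with
# mode=compile_mode, fullgraph=True whenever compile_mode is not False, and
# batch_transcribe() restricts the value to the set asserted below. The helper
# names (resolve_compile_mode, maybe_compile) are invented for this example.
import torch


def resolve_compile_mode(compile_flag: bool, max_autotune_flag: bool) -> str | bool:
    # -max_autotune only takes effect when -compile is also passed (see amt/run.py)
    if compile_flag and max_autotune_flag:
        return "max-autotune"
    if compile_flag:
        return "reduce-overhead"
    return False


def maybe_compile(decode_fn, compile_mode: str | bool):
    # Mirrors the check in batch_transcribe() and the compile call in gpu_manager()
    assert compile_mode in {"reduce-overhead", "max-autotune", False}
    if compile_mode is not False:
        decode_fn = torch.compile(decode_fn, mode=compile_mode, fullgraph=True)
    return decode_fn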
@@ -449,7 +455,7 @@ def transcribe( input_dir=load_dir, gpu_ids=gpu_ids, quantize=quantize, - compile=compile, + compile_mode=compile_mode, ) else: @@ -460,7 +466,7 @@ def transcribe( batch_size=batch_size, input_dir=load_dir, quantize=quantize, - compile=compile, + compile_mode=compile_mode, ) @@ -528,7 +534,11 @@ def main(): batch_size=args.bs, multi_gpu=args.multi_gpu, quantize=args.q8, - compile=args.compile, + compile_mode=( + "max-autotune" + if args.compile and args.max_autotune + else "reduce-overhead" if args.compile else False + ), ) else: print("Unrecognized command") diff --git a/requirements.txt b/requirements.txt index 96a7801..9dd4f77 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ aria @ git+https://github.com/EleutherAI/aria.git -torch >= 2.2 +torch >= 2.3 torchaudio accelerate psutil From c1a8290f60f3c8fc83d37a14ca0f7704bd994905 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 17 Jul 2024 17:06:28 +0000 Subject: [PATCH 2/4] update batch manager --- amt/inference/transcribe.py | 203 ++++++++++++++++++++++++------------ amt/tokenizer.py | 7 +- 2 files changed, 142 insertions(+), 68 deletions(-) diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py index f18705f..906e873 100644 --- a/amt/inference/transcribe.py +++ b/amt/inference/transcribe.py @@ -13,8 +13,11 @@ import numpy as np import concurrent -from torch.multiprocessing import Queue +from multiprocessing import Queue, Manager +from queue import Empty +from collections import deque from concurrent.futures import ThreadPoolExecutor +from typing import Tuple, List, Deque from tqdm import tqdm from functools import wraps from torch.cuda import is_bf16_supported @@ -185,7 +188,7 @@ def prefill( @optional_bf16_autocast @torch.no_grad() def process_segments( - tasks: list, + tasks: List, model: AmtEncoderDecoder, audio_transform: AudioTransform, tokenizer: AmtTokenizer, @@ -278,22 +281,24 @@ def process_segments( def gpu_manager( gpu_batch_queue: Queue, + gpu_waiting_list: dict, result_queue: Queue, model: AmtEncoderDecoder, batch_size: int, compile_mode: str | bool = False, gpu_id: int | None = None, ): + if gpu_id is not None: + torch.cuda.set_device(gpu_id) + if gpu_id is not None: logger = _setup_logger(name=f"GPU-{gpu_id}") else: logger = _setup_logger(name=f"GPU") + gpu_id = 0 logger.info("Started GPU manager") - if gpu_id is not None: - torch.cuda.set_device(gpu_id) - model.decoder.setup_cache( batch_size=batch_size, max_seq_len=MAX_BLOCK_LEN, @@ -315,10 +320,12 @@ def gpu_manager( try: while True: try: + gpu_waiting_list[gpu_id] = time.time() batch = gpu_batch_queue.get(timeout=60) - except Exception as e: - logger.info(f"GPU timed out waiting for batch") - break + gpu_waiting_list.pop(gpu_id, None) + except Empty as e: + gpu_waiting_list.pop(gpu_id, None) + raise e else: try: results = process_segments( @@ -342,10 +349,11 @@ def gpu_manager( except Exception as e: logger.error(f"GPU manager failed with exception: {e}") finally: + gpu_waiting_list.pop(gpu_id, None) logger.info(f"GPU manager terminated") -def _find_min_diff_batch(tasks: list, batch_size: int): +def _find_min_diff_batch(tasks: List, batch_size: int): prefix_lens = [ (len(prefix), idx) for idx, ((audio_seg, prefix), _) in enumerate(tasks) ] @@ -374,55 +382,101 @@ def _find_min_diff_batch(tasks: list, batch_size: int): def gpu_batch_manager( gpu_task_queue: Queue, gpu_batch_queue: Queue, + gpu_waiting_list: dict, batch_size: int, + max_wait_time: float = 0.1, + min_batch_size: int = 1, ): logger = 
_setup_logger(name="B") logger.info("Started batch manager") + + tasks: Deque[Tuple[object, int]] = deque() + gpu_wait_time = 0 + try: - tasks = [] while True: try: - task, pid = gpu_task_queue.get(timeout=0.1) - except Exception as e: + while not gpu_task_queue.empty(): + task, pid = gpu_task_queue.get_nowait() + tasks.append((task, pid)) + except Empty: pass - else: - tasks.append((task, pid)) - if gpu_batch_queue.empty() is False: - continue - # No tasks in queue -> check gpu batch queue - if gpu_batch_queue.empty() is False: - continue - elif len(tasks) == 0: - continue + curr_time = time.time() + num_gpus_waiting = len(gpu_waiting_list) + gpu_wait_time = ( + max( + [ + curr_time - wait_time_abs + for gpu_id, wait_time_abs in gpu_waiting_list.items() + ] + ) + if gpu_waiting_list + else 0.0 + ) - # Get new batch and add to batch queue - if len(tasks) < batch_size: - logger.warning( - f"Not enough tasks ({len(tasks)}) - padding batch" + should_create_batch = ( + len(tasks) >= 4 * batch_size + or ( + num_gpus_waiting > gpu_batch_queue.qsize() + and len(tasks) >= batch_size + ) + or ( + num_gpus_waiting > 0 + and len(tasks) >= min_batch_size + and gpu_wait_time > max_wait_time ) - while len(tasks) < batch_size: - _pad_task, _pid = tasks[0] - tasks.append((_pad_task, -1)) - - assert len(tasks) >= batch_size, "batch error" - new_batch_idxs = _find_min_diff_batch( - tasks, - batch_size=batch_size, ) - gpu_batch_queue.put([tasks[_idx] for _idx in new_batch_idxs]) - tasks = [ - task - for _idx, task in enumerate(tasks) - if _idx not in new_batch_idxs - ] + + if should_create_batch: + logger.debug( + f"Batch created: " + f"num_gpus_waiting={num_gpus_waiting}, " + f"gpu_wait_time={round(gpu_wait_time, 4)}s, " + f"num_tasks_ready={len(tasks)}, " + f"num_batches_ready={gpu_batch_queue.qsize()}" + ) + batch = create_batch(tasks, batch_size, min_batch_size, logger) + gpu_batch_queue.put(batch) + elif gpu_task_queue.empty(): + time.sleep(0.05) + else: + time.sleep(0.01) + except Exception as e: logger.error(f"GPU batch manager failed with exception: {e}") finally: - logger.info(f"GPU batch manager terminated") + logger.info("GPU batch manager terminated") -def _shift_onset(seq: list, shift_ms: int): +def create_batch( + tasks: Deque[Tuple[object, int]], + batch_size: int, + min_batch_size: int, + logger: logging.Logger, +): + assert len(tasks) >= min_batch_size, "Insufficient number of tasks" + + if len(tasks) < batch_size: + logger.info(f"Creating a partial padded batch with {len(tasks)} tasks") + batch_idxs = list(range(len(tasks))) + batch = [tasks.popleft() for _ in batch_idxs] + + while len(batch) < min_batch_size: + pad_task, _ = batch[0] + batch.append((pad_task, -1)) + else: + batch_idxs = _find_min_diff_batch(list(tasks), batch_size) + batch = [tasks[idx] for idx in batch_idxs] + + # Remove the selected tasks from the deque + for idx in sorted(batch_idxs, reverse=True): + del tasks[idx] + + return batch + + +def _shift_onset(seq: List, shift_ms: int): res = [] for tok in seq: if type(tok) is tuple and tok[0] == "onset": @@ -434,7 +488,7 @@ def _shift_onset(seq: list, shift_ms: int): def _truncate_seq( - seq: list, + seq: List, start_ms: int, end_ms: int, tokenizer: AmtTokenizer = AmtTokenizer(), @@ -460,8 +514,10 @@ def _truncate_seq( # TODO: Add detection for pedal messages which occur before notes are played -def process_silent_intervals( - seq: list, intervals: list, tokenizer: AmtTokenizer +def _process_silent_intervals( + seq: List, + intervals: List, + tokenizer: AmtTokenizer, ): def 
adjust_onset(_onset: int): # Adjusts the onset according to the silence intervals @@ -552,7 +608,7 @@ def adjust_onset(_onset: int): return res -def get_silent_intervals(wav: torch.Tensor): +def _get_silent_intervals(wav: torch.Tensor): FRAME_LEN = 2048 HOP_LEN = 512 MIN_WINDOW_S = 5 @@ -614,12 +670,12 @@ def transcribe_file( # Add to gpu queue and wait for results curr_audio_segment = audio_segments.pop(0) - silent_intervals = get_silent_intervals(curr_audio_segment) + silent_intervals = _get_silent_intervals(curr_audio_segment) input_seq = copy.deepcopy(seq) gpu_task_queue.put(((curr_audio_segment, seq), pid)) while True: try: - gpu_result = result_queue.get(timeout=0.1) + gpu_result = result_queue.get(timeout=0.01) except Exception as e: pass else: @@ -634,7 +690,7 @@ def transcribe_file( f"Seen silent intervals in segment {idx}: {silent_intervals}" ) - seq_adj = process_silent_intervals( + seq_adj = _process_silent_intervals( seq, intervals=silent_intervals, tokenizer=tokenizer ) @@ -724,7 +780,7 @@ def process_file( input_dir: str, logger: logging.Logger, ): - def _save_seq(_seq: list, _save_path: str): + def _save_seq(_seq: List, _save_path: str): if os.path.exists(_save_path): logger.info(f"Already exists {_save_path} - overwriting") @@ -762,7 +818,7 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int): return num_removed - pid = threading.get_ident() + pid = int(str(os.getpid()) + str(threading.get_ident())) try: seqs = transcribe_file(file_path, gpu_task_queue, result_queue, pid=pid) except Exception as e: @@ -787,7 +843,7 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int): logger.info(f"{file_queue.qsize()} file(s) remaining in queue") -def watchdog(main_pids: list, child_pids: list): +def watchdog(main_pids: List, child_pids: List): while True: if not all(os.path.exists(f"/proc/{pid}") for pid in main_pids): print("Cleaning up children...") @@ -816,12 +872,14 @@ def worker( def process_file_wrapper(): while True: try: - file_path = file_queue.get(timeout=5) - except Exception as e: + file_path = file_queue.get(timeout=15) + except Empty as e: if file_queue.empty(): logger.info("File queue empty") break else: + # I'm pretty sure empty is thrown due to timeout too + logger.info("Processes timed out waiting for file queue") continue process_file( @@ -835,6 +893,9 @@ def process_file_wrapper(): logger, ) + if file_queue.empty(): + return + try: with ThreadPoolExecutor(max_workers=tasks_per_worker) as executor: futures = [ @@ -849,10 +910,10 @@ def process_file_wrapper(): def batch_transcribe( - file_paths: list, + file_paths: List, model: AmtEncoderDecoder, save_dir: str, - batch_size: int = 16, + batch_size: int = 8, input_dir: str | None = None, gpu_ids: int | None = None, quantize: bool = False, @@ -865,7 +926,7 @@ def batch_transcribe( False, }, "Invalid value for compile_mode" - torch.multiprocessing.set_start_method("spawn") + torch.multiprocessing.set_start_method("forkserver") num_gpus = len(gpu_ids) if gpu_ids is not None else 1 logger = _setup_logger() @@ -893,13 +954,18 @@ def batch_transcribe( min(batch_size * num_gpus, multiprocessing.cpu_count() - num_gpus), file_queue.qsize(), ) + num_processes_per_worker = min(5, file_queue.qsize() // num_workers) + mp_manager = Manager() + gpu_waiting_list = mp_manager.dict() gpu_task_queue = Queue() gpu_batch_queue = Queue() result_queue = Queue() child_pids = [] - logger.info(f"Creating {num_workers} file worker(s)") + logger.info( + f"Creating {num_workers} file worker(s) with {num_processes_per_worker} 
sub-processes" + ) worker_processes = [ multiprocessing.Process( target=worker, @@ -909,7 +975,7 @@ def batch_transcribe( result_queue, save_dir, input_dir, - 5, + num_processes_per_worker, ), ) for _ in range(num_workers) @@ -921,20 +987,19 @@ def batch_transcribe( gpu_batch_manager_process = multiprocessing.Process( target=gpu_batch_manager, - args=(gpu_task_queue, gpu_batch_queue, batch_size), + args=(gpu_task_queue, gpu_batch_queue, gpu_waiting_list, batch_size), ) gpu_batch_manager_process.start() child_pids.append(gpu_batch_manager_process.pid) - time.sleep(5) start_time = time.time() - if num_gpus > 1: gpu_manager_processes = [ multiprocessing.Process( target=gpu_manager, args=( gpu_batch_queue, + gpu_waiting_list, result_queue, model, batch_size, @@ -945,10 +1010,19 @@ def batch_transcribe( for gpu_id in range(len(gpu_ids)) ] for p in gpu_manager_processes: - child_pids.append(p.pid) p.start() + child_pids.append(p.pid) + watchdog_process = multiprocessing.Process( - target=watchdog, args=(os.getpid(), child_pids) + target=watchdog, + args=( + [ + os.getpid(), + gpu_batch_manager_process.pid, + ] + + [p.pid for p in gpu_manager_processes], + child_pids, + ), ) watchdog_process.start() else: @@ -956,14 +1030,15 @@ def batch_transcribe( target=gpu_manager, args=( gpu_batch_queue, + gpu_waiting_list, result_queue, model, batch_size, compile_mode, ), ) - child_pids.append(_gpu_manager_process.pid) _gpu_manager_process.start() + child_pids.append(_gpu_manager_process.pid) gpu_manager_processes = [_gpu_manager_process] watchdog_process = multiprocessing.Process( diff --git a/amt/tokenizer.py b/amt/tokenizer.py index dd34d0b..3146e32 100644 --- a/amt/tokenizer.py +++ b/amt/tokenizer.py @@ -343,11 +343,10 @@ def _detokenize_midi_dict( # Process note and add to note msgs note_to_close = notes_to_close.pop(tok_1_data, None) if note_to_close is None: - print( - f"No 'on' token corresponding to 'off' token: {tok_1, tok_2}" - ) if DEBUG: - raise Exception + print( + f"No 'on' token corresponding to 'off' token: {tok_1, tok_2}" + ) continue else: _pitch = tok_1_data From fbd9401e04667ca53985ce3137630deaf3ea06a6 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 17 Jul 2024 19:41:38 +0000 Subject: [PATCH 3/4] fix logic --- amt/inference/transcribe.py | 91 ++++++++++++++++++++++--------------- 1 file changed, 55 insertions(+), 36 deletions(-) diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py index 906e873..b9c9cf8 100644 --- a/amt/inference/transcribe.py +++ b/amt/inference/transcribe.py @@ -14,6 +14,7 @@ import concurrent from multiprocessing import Queue, Manager +from multiprocessing.synchronize import Lock as LockType from queue import Empty from collections import deque from concurrent.futures import ThreadPoolExecutor @@ -281,7 +282,8 @@ def process_segments( def gpu_manager( gpu_batch_queue: Queue, - gpu_waiting_list: dict, + gpu_waiting_dict: dict, + gpu_waiting_dict_lock: LockType, result_queue: Queue, model: AmtEncoderDecoder, batch_size: int, @@ -320,11 +322,14 @@ def gpu_manager( try: while True: try: - gpu_waiting_list[gpu_id] = time.time() + with gpu_waiting_dict_lock: + gpu_waiting_dict[gpu_id] = time.time() batch = gpu_batch_queue.get(timeout=60) - gpu_waiting_list.pop(gpu_id, None) + with gpu_waiting_dict_lock: + del gpu_waiting_dict[gpu_id] except Empty as e: - gpu_waiting_list.pop(gpu_id, None) + with gpu_waiting_dict_lock: + del gpu_waiting_dict[gpu_id] raise e else: try: @@ -349,7 +354,7 @@ def gpu_manager( except Exception as e: logger.error(f"GPU manager 
failed with exception: {e}") finally: - gpu_waiting_list.pop(gpu_id, None) + del gpu_waiting_dict[gpu_id] logger.info(f"GPU manager terminated") @@ -379,12 +384,19 @@ def _find_min_diff_batch(tasks: List, batch_size: int): ] +# NOTE: +# - For some reason copying gpu_waiting_dict is not working properly and is +# leading to race conditions. I've implemented a lock to stop it. +# - The size of gpu_batch_queue decreases before the code for deleting the +# corresponding entry in gpu_waiting_dict get processed. Adding a short sleep +# is a workaround def gpu_batch_manager( gpu_task_queue: Queue, gpu_batch_queue: Queue, - gpu_waiting_list: dict, + gpu_waiting_dict: dict, + gpu_waiting_dict_lock: LockType, batch_size: int, - max_wait_time: float = 0.1, + max_wait_time: float = 0.25, min_batch_size: int = 1, ): logger = _setup_logger(name="B") @@ -402,27 +414,29 @@ def gpu_batch_manager( except Empty: pass - curr_time = time.time() - num_gpus_waiting = len(gpu_waiting_list) - gpu_wait_time = ( - max( - [ - curr_time - wait_time_abs - for gpu_id, wait_time_abs in gpu_waiting_list.items() - ] + with gpu_waiting_dict_lock: + curr_time = time.time() + num_tasks_in_batch_queue = gpu_batch_queue.qsize() + num_gpus_waiting = len(gpu_waiting_dict) + gpu_wait_time = ( + max( + [ + curr_time - wait_time_abs + for gpu_id, wait_time_abs in gpu_waiting_dict.items() + ] + ) + if gpu_waiting_dict + else 0.0 ) - if gpu_waiting_list - else 0.0 - ) should_create_batch = ( len(tasks) >= 4 * batch_size or ( - num_gpus_waiting > gpu_batch_queue.qsize() + num_gpus_waiting > num_tasks_in_batch_queue and len(tasks) >= batch_size ) or ( - num_gpus_waiting > 0 + num_gpus_waiting > num_tasks_in_batch_queue and len(tasks) >= min_batch_size and gpu_wait_time > max_wait_time ) @@ -430,18 +444,15 @@ def gpu_batch_manager( if should_create_batch: logger.debug( - f"Batch created: " + f"Creating batch: " f"num_gpus_waiting={num_gpus_waiting}, " f"gpu_wait_time={round(gpu_wait_time, 4)}s, " f"num_tasks_ready={len(tasks)}, " - f"num_batches_ready={gpu_batch_queue.qsize()}" + f"num_batches_ready={num_tasks_in_batch_queue}" ) batch = create_batch(tasks, batch_size, min_batch_size, logger) gpu_batch_queue.put(batch) - elif gpu_task_queue.empty(): - time.sleep(0.05) - else: - time.sleep(0.01) + time.sleep(0.025) except Exception as e: logger.error(f"GPU batch manager failed with exception: {e}") @@ -462,16 +473,14 @@ def create_batch( batch_idxs = list(range(len(tasks))) batch = [tasks.popleft() for _ in batch_idxs] - while len(batch) < min_batch_size: + while len(batch) < batch_size: pad_task, _ = batch[0] batch.append((pad_task, -1)) else: batch_idxs = _find_min_diff_batch(list(tasks), batch_size) batch = [tasks[idx] for idx in batch_idxs] - - # Remove the selected tasks from the deque - for idx in sorted(batch_idxs, reverse=True): - del tasks[idx] + for idx in sorted(batch_idxs, reverse=True): + del tasks[idx] return batch @@ -853,6 +862,7 @@ def watchdog(main_pids: List, child_pids: List): except ProcessLookupError: pass + print("Finished cleaning up children") return time.sleep(1) @@ -957,9 +967,10 @@ def batch_transcribe( num_processes_per_worker = min(5, file_queue.qsize() // num_workers) mp_manager = Manager() - gpu_waiting_list = mp_manager.dict() - gpu_task_queue = Queue() + gpu_waiting_dict = mp_manager.dict() + gpu_waiting_dict_lock = mp_manager.Lock() gpu_batch_queue = Queue() + gpu_task_queue = Queue() result_queue = Queue() child_pids = [] @@ -987,7 +998,13 @@ def batch_transcribe( gpu_batch_manager_process = 
multiprocessing.Process( target=gpu_batch_manager, - args=(gpu_task_queue, gpu_batch_queue, gpu_waiting_list, batch_size), + args=( + gpu_task_queue, + gpu_batch_queue, + gpu_waiting_dict, + gpu_waiting_dict_lock, + batch_size, + ), ) gpu_batch_manager_process.start() child_pids.append(gpu_batch_manager_process.pid) @@ -999,7 +1016,8 @@ def batch_transcribe( target=gpu_manager, args=( gpu_batch_queue, - gpu_waiting_list, + gpu_waiting_dict, + gpu_waiting_dict_lock, result_queue, model, batch_size, @@ -1030,7 +1048,8 @@ def batch_transcribe( target=gpu_manager, args=( gpu_batch_queue, - gpu_waiting_list, + gpu_waiting_dict, + gpu_waiting_dict_lock, result_queue, model, batch_size, From 5b2307cef0cb39bbbaf6686f2d15ae176286d92d Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 17 Jul 2024 20:33:01 +0000 Subject: [PATCH 4/4] switch back to spawn and add cli arg --- amt/inference/transcribe.py | 14 ++++++++------ amt/run.py | 9 +++++++++ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/amt/inference/transcribe.py b/amt/inference/transcribe.py index b9c9cf8..7913998 100644 --- a/amt/inference/transcribe.py +++ b/amt/inference/transcribe.py @@ -827,7 +827,7 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int): return num_removed - pid = int(str(os.getpid()) + str(threading.get_ident())) + pid = threading.get_ident() try: seqs = transcribe_file(file_path, gpu_task_queue, result_queue, pid=pid) except Exception as e: @@ -926,6 +926,7 @@ def batch_transcribe( batch_size: int = 8, input_dir: str | None = None, gpu_ids: int | None = None, + num_workers: int | None = None, quantize: bool = False, compile_mode: str | bool = False, ): @@ -936,7 +937,7 @@ def batch_transcribe( False, }, "Invalid value for compile_mode" - torch.multiprocessing.set_start_method("forkserver") + torch.multiprocessing.set_start_method("spawn") num_gpus = len(gpu_ids) if gpu_ids is not None else 1 logger = _setup_logger() @@ -960,10 +961,11 @@ def batch_transcribe( logger.info(f"Files to process: {file_queue.qsize()}/{len(file_paths)}") - num_workers = min( - min(batch_size * num_gpus, multiprocessing.cpu_count() - num_gpus), - file_queue.qsize(), - ) + if num_workers is None: + num_workers = min( + min(batch_size * num_gpus, multiprocessing.cpu_count() - num_gpus), + file_queue.qsize(), + ) num_processes_per_worker = min(5, file_queue.qsize() // num_workers) mp_manager = Manager() diff --git a/amt/run.py b/amt/run.py index d22c319..b57d670 100644 --- a/amt/run.py +++ b/amt/run.py @@ -95,6 +95,11 @@ def _add_transcribe_args(subparser): action="store_true", default=False, ) + subparser.add_argument( + "-num_workers", + help="numer of file worker processes", + type=int, + ) subparser.add_argument("-bs", help="batch size", type=int, default=16) @@ -355,6 +360,7 @@ def transcribe( maestro: bool = False, batch_size: int = 8, multi_gpu: bool = False, + num_workers: int | None = None, quantize: bool = False, compile_mode: str | bool = False, ): @@ -454,6 +460,7 @@ def transcribe( batch_size=batch_size, input_dir=load_dir, gpu_ids=gpu_ids, + num_workers=num_workers, quantize=quantize, compile_mode=compile_mode, ) @@ -465,6 +472,7 @@ def transcribe( save_dir=save_dir, batch_size=batch_size, input_dir=load_dir, + num_workers=num_workers, quantize=quantize, compile_mode=compile_mode, ) @@ -534,6 +542,7 @@ def main(): batch_size=args.bs, multi_gpu=args.multi_gpu, quantize=args.q8, + num_workers=args.num_workers, compile_mode=( "max-autotune" if args.compile and args.max_autotune