Replace results queue with dict (#60)
* replace results queue with dict

* add cleanup
loubbrad authored Dec 2, 2024
1 parent 34bd2a2 commit e4b13b4
Showing 1 changed file with 39 additions and 35 deletions.
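
For context, the core of the change is swapping a shared multiprocessing results Queue for a Manager dict keyed by the requesting pid, so each transcription process waits on its own key instead of pulling arbitrary results and re-enqueueing the ones addressed to other processes. The standalone sketch below illustrates only that pattern; the consume helper and the inline publishing loop are illustrative stand-ins, not the repo's API.

# Illustrative sketch of the pattern adopted by this commit (names are not from
# the repo): a producer publishes results into a Manager dict keyed by pid, and
# each consumer polls only its own key with dict.pop, instead of all consumers
# sharing one results Queue and putting back items addressed to other pids.
import time
from multiprocessing import Manager, Process


def consume(results_dict, pid):
    # Stand-in for the polling loop in transcribe_file.
    while True:
        try:
            result = results_dict.pop(pid)
        except KeyError:
            time.sleep(0.1)
        else:
            break
    print(f"pid {pid} got {result}")


if __name__ == "__main__":
    with Manager() as manager:
        results_dict = manager.dict()
        consumers = [
            Process(target=consume, args=(results_dict, pid)) for pid in range(3)
        ]
        for p in consumers:
            p.start()
        for pid in range(3):
            # Stand-in for gpu_manager publishing a finished result.
            results_dict[pid] = f"result-{pid}"
        for p in consumers:
            p.join()

Under the old scheme, a consumer that happened to pull another process's result had to put it back on the queue; that is the behaviour removed from transcribe_file in the diff below.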
74 changes: 39 additions & 35 deletions amt/inference/transcribe.py
@@ -41,9 +41,6 @@
 CHUNK_LEN_MS = LEN_MS // STRIDE_FACTOR
 
 
-# TODO: Implement continuous batching in a torch.compile friendly way
-
-
 def _setup_logger(name: str | None = None):
     logger_name = f"[{name}] " if name else ""
     logger = logging.getLogger(__name__)
@@ -309,7 +306,7 @@ def gpu_manager(
     gpu_batch_queue: Queue,
     gpu_waiting_dict: dict,
     gpu_waiting_dict_lock: LockType,
-    result_queue: Queue,
+    results_dict: dict,
     model: AmtEncoderDecoder,
     batch_size: int,
     compile_mode: str | bool = False,
@@ -374,7 +371,7 @@ def gpu_manager(
                 # pid = -1 when its a pad sequence
                 for result, (_, pid) in zip(results, batch):
                     if pid != -1:
-                        result_queue.put((result, pid))
+                        results_dict[pid] = result
 
     except Exception as e:
         logger.error(f"GPU manager failed with exception: {e}")
@@ -411,10 +408,10 @@ def _find_min_diff_batch(tasks: List, batch_size: int):
 
 # NOTE:
 # - For some reason copying gpu_waiting_dict is not working properly and is
-#   leading to race conditions. I've implemented a lock to stop it.
-# - The size of gpu_batch_queue decreases before the code for deleting the
-#   corresponding entry in gpu_waiting_dict gets processed. Adding a short
-#   sleep is a workaround
+#   leading to race conditions. I've implemented a lock to stop it: The size of
+# gpu_batch_queue decreases before the code for deleting the corresponding
+# entry in gpu_waiting_dict gets processed. Adding a short sleep is a
+# workaround.
 def gpu_batch_manager(
     gpu_task_queue: Queue,
     gpu_batch_queue: Queue,
@@ -434,7 +431,7 @@ def gpu_batch_manager(
         while True:
             try:
                 while not gpu_task_queue.empty():
-                    task, pid = gpu_task_queue.get_nowait()
+                    task, pid = gpu_task_queue.get(timeout=0.05)
                     tasks.append((task, pid))
             except Empty:
                 pass
@@ -667,7 +664,7 @@ def _get_silent_intervals(wav: torch.Tensor):
     # Filter intervals by minimum length
     valid = lengths > MIN_WINDOW_STEPS
     silent_intervals = [
-        (start * MS_PER_HOP, (end - 1) * MS_PER_HOP)
+        (int(start * MS_PER_HOP), int((end - 1) * MS_PER_HOP))
         for start, end, vl in zip(starts, ends, valid)
         if vl
     ]
@@ -678,7 +675,7 @@ def _get_silent_intervals(wav: torch.Tensor):
 def transcribe_file(
     file_path,
     gpu_task_queue: Queue,
-    result_queue: Queue,
+    results_dict: dict,
     pid: int,
     tokenizer: AmtTokenizer = AmtTokenizer(),
     segment: Tuple[int, int] | None = None,
@@ -703,15 +700,12 @@
         gpu_task_queue.put(((curr_audio_segment, seq), pid))
         while True:
             try:
-                gpu_result = result_queue.get(timeout=0.01)
+                seq = results_dict.pop(pid)
             except Exception as e:
+                time.sleep(0.1)
                 pass
             else:
-                if gpu_result[1] == pid:
-                    seq = gpu_result[0]
-                    break
-                else:
-                    result_queue.put(gpu_result)
+                break
 
         if len(silent_intervals) > 0:
             logger.debug(
@@ -804,7 +798,7 @@ def process_file(
     file_path: str,
     file_queue: Queue,
     gpu_task_queue: Queue,
-    result_queue: Queue,
+    results_dict: dict,
     tokenizer: AmtTokenizer,
     save_dir: str,
     input_dir: str,
@@ -865,9 +859,9 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int):
 
         try:
             seq = transcribe_file(
-                file_path,
-                gpu_task_queue,
-                result_queue,
+                file_path=file_path,
+                gpu_task_queue=gpu_task_queue,
+                results_dict=results_dict,
                 pid=pid,
                 segment=segment,
             )
@@ -876,9 +870,8 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int):
                 f"Failed to process {file_path} segment {idx}: {traceback.format_exc()}"
             )
             task_rmv_cnt = remove_failures_from_queue_(gpu_task_queue, pid)
-            res_rmv_cnt = remove_failures_from_queue_(result_queue, pid)
+            results_dict.pop(pid)
             logger.info(f"Removed {task_rmv_cnt} from task queue")
-            logger.info(f"Removed {res_rmv_cnt} from result queue")
             continue
 
         logger.info(
@@ -921,7 +914,7 @@ def watchdog(main_pids: List, child_pids: List):
 def worker(
     file_queue: Queue,
     gpu_task_queue: Queue,
-    result_queue: Queue,
+    results_dict: dict,
     save_dir: str,
     input_dir: str | None = None,
     tasks_per_worker: int = 5,
@@ -938,15 +931,14 @@ def process_file_wrapper():
                     logger.info("File queue empty")
                     break
                 else:
-                    # I'm pretty sure empty is thrown due to timeout too
                     logger.info("Processes timed out waiting for file queue")
                     continue
 
             process_file(
                 file_path=file_to_process["path"],
                 file_queue=file_queue,
                 gpu_task_queue=gpu_task_queue,
-                result_queue=result_queue,
+                results_dict=results_dict,
                 tokenizer=tokenizer,
                 save_dir=save_dir,
                 input_dir=input_dir,
@@ -955,7 +947,8 @@
             )
 
             if file_queue.empty():
-                return
+                logger.info("File queue empty after processing")
+                break
 
     try:
         with ThreadPoolExecutor(max_workers=tasks_per_worker) as executor:
@@ -1029,6 +1022,9 @@ def batch_transcribe(
 
     # If only processing one file, add even if save file exists
     if len(files_to_process) == 1:
+        # TODO: This workaround should be reimplemented properly
+        while not file_queue.empty():
+            file_queue.get()
         file_queue.put(files_to_process[0])
 
     logger.info(
@@ -1045,15 +1041,16 @@
         file_queue.qsize(),
     )
     num_processes_per_worker = min(
-        5 * (batch_size // num_workers), file_queue.qsize() // num_workers
+        round((4 * batch_size) / num_workers),
+        round(file_queue.qsize() / num_workers),
    )
 
     mp_manager = Manager()
     gpu_waiting_dict = mp_manager.dict()
     gpu_waiting_dict_lock = mp_manager.Lock()
     gpu_batch_queue = Queue()
     gpu_task_queue = Queue()
-    result_queue = Queue()
+    results_dict = mp_manager.dict()
 
     child_pids = []
     logger.info(
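
As a worked example of the num_processes_per_worker change in the hunk above (illustrative numbers, not from the repo): with batch_size=16, num_workers=2, and 100 files queued, the old expression gives min(5 * (16 // 2), 100 // 2) = min(40, 50) = 40 processes per worker, while the new one gives min(round((4 * 16) / 2), round(100 / 2)) = min(32, 50) = 32.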
@@ -1065,7 +1062,7 @@ def batch_transcribe(
             args=(
                 file_queue,
                 gpu_task_queue,
-                result_queue,
+                results_dict,
                 save_dir,
                 input_dir,
                 num_processes_per_worker,
@@ -1100,7 +1097,7 @@ def batch_transcribe(
                 gpu_batch_queue,
                 gpu_waiting_dict,
                 gpu_waiting_dict_lock,
-                result_queue,
+                results_dict,
                 model,
                 batch_size,
                 compile_mode,
@@ -1132,7 +1129,7 @@ def batch_transcribe(
                 gpu_batch_queue,
                 gpu_waiting_dict,
                 gpu_waiting_dict_lock,
-                result_queue,
+                results_dict,
                 model,
                 batch_size,
                 compile_mode,
@@ -1174,14 +1171,21 @@ def batch_transcribe(
     watchdog_process.join()
     gpu_batch_manager_process.terminate()
     gpu_batch_manager_process.join()
+
     file_queue.close()
     file_queue.join_thread()
     gpu_task_queue.close()
     gpu_task_queue.join_thread()
     gpu_batch_queue.close()
     gpu_batch_queue.join_thread()
-    result_queue.close()
-    result_queue.join_thread()
+
+    for p in worker_processes:
+        if p.is_alive():
+            p.terminate()
+            p.join()
+
+    mp_manager.shutdown()
+    multiprocessing.resource_tracker.unregister_after_fork = True
 
     time_taken_s = int(time.time() - start_time)
     logger.info(
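
The cleanup added in the final hunk follows the usual teardown order for Manager-backed state: make sure no child can still touch the shared dict before the manager process that owns it is shut down, since proxy access after shutdown raises errors. A minimal generic sketch of that ordering (illustrative names, not the repo's helpers):

# Generic teardown sketch (illustrative only): stop stragglers first, then
# shut the Manager down last.
import time
from multiprocessing import Manager, Process


def child(shared):
    shared["started"] = True
    time.sleep(60)  # pretend this worker hangs around after its useful work


if __name__ == "__main__":
    manager = Manager()
    shared = manager.dict()
    p = Process(target=child, args=(shared,))
    p.start()
    time.sleep(1)

    if p.is_alive():
        p.terminate()
    p.join()

    manager.shutdown()  # stop the manager process last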