Skip to content

Commit

Permalink
Reimplement worker threadpool (#46)
Browse files Browse the repository at this point in the history
* fix prefill CUDA mem leakage

* implement threadpool
  • Loading branch information
loubbrad authored Jul 14, 2024
1 parent 525b53f commit e0f66da
Showing 1 changed file with 47 additions and 35 deletions.
82 changes: 47 additions & 35 deletions amt/inference/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
import torch._dynamo.config
import torch._inductor.config
import numpy as np
import concurrent

from torch.multiprocessing import Queue
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
from functools import wraps
from torch.cuda import is_bf16_supported
Expand Down Expand Up @@ -313,7 +315,7 @@ def gpu_manager(
try:
while True:
try:
batch = gpu_batch_queue.get(timeout=30)
batch = gpu_batch_queue.get(timeout=60)
except Exception as e:
logger.info(f"GPU timed out waiting for batch")
break
Expand Down Expand Up @@ -380,12 +382,13 @@ def gpu_batch_manager(
tasks = []
while True:
try:
task, pid = gpu_task_queue.get(timeout=0.05)
task, pid = gpu_task_queue.get(timeout=0.1)
except Exception as e:
pass
else:
tasks.append((task, pid))
continue
if gpu_batch_queue.empty() is False:
continue

# No tasks in queue -> check gpu batch queue
if gpu_batch_queue.empty() is False:
Expand Down Expand Up @@ -784,9 +787,9 @@ def remove_failures_from_queue_(_queue: Queue, _pid: int):
logger.info(f"{file_queue.qsize()} file(s) remaining in queue")


def watchdog(main_gpu_pid: int, child_pids: list):
def watchdog(main_pids: list, child_pids: list):
while True:
if not os.path.exists(f"/proc/{main_gpu_pid}"):
if not all(os.path.exists(f"/proc/{pid}") for pid in main_pids):
print("Cleaning up children...")
for pid in child_pids:
try:
Expand All @@ -805,43 +808,44 @@ def worker(
result_queue: Queue,
save_dir: str,
input_dir: str | None = None,
tasks_per_worker: int = 1,
tasks_per_worker: int = 5,
):
logger = _setup_logger(name="F")
tokenizer = AmtTokenizer()
threads = []
try:
while not file_queue.empty() or any(t.is_alive() for t in threads):
while len(threads) < tasks_per_worker and not file_queue.empty():
logging.info("Starting worker")
file_path = file_queue.get()
t = threading.Thread(
target=process_file,
args=(
file_path,
file_queue,
gpu_task_queue,
result_queue,
tokenizer,
save_dir,
input_dir,
logger,
),
)
t.start()
threads.append(t)

threads = [t for t in threads if t.is_alive()]

time.sleep(0.1)
def process_file_wrapper():
while True:
try:
file_path = file_queue.get(timeout=5)
except Exception as e:
if file_queue.empty():
logger.info("File queue empty")
break
else:
continue

for t in threads:
t.join()
process_file(
file_path,
file_queue,
gpu_task_queue,
result_queue,
tokenizer,
save_dir,
input_dir,
logger,
)

try:
with ThreadPoolExecutor(max_workers=tasks_per_worker) as executor:
futures = [
executor.submit(process_file_wrapper)
for _ in range(tasks_per_worker)
]
concurrent.futures.wait(futures)
except Exception as e:
logger.error(f"File worker failed with exception: {e}")
finally:
logger.info(f"File worker terminated")
logger.info("File worker terminated")


def batch_transcribe(
Expand Down Expand Up @@ -941,7 +945,7 @@ def batch_transcribe(
for gpu_id in range(len(gpu_ids))
]
for p in gpu_manager_processes:
child_pids.append(gpu_manager_processes.pid)
child_pids.append(p.pid)
p.start()
watchdog_process = multiprocessing.Process(
target=watchdog, args=(os.getpid(), child_pids)
Expand All @@ -963,7 +967,15 @@ def batch_transcribe(
gpu_manager_processes = [_gpu_manager_process]

watchdog_process = multiprocessing.Process(
target=watchdog, args=(os.getpid(), child_pids)
target=watchdog,
args=(
[
os.getpid(),
gpu_batch_manager_process.pid,
_gpu_manager_process.pid,
],
child_pids,
),
)
watchdog_process.start()

Expand Down

0 comments on commit e0f66da

Please sign in to comment.