Skip to content

Commit

Permalink
feat(abstractions): auto termination of backends
Browse files Browse the repository at this point in the history
  • Loading branch information
TianyiQ committed Dec 4, 2024
1 parent d5cca3d commit 20a38ca
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 98 deletions.
6 changes: 3 additions & 3 deletions examples/abstractions/finetuning_datamanip.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def dialogue_manipulation():
# ============== Generating a dialogue, using a model to play the role of both user and assistant ==============
global llama8b_instruct
dialogue_data = Data(
"dialogue_data",
"dialogue_data1",
data_content=[
{
"input": "Is Eiffel Tower in Paris?",
Expand All @@ -120,11 +120,11 @@ def dialogue_manipulation():
]
)
dialogue_data = llama8b_instruct.inference(
dialogue_data, "dialogue_data", backend="sglang"
dialogue_data, "dialogue_data2", backend="sglang"
)
dialogue_data = dialogue_data.switch_role_to_user()
dialogue_data = llama8b_instruct.inference(
dialogue_data, "dialogue_data", backend="sglang"
dialogue_data, "dialogue_data3", backend="sglang"
)
dialogue_data = dialogue_data.switch_role_to_assistant()
print(list(dialogue_data.all_passages()))
Expand Down
276 changes: 182 additions & 94 deletions src/abstractions/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@

import torch
import time
import gc
import json
import pwd
from typing import List, Tuple, Literal, Union, Dict, Callable, Optional
from nvitop import GpuProcess, Device
import multiprocessing
import subprocess
from nvitop import GpuProcess, Device
Expand Down Expand Up @@ -207,7 +210,6 @@ def get_model_size(model_repoid_or_path: str) -> float:
)
return model_size


def start_inference_backend(
model_repoid_or_path: str,
backend_type: Literal["sglang", "vllm"] = "sglang",
Expand All @@ -216,28 +218,35 @@ def start_inference_backend(
port: int = PORT_NUM,
num_gpus: int = None,
template_type: Literal["auto", "alpaca", "mistral"] = "auto",
) -> Tuple[subprocess.Popen, Callable]:
) -> Tuple[subprocess.Popen, Callable, Callable]:
"""Start an inference backend for a given model.
Returns a tuple containing the backend process and the function to process a batch of samples.
When purpose is "logprobs", the returned function will return the log probability of the prompt text itself, without generating any text. The probability will be stored in the "logprob" field of the output dictionary, with all other fields staying the same.
When purpose is "responses", the returned function will generate a response to the prompt text. The response will be stored in the "predict" field of the output dictionary, with all other fields staying the same.
:param model_repoid_or_path: The model repo ID or path (e.g., "meta-llama/Llama-3.1-8B-Instruct").
:type model_repoid_or_path: str
:param backend_type: The type of backend to start, defaults to "sglang"
:type backend_type: Literal["sglang", "vllm"], optional
:param purpose: The purpose of the backend, defaults to "logprobs"
:type purpose: Literal["responses, "logprobs"], optional
:param silent: Whether to run the backend silently, defaults to True
:type silent: bool, optional
:param port: The port number to use for the backend, defaults to PORT_NUM
:type port: int, optional
:param num_gpus: The number of GPUs to use for the backend, defaults to None (use all available GPUs)
:type num_gpus: int, optional
:param template_type: The type of template to use for the backend, defaults to "auto", which uses the appropriate template (not limited to alpaca/mistral) based on the model config file
:type template_type: Literal["auto", "alpaca", "mistral"], optional
:return: A tuple containing the backend process and the function to process a batch of samples (type signature: List[dict] -> List[dict], with optional metadata arguments)
:rtype: Tuple[subprocess.Popen, Callable]
:return: A tuple containing the backend process, the function to process a batch of samples (type signature: List[dict] -> List[dict], with optional metadata arguments), and the function to destroy the backend after use.
:rtype: Tuple[subprocess.Popen, Callable, Callable]
"""
if eval(os.environ.get("LOUD_BACKEND", "0")):
silent = False
Expand Down Expand Up @@ -307,8 +316,35 @@ def vllm_process_batch(
dic["predict"] = generated_text

return sample_dicts

def vllm_free_gpu_memory():
"""Remove the vllm model and free vllm cache. This should wipe out all GPU memory used by self."""
if destroy_model_parallel is not None:
try:
destroy_model_parallel()
except Exception as e:
print(f"destroy_model_parallel fails: {type(e)} {e}")

nonlocal vllm_model
try:
del vllm_model.llm_engine.model_executor.driver_worker
except Exception as e:
print(f"del model_executor.driver_worker fails: {type(e)} {e}")

del vllm_model.llm_engine
del vllm_model
gc.collect()
torch.cuda.empty_cache()
try:
torch.distributed.destroy_process_group()
except:
print("No process group to destroy.")

restart_ray_cluster()
gc.collect()
print("Successfully deleted the vllm model and freed the GPU memory.")

return vllm_model, vllm_process_batch
return vllm_model, vllm_process_batch, vllm_free_gpu_memory

elif backend_type == "sglang":

Expand All @@ -317,96 +353,126 @@ def vllm_process_batch(
warnings.warn(
f"SGLang backend only supports auto template type. Ignoring template_type={template_type}. This is not an issue if you simply intend to perform inference on HistLlama models, but may be an issue if the model is neither in the HistLlama family nor in SGLang's supported models list, in which case you may use NO_SGLANG=1 to disable sglang backend."
)

with open(os.devnull, "w") as devnull:
frac_static = 0.8 if purpose == "responses" else 0.7
prefill_size = 8192 if purpose == "responses" else 1024

model_size = get_model_size(model_repoid_or_path)
assert model_size is not None

if model_size <= 10 and not os.environ.get("FORCE_TP"):
args = [
"python",
"-m",
"sglang.launch_server",
"--port",
f"{port}",
f"--dp",
f"{num_gpus}",
"--model",
model_repoid_or_path,
"--mem-fraction-static",
f"{frac_static}",
"--chunked-prefill-size",
f"{prefill_size}",
"--trust-remote-code",
]

else:
min_gpus_per_instance = (
2 if model_size <= 30 else 4 if model_size <= 80 else 8
)

if os.environ.get("FORCE_TP"):
min_gpus_per_instance = int(os.environ.get("FORCE_TP"))

assert num_gpus % min_gpus_per_instance == 0
args = [
"python",
"-m",
"sglang.launch_server",
"--port",
f"{port}",
f"--tp",
f"{min_gpus_per_instance}",
f"--dp",
f"{num_gpus//min_gpus_per_instance}",
"--model",
model_repoid_or_path,
"--mem-fraction-static",
f"{frac_static}",
"--chunked-prefill-size",
f"{prefill_size}",
"--trust-remote-code",
]

# if 'int4' not in model_repoid_or_path.lower():
# args += ['--quantization', 'fp8']

if "phi" in model_repoid_or_path.lower():
args += ["--disable-flashinfer"]

if "smol" in model_repoid_or_path.lower():
args += ["--chat-template=chatml"]

print(f"Starting backend for {model_repoid_or_path} - {args}", flush=True)

if silent:
new_env = os.environ.copy()
new_env["PYTHONWARNINGS"] = "ignore"
backend = subprocess.Popen(
args, stdout=devnull, stderr=devnull, env=new_env
)
else:
backend = subprocess.Popen(args)

# Wait for backend to start
for _ in range(40):
time.sleep(30)
try:
print("Trying to connect to backend...", flush=True)

backend_key = f"{model_repoid_or_path}-{backend_type}-{purpose}-{num_gpus}"
connected = False

if os.path.exists(f"{root}/output/backend_history.json"):
with open(f"{root}/output/backend_history.json", "r") as f:
backend_history = json.load(f)
else:
backend_history = {}

print(f"Current backend history: {backend_history}", flush=True)
print(f"Looking for prior backend with key {backend_key}...", flush=True)

if backend_key in backend_history:
backend_port = backend_history[backend_key]
print(f"Found prior backend with key {backend_key} at port {backend_port}.", flush=True)

try:
sgl.set_default_backend(sgl.RuntimeEndpoint(f"http://localhost:{port}"))
connected = True
print("Connected to backend.", flush=True)
break
except:
print(
"Failed to connect to backend (this is to be expected if backend is still starting). Retrying after 30s...",
flush=True,
)
pass
else:
raise Exception("Failed to connect to backend after 20 minutes.")
del backend_history[backend_key]
print("Failed to connect to backend. Will start a new one.", flush=True)

if not connected:
with open(os.devnull, "w") as devnull:
frac_static = 0.8 if purpose == "responses" else 0.7
prefill_size = 8192 if purpose == "responses" else 1024

model_size = get_model_size(model_repoid_or_path)
assert model_size is not None

if model_size <= 10 and not os.environ.get("FORCE_TP"):
args = [
"python",
"-m",
"sglang.launch_server",
"--port",
f"{port}",
f"--dp",
f"{num_gpus}",
"--model",
model_repoid_or_path,
"--mem-fraction-static",
f"{frac_static}",
"--chunked-prefill-size",
f"{prefill_size}",
"--trust-remote-code",
]

else:
min_gpus_per_instance = (
2 if model_size <= 30 else 4 if model_size <= 80 else 8
)

if os.environ.get("FORCE_TP"):
min_gpus_per_instance = int(os.environ.get("FORCE_TP"))

assert num_gpus % min_gpus_per_instance == 0
args = [
"python",
"-m",
"sglang.launch_server",
"--port",
f"{port}",
f"--tp",
f"{min_gpus_per_instance}",
f"--dp",
f"{num_gpus//min_gpus_per_instance}",
"--model",
model_repoid_or_path,
"--mem-fraction-static",
f"{frac_static}",
"--chunked-prefill-size",
f"{prefill_size}",
"--trust-remote-code",
]

# if 'int4' not in model_repoid_or_path.lower():
# args += ['--quantization', 'fp8']

if "phi" in model_repoid_or_path.lower():
args += ["--disable-flashinfer"]

if "smol" in model_repoid_or_path.lower():
args += ["--chat-template=chatml"]

print(f"Starting backend for {model_repoid_or_path} - {args}", flush=True)

if silent:
new_env = os.environ.copy()
new_env["PYTHONWARNINGS"] = "ignore"
backend = subprocess.Popen(
args, stdout=devnull, stderr=devnull, env=new_env
)
else:
backend = subprocess.Popen(args)

print(f"Registered backend with key {backend_key} at port {port}.", flush=True)
backend_history[backend_key] = port
with open(f"{root}/output/backend_history.json", "w") as f:
json.dump(backend_history, f)

# Wait for backend to start
for _ in range(40):
time.sleep(30)
try:
print(f"Trying to connect to backend (at port {port})...", flush=True)
sgl.set_default_backend(sgl.RuntimeEndpoint(f"http://localhost:{port}"))
print("Connected to backend.", flush=True)
break
except:
print(
"Failed to connect to backend (this is to be expected if backend is still starting). Retrying after 30s...",
flush=True,
)
pass
else:
raise Exception("Failed to connect to backend after 20 minutes.")

@sgl.function
def get_response(
Expand Down Expand Up @@ -553,8 +619,30 @@ def sglang_process_batch(
raise Exception(f"More actual failures ({failure_count}) than cases not completed ({count}), which is unexpected.")

return sample_dicts

return backend, sglang_process_batch

def sglang_free_gpu_memory():
"""Wipe out all GPU memory used by the user."""
try:
backend.kill()
except:
print("backend.kill() failed.")

MY_USERNAME = pwd.getpwuid(os.getuid()).pw_name
print(f"Killing all processes on GPU for user {MY_USERNAME}.")

devices = Device.cuda.all()
signal.signal(signal.SIGCHLD, signal.SIG_IGN)
for device in devices:
processes = device.processes()
processes = GpuProcess.take_snapshots(processes.values(), failsafe=True)
for process in processes:
if process.username.lower() == MY_USERNAME.lower():
print(f'Killing process {process.pid}: {process.cmdline}')
os.kill(process.pid, signal.SIGTERM)
os.kill(process.pid, signal.SIGINT)
os.kill(process.pid, signal.SIGKILL)

return backend, sglang_process_batch, sglang_free_gpu_memory

raise ValueError(f"Backend type {backend_type} not recognized.")

Expand Down
4 changes: 3 additions & 1 deletion src/abstractions/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def inference_standalone(
purpose: Literal["responses", "logprobs"],
conn: multiprocessing.connection.Connection,
):
backend, process_batch = start_inference_backend(
backend, process_batch, mopup_memory = start_inference_backend(
model_path,
backend_type,
num_gpus=num_gpus,
Expand All @@ -69,6 +69,8 @@ def inference_standalone(
map_key_fields=True,
)
print("Job finished.")
mopup_memory()
print("Memory mopup done.")
conn.send(result_data.data_path)


Expand Down
1 change: 1 addition & 0 deletions src/config/requirements-pip.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ yajl
numpy
pandas
gensim
nvitop
seaborn
argparse
guidance
Expand Down

0 comments on commit 20a38ca

Please sign in to comment.