Skip to content

Commit

Permalink
fix model server stop process (#217)
Browse files Browse the repository at this point in the history
* fix model server stop process

* replace

* replace

* add test

* add multiple pids test

* add check install for linux

* reformat
  • Loading branch information
cotran2 authored Oct 25, 2024
1 parent ff6e9bd commit 25dddcb
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 48 deletions.
134 changes: 86 additions & 48 deletions model_server/app/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,27 @@
log = logging.getLogger("model_server.cli")
log.setLevel(logging.INFO)

# Path to the file where the server process ID will be stored
PID_FILE = os.path.join(tempfile.gettempdir(), "model_server.pid")


def run_server():
def run_server(port=51000):
"""Start, stop, or restart the Uvicorn server based on command-line arguments."""
if len(sys.argv) > 1:
action = sys.argv[1]
else:
action = "start"

if action == "start":
start_server()
start_server(port)
elif action == "stop":
stop_server()
stop_server(port)
elif action == "restart":
restart_server()
restart_server(port)
else:
log.info(f"Unknown action: {action}")
sys.exit(1)


def start_server():
"""Start the Uvicorn server and save the process ID."""
if os.path.exists(PID_FILE):
log.info("Server is already running. Use 'model_server restart' to restart it.")
sys.exit(1)

def start_server(port=51000):
"""Start the Uvicorn server"""
log.info(
"Starting model server - loading some awesomeness, this may take some time :)"
)
Expand All @@ -55,7 +48,7 @@ def start_server():
"--host",
"0.0.0.0",
"--port",
"51000",
f"{port}",
],
start_new_session=True,
bufsize=1,
Expand All @@ -64,10 +57,7 @@ def start_server():
stderr=subprocess.PIPE, # Suppress standard error. There is a logger that model_server prints to
)

if wait_for_health_check("http://0.0.0.0:51000/healthz"):
# Write the process ID to the PID file
with open(PID_FILE, "w") as f:
f.write(str(process.pid))
if wait_for_health_check(f"http://0.0.0.0:{port}/healthz"):
log.info(f"Model server started with PID {process.pid}")
else:
# Add model_server boot-up logs
Expand All @@ -89,40 +79,88 @@ def wait_for_health_check(url, timeout=180):
return False


def stop_server():
"""Stop the running Uvicorn server."""
log.info("Stopping model server")
if not os.path.exists(PID_FILE):
log.info("Process id file not found, seems like model server was not running")
return
def check_and_install_lsof():
"""Check if lsof is installed, and if not, install it using apt-get."""
try:
# Check if lsof is installed by running "lsof -v"
subprocess.run(
["lsof", "-v"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
print("lsof is already installed.")
except subprocess.CalledProcessError:
print("lsof not found, installing...")
try:
# Update package list and install lsof
subprocess.run(["sudo", "apt-get", "update"], check=True)
subprocess.run(["sudo", "apt-get", "install", "-y", "lsof"], check=True)
print("lsof installed successfully.")
except subprocess.CalledProcessError as install_error:
print(f"Failed to install lsof: {install_error}")

# Read the process ID from the PID file
with open(PID_FILE, "r") as f:
pid = int(f.read())

def kill_process(port=51000, wait=True, timeout=10):
"""Stop the running Uvicorn server."""
log.info("Stopping model server")
try:
# Get process by PID
process = psutil.Process(pid)

# Gracefully terminate the process
process.terminate() # Sends SIGTERM by default
process.wait(timeout=10) # Wait for up to 10 seconds for the process to exit

log.info(f"Model server with PID {pid} stopped.")
os.remove(PID_FILE)

except psutil.NoSuchProcess:
log.info(f"Model server with PID {pid} not found. Cleaning up PID file.")
os.remove(PID_FILE)
except psutil.TimeoutExpired:
log.info(
f"Model server with PID {pid} did not terminate in time. Forcing shutdown."
# Run the function to check and install lsof if necessary
# Step 1: Run lsof command to get the process using the port
lsof_command = f"lsof -n | grep {port} | grep -i LISTEN"
result = subprocess.run(
lsof_command, shell=True, capture_output=True, text=True
)
process.kill() # Forcefully kill the process
os.remove(PID_FILE)

if result.returncode != 0:
print(f"No process found listening on port {port}.")
return

# Step 2: Parse the process IDs from the output
process_ids = [line.split()[1] for line in result.stdout.splitlines()]

if not process_ids:
print(f"No process found listening on port {port}.")
return

# Step 3: Kill each process using its PID
for pid in process_ids:
print(f"Killing model server process with PID {pid}")
subprocess.run(f"kill {pid}", shell=True)

if wait:
# Step 4: Wait for the process to be killed by checking if it's still running
start_time = time.time()

while True:
check_process = subprocess.run(
f"ps -p {pid}", shell=True, capture_output=True, text=True
)
if check_process.returncode != 0:
print(f"Process {pid} has been killed.")
break

elapsed_time = time.time() - start_time
if elapsed_time > timeout:
print(
f"Process {pid} did not terminate within {timeout} seconds."
)
print(f"Attempting to force kill process {pid}...")
subprocess.run(f"kill -9 {pid}", shell=True) # SIGKILL
break

print(
f"Waiting for process {pid} to be killed... ({elapsed_time:.2f} seconds)"
)
time.sleep(0.5)

except Exception as e:
print(f"Error occurred: {e}")


def stop_server(port=51000, wait=True, timeout=10):
check_and_install_lsof()
kill_process(port, wait, timeout)


def restart_server():
def restart_server(port=51000):
"""Restart the Uvicorn server."""
stop_server()
start_server()
stop_server(port)
start_server(port)
55 changes: 55 additions & 0 deletions model_server/app/tests/test_cli_stop_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import unittest
from unittest.mock import patch, MagicMock
import subprocess
import time
from app.cli import kill_process


class TestStopServer(unittest.TestCase):
@patch("subprocess.run")
def test_stop_server_no_process(self, mock_run):
# Mock subprocess.run to simulate no process listening on the port
mock_run.return_value.returncode = 1
with patch("builtins.print") as mock_print:
kill_process(port=51000)
mock_print.assert_called_with("No process found listening on port 51000.")

@patch("subprocess.run")
def test_stop_server_process_killed(self, mock_run):
# Simulate lsof returning a process id
mock_run.side_effect = [
MagicMock(returncode=0, stdout="uvicorn 1234 user LISTEN\n"),
MagicMock(returncode=0), # for killing the process
MagicMock(returncode=1), # for checking the process after it is killed
]
with patch("builtins.print") as mock_print:
kill_process(port=51000, wait=True, timeout=5)
mock_print.assert_any_call("Killing model server process with PID 1234")
mock_print.assert_any_call("Process 1234 has been killed.")

@patch("subprocess.run")
def test_stop_server_multiple_pids(self, mock_run):
# Simulate lsof returning multiple process ids (e.g., 1234 and 5678)
mock_run.side_effect = [
MagicMock(
returncode=0,
stdout="uvicorn 1234 user LISTEN\nuvicorn 5678 user LISTEN\n",
), # lsof output
MagicMock(returncode=0), # first kill command for PID 1234
MagicMock(returncode=1), # PID 1234 is successfully terminated
MagicMock(returncode=0), # second kill command for PID 5678
MagicMock(returncode=1), # PID 5678 is successfully terminated
]

with patch("builtins.print") as mock_print:
kill_process(port=51000, wait=True, timeout=5)

# Assert that the function tried to kill both PIDs
mock_print.assert_any_call("Killing model server process with PID 1234")
mock_print.assert_any_call("Process 1234 has been killed.")
mock_print.assert_any_call("Killing model server process with PID 5678")
mock_print.assert_any_call("Process 5678 has been killed.")


if __name__ == "__main__":
unittest.main()

0 comments on commit 25dddcb

Please sign in to comment.