From 25dddcbfd9e02132bc72283b3ce7818190fa90cc Mon Sep 17 00:00:00 2001 From: CTran Date: Thu, 24 Oct 2024 19:21:47 -0700 Subject: [PATCH] fix model server stop process (#217) * fix model server stop process * replace * replace * add test * add multiple pids test * add check install for linux * reformat --- model_server/app/cli.py | 134 +++++++++++------- .../app/tests/test_cli_stop_server.py | 55 +++++++ 2 files changed, 141 insertions(+), 48 deletions(-) create mode 100644 model_server/app/tests/test_cli_stop_server.py diff --git a/model_server/app/cli.py b/model_server/app/cli.py index 6a3f81ea..dd6a5679 100644 --- a/model_server/app/cli.py +++ b/model_server/app/cli.py @@ -15,11 +15,8 @@ log = logging.getLogger("model_server.cli") log.setLevel(logging.INFO) -# Path to the file where the server process ID will be stored -PID_FILE = os.path.join(tempfile.gettempdir(), "model_server.pid") - -def run_server(): +def run_server(port=51000): """Start, stop, or restart the Uvicorn server based on command-line arguments.""" if len(sys.argv) > 1: action = sys.argv[1] @@ -27,22 +24,18 @@ def run_server(): action = "start" if action == "start": - start_server() + start_server(port) elif action == "stop": - stop_server() + stop_server(port) elif action == "restart": - restart_server() + restart_server(port) else: log.info(f"Unknown action: {action}") sys.exit(1) -def start_server(): - """Start the Uvicorn server and save the process ID.""" - if os.path.exists(PID_FILE): - log.info("Server is already running. Use 'model_server restart' to restart it.") - sys.exit(1) - +def start_server(port=51000): + """Start the Uvicorn server""" log.info( "Starting model server - loading some awesomeness, this may take some time :)" ) @@ -55,7 +48,7 @@ def start_server(): "--host", "0.0.0.0", "--port", - "51000", + f"{port}", ], start_new_session=True, bufsize=1, @@ -64,10 +57,7 @@ def start_server(): stderr=subprocess.PIPE, # Suppress standard error. There is a logger that model_server prints to ) - if wait_for_health_check("http://0.0.0.0:51000/healthz"): - # Write the process ID to the PID file - with open(PID_FILE, "w") as f: - f.write(str(process.pid)) + if wait_for_health_check(f"http://0.0.0.0:{port}/healthz"): log.info(f"Model server started with PID {process.pid}") else: # Add model_server boot-up logs @@ -89,40 +79,88 @@ def wait_for_health_check(url, timeout=180): return False -def stop_server(): - """Stop the running Uvicorn server.""" - log.info("Stopping model server") - if not os.path.exists(PID_FILE): - log.info("Process id file not found, seems like model server was not running") - return +def check_and_install_lsof(): + """Check if lsof is installed, and if not, install it using apt-get.""" + try: + # Check if lsof is installed by running "lsof -v" + subprocess.run( + ["lsof", "-v"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + print("lsof is already installed.") + except subprocess.CalledProcessError: + print("lsof not found, installing...") + try: + # Update package list and install lsof + subprocess.run(["sudo", "apt-get", "update"], check=True) + subprocess.run(["sudo", "apt-get", "install", "-y", "lsof"], check=True) + print("lsof installed successfully.") + except subprocess.CalledProcessError as install_error: + print(f"Failed to install lsof: {install_error}") - # Read the process ID from the PID file - with open(PID_FILE, "r") as f: - pid = int(f.read()) +def kill_process(port=51000, wait=True, timeout=10): + """Stop the running Uvicorn server.""" + log.info("Stopping model server") try: - # Get process by PID - process = psutil.Process(pid) - - # Gracefully terminate the process - process.terminate() # Sends SIGTERM by default - process.wait(timeout=10) # Wait for up to 10 seconds for the process to exit - - log.info(f"Model server with PID {pid} stopped.") - os.remove(PID_FILE) - - except psutil.NoSuchProcess: - log.info(f"Model server with PID {pid} not found. Cleaning up PID file.") - os.remove(PID_FILE) - except psutil.TimeoutExpired: - log.info( - f"Model server with PID {pid} did not terminate in time. Forcing shutdown." + # Run the function to check and install lsof if necessary + # Step 1: Run lsof command to get the process using the port + lsof_command = f"lsof -n | grep {port} | grep -i LISTEN" + result = subprocess.run( + lsof_command, shell=True, capture_output=True, text=True ) - process.kill() # Forcefully kill the process - os.remove(PID_FILE) + + if result.returncode != 0: + print(f"No process found listening on port {port}.") + return + + # Step 2: Parse the process IDs from the output + process_ids = [line.split()[1] for line in result.stdout.splitlines()] + + if not process_ids: + print(f"No process found listening on port {port}.") + return + + # Step 3: Kill each process using its PID + for pid in process_ids: + print(f"Killing model server process with PID {pid}") + subprocess.run(f"kill {pid}", shell=True) + + if wait: + # Step 4: Wait for the process to be killed by checking if it's still running + start_time = time.time() + + while True: + check_process = subprocess.run( + f"ps -p {pid}", shell=True, capture_output=True, text=True + ) + if check_process.returncode != 0: + print(f"Process {pid} has been killed.") + break + + elapsed_time = time.time() - start_time + if elapsed_time > timeout: + print( + f"Process {pid} did not terminate within {timeout} seconds." + ) + print(f"Attempting to force kill process {pid}...") + subprocess.run(f"kill -9 {pid}", shell=True) # SIGKILL + break + + print( + f"Waiting for process {pid} to be killed... ({elapsed_time:.2f} seconds)" + ) + time.sleep(0.5) + + except Exception as e: + print(f"Error occurred: {e}") + + +def stop_server(port=51000, wait=True, timeout=10): + check_and_install_lsof() + kill_process(port, wait, timeout) -def restart_server(): +def restart_server(port=51000): """Restart the Uvicorn server.""" - stop_server() - start_server() + stop_server(port) + start_server(port) diff --git a/model_server/app/tests/test_cli_stop_server.py b/model_server/app/tests/test_cli_stop_server.py new file mode 100644 index 00000000..4f3955a7 --- /dev/null +++ b/model_server/app/tests/test_cli_stop_server.py @@ -0,0 +1,55 @@ +import unittest +from unittest.mock import patch, MagicMock +import subprocess +import time +from app.cli import kill_process + + +class TestStopServer(unittest.TestCase): + @patch("subprocess.run") + def test_stop_server_no_process(self, mock_run): + # Mock subprocess.run to simulate no process listening on the port + mock_run.return_value.returncode = 1 + with patch("builtins.print") as mock_print: + kill_process(port=51000) + mock_print.assert_called_with("No process found listening on port 51000.") + + @patch("subprocess.run") + def test_stop_server_process_killed(self, mock_run): + # Simulate lsof returning a process id + mock_run.side_effect = [ + MagicMock(returncode=0, stdout="uvicorn 1234 user LISTEN\n"), + MagicMock(returncode=0), # for killing the process + MagicMock(returncode=1), # for checking the process after it is killed + ] + with patch("builtins.print") as mock_print: + kill_process(port=51000, wait=True, timeout=5) + mock_print.assert_any_call("Killing model server process with PID 1234") + mock_print.assert_any_call("Process 1234 has been killed.") + + @patch("subprocess.run") + def test_stop_server_multiple_pids(self, mock_run): + # Simulate lsof returning multiple process ids (e.g., 1234 and 5678) + mock_run.side_effect = [ + MagicMock( + returncode=0, + stdout="uvicorn 1234 user LISTEN\nuvicorn 5678 user LISTEN\n", + ), # lsof output + MagicMock(returncode=0), # first kill command for PID 1234 + MagicMock(returncode=1), # PID 1234 is successfully terminated + MagicMock(returncode=0), # second kill command for PID 5678 + MagicMock(returncode=1), # PID 5678 is successfully terminated + ] + + with patch("builtins.print") as mock_print: + kill_process(port=51000, wait=True, timeout=5) + + # Assert that the function tried to kill both PIDs + mock_print.assert_any_call("Killing model server process with PID 1234") + mock_print.assert_any_call("Process 1234 has been killed.") + mock_print.assert_any_call("Killing model server process with PID 5678") + mock_print.assert_any_call("Process 5678 has been killed.") + + +if __name__ == "__main__": + unittest.main()