Skip to content

Commit

Permalink
Startup: Check if the port is available and fallback
Browse files Browse the repository at this point in the history
Similar to Gradio, fall back to port + 1 if the config port isn't
bindable. If both ports aren't available, let the user know and exit.
An infinite loop of finding a port isn't advisable.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
kingbri1 committed Mar 12, 2024
1 parent 7c6fd7a commit 894be4a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
12 changes: 12 additions & 0 deletions common/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Common utility functions"""

import socket
import traceback
from loguru import logger
from pydantic import BaseModel
Expand Down Expand Up @@ -67,3 +68,14 @@ def prune_dict(input_dict):
"""Trim out instances of None from a dictionary"""

return {k: v for k, v in input_dict.items() if v is not None}


def is_port_in_use(port: int) -> bool:
"""
Checks if a port is in use
From https://stackoverflow.com/questions/2470971/fast-way-to-test-if-a-port-is-in-use-using-python
"""

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0
26 changes: 23 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
get_generator_error,
handle_request_error,
load_progress,
is_port_in_use,
unwrap,
)
from OAI.types.completion import CompletionRequest
Expand Down Expand Up @@ -732,6 +733,28 @@ def entrypoint(args: Optional[dict] = None):

network_config = get_network_config()

host = unwrap(network_config.get("host"), "127.0.0.1")
port = unwrap(network_config.get("port"), 5000)

# Check if the port is available and attempt to bind a fallback
if is_port_in_use(port):
fallback_port = port + 1

if is_port_in_use(fallback_port):
logger.error(
f"Ports {port} and {fallback_port} are in use by different services.\n"
"Please free up those ports or specify a different one.\n"
"Exiting."
)

return
else:
logger.warning(
f"Port {port} is currently in use. Switching to {fallback_port}."
)

port = fallback_port

# Initialize auth keys
load_auth_keys(unwrap(network_config.get("disable_auth"), False))

Expand Down Expand Up @@ -788,9 +811,6 @@ def entrypoint(args: Optional[dict] = None):
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
MODEL_CONTAINER.load_loras(lora_dir.resolve(), **lora_config)

host = unwrap(network_config.get("host"), "127.0.0.1")
port = unwrap(network_config.get("port"), 5000)

# TODO: Replace this with abortables, async via producer consumer, or something else
api_thread = threading.Thread(target=partial(start_api, host, port), daemon=True)

Expand Down

0 comments on commit 894be4a

Please sign in to comment.