Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Avoid high concurrency issue with control master #4455

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion sky/utils/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ def __init__(
self._ssh_proxy_command = ssh_proxy_command
self.disable_control_master = (
disable_control_master or
control_master_utils.should_disable_control_master())
control_master_utils.should_disable_control_master(ip))
if docker_user is not None:
assert port is None or port == 22, (
f'port must be None or 22 for docker_user, got {port}.')
Expand Down Expand Up @@ -623,6 +623,10 @@ def run(
else:
command += [f'> {log_path}']
executable = '/bin/bash'
if ssh_mode == SshMode.INTERACTIVE:
# By default we disable stdin in run_with_log to avoid blocking, but
# for interactive mode, we need to enable it.
kwargs['stdin'] = None
return log_lib.run_with_log(' '.join(command),
log_path,
require_outputs=require_outputs,
Expand Down Expand Up @@ -814,6 +818,10 @@ def run(
else:
command += [f'> {log_path}']
executable = '/bin/bash'
if ssh_mode == SshMode.INTERACTIVE:
# By default we disable stdin in run_with_log to avoid blocking, but
# for interactive mode, we need to enable it.
kwargs['stdin'] = None
return log_lib.run_with_log(' '.join(command),
log_path,
require_outputs=require_outputs,
Expand Down
50 changes: 47 additions & 3 deletions sky/utils/control_master_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
"""Utils to check if the ssh control master should be disabled."""

import functools
import subprocess

from sky import sky_logging
from sky.utils import subprocess_utils

logger = sky_logging.init_logger(__name__)

# The maximum number of concurrent ssh connections to a same node. This is a
# heuristic value, based on the observation that when the number of concurrent
# ssh connections to a node with control master is high, new connections through
# control master will hang.
_MAX_CONCURRENT_SSH_CONNECTIONS = 32
Comment on lines +11 to +15
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

have you test if this is related to number of cpu?



def is_tmp_9p_filesystem() -> bool:
"""Check if the /tmp filesystem is 9p.
Expand Down Expand Up @@ -34,16 +41,53 @@ def is_tmp_9p_filesystem() -> bool:
return filesystem_types[1].lower() == '9p'


@functools.lru_cache
def should_disable_control_master() -> bool:
def should_disable_control_master(ip: str) -> bool:
"""Whether disable ssh control master based on file system.

Args:
ip: The ip address of the node.

Returns:
bool: True if the ssh control master should be disabled,
False otherwise.
"""
if is_tmp_9p_filesystem():
if is_unsupported_filesystem():
return True
if is_high_concurrency(ip):
return True
# there may be additional criteria to disable ssh control master
# in the future. They should be checked here
return False


@functools.lru_cache(maxsize=1)
def is_unsupported_filesystem():
"""Determine if the filesystem is unsupported."""
return is_tmp_9p_filesystem()


def is_high_concurrency(ip: str) -> bool:
"""Determine if the node has high concurrent ssh connections.

Args:
ip: The IP address to check
threshold: Maximum number of allowed concurrent SSH connections
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stale? seems now the threshold is a constant (but not an arg)


Returns:
bool: True if number of concurrent SSH connections exceeds threshold
"""
try:
# Use pgrep to efficiently find ssh processes and pipe to grep for IP
cmd = f'pgrep -f ssh | xargs -r ps -p | grep -c {ip}'
Comment on lines +80 to +81
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that only tracks the SSH command from local machine? How about the connections from other laptops

proc = subprocess.run(cmd,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
check=False)
if proc.returncode == 0:
count = int(proc.stdout.strip())
return count >= _MAX_CONCURRENT_SSH_CONNECTIONS
return False
except (subprocess.SubprocessError, ValueError):
return False
Loading