diff --git a/sky/utils/command_runner.py b/sky/utils/command_runner.py index 92d1f2749d7..09ddef0e386 100644 --- a/sky/utils/command_runner.py +++ b/sky/utils/command_runner.py @@ -471,7 +471,7 @@ def __init__( self._ssh_proxy_command = ssh_proxy_command self.disable_control_master = ( disable_control_master or - control_master_utils.should_disable_control_master()) + control_master_utils.should_disable_control_master(ip)) if docker_user is not None: assert port is None or port == 22, ( f'port must be None or 22 for docker_user, got {port}.') @@ -623,6 +623,10 @@ def run( else: command += [f'> {log_path}'] executable = '/bin/bash' + if ssh_mode == SshMode.INTERACTIVE: + # By default we disable stdin in run_with_log to avoid blocking, but + # for interactive mode, we need to enable it. + kwargs['stdin'] = None return log_lib.run_with_log(' '.join(command), log_path, require_outputs=require_outputs, @@ -814,6 +818,10 @@ def run( else: command += [f'> {log_path}'] executable = '/bin/bash' + if ssh_mode == SshMode.INTERACTIVE: + # By default we disable stdin in run_with_log to avoid blocking, but + # for interactive mode, we need to enable it. + kwargs['stdin'] = None return log_lib.run_with_log(' '.join(command), log_path, require_outputs=require_outputs, diff --git a/sky/utils/control_master_utils.py b/sky/utils/control_master_utils.py index d645014c417..6798ef56be0 100644 --- a/sky/utils/control_master_utils.py +++ b/sky/utils/control_master_utils.py @@ -1,12 +1,19 @@ """Utils to check if the ssh control master should be disabled.""" import functools +import subprocess from sky import sky_logging from sky.utils import subprocess_utils logger = sky_logging.init_logger(__name__) +# The maximum number of concurrent ssh connections to a same node. This is a +# heuristic value, based on the observation that when the number of concurrent +# ssh connections to a node with control master is high, new connections through +# control master will hang. 
# Number of ssh processes mentioning a node at or above which
# is_high_concurrency() reports the node as highly concurrent.
_MAX_CONCURRENT_SSH_CONNECTIONS = 32


def should_disable_control_master(ip: str) -> bool:
    """Whether to disable ssh control master for connections to a node.

    Control master is disabled when either the local filesystem check
    (is_unsupported_filesystem) or the per-node concurrency check
    (is_high_concurrency) reports a problem.

    Args:
        ip: The ip address of the node.

    Returns:
        bool: True if the ssh control master should be disabled,
        False otherwise.
    """
    if is_unsupported_filesystem():
        return True
    if is_high_concurrency(ip):
        return True
    # there may be additional criteria to disable ssh control master
    # in the future. They should be checked here
    return False


@functools.lru_cache(maxsize=1)
def is_unsupported_filesystem() -> bool:
    """Determine if the filesystem is unsupported.

    Delegates to is_tmp_9p_filesystem(); the result is cached after the
    first call, so the check runs at most once per process.
    """
    return is_tmp_9p_filesystem()


def _count_lines_mentioning_host(listing: str, host: str) -> int:
    """Count the lines of `listing` that mention `host` as a whole token.

    The host is matched literally (dots are not wildcards) and must not be
    embedded in a longer address or name, so '10.0.0.1' does not match
    '10.0.0.10', '110.0.0.1' or '1.10.0.0.1'.
    """
    # Function-scope import: the module's import block does not provide re.
    import re

    pattern = re.compile(r'(?<![\w.])' + re.escape(host) + r'(?!\.?\w)')
    return sum(1 for line in listing.splitlines() if pattern.search(line))


def is_high_concurrency(ip: str) -> bool:
    """Determine if the node has high concurrent ssh connections.

    Lists the local processes whose command line matches 'ssh' and counts
    how many of the listed lines mention `ip`. This is a best-effort
    heuristic: any failure to run or read the listing yields False.

    Args:
        ip: The IP address to check.

    Returns:
        bool: True if the number of matching ssh processes is at least
        _MAX_CONCURRENT_SSH_CONNECTIONS, False otherwise (including when
        `ip` is empty or the process listing could not be obtained).
    """
    if not ip:
        # Nothing to match against; an empty pattern would be meaningless.
        return False
    try:
        # The IP is deliberately kept out of the shell command: it is never
        # interpreted by the shell, and the helper shell's own command line
        # cannot contain it. Matching is done in Python instead of grep.
        proc = subprocess.run('pgrep -f ssh | xargs -r ps -p',
                              shell=True,
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE,
                              text=True,
                              check=False)
    except (subprocess.SubprocessError, OSError, ValueError):
        # ValueError also covers output that cannot be decoded as text.
        return False
    # The pipeline's exit status is not consulted: with no ssh processes
    # the listing is simply empty and the count is zero.
    count = _count_lines_mentioning_host(proc.stdout or '', ip)
    return count >= _MAX_CONCURRENT_SSH_CONNECTIONS