diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1bebbf8..006cb5c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,3 +21,13 @@ jobs: python-version: ${{ matrix.python-version }} - run: pip install -r requirements-travis.txt - run: make check + + static-checks: + runs-on: ubuntu-latest + steps: + - name: Check out repository code + uses: actions/checkout@v4 + - name: run static checks + uses: avocado-framework/avocado-ci-tools@main + with: + avocado-static-checks: true diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..1ceb9f9 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "avocado-static-checks"] + path = static-checks + url = ../avocado-static-checks + branch = main diff --git a/aexpect/__init__.py b/aexpect/__init__.py index 6501f94..3e32262 100644 --- a/aexpect/__init__.py +++ b/aexpect/__init__.py @@ -14,23 +14,24 @@ entry-points. """ -from .exceptions import ExpectError -from .exceptions import ExpectProcessTerminatedError -from .exceptions import ExpectTimeoutError -from .exceptions import ShellCmdError -from .exceptions import ShellError -from .exceptions import ShellProcessTerminatedError -from .exceptions import ShellStatusError -from .exceptions import ShellTimeoutError - -from .client import Spawn -from .client import Tail -from .client import Expect -from .client import ShellSession -from .client import kill_tail_threads -from .client import run_tail -from .client import run_bg -from .client import run_fg - -from . import remote -from . import rss_client +from . import remote, rss_client +from .client import ( + Expect, + ShellSession, + Spawn, + Tail, + kill_tail_threads, + run_bg, + run_fg, + run_tail, +) +from .exceptions import ( + ExpectError, + ExpectProcessTerminatedError, + ExpectTimeoutError, + ShellCmdError, + ShellError, + ShellProcessTerminatedError, + ShellStatusError, + ShellTimeoutError, +) diff --git a/aexpect/client.py b/aexpect/client.py index d686f59..117c345 100644 --- a/aexpect/client.py +++ b/aexpect/client.py @@ -17,38 +17,39 @@ # disable too-many-* as we need them pylint: disable=R0902,R0913,R0914,C0302 -import time -import signal +import locale +import logging import os import re -import threading -import shutil import select +import shutil +import signal import subprocess -import locale -import logging +import threading +import time -from aexpect.exceptions import ExpectError -from aexpect.exceptions import ExpectProcessTerminatedError -from aexpect.exceptions import ExpectTimeoutError -from aexpect.exceptions import ShellCmdError -from aexpect.exceptions import ShellError -from aexpect.exceptions import ShellProcessTerminatedError -from aexpect.exceptions import ShellStatusError -from aexpect.exceptions import ShellTimeoutError - -from aexpect.shared import BASE_DIR -from aexpect.shared import get_filenames -from aexpect.shared import get_reader_filename -from aexpect.shared import get_lock_fd -from aexpect.shared import is_file_locked -from aexpect.shared import unlock_fd -from aexpect.shared import wait_for_lock - -from aexpect.utils import astring -from aexpect.utils import data_factory -from aexpect.utils import process as utils_process +from aexpect.exceptions import ( + ExpectError, + ExpectProcessTerminatedError, + ExpectTimeoutError, + ShellCmdError, + ShellError, + ShellProcessTerminatedError, + ShellStatusError, + ShellTimeoutError, +) +from aexpect.shared import ( + BASE_DIR, + get_filenames, + get_lock_fd, + get_reader_filename, + is_file_locked, + unlock_fd, + wait_for_lock, +) +from aexpect.utils import astring, data_factory from aexpect.utils import path as utils_path +from aexpect.utils import process as utils_process from aexpect.utils import wait as utils_wait _THREAD_KILL_REQUESTED = threading.Event() @@ -105,8 +106,16 @@ class Spawn: resumes _tail() if needed. """ - def __init__(self, command=None, a_id=None, auto_close=False, echo=False, - linesep="\n", pass_fds=(), encoding=None): + def __init__( + self, + command=None, + a_id=None, + auto_close=False, + echo=False, + linesep="\n", + pass_fds=(), + encoding=None, + ): """ Initialize the class and run command as a child process. @@ -140,14 +149,16 @@ def __init__(self, command=None, a_id=None, auto_close=False, echo=False, # Define filenames for communication with server utils_path.init_dir(base_dir) - (self.shell_pid_filename, - self.status_filename, - self.output_filename, - self.inpipe_filename, - self.ctrlpipe_filename, - self.lock_server_running_filename, - self.lock_client_starting_filename, - self.server_log_filename) = get_filenames(base_dir) + ( + self.shell_pid_filename, + self.status_filename, + self.output_filename, + self.inpipe_filename, + self.ctrlpipe_filename, + self.lock_server_running_filename, + self.lock_client_starting_filename, + self.server_log_filename, + ) = get_filenames(base_dir) assert os.path.isdir(base_dir) @@ -167,8 +178,8 @@ def __init__(self, command=None, a_id=None, auto_close=False, echo=False, # Define the reader filenames self.reader_filenames = dict( - (reader, get_reader_filename(base_dir, reader)) - for reader in self.readers) + (reader, get_reader_filename(base_dir, reader)) for reader in self.readers + ) # Let the server know a client intends to open some pipes; # if the executed command terminates quickly, the server will wait for @@ -177,13 +188,16 @@ def __init__(self, command=None, a_id=None, auto_close=False, echo=False, # Start the server (which runs the command) if command: - helper_cmd = utils_path.find_command('aexpect_helper') - self._aexpect_helper = subprocess.Popen([helper_cmd], # pylint: disable=R1732 - shell=True, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - pass_fds=pass_fds) + helper_cmd = utils_path.find_command("aexpect_helper") + # pylint: disable=R1732 + self._aexpect_helper = subprocess.Popen( + [helper_cmd], + shell=True, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + pass_fds=pass_fds, + ) sub = self._aexpect_helper # Send parameters to the server sub.stdin.write(f"{self.a_id}\n".encode(self.encoding)) @@ -193,8 +207,9 @@ def __init__(self, command=None, a_id=None, auto_close=False, echo=False, sub.stdin.write(f"{command}\n".encode(self.encoding)) sub.stdin.flush() # Wait for the server to complete its initialization - while (f"Server {self.a_id} ready" not in - sub.stdout.readline().decode(self.encoding, "ignore")): + while f"Server {self.a_id} ready" not in sub.stdout.readline().decode( + self.encoding, "ignore" + ): pass # Open the reading pipes @@ -205,8 +220,7 @@ def __init__(self, command=None, a_id=None, auto_close=False, echo=False, except OSError: LOG.warning("Failed to open reader '%s'", filename) else: - LOG.warning("Not opening readers, lock_server_running " - "not locked") + LOG.warning("Not opening readers, lock_server_running not locked") # Allow the server to continue unlock_fd(lock_client_starting) @@ -313,8 +327,7 @@ def get_pid(self): command. """ try: - with open(self.shell_pid_filename, 'r', - encoding='utf-8') as pid_file: + with open(self.shell_pid_filename, "r", encoding="utf-8") as pid_file: try: return int(pid_file.read()) except ValueError: @@ -329,8 +342,7 @@ def get_status(self): """ wait_for_lock(self.lock_server_running_filename) try: - with open(self.status_filename, 'r', - encoding='utf-8') as status_file: + with open(self.status_filename, "r", encoding="utf-8") as status_file: try: return int(status_file.read()) except ValueError: @@ -343,9 +355,8 @@ def get_output(self): Return the STDOUT and STDERR output of the process so far. """ try: - with open(self.output_filename, 'rb') as output_file: - return output_file.read().decode(self.encoding, - 'backslashreplace') + with open(self.output_filename, "rb") as output_file: + return output_file.read().decode(self.encoding, "backslashreplace") except IOError: return None @@ -393,8 +404,8 @@ def close(self, sig=signal.SIGKILL): self._close_reader_fds() self.reader_fds = {} # Remove all used files - if 'AEXPECT_DEBUG' not in os.environ: - shutil.rmtree(os.path.join(BASE_DIR, f'aexpect_{self.a_id}')) + if "AEXPECT_DEBUG" not in os.environ: + shutil.rmtree(os.path.join(BASE_DIR, f"aexpect_{self.a_id}")) self._close_aexpect_helper() self.closed = True @@ -440,13 +451,20 @@ def sendcontrol(self, char): if 97 <= val <= 122: val = val - 97 + 1 # ctrl+a = '\0x01' return self.send(chr(val)) - mapping = {'@': 0, '`': 0, - '[': 27, '{': 27, - '\\': 28, '|': 28, - ']': 29, '}': 29, - '^': 30, '~': 30, - '_': 31, - '?': 127} + mapping = { + "@": 0, + "`": 0, + "[": 27, + "{": 27, + "\\": 28, + "|": 28, + "]": 29, + "}": 29, + "^": 30, + "~": 30, + "_": 31, + "?": 127, + } return self.send(chr(mapping[char])) def send_ctrl(self, control_str=""): @@ -487,10 +505,22 @@ class Tail(Spawn): When this class is unpickled, it automatically resumes reporting output. """ - def __init__(self, command=None, a_id=None, auto_close=False, echo=False, - linesep="\n", termination_func=None, termination_params=(), - output_func=None, output_params=(), output_prefix="", - thread_name=None, pass_fds=(), encoding=None): + def __init__( + self, + command=None, + a_id=None, + auto_close=False, + echo=False, + linesep="\n", + termination_func=None, + termination_params=(), + output_func=None, + output_params=(), + output_prefix="", + thread_name=None, + pass_fds=(), + encoding=None, + ): """ Initialize the class and run command as a child process. @@ -527,8 +557,7 @@ def __init__(self, command=None, a_id=None, auto_close=False, echo=False, self._add_close_hook(Tail._join_thread) # Init the superclass - super().__init__(command, a_id, auto_close, echo, linesep, - pass_fds, encoding) + super().__init__(command, a_id, auto_close, echo, linesep, pass_fds, encoding) if thread_name is None: self.thread_name = f"tail_thread_{self.a_id}_{str(command)[:10]}" else: @@ -551,12 +580,14 @@ def __reduce__(self): return self.__class__, (self.__getinitargs__()) def __getinitargs__(self): - return Spawn.__getinitargs__(self) + (self.termination_func, - self.termination_params, - self.output_func, - self.output_params, - self.output_prefix, - self.thread_name) + return Spawn.__getinitargs__(self) + ( + self.termination_func, + self.termination_params, + self.output_func, + self.output_params, + self.output_prefix, + self.thread_name, + ) def set_termination_func(self, termination_func): """ @@ -607,8 +638,9 @@ def set_output_prefix(self, output_prefix): """ self.output_prefix = output_prefix - def _tail(self): # speed optimization pylint: disable=too-many-branches,too-many-statements - + def _tail( + self, + ): # speed optimization pylint: disable=too-many-branches,too-many-statements def _print_line(text): # Pre-pend prefix and remove trailing whitespace text = self.output_prefix + text.rstrip() @@ -619,8 +651,7 @@ def _print_line(text): else: self.output_func(text) except TypeError: - LOG.warning("Failed to print_line '%s' '%s'", - self.output_params, text) + LOG.warning("Failed to print_line '%s' '%s'", self.output_params, text) try: tail_pipe = self._get_fd("tail") @@ -656,7 +687,7 @@ def _print_line(text): _print_line(line) # Leave only the last line last_newline_index = bfr.rfind("\n") - bfr = bfr[last_newline_index + 1:] + bfr = bfr[last_newline_index + 1 :] else: # No output is available right now; flush the bfr if bfr: @@ -676,14 +707,12 @@ def _print_line(text): params = self.termination_params + (status,) self.termination_func(*params) except TypeError: - LOG.warning("Termination function execution failure '%s'", - params) + LOG.warning("Termination function execution failure '%s'", params) finally: self.tail_thread = None def _start_thread(self): - self.tail_thread = threading.Thread(target=self._tail, - name=self.thread_name) + self.tail_thread = threading.Thread(target=self._tail, name=self.thread_name) self.tail_thread.start() def _join_thread(self): @@ -704,10 +733,22 @@ class Expect(Tail): It also provides all of Tail's functionality. """ - def __init__(self, command=None, a_id=None, auto_close=True, echo=False, - linesep="\n", termination_func=None, termination_params=(), - output_func=None, output_params=(), output_prefix="", - thread_name=None, pass_fds=(), encoding=None): + def __init__( + self, + command=None, + a_id=None, + auto_close=True, + echo=False, + linesep="\n", + termination_func=None, + termination_params=(), + output_func=None, + output_params=(), + output_prefix="", + thread_name=None, + pass_fds=(), + encoding=None, + ): """ Initialize the class and run command as a child process. @@ -742,10 +783,21 @@ def __init__(self, command=None, a_id=None, auto_close=True, echo=False, self._add_reader("expect") # Init the superclass - super().__init__(command, a_id, auto_close, echo, linesep, - termination_func, termination_params, - output_func, output_params, output_prefix, thread_name, - pass_fds, encoding) + super().__init__( + command, + a_id, + auto_close, + echo, + linesep, + termination_func, + termination_params, + output_func, + output_params, + output_prefix, + thread_name, + pass_fds, + encoding, + ) def __reduce__(self): return self.__class__, (self.__getinitargs__()) @@ -843,9 +895,15 @@ def match_patterns_multiline(cont, patterns): return i return None - def read_until_output_matches(self, patterns, filter_func=lambda x: x, - timeout=60.0, internal_timeout=None, - print_func=None, match_func=None): + def read_until_output_matches( + self, + patterns, + filter_func=lambda x: x, + timeout=60.0, + internal_timeout=None, + print_func=None, + match_func=None, + ): """ Read from child using read_nonblocking until a pattern matches. @@ -887,8 +945,9 @@ def read_until_output_matches(self, patterns, filter_func=lambda x: x, if not poll_status: raise ExpectTimeoutError(patterns, output) # Read data from child - read, data = self._read_nonblocking(internal_timeout, - end_time - time.time()) + read, data = self._read_nonblocking( + internal_timeout, end_time - time.time() + ) if not read: break if not data: @@ -905,13 +964,13 @@ def read_until_output_matches(self, patterns, filter_func=lambda x: x, # Check if the child has terminated if utils_wait.wait_for(lambda: not self.is_alive(), 5, 0, 0.1): - raise ExpectProcessTerminatedError(patterns, self.get_status(), - output) + raise ExpectProcessTerminatedError(patterns, self.get_status(), output) # This shouldn't happen raise ExpectError(patterns, output) - def read_until_last_word_matches(self, patterns, timeout=60.0, - internal_timeout=None, print_func=None): + def read_until_last_word_matches( + self, patterns, timeout=60.0, internal_timeout=None, print_func=None + ): """ Read using read_nonblocking until the last word of the output matches one of the patterns (using match_patterns), or until timeout expires. @@ -934,12 +993,13 @@ def _get_last_word(cont): return cont.split()[-1] return "" - return self.read_until_output_matches(patterns, _get_last_word, - timeout, internal_timeout, - print_func) + return self.read_until_output_matches( + patterns, _get_last_word, timeout, internal_timeout, print_func + ) - def read_until_last_line_matches(self, patterns, timeout=60.0, - internal_timeout=None, print_func=None): + def read_until_last_line_matches( + self, patterns, timeout=60.0, internal_timeout=None, print_func=None + ): """ Read until the last non-empty line matches a pattern. @@ -967,13 +1027,13 @@ def _get_last_nonempty_line(cont): return nonempty_lines[-1] return "" - return self.read_until_output_matches(patterns, - _get_last_nonempty_line, - timeout, internal_timeout, - print_func) + return self.read_until_output_matches( + patterns, _get_last_nonempty_line, timeout, internal_timeout, print_func + ) - def read_until_any_line_matches(self, patterns, timeout=60.0, - internal_timeout=None, print_func=None): + def read_until_any_line_matches( + self, patterns, timeout=60.0, internal_timeout=None, print_func=None + ): """ Read using read_nonblocking until any line matches a pattern. @@ -995,11 +1055,14 @@ def read_until_any_line_matches(self, patterns, timeout=60.0, terminates while waiting for output :raise ExpectError: Raised if an unknown error occurs """ - return self.read_until_output_matches(patterns, - lambda x: x.splitlines(), - timeout, internal_timeout, - print_func, - self.match_patterns_multiline) + return self.read_until_output_matches( + patterns, + lambda x: x.splitlines(), + timeout, + internal_timeout, + print_func, + self.match_patterns_multiline, + ) class ShellSession(Expect): @@ -1016,11 +1079,24 @@ class ShellSession(Expect): # Return code pattern of shell interpreter __RE_STATUS = re.compile("^-?[0-9]+$") - def __init__(self, command=None, a_id=None, auto_close=True, echo=False, - linesep="\n", termination_func=None, termination_params=(), - output_func=None, output_params=(), output_prefix="", - thread_name=None, prompt=r"[\#\$]\s*$", - status_test_command="echo $?", pass_fds=(), encoding=None): + def __init__( + self, + command=None, + a_id=None, + auto_close=True, + echo=False, + linesep="\n", + termination_func=None, + termination_params=(), + output_func=None, + output_params=(), + output_prefix="", + thread_name=None, + prompt=r"[\#\$]\s*$", + status_test_command="echo $?", + pass_fds=(), + encoding=None, + ): """ Initialize the class and run command as a child process. @@ -1056,10 +1132,21 @@ def __init__(self, command=None, a_id=None, auto_close=True, echo=False, locale.getpreferredencoding()) """ # Init the superclass - super().__init__(command, a_id, auto_close, echo, linesep, - termination_func, termination_params, - output_func, output_params, output_prefix, thread_name, - pass_fds, encoding) + super().__init__( + command, + a_id, + auto_close, + echo, + linesep, + termination_func, + termination_params, + output_func, + output_params, + output_prefix, + thread_name, + pass_fds, + encoding, + ) # Remember some attributes self.prompt = prompt @@ -1069,8 +1156,7 @@ def __reduce__(self): return self.__class__, (self.__getinitargs__()) def __getinitargs__(self): - return Expect.__getinitargs__(self) + (self.prompt, - self.status_test_command) + return Expect.__getinitargs__(self) + (self.prompt, self.status_test_command) @classmethod def remove_command_echo(cls, cont, cmd): @@ -1130,8 +1216,7 @@ def is_responsive(self, timeout=5.0): # No output -- report unresponsive return False - def read_up_to_prompt(self, timeout=60.0, internal_timeout=None, - print_func=None): + def read_up_to_prompt(self, timeout=60.0, internal_timeout=None, print_func=None): """ Read until the last non-empty line matches the prompt. @@ -1151,12 +1236,13 @@ def read_up_to_prompt(self, timeout=60.0, internal_timeout=None, terminates while waiting for output :raise ExpectError: Raised if an unknown error occurs """ - return self.read_until_last_line_matches([self.prompt], timeout, - internal_timeout, - print_func)[1] + return self.read_until_last_line_matches( + [self.prompt], timeout, internal_timeout, print_func + )[1] - def cmd_output(self, cmd, timeout=60, internal_timeout=None, - print_func=None, safe=False): + def cmd_output( + self, cmd, timeout=60, internal_timeout=None, print_func=None, safe=False + ): """ Send a command and return its output. @@ -1197,8 +1283,7 @@ def cmd_output(self, cmd, timeout=60, internal_timeout=None, raise ShellError(cmd, output) from error # Remove the echoed command and the final shell prompt - return self.remove_last_nonempty_line(self.remove_command_echo(out, - cmd)) + return self.remove_last_nonempty_line(self.remove_command_echo(out, cmd)) def cmd_output_safe(self, cmd, timeout=60): """ @@ -1236,8 +1321,7 @@ def cmd_output_safe(self, cmd, timeout=60): self.sendline() except ExpectProcessTerminatedError as error: output = self.remove_command_echo(f"{out}{error.output}", cmd) - raise ShellProcessTerminatedError(cmd, error.status, - output) from error + raise ShellProcessTerminatedError(cmd, error.status, output) from error except ExpectError as error: output = self.remove_command_echo(f"{out}{error.output}", cmd) raise ShellError(cmd, output) from error @@ -1246,11 +1330,11 @@ def cmd_output_safe(self, cmd, timeout=60): raise ShellTimeoutError(cmd, out) # Remove the echoed command and the final shell prompt - return self.remove_last_nonempty_line(self.remove_command_echo(out, - cmd)) + return self.remove_last_nonempty_line(self.remove_command_echo(out, cmd)) - def cmd_status_output(self, cmd, timeout=60, internal_timeout=None, - print_func=None, safe=False): + def cmd_status_output( + self, cmd, timeout=60, internal_timeout=None, print_func=None, safe=False + ): """ Send a command and return its exit status and output. @@ -1277,20 +1361,23 @@ def cmd_status_output(self, cmd, timeout=60, internal_timeout=None, out = self.cmd_output(cmd, timeout, internal_timeout, print_func, safe) try: # Send the 'echo $?' (or equivalent) command to get the exit status - status = self.cmd_output(self.status_test_command, 30, - internal_timeout, print_func, safe) + status = self.cmd_output( + self.status_test_command, 30, internal_timeout, print_func, safe + ) except ShellError as error: raise ShellStatusError(cmd, out) from error # Get the first line consisting of digits only - digit_lines = [_ for _ in status.splitlines() - if self.__RE_STATUS.match(_.strip())] + digit_lines = [ + _ for _ in status.splitlines() if self.__RE_STATUS.match(_.strip()) + ] if digit_lines: return int(digit_lines[0].strip()), out raise ShellStatusError(cmd, out) - def cmd_status(self, cmd, timeout=60, internal_timeout=None, - print_func=None, safe=False): + def cmd_status( + self, cmd, timeout=60, internal_timeout=None, print_func=None, safe=False + ): """ Send a command and return its exit status. @@ -1313,11 +1400,19 @@ def cmd_status(self, cmd, timeout=60, internal_timeout=None, :raise ShellStatusError: Raised if the exit status cannot be obtained :raise ShellError: Raised if an unknown error occurs """ - return self.cmd_status_output(cmd, timeout, internal_timeout, - print_func, safe)[0] + return self.cmd_status_output(cmd, timeout, internal_timeout, print_func, safe)[ + 0 + ] - def cmd(self, cmd, timeout=60, internal_timeout=None, print_func=None, - ok_status=None, ignore_all_errors=False): + def cmd( + self, + cmd, + timeout=60, + internal_timeout=None, + print_func=None, + ok_status=None, + ignore_all_errors=False, + ): """ Send a command and return its output. If the command's exit status is nonzero, raise an exception. @@ -1344,11 +1439,13 @@ def cmd(self, cmd, timeout=60, internal_timeout=None, print_func=None, :raise ShellCmdError: Raised if the exit status is nonzero """ if ok_status is None: - ok_status = [0, ] + ok_status = [ + 0, + ] try: - status, output = self.cmd_status_output(cmd, timeout, - internal_timeout, - print_func) + status, output = self.cmd_status_output( + cmd, timeout, internal_timeout, print_func + ) if status not in ok_status: raise ShellCmdError(cmd, status, output) return output @@ -1357,23 +1454,25 @@ def cmd(self, cmd, timeout=60, internal_timeout=None, print_func=None, return None raise - def get_command_output(self, cmd, timeout=60, internal_timeout=None, - print_func=None): + def get_command_output( + self, cmd, timeout=60, internal_timeout=None, print_func=None + ): """ Alias for cmd_output() for backward compatibility. """ return self.cmd_output(cmd, timeout, internal_timeout, print_func) - def get_command_status_output(self, cmd, timeout=60, internal_timeout=None, - print_func=None): + def get_command_status_output( + self, cmd, timeout=60, internal_timeout=None, print_func=None + ): """ Alias for cmd_status_output() for backward compatibility. """ - return self.cmd_status_output(cmd, timeout, internal_timeout, - print_func) + return self.cmd_status_output(cmd, timeout, internal_timeout, print_func) - def get_command_status(self, cmd, timeout=60, internal_timeout=None, - print_func=None): + def get_command_status( + self, cmd, timeout=60, internal_timeout=None, print_func=None + ): """ Alias for cmd_status() for backward compatibility. """ @@ -1389,14 +1488,29 @@ class RemoteSession(ShellSession): connection attributes like client, host, port, username, and password. """ - def __init__(self, command=None, a_id=None, auto_close=True, echo=False, - linesep="\n", termination_func=None, termination_params=(), - output_func=None, output_params=(), output_prefix="", - thread_name=None, prompt=r"[\#\$]\s*$", - status_test_command="echo $?", - client="ssh", host="localhost", port=22, - username="root", password="test1234", - pass_fds=(), encoding=None): + def __init__( + self, + command=None, + a_id=None, + auto_close=True, + echo=False, + linesep="\n", + termination_func=None, + termination_params=(), + output_func=None, + output_params=(), + output_prefix="", + thread_name=None, + prompt=r"[\#\$]\s*$", + status_test_command="echo $?", + client="ssh", + host="localhost", + port=22, + username="root", + password="test1234", + pass_fds=(), + encoding=None, + ): """ Initialize the class and run command as a child process. @@ -1437,11 +1551,23 @@ def __init__(self, command=None, a_id=None, auto_close=True, echo=False, locale.getpreferredencoding()) """ # Init the superclass - super().__init__(command, a_id, auto_close, echo, linesep, - termination_func, termination_params, - output_func, output_params, output_prefix, thread_name, - prompt, status_test_command, - pass_fds, encoding) + super().__init__( + command, + a_id, + auto_close, + echo, + linesep, + termination_func, + termination_params, + output_func, + output_params, + output_prefix, + thread_name, + prompt, + status_test_command, + pass_fds, + encoding, + ) # Remember some attributes self.client = client @@ -1451,9 +1577,16 @@ def __init__(self, command=None, a_id=None, auto_close=True, echo=False, self.password = password -def run_tail(command, termination_func=None, output_func=None, - output_prefix="", timeout=1.0, auto_close=True, pass_fds=(), - encoding=None): +def run_tail( + command, + termination_func=None, + output_func=None, + output_prefix="", + timeout=1.0, + auto_close=True, + pass_fds=(), + encoding=None, +): """ Run a subprocess in the background and collect its output and exit status. @@ -1480,13 +1613,15 @@ def run_tail(command, termination_func=None, output_func=None, :return: A Expect object. """ - bg_process = Tail(command=command, - termination_func=termination_func, - output_func=output_func, - output_prefix=output_prefix, - auto_close=auto_close, - pass_fds=pass_fds, - encoding=encoding) + bg_process = Tail( + command=command, + termination_func=termination_func, + output_func=output_func, + output_prefix=output_prefix, + auto_close=auto_close, + pass_fds=pass_fds, + encoding=encoding, + ) end_time = time.time() + timeout while time.time() < end_time and bg_process.is_alive(): @@ -1495,8 +1630,16 @@ def run_tail(command, termination_func=None, output_func=None, return bg_process -def run_bg(command, termination_func=None, output_func=None, output_prefix="", - timeout=1.0, auto_close=True, pass_fds=(), encoding=None): +def run_bg( + command, + termination_func=None, + output_func=None, + output_prefix="", + timeout=1.0, + auto_close=True, + pass_fds=(), + encoding=None, +): """ Run a subprocess in the background and collect its output and exit status. @@ -1523,13 +1666,15 @@ def run_bg(command, termination_func=None, output_func=None, output_prefix="", :return: A Expect object. """ - bg_process = Expect(command=command, - termination_func=termination_func, - output_func=output_func, - output_prefix=output_prefix, - auto_close=auto_close, - pass_fds=pass_fds, - encoding=encoding) + bg_process = Expect( + command=command, + termination_func=termination_func, + output_func=output_func, + output_prefix=output_prefix, + auto_close=auto_close, + pass_fds=pass_fds, + encoding=encoding, + ) end_time = time.time() + timeout while time.time() < end_time and bg_process.is_alive(): @@ -1538,8 +1683,9 @@ def run_bg(command, termination_func=None, output_func=None, output_prefix="", return bg_process -def run_fg(command, output_func=None, output_prefix="", timeout=1.0, - pass_fds=(), encoding=None): +def run_fg( + command, output_func=None, output_prefix="", timeout=1.0, pass_fds=(), encoding=None +): """ Run a subprocess in the foreground and collect its output and exit status. @@ -1564,8 +1710,15 @@ def run_fg(command, output_func=None, output_prefix="", timeout=1.0, STDOUT/STDERR output. If timeout expires before the process terminates, the returned status is None. """ - bg_process = run_bg(command, None, output_func, output_prefix, timeout, - pass_fds=pass_fds, encoding=encoding) + bg_process = run_bg( + command, + None, + output_func, + output_prefix, + timeout, + pass_fds=pass_fds, + encoding=encoding, + ) output = bg_process.get_output() if bg_process.is_alive(): status = None diff --git a/aexpect/exceptions.py b/aexpect/exceptions.py index 8d3eace..a775abe 100644 --- a/aexpect/exceptions.py +++ b/aexpect/exceptions.py @@ -27,8 +27,10 @@ def _pattern_str(self): return f"patterns {self.patterns!r}" def __str__(self): - return ("Unknown error occurred while looking for " - f"{self._pattern_str()} (output: {self.output!r})") + return ( + "Unknown error occurred while looking for " + f"{self._pattern_str()} (output: {self.output!r})" + ) class ExpectTimeoutError(ExpectError): @@ -36,8 +38,10 @@ class ExpectTimeoutError(ExpectError): """Timeout when looking for output""" def __str__(self): - return ("Timeout expired while looking for " - f"{self._pattern_str()} (output: {self.output!r})") + return ( + "Timeout expired while looking for " + f"{self._pattern_str()} (output: {self.output!r})" + ) class ExpectProcessTerminatedError(ExpectError): @@ -49,9 +53,11 @@ def __init__(self, patterns, status, output): self.status = status def __str__(self): - return ("Process terminated while looking for " - f"{self._pattern_str()} (status: {self.status!r}, " - f"output: {self.output!r})") + return ( + "Process terminated while looking for " + f"{self._pattern_str()} (status: {self.status!r}, " + f"output: {self.output!r})" + ) class ShellError(Exception): @@ -64,8 +70,10 @@ def __init__(self, cmd, output): self.output = output def __str__(self): - return (f"Could not execute shell command {self.cmd!r} " - "(output: {self.output!r})") + return ( + f"Could not execute shell command {self.cmd!r} " + "(output: {self.output!r})" + ) class ShellTimeoutError(ShellError): @@ -73,8 +81,10 @@ class ShellTimeoutError(ShellError): """Timeout when waiting for command to complete""" def __str__(self): - return ("Timeout expired while waiting for shell command to " - f"complete: {self.cmd!r} (output: {self.output!r})") + return ( + "Timeout expired while waiting for shell command to " + f"complete: {self.cmd!r} (output: {self.output!r})" + ) class ShellProcessTerminatedError(ShellError): @@ -89,9 +99,11 @@ def __init__(self, cmd, status, output): self.status = status def __str__(self): - return ("Shell process terminated while waiting for command to " - f"complete: {self.cmd!r} (status: {self.status}, " - f"output: {self.output!r})") + return ( + "Shell process terminated while waiting for command to " + f"complete: {self.cmd!r} (status: {self.status}, " + f"output: {self.output!r})" + ) class ShellCmdError(ShellError): @@ -106,8 +118,10 @@ def __init__(self, cmd, status, output): self.status = status def __str__(self): - return (f"Shell command failed: {self.cmd!r} (status: {self.status}" - f", output: {self.output!r})") + return ( + f"Shell command failed: {self.cmd!r} (status: {self.status}" + f", output: {self.output!r})" + ) class ShellStatusError(ShellError): @@ -117,5 +131,7 @@ class ShellStatusError(ShellError): """ def __str__(self): - return (f"Could not get exit status of command: {self.cmd!r} " - f"(output: {self.output!r})") + return ( + f"Could not get exit status of command: {self.cmd!r} " + f"(output: {self.output!r})" + ) diff --git a/aexpect/ops_linux.py b/aexpect/ops_linux.py index 24925ab..2b560a2 100644 --- a/aexpect/ops_linux.py +++ b/aexpect/ops_linux.py @@ -54,6 +54,7 @@ from shlex import quote from aexpect.exceptions import ShellCmdError + # Need this import for sphinx and other documentation to produce links later on # from .client import ShellSession @@ -231,7 +232,7 @@ def ls(session, path, quote_path=True, flags="-1UNq"): # pylint: disable=C0103 Just like :py:func:`os.listdir`, does not include file names starting with dot (`'.'`) """ - cmd = f'ls {flags} {quote(path)}' if quote_path else f'ls {flags} {path}' + cmd = f"ls {flags} {quote(path)}" if quote_path else f"ls {flags} {path}" status, output = session.cmd_status_output(cmd) status, output = _process_status_output(cmd, status, output) return output.splitlines() @@ -253,7 +254,7 @@ def glob(session, glob_pattern): Alternative implementations are either shell specific or use `echo` and thus cannot handle spaces in paths well. """ - cmd = f'find {glob_pattern} -maxdepth 0' + cmd = f"find {glob_pattern} -maxdepth 0" status, output = session.cmd_status_output(cmd) if status == 1: return [] @@ -276,8 +277,11 @@ def move(session, source, target, quote_path=True, flags=""): Calls `mv source target` through a session. See `man mv` for what source and target can be and what behavior to expect. """ - cmd = f'mv {flags} {quote(source)} {quote(target)}' if quote_path \ - else f'mv {flags} {source} {target}' + cmd = ( + f"mv {flags} {quote(source)} {quote(target)}" + if quote_path + else f"mv {flags} {source} {target}" + ) status, output = session.cmd_status_output(cmd) _process_status_output(cmd, status, output) @@ -297,8 +301,11 @@ def copy(session, source, target, quote_path=True, flags=""): Calls `cp source target` through a session. See `man cp` for what source and target can be and what behavior to expect. """ - cmd = f'cp {flags} {quote(source)} {quote(target)}' if quote_path \ - else f'cp {flags} {source} {target}' + cmd = ( + f"cp {flags} {quote(source)} {quote(target)}" + if quote_path + else f"cp {flags} {source} {target}" + ) status, output = session.cmd_status_output(cmd) _process_status_output(cmd, status, output) @@ -316,7 +323,7 @@ def remove(session, path, quote_path=True, flags="-fr"): Calls `rm -rf`. """ - cmd = f'rm {flags} {quote(path)}' if quote_path else f'rm {flags} {path}' + cmd = f"rm {flags} {quote(path)}" if quote_path else f"rm {flags} {path}" status, output = session.cmd_status_output(cmd) _process_status_output(cmd, status, output) @@ -335,7 +342,7 @@ def make_tempdir(session, template=None): Calls `mktemp -d`, refer to `man` for more info. """ - cmd = f'mktemp -d {template}' if template is not None else 'mktemp -d' + cmd = f"mktemp -d {template}" if template is not None else "mktemp -d" status, output = session.cmd_status_output(cmd) _, output = _process_status_output(cmd, status, output) return output @@ -355,7 +362,7 @@ def make_tempfile(session, template=None): Calls `mktemp`, refer to `man` for more info. """ - cmd = f'mktemp {template}' if template is not None else 'mktemp' + cmd = f"mktemp {template}" if template is not None else "mktemp" status, output = session.cmd_status_output(cmd) _, output = _process_status_output(cmd, status, output) return output @@ -377,7 +384,7 @@ def cat(session, filename, quote_path=True, flags=""): Should only be used for very small files without tabs or other fancy contents. Otherwise, it is better to download the file or use some other method. """ - cmd = f'cat {flags} {quote(filename)}' if quote_path else f'cat {flags} {filename}' + cmd = f"cat {flags} {quote(filename)}" if quote_path else f"cat {flags} {filename}" status, output = session.cmd_status_output(cmd) _, output = _process_status_output(cmd, status, output) return output @@ -463,15 +470,16 @@ def hash_file(session, filename, size="", method="md5"): if output: output = output.strip() if status != 0 or not output: - raise RuntimeError(f'Could not hash {filename} using {cmd}: {output}') + raise RuntimeError(f"Could not hash {filename} using {cmd}: {output}") # parse output hash_str = output.split(maxsplit=1)[0].lower() # check that all chars are hex - if hash_str.strip('0123456789abcdef'): - raise RuntimeError('Resulting hash string has unexpected characters: ' - + hash_str) + if hash_str.strip("0123456789abcdef"): + raise RuntimeError( + "Resulting hash string has unexpected characters: " + hash_str + ) return hash_str @@ -486,6 +494,6 @@ def extract_tarball(session, tarball, target_dir, flags="-ap"): :param str flags: extra flags passed to ``tar`` on the command line :raises: :py:class:`RuntimeError` if tar command returned non-null """ - cmd = f'tar -C {quote(target_dir)} {flags} -xf {quote(tarball)}' + cmd = f"tar -C {quote(target_dir)} {flags} -xf {quote(tarball)}" status, output = session.cmd_status_output(cmd) _process_status_output(cmd, status, output) diff --git a/aexpect/ops_windows.py b/aexpect/ops_windows.py index dc0798b..6038374 100644 --- a/aexpect/ops_windows.py +++ b/aexpect/ops_windows.py @@ -33,12 +33,12 @@ Network configuration and downloads. """ +import logging import re import uuid -import logging -from textwrap import dedent from base64 import b64encode from enum import Enum, auto +from textwrap import dedent # avocado imports from aexpect.exceptions import ShellError, ShellProcessTerminatedError @@ -67,6 +67,7 @@ def ps_cmd(session, script, timeout=60): This function was slightly based on a similar function from `pywinrm`. """ + def nicely_log_str(header, string): LOG.info(header) LOG.info("-" * len(header)) @@ -78,12 +79,14 @@ def nicely_log_str(header, string): # must use utf16 little endian on Windows encoded_cmd = b64encode(script.encode("utf_16_le")).decode("ascii") try: - output = session.cmd(f"powershell -NoLogo -NonInteractive -OutputFormat text " - f"-ExecutionPolicy Bypass -EncodedCommand {encoded_cmd}", - timeout=timeout) + output = session.cmd( + f"powershell -NoLogo -NonInteractive -OutputFormat text " + f"-ExecutionPolicy Bypass -EncodedCommand {encoded_cmd}", + timeout=timeout, + ) # there was an error but the exit code was still zero - if output.startswith("#< CLIXML") and "" in output: + if output.startswith("#< CLIXML") and '' in output: raise ShellError(script, output) # TODO: PowerShell sometimes ignores the -OutputFormat and outputs @@ -113,7 +116,7 @@ def ps_file(session, filename, timeout=60): :returns: the output of the command :rtype: str """ - cmd = f"powershell -ExecutionPolicy RemoteSigned -File \"{filename}\"" + cmd = f'powershell -ExecutionPolicy RemoteSigned -File "{filename}"' LOG.info("Executing PowerShell file with `%s`", cmd) return session.cmd(cmd, timeout=timeout) @@ -132,7 +135,7 @@ def _clean_error_message(message): output = [] # the strings are between ... tags - for msgstr in message.split(""): + for msgstr in message.split(''): if not msgstr.endswith(""): continue msgstr = re.sub(r"$", "", msgstr).replace("_x000D__x000A_", "") @@ -203,10 +206,16 @@ def query_registry(session, key_name, value_name): :returns: value of the corresponding sub key and value name or None if not found :rtype: str or None """ - key_types = ["REG_SZ", "REG_MULTI_SZ", "REG_EXPAND_SZ", - "REG_DWORD", "REG_BINARY", "REG_NONE"] + key_types = [ + "REG_SZ", + "REG_MULTI_SZ", + "REG_EXPAND_SZ", + "REG_DWORD", + "REG_BINARY", + "REG_NONE", + ] - out = session.cmd_output(f"reg query \"{key_name}\" /v {value_name}").strip() + out = session.cmd_output(f'reg query "{key_name}" /v {value_name}').strip() LOG.info("Reg query for key %s and value %s: %s", key_name, value_name, out) for k in key_types: parts = out.split(k) @@ -216,7 +225,9 @@ def query_registry(session, key_name, value_name): return None -def add_reg_key(session, path, name, value, key_type, force=True): # pylint: disable=R0913 +def add_reg_key( + session, path, name, value, key_type, force=True +): # pylint: disable=R0913 """ Wrapper around the reg add command. @@ -231,7 +242,7 @@ def add_reg_key(session, path, name, value, key_type, force=True): # pylint: d if not isinstance(key_type, RegistryKeyType): raise TypeError(f"{key_type} must be instance of RegistryKeyType") - cmd = f"reg add \"{path}\" /v {name} /d {value} /t {key_type.value}" + cmd = f'reg add "{path}" /v {name} /d {value} /t {key_type.value}' if force: cmd += " /f" @@ -269,7 +280,7 @@ def path_exists(session, path): :returns: whether the given path exists :rtype: str """ - output = session.cmd_output(f"if exist \"{path}\" echo yes") + output = session.cmd_output(f'if exist "{path}" echo yes') return output.strip() == "yes" @@ -283,7 +294,7 @@ def hash_file(session, filename): :returns: hash value :rtype: str """ - output = session.cmd(f"certutil -hashfile \"{filename}\" MD5") + output = session.cmd(f'certutil -hashfile "{filename}" MD5') LOG.debug("certutil output for %s is %s", filename, output) hash_value = output.splitlines()[1] LOG.debug("Hash value is %s", hash_value) @@ -302,9 +313,9 @@ def mkdir(session, path, exist_ok=True): exists and exist_ok is False """ if exist_ok: - cmd = f"if not exist \"{path}\" mkdir \"{path}\"" + cmd = f'if not exist "{path}" mkdir "{path}"' else: - cmd = f"mkdir \"{path}\"" + cmd = f'mkdir "{path}"' session.cmd(cmd) @@ -318,7 +329,7 @@ def touch(session, path): :returns: path of the file created :rtype: str """ - session.cmd(f"type nul > \"{path}\"") + session.cmd(f'type nul > "{path}"') return path @@ -381,7 +392,9 @@ def tempfile(session, local=True): ############################################################################### -def run_as(session, user, password, command, timeout=60, background=False): # pylint: disable=R0913 +def run_as( + session, user, password, command, timeout=60, background=False +): # pylint: disable=R0913 """ Run a command as different user. @@ -406,12 +419,12 @@ def _run_as(cmd_to_run): session.sendline(password) if background is True: - _run_as(f"runas /user:{user} \"{command}\"") + _run_as(f'runas /user:{user} "{command}"') return None outfile = tempfile(session, local=False) # runas produces no output, so we add a redirection to the inner command - cmd = f"runas /user:{user} \"cmd.exe /c {command} > {outfile}\"" + cmd = f'runas /user:{user} "cmd.exe /c {command} > {outfile}"' _run_as(cmd) LOG.debug("Waiting %s seconds for the command to finish", timeout) # no need to capture anything, we redirected the output @@ -432,7 +445,7 @@ def kill_program(session, program, kill_children=True): :param str program: name of the program to kill :param bool kill_children: whether to kill child processes """ - cmd = f"taskkill /f /im \"{program}\"" + cmd = f'taskkill /f /im "{program}"' if kill_children: cmd += " /t" session.cmd(cmd) @@ -452,8 +465,9 @@ def wait_process_end(session, process, timeout=30): LOG.info("Waiting for process %s to end using PowerShell", process) # give some extra timeout to wait for PowerShell to return - ps_cmd(session, f"Wait-Process -Name {process} -Timeout {timeout}", - timeout=timeout+5) + ps_cmd( + session, f"Wait-Process -Name {process} -Timeout {timeout}", timeout=timeout + 5 + ) def kill_session(session): @@ -504,7 +518,9 @@ def find_unused_port(session, start_port=10024): :rtype: int """ # pylint: disable=C0301 - output = ps_cmd(session, f""" + output = ps_cmd( + session, + f""" $port = {start_port} while($true) {{ try {{ @@ -521,7 +537,8 @@ def find_unused_port(session, start_port=10024): }} Write-Host "::unused=$port" - """) + """, + ) # pylint: enable=C0301 port = re.search(r"::unused=(\d+)", output).group(1) return int(port) @@ -540,7 +557,8 @@ def curl(session, url, proxy_address=None, proxy_port=None, insecure=False): :rtype: (int, str) """ # extra indentation is not allowed within PS heredocs - skip_ssl_template = dedent("""\ + skip_ssl_template = dedent( + """\ Add-Type -Language CSharp @" namespace System.Net { @@ -555,22 +573,28 @@ def curl(session, url, proxy_address=None, proxy_port=None, insecure=False): } } "@; - [System.Net.Util]::Init()""") + [System.Net.Util]::Init()""" + ) # TODO: WebRequest (wrapper on top of System.Net.WebClient) is painfully slow # for some reason compared to the 3rd-party curl.exe tool, but let's use it for now. - webrequest_cmd = f"Invoke-WebRequest -Uri \"{url}\"" + webrequest_cmd = f'Invoke-WebRequest -Uri "{url}"' if proxy_address is not None: # the session user must have access to the proxy (otherwise use -ProxyCredential) - webrequest_cmd += f" -Proxy http://{proxy_address}:{proxy_port} -ProxyUseDefaultCredentials" + webrequest_cmd += ( + f" -Proxy http://{proxy_address}:{proxy_port} -ProxyUseDefaultCredentials" + ) - output = ps_cmd(session, f""" + output = ps_cmd( + session, + f""" {skip_ssl_template if insecure else ''} $ProgressPreference = 'SilentlyContinue' $result = {webrequest_cmd} Write-Output $result.StatusCode Write-Output $result.Content - """) + """, + ) LOG.debug("Powershell request produced output:\n'%s'", output) try: diff --git a/aexpect/remote.py b/aexpect/remote.py index a49fa32..978d3af 100644 --- a/aexpect/remote.py +++ b/aexpect/remote.py @@ -43,16 +43,15 @@ # ..todo:: we could reduce the disabled issues after more significant refactoring from __future__ import division + import logging -import time import re import shlex +import time -from aexpect.client import Expect -from aexpect.client import RemoteSession -from aexpect.exceptions import ExpectTimeoutError -from aexpect.exceptions import ExpectProcessTerminatedError from aexpect import rss_client +from aexpect.client import Expect, RemoteSession +from aexpect.exceptions import ExpectProcessTerminatedError, ExpectTimeoutError LOG = logging.getLogger(__name__) @@ -69,7 +68,7 @@ class RemoteError(Exception): class LoginError(RemoteError): """Base class for any remote error related to logins and session creation.""" - def __init__(self, msg, output=''): + def __init__(self, msg, output=""): super().__init__() self.msg = msg self.output = output @@ -85,27 +84,26 @@ class LoginAuthenticationError(LoginError): class LoginTimeoutError(LoginError): """Remote error related to login timeout expiration.""" - def __init__(self, output=''): + def __init__(self, output=""): super().__init__("Login timeout expired", output) class LoginProcessTerminatedError(LoginError): """Remote error related to login process termination.""" - def __init__(self, status, output=''): + def __init__(self, status, output=""): super().__init__("Client process terminated", output) self.status = status def __str__(self): - return (f"{self.msg} (status: {self.status}, " - f"output: {self.output!r})") + return f"{self.msg} (status: {self.status}, " f"output: {self.output!r})" class LoginBadClientError(LoginError): """Remote error related to unknown remote shell client.""" def __init__(self, client): - super().__init__('Unknown remote shell client') + super().__init__("Unknown remote shell client") self.client = client def __str__(self): @@ -132,8 +130,10 @@ def __init__(self, client): self.client = client def __str__(self): - return (f"Unknown file copy client: '{self.client}', " - "valid values are scp and rss") + return ( + f"Unknown file copy client: '{self.client}', " + "valid values are scp and rss" + ) class SCPError(TransferError): @@ -166,8 +166,10 @@ def __init__(self, status, output): self.status = status def __str__(self): - return (f"SCP transfer failed (status: {self.status}, " - f"output: {self.output!r})") + return ( + f"SCP transfer failed (status: {self.status}, " + f"output: {self.output!r})" + ) class NetcatError(TransferError): @@ -189,8 +191,10 @@ def __init__(self, status, output): self.status = status def __str__(self): - return (f"Netcat transfer failed (status: {self.status}, " - f"output: {self.output!r})") + return ( + f"Netcat transfer failed (status: {self.status}, " + f"output: {self.output!r})" + ) class NetcatTransferIntegrityError(NetcatError): @@ -216,12 +220,13 @@ def quote_path(path): :return: Shell escaped version """ if isinstance(path, list): - return ' '.join(map(shlex.quote, path)) + return " ".join(map(shlex.quote, path)) return shlex.quote(path) -def handle_prompts(session, username, password, prompt=PROMPT_LINUX, - timeout=10, debug=False): +def handle_prompts( + session, username, password, prompt=PROMPT_LINUX, timeout=10, debug=False +): """ Connect to a remote host (guest) using SSH or Telnet or else. @@ -252,18 +257,28 @@ def handle_prompts(session, username, password, prompt=PROMPT_LINUX, while True: try: match, text = session.read_until_last_line_matches( - [r"[Aa]re you sure", r"[Pp]assword:\s*", - # Prompt of rescue mode for Red Hat. - r"\(or (press|type) Control-D to continue\):\s*$", - r"[Gg]ive.*[Ll]ogin:\s*$", # Prompt of rescue mode for SUSE. - r"(?", - r"[Ll]ost connection"], - timeout=timeout, internal_timeout=0.5) + [ + r"[Aa]re you sure", + r"[Pp]assword:\s*", + # Prompt of rescue mode for Red Hat. + r"\(or (press|type) Control-D to continue\):\s*$", + r"[Gg]ive.*[Ll]ogin:\s*$", # Prompt of rescue mode for SUSE. + r"(?", + r"[Ll]ost connection", + ], + timeout=timeout, + internal_timeout=0.5, + ) output += text if match == 0: # "Are you sure you want to continue connecting" if debug: @@ -272,18 +287,15 @@ def handle_prompts(session, username, password, prompt=PROMPT_LINUX, elif match in [1, 2, 3, 10]: # "password:" if password_prompt_count == 0: if debug: - LOG.debug("Got password prompt, sending '%s'", - password) + LOG.debug("Got password prompt, sending '%s'", password) session.sendline(password) password_prompt_count += 1 else: - raise LoginAuthenticationError("Got password prompt twice", - text) + raise LoginAuthenticationError("Got password prompt twice", text) elif match in [4, 9]: # "login:" if login_prompt_count == 0 and password_prompt_count == 0: if debug: - LOG.debug("Got username prompt; sending '%s'", - username) + LOG.debug("Got username prompt; sending '%s'", username) session.sendline(username) login_prompt_count += 1 else: @@ -316,15 +328,13 @@ def handle_prompts(session, username, password, prompt=PROMPT_LINUX, session.sendline() elif match == 14: # VMware vCenter command prompt # Some old vsphere version (e.x. 6.0.0) needs to enable first. - cmd = 'shell.set --enabled True' + cmd = "shell.set --enabled True" LOG.debug( - "Got VMware VCenter prompt, " - "send '%s' to enable shell first", cmd) + "Got VMware VCenter prompt, send '%s' to enable shell first", cmd + ) session.sendline(cmd) - LOG.debug( - "Got VMware VCenter prompt, " - "send 'shell' to launch bash") - session.sendline('shell') + LOG.debug("Got VMware VCenter prompt, send 'shell' to launch bash") + session.sendline("shell") except ExpectTimeoutError as error: # sometimes, linux kernel print some message to console # the message maybe impact match login pattern, so send @@ -341,13 +351,26 @@ def handle_prompts(session, username, password, prompt=PROMPT_LINUX, return output -def remote_login(client, host, port, username, password, prompt, linesep="\n", - log_filename=None, log_function=None, timeout=10, - interface=None, identity_file=None, - status_test_command="echo $?", verbose=False, bind_ip=None, - preferred_authentication='password', - user_known_hosts_file='/dev/null', - extra_cmdline=''): +def remote_login( + client, + host, + port, + username, + password, + prompt, + linesep="\n", + log_filename=None, + log_function=None, + timeout=10, + interface=None, + identity_file=None, + status_test_command="echo $?", + verbose=False, + bind_ip=None, + preferred_authentication="password", + user_known_hosts_file="/dev/null", + extra_cmdline="", +): """ Log into a remote host (guest) using SSH/Telnet/Netcat. @@ -391,12 +414,13 @@ def remote_login(client, host, port, username, password, prompt, linesep="\n", extra_params.append(extra_cmdline) if host and host.lower().startswith("fe80"): if not interface: - raise RemoteError("When using ipv6 linklocal an interface must " - "be assigned") + raise RemoteError("When using ipv6 linklocal an interface must be assigned") host = f"{host}%{interface}" if client == "ssh": - cmd = (f"ssh {' '.join(extra_params)} -o UserKnownHostsFile={user_known_hosts_file} " - f"-o StrictHostKeyChecking=no -p {port}") + cmd = ( + f"ssh {' '.join(extra_params)} -o UserKnownHostsFile={user_known_hosts_file} " + f"-o StrictHostKeyChecking=no -p {port}" + ) if bind_ip: cmd += f" -b {bind_ip}" if identity_file: @@ -416,11 +440,20 @@ def remote_login(client, host, port, username, password, prompt, linesep="\n", if verbose: LOG.debug("Login command: '%s'", cmd) - session = RemoteSession(cmd, linesep=linesep, output_func=log_function, - output_params=output_params, output_prefix=host, - prompt=prompt, status_test_command=status_test_command, - client=client, host=host, port=port, - username=username, password=password) + session = RemoteSession( + cmd, + linesep=linesep, + output_func=log_function, + output_params=output_params, + output_prefix=host, + prompt=prompt, + status_test_command=status_test_command, + client=client, + host=host, + port=port, + username=username, + password=password, + ) try: handle_prompts(session, username, password, prompt, timeout) except Exception: @@ -429,11 +462,22 @@ def remote_login(client, host, port, username, password, prompt, linesep="\n", return session -def wait_for_login(client, host, port, username, password, prompt, - linesep="\n", log_filename=None, log_function=None, - timeout=240, internal_timeout=10, interface=None, - preferred_authentication='password', - user_known_hosts_file='/dev/null'): +def wait_for_login( + client, + host, + port, + username, + password, + prompt, + linesep="\n", + log_filename=None, + log_function=None, + timeout=240, + internal_timeout=10, + interface=None, + preferred_authentication="password", + user_known_hosts_file="/dev/null", +): """ Make multiple attempts to log into a guest until one succeeds or timeouts. @@ -460,31 +504,56 @@ def wait_for_login(client, host, port, username, password, prompt, :raise: Whatever remote_login() raises :return: A RemoteSession object. """ - LOG.debug("Attempting to log into %s:%s using %s (timeout %ds)", - host, port, client, timeout) + LOG.debug( + "Attempting to log into %s:%s using %s (timeout %ds)", + host, + port, + client, + timeout, + ) end_time = time.time() + timeout verbose = False while time.time() < end_time: try: - return remote_login(client, host, port, username, password, prompt, - linesep, log_filename, log_function, - internal_timeout, interface, verbose=verbose, - preferred_authentication=preferred_authentication, - user_known_hosts_file=user_known_hosts_file) + return remote_login( + client, + host, + port, + username, + password, + prompt, + linesep, + log_filename, + log_function, + internal_timeout, + interface, + verbose=verbose, + preferred_authentication=preferred_authentication, + user_known_hosts_file=user_known_hosts_file, + ) except LoginError as error: LOG.debug(error) verbose = True time.sleep(2) # Timeout expired; try one more time but don't catch exceptions - return remote_login(client, host, port, username, password, prompt, - linesep, log_filename, log_function, - internal_timeout, interface, - preferred_authentication=preferred_authentication, - user_known_hosts_file=user_known_hosts_file) - - -def _remote_scp( - session, password_list, transfer_timeout=600, login_timeout=300): + return remote_login( + client, + host, + port, + username, + password, + prompt, + linesep, + log_filename, + log_function, + internal_timeout, + interface, + preferred_authentication=preferred_authentication, + user_known_hosts_file=user_known_hosts_file, + ) + + +def _remote_scp(session, password_list, transfer_timeout=600, login_timeout=300): """ Transfer files using SCP, given a command line. @@ -516,29 +585,34 @@ def _remote_scp( try: match, text = session.read_until_last_line_matches( [r"[Aa]re you sure", r"[Pp]assword:\s*$", r"lost connection"], - timeout=timeout, internal_timeout=0.5) + timeout=timeout, + internal_timeout=0.5, + ) if match == 0: # "Are you sure you want to continue connecting" LOG.debug("Got 'Are you sure...', sending 'yes'") session.sendline("yes") elif match == 1: # "password:" if password_prompt_count == 0: - LOG.debug("Got password prompt, sending '%s'", - password_list[password_prompt_count]) + LOG.debug( + "Got password prompt, sending '%s'", + password_list[password_prompt_count], + ) session.sendline(password_list[password_prompt_count]) password_prompt_count += 1 timeout = transfer_timeout if scp_type == 1: authentication_done = True elif password_prompt_count == 1 and scp_type == 2: - LOG.debug("Got password prompt, sending '%s'", - password_list[password_prompt_count]) + LOG.debug( + "Got password prompt, sending '%s'", + password_list[password_prompt_count], + ) session.sendline(password_list[password_prompt_count]) password_prompt_count += 1 timeout = transfer_timeout authentication_done = True else: - raise SCPAuthenticationError("Got password prompt twice", - text) + raise SCPAuthenticationError("Got password prompt twice", text) elif match == 2: # "lost connection" raise SCPError("SCP client said 'lost connection'", text) except ExpectTimeoutError as error: @@ -552,8 +626,14 @@ def _remote_scp( raise SCPTransferFailedError(error.status, error.output) from error -def remote_scp(command, password_list, log_filename=None, log_function=None, - transfer_timeout=600, login_timeout=300): +def remote_scp( + command, + password_list, + log_filename=None, + log_function=None, + transfer_timeout=600, + login_timeout=300, +): """ Transfer files using SCP, given a command line. @@ -569,22 +649,33 @@ def remote_scp(command, password_list, log_filename=None, log_function=None, or the password prompt) :raise: Whatever _remote_scp() raises """ - LOG.debug("Trying to SCP with command '%s', timeout %ss", - command, transfer_timeout) + LOG.debug("Trying to SCP with command '%s', timeout %ss", command, transfer_timeout) if log_filename: output_func = log_function output_params = (log_filename,) else: output_func = None output_params = () - with Expect(command, output_func=output_func, - output_params=output_params) as session: + with Expect( + command, output_func=output_func, output_params=output_params + ) as session: _remote_scp(session, password_list, transfer_timeout, login_timeout) -def scp_to_remote(host, port, username, password, local_path, remote_path, - limit="", log_filename=None, log_function=None, - timeout=600, interface=None, directory=True): +def scp_to_remote( + host, + port, + username, + password, + local_path, + remote_path, + limit="", + log_filename=None, + log_function=None, + timeout=600, + interface=None, + directory=True, +): """ Copy files to a remote host (guest) through scp. @@ -609,26 +700,40 @@ def scp_to_remote(host, port, username, password, local_path, remote_path, if host and host.lower().startswith("fe80"): if not interface: - raise SCPError("When using ipv6 linklocal address must assign", - "the interface the neighbour attache") + raise SCPError( + "When using ipv6 linklocal address must assign", + "the interface the neighbour attache", + ) host = f"{host}%{interface}" command = "scp" if directory: command = f"{command} -r" - command += (r" -v -o UserKnownHostsFile=/dev/null " - r"-o StrictHostKeyChecking=no " - fr"-o PreferredAuthentications=password {limit} " - fr"-P {port} {quote_path(local_path)} {username}@\[{host}\]:" - fr"{shlex.quote(remote_path)}") + command += ( + r" -v -o UserKnownHostsFile=/dev/null " + r"-o StrictHostKeyChecking=no " + rf"-o PreferredAuthentications=password {limit} " + rf"-P {port} {quote_path(local_path)} {username}@\[{host}\]:" + rf"{shlex.quote(remote_path)}" + ) password_list = [password] - return remote_scp(command, password_list, - log_filename, log_function, timeout) - - -def scp_from_remote(host, port, username, password, remote_path, local_path, - limit="", log_filename=None, log_function=None, - timeout=600, interface=None, directory=True): + return remote_scp(command, password_list, log_filename, log_function, timeout) + + +def scp_from_remote( + host, + port, + username, + password, + remote_path, + local_path, + limit="", + log_filename=None, + log_function=None, + timeout=600, + interface=None, + directory=True, +): """ Copy files from a remote host (guest). @@ -652,27 +757,44 @@ def scp_from_remote(host, port, username, password, remote_path, local_path, limit = f"-l {limit}" if host and host.lower().startswith("fe80"): if not interface: - raise SCPError("When using ipv6 linklocal address must assign, ", - "the interface the neighbour attache") + raise SCPError( + "When using ipv6 linklocal address must assign, ", + "the interface the neighbour attache", + ) host = f"{host}%{interface}" command = "scp" if directory: command = f"{command} -r" - command += (r" -v -o UserKnownHostsFile=/dev/null " - r"-o StrictHostKeyChecking=no " - fr"-o PreferredAuthentications=password {limit} " - fr"-P {port} {username}@\[{host}\]:{quote_path(remote_path)} " - fr"{shlex.quote(local_path)}") + command += ( + r" -v -o UserKnownHostsFile=/dev/null " + r"-o StrictHostKeyChecking=no " + rf"-o PreferredAuthentications=password {limit} " + rf"-P {port} {username}@\[{host}\]:{quote_path(remote_path)} " + rf"{shlex.quote(local_path)}" + ) password_list = [password] - remote_scp(command, password_list, - log_filename, log_function, timeout) - - -def scp_between_remotes(src, dst, port, s_passwd, d_passwd, s_name, d_name, - s_path, d_path, limit="", - log_filename=None, log_function=None, timeout=600, - src_inter=None, dst_inter=None, directory=True): + remote_scp(command, password_list, log_filename, log_function, timeout) + + +def scp_between_remotes( + src, + dst, + port, + s_passwd, + d_passwd, + s_name, + d_name, + s_path, + d_path, + limit="", + log_filename=None, + log_function=None, + timeout=600, + src_inter=None, + dst_inter=None, + directory=True, +): """ Copy files from a remote host (guest) to another remote host (guest). @@ -700,35 +822,54 @@ def scp_between_remotes(src, dst, port, s_passwd, d_passwd, s_name, d_name, limit = f"-l {limit}" if src and src.lower().startswith("fe80"): if not src_inter: - raise SCPError("When using ipv6 linklocal address must assign ", - "the interface the neighbour attache") + raise SCPError( + "When using ipv6 linklocal address must assign ", + "the interface the neighbour attache", + ) src = f"{src}%{src_inter}" if dst and dst.lower().startswith("fe80"): if not dst_inter: - raise SCPError("When using ipv6 linklocal address must assign ", - "the interface the neighbour attache") + raise SCPError( + "When using ipv6 linklocal address must assign ", + "the interface the neighbour attache", + ) dst = f"{dst}%{dst_inter}" command = "scp" if directory: command = f"{command} -r" - command += (r" -v -o UserKnownHostsFile=/dev/null " - r"-o StrictHostKeyChecking=no " - fr"-o PreferredAuthentications=password {limit} -P {port}" - fr" {s_name}@\[{src}\]:{quote_path(s_path)} {d_name}@\[{dst}\]" - fr":{shlex.quote(d_path)}") + command += ( + r" -v -o UserKnownHostsFile=/dev/null " + r"-o StrictHostKeyChecking=no " + rf"-o PreferredAuthentications=password {limit} -P {port}" + rf" {s_name}@\[{src}\]:{quote_path(s_path)} {d_name}@\[{dst}\]" + rf":{shlex.quote(d_path)}" + ) password_list = [s_passwd, d_passwd] - return remote_scp(command, password_list, - log_filename, log_function, timeout) + return remote_scp(command, password_list, log_filename, log_function, timeout) # noinspection PyBroadException -def nc_copy_between_remotes(src, dst, s_port, s_passwd, d_passwd, - s_name, d_name, s_path, d_path, - c_type="ssh", c_prompt="\n", - d_port="8888", d_protocol="tcp", timeout=2, - check_sum=True, s_session=None, - d_session=None, file_transfer_timeout=600): +def nc_copy_between_remotes( + src, + dst, + s_port, + s_passwd, + d_passwd, + s_name, + d_name, + s_path, + d_path, + c_type="ssh", + c_prompt="\n", + d_port="8888", + d_protocol="tcp", + timeout=2, + check_sum=True, + s_session=None, + d_session=None, + file_transfer_timeout=600, +): """ Copy files from guest to guest using netcat. @@ -758,19 +899,9 @@ def nc_copy_between_remotes(src, dst, s_port, s_passwd, d_passwd, """ check_string = "NCFT" if not s_session: - s_session = remote_login(c_type, - src, - s_port, - s_name, - s_passwd, - c_prompt) + s_session = remote_login(c_type, src, s_port, s_name, s_passwd, c_prompt) if not d_session: - d_session = remote_login(c_type, - dst, - s_port, - d_name, - d_passwd, - c_prompt) + d_session = remote_login(c_type, dst, s_port, d_name, d_passwd, c_prompt) try: s_session.cmd(f"iptables -I INPUT -p {d_protocol} -j ACCEPT") @@ -786,12 +917,12 @@ def nc_copy_between_remotes(src, dst, s_port, s_passwd, d_passwd, d_session.sendline(receive_cmd) send_cmd = f"{cmd} {dst} {d_port} < {s_path}" status, output = s_session.cmd_status_output( - send_cmd, timeout=file_transfer_timeout) + send_cmd, timeout=file_transfer_timeout + ) if status: err = f"Fail to transfer file between {src} -> {dst}." if check_string not in output: - err += ("src did not receive check " - f"string {check_string} sent by dst.") + err += "src did not receive check " f"string {check_string} sent by dst." err += f"send nc command {send_cmd}, output {output}" err += f"Receive nc command {receive_cmd}." raise NetcatTransferFailedError(status, err) @@ -802,17 +933,30 @@ def nc_copy_between_remotes(src, dst, s_port, s_passwd, d_passwd, src_md5 = output.split()[0] dst_md5 = d_session.cmd(f"md5sum {d_path}").split()[0] if src_md5.strip() != dst_md5.strip(): - err_msg = ("Files md5sum mismatch, " - f"file {s_path} md5sum is '{src_md5}', " - f"but the file {d_path} md5sum is {dst_md5}") + err_msg = ( + "Files md5sum mismatch, " + f"file {s_path} md5sum is '{src_md5}', " + f"but the file {d_path} md5sum is {dst_md5}" + ) raise NetcatTransferIntegrityError(err_msg) return True -def udp_copy_between_remotes(src, dst, s_port, s_passwd, d_passwd, - s_name, d_name, s_path, d_path, - c_type="ssh", c_prompt="\n", - d_port="9000", timeout=600): +def udp_copy_between_remotes( + src, + dst, + s_port, + s_passwd, + d_passwd, + s_name, + d_name, + s_path, + d_path, + c_type="ssh", + c_prompt="\n", + d_port="9000", + timeout=600, +): """ Copy files from guest to guest using udp. @@ -839,23 +983,21 @@ def get_abs_path(session, filename, extension): cmd_tmp += "extension='%s'\" get drive^,path" cmd = cmd_tmp % (filename, extension) info = session.cmd_output(cmd, timeout=360).strip() - drive_path = re.search(r'(\w):\s+(\S+)', info, re.M) + drive_path = re.search(r"(\w):\s+(\S+)", info, re.M) if not drive_path: - raise UDPError(f"Not found file {filename}.{extension} " - "in your guest") + raise UDPError(f"Not found file {filename}.{extension} " "in your guest") return ":".join(drive_path.groups()) def get_file_md5(session, file_path): """Get files md5sums.""" if c_type == "ssh": md5_cmd = f"md5sum {file_path}" - md5_reg = fr"(\w+)\s+{file_path}.*" + md5_reg = rf"(\w+)\s+{file_path}.*" else: drive_path = get_abs_path(session, "md5sums", "exe") filename = file_path.split("\\")[-1] - md5_reg = fr"{filename}\s+(\w+)" - md5_cmd = (f'{drive_path}md5sums.exe {file_path} | ' - f'find "{filename}"') + md5_reg = rf"{filename}\s+(\w+)" + md5_cmd = f"{drive_path}md5sums.exe {file_path} | " f'find "{filename}"' output = session.cmd_output(md5_cmd) file_md5 = re.findall(md5_reg, output) if not output: @@ -893,9 +1035,13 @@ def start_client(session): else: drive_path = get_abs_path(session, "recvfile", "exe") client_cmd_tmp = "%srecvfile.exe %s %s %s %s" - client_cmd = client_cmd_tmp % (drive_path, src, d_port, - s_path.split("\\")[-1], - d_path.split("\\")[-1]) + client_cmd = client_cmd_tmp % ( + drive_path, + src, + d_port, + s_path.split("\\")[-1], + d_path.split("\\")[-1], + ) session.cmd_output_safe(client_cmd, timeout) def stop_server(session): @@ -914,9 +1060,11 @@ def stop_server(session): start_client(d_session) dst_md5 = get_file_md5(d_session, d_path) if src_md5 != dst_md5: - err_msg = ("Files md5sum mismatch, " - f"file {s_path} md5sum is '{src_md5}', " - f"but the file {d_path} md5sum is {dst_md5}") + err_msg = ( + "Files md5sum mismatch, " + f"file {s_path} md5sum is '{src_md5}', " + f"but the file {d_path} md5sum is {dst_md5}" + ) raise UDPError(err_msg) finally: stop_server(s_session) @@ -924,8 +1072,14 @@ def stop_server(session): d_session.close() -def login_from_session(session, log_filename=None, log_function=None, - timeout=240, internal_timeout=10, interface=None): +def login_from_session( + session, + log_filename=None, + log_function=None, + timeout=240, + internal_timeout=10, + interface=None, +): """ Log in remotely and return a session for the connection with the same configuration as a previous session. @@ -942,16 +1096,33 @@ def login_from_session(session, log_filename=None, log_function=None, The rest of the arguments are identical to wait_for_login(). """ - return wait_for_login(session.client, session.host, session.port, - session.username, session.password, - session.prompt, session.linesep, - log_filename, log_function, - timeout, internal_timeout, interface) - - -def scp_to_session(session, local_path, remote_path, - limit="", log_filename=None, log_function=None, - timeout=600, interface=None, directory=True): + return wait_for_login( + session.client, + session.host, + session.port, + session.username, + session.password, + session.prompt, + session.linesep, + log_filename, + log_function, + timeout, + internal_timeout, + interface, + ) + + +def scp_to_session( + session, + local_path, + remote_path, + limit="", + log_filename=None, + log_function=None, + timeout=600, + interface=None, + directory=True, +): """ Secure copy a filepath (w/o wildcard) to a remote location with the same configuration as a previous session. @@ -969,16 +1140,33 @@ def scp_to_session(session, local_path, remote_path, The rest of the arguments are identical to scp_to_remote(). """ - scp_to_remote(session.host, session.port, - session.username, session.password, - local_path, remote_path, - limit, log_filename, log_function, - timeout, interface, directory) - - -def scp_from_session(session, remote_path, local_path, - limit="", log_filename=None, log_function=None, - timeout=600, interface=None, directory=True): + scp_to_remote( + session.host, + session.port, + session.username, + session.password, + local_path, + remote_path, + limit, + log_filename, + log_function, + timeout, + interface, + directory, + ) + + +def scp_from_session( + session, + remote_path, + local_path, + limit="", + log_filename=None, + log_function=None, + timeout=600, + interface=None, + directory=True, +): """ Secure copy a filepath (w/o wildcard) from a remote location with the same configuration as a previous session. @@ -996,11 +1184,20 @@ def scp_from_session(session, remote_path, local_path, The rest of the arguments are identical to scp_from_remote(). """ - scp_from_remote(session.host, session.port, - session.username, session.password, - remote_path, local_path, - limit, log_filename, log_function, - timeout, interface, directory) + scp_from_remote( + session.host, + session.port, + session.username, + session.password, + remote_path, + local_path, + limit, + log_filename, + log_function, + timeout, + interface, + directory, + ) def throughput_transfer(func): @@ -1031,10 +1228,23 @@ def transfer(*args, **kwargs): # noinspection PyUnusedLocal @throughput_transfer -def copy_files_to(address, client, username, password, port, local_path, - remote_path, limit="", log_filename=None, log_function=None, - verbose=False, timeout=600, interface=None, filesize=None, # pylint: disable=unused-argument - directory=True): +def copy_files_to( + address, + client, + username, + password, + port, + local_path, + remote_path, + limit="", + log_filename=None, + log_function=None, + verbose=False, + timeout=600, + interface=None, + filesize=None, # pylint: disable=unused-argument + directory=True, +): """ Copy files to a remote host (guest) using the selected client. @@ -1058,9 +1268,20 @@ def copy_files_to(address, client, username, password, port, local_path, :raise: Whatever remote_scp() raises """ if client == "scp": - scp_to_remote(address, port, username, password, local_path, - remote_path, limit, log_filename, log_function, timeout, - interface=interface, directory=directory) + scp_to_remote( + address, + port, + username, + password, + local_path, + remote_path, + limit, + log_filename, + log_function, + timeout, + interface=interface, + directory=directory, + ) elif client == "rss": log_func = None if verbose: @@ -1076,10 +1297,23 @@ def copy_files_to(address, client, username, password, port, local_path, # noinspection PyUnusedLocal @throughput_transfer -def copy_files_from(address, client, username, password, port, remote_path, - local_path, limit="", log_filename=None, log_function=None, - verbose=False, timeout=600, interface=None, filesize=None, # pylint: disable=unused-argument - directory=True): +def copy_files_from( + address, + client, + username, + password, + port, + remote_path, + local_path, + limit="", + log_filename=None, + log_function=None, + verbose=False, + timeout=600, + interface=None, + filesize=None, # pylint: disable=unused-argument + directory=True, +): """ Copy files from a remote host (guest) using the selected client. @@ -1103,9 +1337,20 @@ def copy_files_from(address, client, username, password, port, remote_path, :raise: Whatever ``remote_scp()`` raises """ if client == "scp": - scp_from_remote(address, port, username, password, remote_path, - local_path, limit, log_filename, log_function, timeout, - interface=interface, directory=directory) + scp_from_remote( + address, + port, + username, + password, + remote_path, + local_path, + limit, + log_filename, + log_function, + timeout, + interface=interface, + directory=directory, + ) elif client == "rss": log_func = None if verbose: diff --git a/aexpect/remote_door.py b/aexpect/remote_door.py index 36b4a13..78b785a 100644 --- a/aexpect/remote_door.py +++ b/aexpect/remote_door.py @@ -57,13 +57,13 @@ # disable too-many-* as we need them pylint: disable=R0912,R0913,R0914,R0915,C0302 # ..todo:: we could reduce the disabled issues after more significant refactoring +import importlib +import inspect +import logging import os import re -import logging -import inspect -import importlib -import threading import tempfile +import threading import time # NOTE: enable this before importing the Pyro backend in order to debug issues @@ -73,8 +73,10 @@ # noinspection PyPackageRequirements,PyUnresolvedReferences import Pyro4 except ImportError: - logging.warning("Remote object backend (Pyro4) not found, some functionality" - " of the remote door will not be available") + logging.warning( + "Remote object backend (Pyro4) not found, some functionality" + " of the remote door will not be available" + ) # NOTE: disable aexpect importing on the remote side if not available as the # remote door can run code remotely without the requirement for the aexpect @@ -83,8 +85,10 @@ # noinspection PyUnresolvedReferences from aexpect import remote except ImportError: - logging.warning("Failed to import 'aexpect.remote', some functionality " - "might not be available") + logging.warning( + "Failed to import 'aexpect.remote', some functionality " + "might not be available" + ) LOG = logging.getLogger(__name__) @@ -109,15 +113,15 @@ def _string_call(function, *args, **kwargs): - def arg_to_str(arg): if isinstance(arg, str): return f"r'{arg}'" return f"{arg}" args = tuple(arg_to_str(arg) for arg in args) - kwargs = tuple(f"{key}={arg_to_str(value)}" - for key, value in sorted(kwargs.items())) + kwargs = tuple( + f"{key}={arg_to_str(value)}" for key, value in sorted(kwargs.items()) + ) arguments = ", ".join(args + kwargs) return f"result = {function}({arguments})\n" @@ -173,8 +177,11 @@ def run_remote_util(session, utility, function, *args, detach=False, **kwargs): control_body = f"import {utility}\n" control_body += _string_call(utility + "." + function, *args, **kwargs) control_path = _string_generated_control(session.client, control_body) - LOG.debug("Accessing %s remote utility using the wrapper control %s", - utility, control_path) + LOG.debug( + "Accessing %s remote utility using the wrapper control %s", + utility, + control_path, + ) full_output = run_subcontrol(session, control_path, detach=detach) return "None" if detach else re.search("RESULT = (.*)", full_output).group(1) @@ -206,8 +213,9 @@ def wrapper(session, *args, **kwargs): control_body = "\n" + "".join(fn_source).replace("_session, ", "") control_body += "\n" + _string_call(function.__name__, *args, **kwargs) control_path = _string_generated_control(session.client, control_body) - LOG.debug("Running remotely a function using the wrapper control %s", - control_path) + LOG.debug( + "Running remotely a function using the wrapper control %s", control_path + ) full_output = run_subcontrol(session, control_path) return re.search("RESULT = (.*)", full_output).group(1) @@ -233,11 +241,18 @@ def _copy_control(session, control_path, is_utility=False): # ..todo:: use `remote_dir` here remote_control_path = "%TEMP%\\" + os.path.basename(control_path) else: - raise NotImplementedError("run_subcontrol not implemented for client " - f"{session.client}") - remote.copy_files_to(session.host, transfer_client, - session.username, session.password, transfer_port, - control_path, remote_control_path) + raise NotImplementedError( + "run_subcontrol not implemented for client " f"{session.client}" + ) + remote.copy_files_to( + session.host, + transfer_client, + session.username, + session.password, + transfer_port, + control_path, + remote_control_path, + ) return remote_control_path @@ -261,11 +276,13 @@ def run_subcontrol(session, control_path, timeout=600, detach=False): # run on remote Windows hosts # ..todo:: combine with REMOTE_PYTHON_BINARY elif session.client == "nc": - python_binary = session.cmd("where python", timeout=timeout, - print_func=LOG.info).strip() + python_binary = session.cmd( + "where python", timeout=timeout, print_func=LOG.info + ).strip() else: - raise NotImplementedError("run_subcontrol not implemented for client " - f"{session.client}") + raise NotImplementedError( + "run_subcontrol not implemented for client " f"{session.client}" + ) cmd = python_binary + " " + remote_control_path if detach: session.set_output_func(LOG.info) @@ -356,12 +373,11 @@ def set_subcontrol_parameter(subcontrol, parameter, value): .. warning:: The `subcontrol` parameter is control path externally but control content internally after decoration. """ - match = re.search(fr"{parameter.upper()}[ \t\v]*=[ \t\v]*.*", subcontrol) + match = re.search(rf"{parameter.upper()}[ \t\v]*=[ \t\v]*.*", subcontrol) if match is None: return subcontrol # re.sub does undesirable post-processing of the replaced string - return subcontrol.replace(match.group(), - f"{parameter.upper()} = {value!r}") + return subcontrol.replace(match.group(), f"{parameter.upper()} = {value!r}") @set_subcontrol @@ -379,8 +395,7 @@ def set_subcontrol_parameter_list(subcontrol, list_name, value): .. warning:: The `subcontrol` parameter is control path externally but control content internally after decoration. """ - match = re.search(fr"{list_name.upper()}[ \t\v]*=[ \t\v]*\[.*\]", - subcontrol) + match = re.search(rf"{list_name.upper()}[ \t\v]*=[ \t\v]*\[.*\]", subcontrol) if match is None: return subcontrol # re.sub does undesirable post-processing of the replaced string @@ -402,12 +417,11 @@ def set_subcontrol_parameter_dict(subcontrol, dict_name, value): .. warning:: The `subcontrol` parameter is control path externally but control content internally after decoration. """ - match = re.search(fr"{dict_name.upper()}[ \t\v]*=[ \t\v]*\{{.*\}}", subcontrol) + match = re.search(rf"{dict_name.upper()}[ \t\v]*=[ \t\v]*\{{.*\}}", subcontrol) if match is None: return subcontrol # re.sub does undesirable post-processing of the replaced string - return subcontrol.replace(match.group(), - f"{dict_name.upper()} = {value!r}") + return subcontrol.replace(match.group(), f"{dict_name.upper()} = {value!r}") @set_subcontrol @@ -446,16 +460,18 @@ def set_subcontrol_parameter_object(subcontrol, value): pyrod_running = False # address already in use OS error except OSError: - pyro_daemon = Pyro4.Proxy("PYRO:" + Pyro4.constants.DAEMON_NAME + - "@" + host_ip + ":1437") + pyro_daemon = Pyro4.Proxy( + "PYRO:" + Pyro4.constants.DAEMON_NAME + "@" + host_ip + ":1437" + ) pyro_daemon.ping() registered = pyro_daemon.registered() - LOG.debug("Pyro4 daemon already started, available objects: %s", - registered) - assert len(registered) == 2, "The Pyro4 daemon should contain only two"\ - " initially registered objects" - assert registered[0] == "Pyro.Daemon", "The Pyro4 daemon must be first"\ - " registered object" + LOG.debug("Pyro4 daemon already started, available objects: %s", registered) + assert len(registered) == 2, ( + "The Pyro4 daemon should contain only two" " initially registered objects" + ) + assert registered[0] == "Pyro.Daemon", ( + "The Pyro4 daemon must be first" " registered object" + ) uri = "PYRO:" + registered[1] + "@" + host_ip + ":1437" pyrod_running = True @@ -466,8 +482,9 @@ def set_subcontrol_parameter_object(subcontrol, value): LOG.debug("Sending the params object to the host via uri %s", uri) # post-processing of the replaced string is allowed for a URI - subcontrol = re.sub("URI[ \t\v]*=[ \t\v]*\".*\"", f"URI = \"{uri}\"", - subcontrol, count=1) + subcontrol = re.sub( + 'URI[ \t\v]*=[ \t\v]*".*"', f'URI = "{uri}"', subcontrol, count=1 + ) return subcontrol @@ -542,8 +559,15 @@ def get_remote_object(object_name, session=None, host="localhost", port=9090): # if there is no door on the other side, open one _copy_control(session, os.path.abspath(__file__), is_utility=True) - run_remote_util(session, "remote_door", "share_local_object", - object_name, host=host, port=port, detach=True) + run_remote_util( + session, + "remote_door", + "share_local_object", + object_name, + host=host, + port=port, + detach=True, + ) output, attempts = "", 10 for _ in range(attempts): output = session.get_output() @@ -592,8 +616,15 @@ def get_remote_objects(session=None, host="localhost", port=0): # if there is no door on the other side, open one _copy_control(session, os.path.abspath(__file__), is_utility=True) - run_remote_util(session, "remote_door", "share_local_objects", - wait=True, host=host, port=port, detach=True) + run_remote_util( + session, + "remote_door", + "share_local_objects", + wait=True, + host=host, + port=port, + detach=True, + ) control_log = session.cmd("cat " + REMOTE_CONTROL_LOG) for _ in range(10): if "ready" in control_log: @@ -601,8 +632,7 @@ def get_remote_objects(session=None, host="localhost", port=0): time.sleep(1) control_log = session.cmd("cat " + REMOTE_CONTROL_LOG) else: - raise OSError("Local objects sharing failed:\n" - f"{control_log}") from error + raise OSError("Local objects sharing failed:\n" f"{control_log}") from error LOG.debug("Local objects sharing output:\n%s", control_log) remote_objects = flame.connect(host + ":" + str(port)) @@ -641,8 +671,10 @@ def share_local_object(object_name, whitelist=None, host="localhost", port=9090) except OSError: pyro_daemon = Pyro4.Proxy(f"PYRO:{Pyro4.constants.DAEMON_NAME}@{host}") pyro_daemon.ping() - LOG.debug("Pyro4 daemon already started, available objects: %s", - pyro_daemon.registered()) + LOG.debug( + "Pyro4 daemon already started, available objects: %s", + pyro_daemon.registered(), + ) pyrod_running = True # name server @@ -655,6 +687,7 @@ def share_local_object(object_name, whitelist=None, host="localhost", port=9090) except (OSError, Pyro4.errors.NamingError): # noinspection PyPackageRequirements from Pyro4 import naming + ns_uri, ns_daemon, _bc_server = naming.startNS(host=host, port=port) ns_server = Pyro4.Proxy(ns_uri) LOG.debug("Pyro4 name server started successfully with URI %s", ns_uri) @@ -669,18 +702,26 @@ def wrapper(*args, **kwargs): rarg = fun(*args, **kwargs) def proxify_type(_rarg): - if _rarg is None or type(_rarg) in (bool, int, float, str): # pylint: disable=C0123 + if _rarg is None or type(_rarg) in ( + bool, + int, + float, + str, + ): # pylint: disable=C0123 return _rarg if isinstance(_rarg, tuple): return tuple(proxify_type(e) for e in _rarg) if isinstance(_rarg, list): return [proxify_type(e) for e in _rarg] if isinstance(_rarg, dict): - return {proxify_type(k): proxify_type(v) for (k, v) in _rarg.items()} + return { + proxify_type(k): proxify_type(v) for (k, v) in _rarg.items() + } pyro_daemon.register(_rarg) return _rarg import types + if isinstance(rarg, types.GeneratorType): def generator_wrapper(): @@ -742,6 +783,7 @@ def share_local_objects(wait=False, host="localhost", port=0): # main retrieval of the local objects # noinspection PyPackageRequirements from Pyro4.utils import flame + flame.start(pyro_daemon) # request loop @@ -751,8 +793,14 @@ def share_local_objects(wait=False, host="localhost", port=0): loop.join() -def share_remote_objects(session, control_path, host="localhost", port=9090, - os_type="windows", extra_params=None): +def share_remote_objects( + session, + control_path, + host="localhost", + port=9090, + os_type="windows", + extra_params=None, +): """ Create and share remote objects from a remote location over the network. @@ -797,18 +845,30 @@ def share_remote_objects(session, control_path, host="localhost", port=9090, # optional parameters (set only if present and/or available) for key in extra_params.keys(): local_path = set_subcontrol_parameter(local_path, key, extra_params[key]) - remote_path = os.path.join(REMOTE_CONTROL_DIR, - os.path.basename(control_path)) + remote_path = os.path.join(REMOTE_CONTROL_DIR, os.path.basename(control_path)) # NOTE: since we are creating the path in Linux but use it in Windows, # we replace some backslashes if os_type == "windows": remote_path = remote_path.replace("/", "\\") - remote.copy_files_to(session.host, transfer_client, - session.username, session.password, transfer_port, - local_path, remote_path, timeout=10) - middleware_session = remote.wait_for_login(session.client, session.host, session.port, - session.username, session.password, - session.prompt, session.linesep) + remote.copy_files_to( + session.host, + transfer_client, + session.username, + session.password, + transfer_port, + local_path, + remote_path, + timeout=10, + ) + middleware_session = remote.wait_for_login( + session.client, + session.host, + session.port, + session.username, + session.password, + session.prompt, + session.linesep, + ) middleware_session.set_status_test_command(session.status_test_command) middleware_session.set_output_func(LOG.info) middleware_session.set_output_params(()) @@ -854,7 +914,9 @@ def list_module_exceptions(modstr): for name in imported_module.__dict__: if not inspect.isclass(imported_module.__dict__[name]): continue - if issubclass(imported_module.__dict__[name], Exception) or name.endswith('Error'): + if issubclass(imported_module.__dict__[name], Exception) or name.endswith( + "Error" + ): module_exceptions.append(modstr + "." + name) return module_exceptions @@ -862,11 +924,14 @@ def list_module_exceptions(modstr): modules = [] if not modules else modules for module in modules: exceptions += list_module_exceptions(module) - LOG.debug("Registering the following exceptions for deserialization: %s", - ", ".join(exceptions)) + LOG.debug( + "Registering the following exceptions for deserialization: %s", + ", ".join(exceptions), + ) class RemoteCustomException(Exception): """Standard class to instantiate during remote expectation deserialization.""" + __customclass__ = None def recreate_exception(class_name, class_dict): diff --git a/aexpect/rss_client.py b/aexpect/rss_client.py index 344bfbf..b7bf237 100644 --- a/aexpect/rss_client.py +++ b/aexpect/rss_client.py @@ -25,13 +25,14 @@ # ..todo:: we could reduce the disabled issues after more significant refactoring from __future__ import division, print_function + +import argparse +import glob +import os import socket import struct -import time import sys -import os -import glob -import argparse +import time # Globals CHUNKSIZE = 65536 @@ -63,8 +64,7 @@ def __init__(self, msg, e=None, filename=None): def __str__(self): errmsg = self.msg if self.error and self.filename: - errmsg += (f" (error: {self.error}," - f" filename: {self.filename})") + errmsg += f" (error: {self.error}," f" filename: {self.filename})" elif self.error: errmsg += f" ({self.error})" elif self.filename: @@ -122,24 +122,25 @@ def __init__(self, address, port, log_func=None, timeout=20): :param timeout: Time duration to wait for connection to succeed :raise FileTransferConnectError: Raised if the connection fails """ - family = socket.AF_INET6 if ':' in address else socket.AF_INET + family = socket.AF_INET6 if ":" in address else socket.AF_INET self._socket = socket.socket(family, socket.SOCK_STREAM) self._socket.settimeout(timeout) try: - addrinfo = socket.getaddrinfo(address, port, family, - socket.SOCK_STREAM, - socket.IPPROTO_TCP) + addrinfo = socket.getaddrinfo( + address, port, family, socket.SOCK_STREAM, socket.IPPROTO_TCP + ) self._socket.connect(addrinfo[0][4]) except socket.error as error: - raise FileTransferConnectError("Cannot connect to server at " - f"{address}:{port}", - error) from error + raise FileTransferConnectError( + "Cannot connect to server at " f"{address}:{port}", error + ) from error try: if self._receive_msg(timeout) != RSS_MAGIC: raise FileTransferConnectError("Received wrong magic number") except FileTransferTimeoutError as timeout_error: - raise FileTransferConnectError("Timeout expired while waiting to " - "receive magic number") from timeout_error + raise FileTransferConnectError( + "Timeout expired while waiting to receive magic number" + ) from timeout_error self._send(struct.pack("=i", CHUNKSIZE)) self._log_func = log_func self._last_time = time.time() @@ -162,11 +163,13 @@ def _send(self, data, timeout=60): self._socket.settimeout(timeout) self._socket.sendall(data) except socket.timeout as error: - raise FileTransferTimeoutError("Timeout expired while sending " - "data to server") from error + raise FileTransferTimeoutError( + "Timeout expired while sending data to server" + ) from error except socket.error as error: - raise FileTransferSocketError("Could not send data to server", - error) from error + raise FileTransferSocketError( + "Could not send data to server", error + ) from error def _receive(self, size, timeout=60): strs = [] @@ -179,29 +182,32 @@ def _receive(self, size, timeout=60): self._socket.settimeout(timeout) data = self._socket.recv(size) if not data: - raise FileTransferProtocolError("Connection closed " - "unexpectedly while " - "receiving data from " - "server") + raise FileTransferProtocolError( + "Connection closed " + "unexpectedly while " + "receiving data from " + "server" + ) strs.append(data) size -= len(data) except socket.timeout as error: - raise FileTransferTimeoutError("Timeout expired while receiving " - "data from server") from error + raise FileTransferTimeoutError( + "Timeout expired while receiving data from server" + ) from error except socket.error as error: - raise FileTransferSocketError("Error receiving data from server", - error) from error + raise FileTransferSocketError( + "Error receiving data from server", error + ) from error return b"".join(strs) def _report_stats(self, data): if self._log_func: delta = time.time() - self._last_time if delta >= 1: - transferred = self.transferred / 1048576. + transferred = self.transferred / 1048576.0 speed = (self.transferred - self._last_transferred) / delta - speed /= 1048576. - self._log_func(f"{data} {transferred:.3f} MB ({speed:.3f}" - " MB/sec)") + speed /= 1048576.0 + self._log_func(f"{data} {transferred:.3f} MB ({speed:.3f}" " MB/sec)") self._last_time = time.time() self._last_transferred = self.transferred @@ -358,9 +364,11 @@ def upload(self, src_pattern, dst_path, timeout=600): else: # If nothing was transferred, raise an exception if not matches: - raise FileTransferNotFoundError(f"Pattern {src_pattern} " - "does not match any files " - "or directories") + raise FileTransferNotFoundError( + f"Pattern {src_pattern} " + "does not match any files " + "or directories" + ) # Look for RSS_OK or RSS_ERROR msg = self._receive_msg(int(end_time - time.time())) if msg == RSS_OK: @@ -473,7 +481,8 @@ def download(self, src_pattern, dst_path, timeout=600): if not file_count and not dir_count: raise FileTransferNotFoundError( f"Pattern {src_pattern} does not match any files " - "or directories that could be downloaded") + "or directories that could be downloaded" + ) break elif msg == RSS_ERROR: # Receive error message and abort @@ -488,8 +497,9 @@ def download(self, src_pattern, dst_path, timeout=600): raise -def upload(address, port, src_pattern, dst_path, log_func=None, timeout=60, - connect_timeout=20): +def upload( + address, port, src_pattern, dst_path, log_func=None, timeout=60, connect_timeout=20 +): """ Connect to server and upload files. @@ -500,8 +510,9 @@ def upload(address, port, src_pattern, dst_path, log_func=None, timeout=60, client.close() -def download(address, port, src_pattern, dst_path, log_func=None, timeout=60, - connect_timeout=20): +def download( + address, port, src_pattern, dst_path, log_func=None, timeout=60, connect_timeout=20 +): """ Connect to server and upload files. @@ -520,18 +531,31 @@ def main(): parser.add_argument("port") parser.add_argument("src_pattern") parser.add_argument("dst_path") - parser.add_argument("-d", "--download", - action="store_true", dest="download", - help="download files from server") - parser.add_argument("-u", "--upload", - action="store_true", dest="upload", - help="upload files to server") - parser.add_argument("-v", "--verbose", - action="store_true", dest="verbose", - help="be verbose") - parser.add_argument("-t", "--timeout", - type=int, dest="timeout", default=3600, - help="transfer timeout") + parser.add_argument( + "-d", + "--download", + action="store_true", + dest="download", + help="download files from server", + ) + parser.add_argument( + "-u", + "--upload", + action="store_true", + dest="upload", + help="upload files to server", + ) + parser.add_argument( + "-v", "--verbose", action="store_true", dest="verbose", help="be verbose" + ) + parser.add_argument( + "-t", + "--timeout", + type=int, + dest="timeout", + default=3600, + help="transfer timeout", + ) args = parser.parse_args() if args.download == args.upload: parser.error("you must specify either -d or -u") diff --git a/aexpect/shared.py b/aexpect/shared.py index 15cf095..6ef12ac 100644 --- a/aexpect/shared.py +++ b/aexpect/shared.py @@ -11,11 +11,11 @@ """Some shared functions""" -import os import fcntl +import os import termios -BASE_DIR = os.environ.get('TMPDIR', '/tmp') +BASE_DIR = os.environ.get("TMPDIR", "/tmp") def get_lock_fd(filename): @@ -59,14 +59,22 @@ def wait_for_lock(filename): def makeraw(shell_fd): """Turn console into 'raw' format""" attr = termios.tcgetattr(shell_fd) - attr[0] &= ~(termios.IGNBRK | termios.BRKINT | termios.PARMRK | - termios.ISTRIP | termios.INLCR | termios.IGNCR | - termios.ICRNL | termios.IXON) + attr[0] &= ~( + termios.IGNBRK + | termios.BRKINT + | termios.PARMRK + | termios.ISTRIP + | termios.INLCR + | termios.IGNCR + | termios.ICRNL + | termios.IXON + ) attr[1] &= ~termios.OPOST attr[2] &= ~(termios.CSIZE | termios.PARENB) attr[2] |= termios.CS8 - attr[3] &= ~(termios.ECHO | termios.ECHONL | termios.ICANON | - termios.ISIG | termios.IEXTEN) + attr[3] &= ~( + termios.ECHO | termios.ECHONL | termios.ICANON | termios.ISIG | termios.IEXTEN + ) termios.tcsetattr(shell_fd, termios.TCSANOW, attr) @@ -86,9 +94,16 @@ def makestandard(shell_fd, echo): def get_filenames(base_dir): """Get paths to files produced by aexpect in it's working dir""" - files = ("shell-pid", "status", "output", "inpipe", "ctrlpipe", - "lock-server-running", "lock-client-starting", - "server-log") + files = ( + "shell-pid", + "status", + "output", + "inpipe", + "ctrlpipe", + "lock-server-running", + "lock-client-starting", + "server-log", + ) return [os.path.join(base_dir, s) for s in files] diff --git a/aexpect/utils/astring.py b/aexpect/utils/astring.py index 4e58313..abf81f8 100644 --- a/aexpect/utils/astring.py +++ b/aexpect/utils/astring.py @@ -42,8 +42,7 @@ def strip_console_codes(output, custom_codes=None): while index < len(output): tmp_index = 0 tmp_word = "" - while (len(re.findall("\x1b", tmp_word)) < 2 and - index + tmp_index < len(output)): + while len(re.findall("\x1b", tmp_word)) < 2 and index + tmp_index < len(output): tmp_word += output[index + tmp_index] tmp_index += 1 @@ -55,12 +54,14 @@ def strip_console_codes(output, custom_codes=None): special_code = re.findall(console_codes, tmp_word)[0] except IndexError as error: if index + tmp_index < len(output): - raise ValueError(f"{tmp_word} is not included in the known " - "console codes list " - f"{console_codes}") from error + raise ValueError( + f"{tmp_word} is not included in the known " + "console codes list " + f"{console_codes}" + ) from error continue if special_code == tmp_word: continue old_word = tmp_word - return_str += tmp_word[len(special_code):] + return_str += tmp_word[len(special_code) :] return return_str diff --git a/aexpect/utils/data_factory.py b/aexpect/utils/data_factory.py index 3ddf41d..098fb19 100644 --- a/aexpect/utils/data_factory.py +++ b/aexpect/utils/data_factory.py @@ -17,8 +17,7 @@ _RAND_POOL = random.SystemRandom() -def generate_random_string(length, ignore=string.punctuation, - convert=""): +def generate_random_string(length, ignore=string.punctuation, convert=""): """ Generate a random string using alphanumeric characters. diff --git a/aexpect/utils/path.py b/aexpect/utils/path.py index 7508888..00e1295 100644 --- a/aexpect/utils/path.py +++ b/aexpect/utils/path.py @@ -29,8 +29,10 @@ def __init__(self, cmd, paths): self.paths = paths def __str__(self): - return (f"Command '{self.cmd}' could not be found in any of the PATH " - f"dirs: {self.paths}") + return ( + f"Command '{self.cmd}' could not be found in any of the PATH " + f"dirs: {self.paths}" + ) def find_command(cmd, default=None): @@ -44,12 +46,19 @@ def find_command(cmd, default=None): command was not found and no default was given. """ try: - path_paths = os.environ['PATH'].split(":") + path_paths = os.environ["PATH"].split(":") except IndexError: path_paths = [] - for common_path in ["/usr/libexec", "/usr/local/sbin", "/usr/local/bin", - "/usr/sbin", "/usr/bin", "/sbin", "/bin"]: + for common_path in [ + "/usr/libexec", + "/usr/local/sbin", + "/usr/local/bin", + "/usr/sbin", + "/usr/bin", + "/sbin", + "/bin", + ]: if common_path not in path_paths: path_paths.append(common_path) diff --git a/aexpect/utils/process.py b/aexpect/utils/process.py index 0f85b97..d8cc0e2 100644 --- a/aexpect/utils/process.py +++ b/aexpect/utils/process.py @@ -11,15 +11,16 @@ """Process handling helpers""" -import subprocess -import signal import os +import signal +import subprocess def getoutput(cmd): """Executes command and returns stdout+stderr without tailing \n\r""" - with subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, - stderr=subprocess.PIPE) as proc: + with subprocess.Popen( + cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) as proc: return proc.communicate()[0].decode().rstrip("\n\r") @@ -84,7 +85,7 @@ def get_children_pids(ppid): param ppid: parent PID return: list of PIDs of all children/threads of ppid """ - return getoutput(f"ps -L --ppid={int(ppid)} -o lwp").split('\n')[1:] + return getoutput(f"ps -L --ppid={int(ppid)} -o lwp").split("\n")[1:] def process_in_ptree_is_defunct(ppid): @@ -104,7 +105,7 @@ def process_in_ptree_is_defunct(ppid): for pid in pids: cmd = f"ps --no-headers -o cmd {int(pid)}" proc_name = getoutput(cmd) - if '' in proc_name: + if "" in proc_name: defunct = True break return defunct diff --git a/aexpect/utils/wait.py b/aexpect/utils/wait.py index 04c5011..15acf8c 100644 --- a/aexpect/utils/wait.py +++ b/aexpect/utils/wait.py @@ -11,8 +11,8 @@ """Module that helps waiting for conditions""" -import time import logging +import time _LOG = logging.getLogger(__file__) diff --git a/scripts/aexpect_helper b/scripts/aexpect_helper index 4a41d49..0884db2 100755 --- a/scripts/aexpect_helper +++ b/scripts/aexpect_helper @@ -15,21 +15,23 @@ Helper script that runs and interacts with the process executed by aexpect """ -import os -import sys import logging +import os import pty -import tempfile import select +import sys +import tempfile -from aexpect.shared import BASE_DIR -from aexpect.shared import get_filenames -from aexpect.shared import get_reader_filename -from aexpect.shared import get_lock_fd -from aexpect.shared import unlock_fd -from aexpect.shared import wait_for_lock -from aexpect.shared import makestandard -from aexpect.shared import makeraw +from aexpect.shared import ( + BASE_DIR, + get_filenames, + get_lock_fd, + get_reader_filename, + makeraw, + makestandard, + unlock_fd, + wait_for_lock, +) def main(): # too-many-* pylint:disable=R0914,R0912,R0915 diff --git a/setup.py b/setup.py index 4cdd551..50c67b0 100644 --- a/setup.py +++ b/setup.py @@ -17,22 +17,23 @@ from setuptools import setup -if __name__ == '__main__': - setup(name='aexpect', - version='1.7.0', - description='Aexpect', - author='Aexpect developers', - author_email='avocado-devel@redhat.com', - url='http://avocado-framework.github.io/', - license="GPLv2+", - classifiers=[ +if __name__ == "__main__": + setup( + name="aexpect", + version="1.7.0", + description="Aexpect", + author="Aexpect developers", + author_email="avocado-devel@redhat.com", + url="http://avocado-framework.github.io/", + license="GPLv2+", + classifiers=[ "Development Status :: 6 - Mature", "License :: OSI Approved :: GNU General Public License v2 or later (GPLv2+)", "Natural Language :: English", "Operating System :: POSIX", "Programming Language :: Python :: 3", - ], - packages=['aexpect', - 'aexpect.utils'], - scripts=['scripts/aexpect_helper'], - test_suite='tests') + ], + packages=["aexpect", "aexpect.utils"], + scripts=["scripts/aexpect_helper"], + test_suite="tests", + ) diff --git a/static-checks b/static-checks new file mode 160000 index 0000000..1484a86 --- /dev/null +++ b/static-checks @@ -0,0 +1 @@ +Subproject commit 1484a86d9de27be65f665f21c43e33fabb643bd7 diff --git a/tests/test_client.py b/tests/test_client.py index d7fce27..feff19f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -24,13 +24,11 @@ class ClientTest(unittest.TestCase): - def test_client_spawn(self): """ Tests the basic spawning of an interactive process """ - key = "".join([random.choice(string.ascii_uppercase) - for _ in range(10)]) + key = "".join([random.choice(string.ascii_uppercase) for _ in range(10)]) python = client.Spawn(sys.executable) self.assertTrue(python.is_alive()) python.sendline(f"print('{key}')") @@ -41,24 +39,37 @@ def test_client_spawn(self): class CommandsTests(unittest.TestCase): - def setUp(self): - non_get_cmds = ('get_id', 'get_output', 'get_pid', 'get_status', - 'get_stripped_output') - self.cmds = [cmd for cmd in dir(client.ShellSession) - if cmd.startswith('get') and cmd not in non_get_cmds] - self.cmds.extend(cmd for cmd in dir(client.ShellSession) - if cmd.startswith("cmd")) + non_get_cmds = ( + "get_id", + "get_output", + "get_pid", + "get_status", + "get_stripped_output", + ) + self.cmds = [ + cmd + for cmd in dir(client.ShellSession) + if cmd.startswith("get") and cmd not in non_get_cmds + ] + self.cmds.extend( + cmd for cmd in dir(client.ShellSession) if cmd.startswith("cmd") + ) def test_cmd_true(self): """Check that the true command finishes properly""" for cmd in self.cmds: - if cmd in ('get_id', 'get_output', 'get_pid', 'get_status', - 'get_stripped_output'): + if cmd in ( + "get_id", + "get_output", + "get_pid", + "get_status", + "get_stripped_output", + ): # These are not commands continue session = client.ShellSession("sh") - getattr(session, cmd)('true') + getattr(session, cmd)("true") def test_cmd_terminated(self): """ @@ -66,8 +77,13 @@ def test_cmd_terminated(self): raised """ for cmd in self.cmds: - if cmd in ('get_id', 'get_output', 'get_pid', 'get_status', - 'get_stripped_output'): + if cmd in ( + "get_id", + "get_output", + "get_pid", + "get_status", + "get_stripped_output", + ): # These are not commands continue session = client.ShellSession("sh") @@ -80,50 +96,62 @@ def test_cmd_terminated(self): # command will be processed after the helper realizes # it's dead. out = getattr(session, cmd)(f"kill {session.get_pid()}") - out += getattr(session, cmd)('true') - self.fail("Killed session did not produce 'ShellError' using " - f"command {cmd} ({self.cmds})\n{out}") + out += getattr(session, cmd)("true") + self.fail( + "Killed session did not produce 'ShellError' using " + f"command {cmd} ({self.cmds})\n{out}" + ) except client.ShellError as details: if cmd in ("cmd_output", "cmd_output_safe"): - if not isinstance(details, - client.ShellProcessTerminatedError): - self.fail(f"Incorrect exception '{details}' " - f"({type(details)}) was raised using command" - f" {cmd} ({self.cmds})\n{out}") + if not isinstance(details, client.ShellProcessTerminatedError): + self.fail( + f"Incorrect exception '{details}' " + f"({type(details)}) was raised using command" + f" {cmd} ({self.cmds})\n{out}" + ) def test_cmd_timeout(self): """Check that 0s timeout timeouts""" for cmd in self.cmds: - if cmd in ('get_id', 'get_output', 'get_pid', 'get_status', - 'get_stripped_output'): + if cmd in ( + "get_id", + "get_output", + "get_pid", + "get_status", + "get_stripped_output", + ): # These are not commands continue session = client.ShellSession("sh") try: - execute = (f"{sys.executable} -c " - "'import time; time.sleep(10)'") + execute = f"{sys.executable} -c " "'import time; time.sleep(10)'" out = getattr(session, cmd)(execute, timeout=0) - self.fail("Killed session did not produce 'ShellError' using " - f"command {cmd} ({self.cmds})\n{out}") + self.fail( + "Killed session did not produce 'ShellError' using " + f"command {cmd} ({self.cmds})\n{out}" + ) except client.ShellError as details: if cmd in ("cmd_output", "cmd_output_safe"): - if not isinstance(details, - client.ShellTimeoutError): - self.fail(f"Incorrect exception '{details}' " - f"({type(details)}) was raised " - f"using command {cmd} ({self.cmds})") - - @unittest.skipUnless(os.environ.get('AEXPECT_TIME_SENSITIVE'), - "AEXPECT_TIME_SENSITIVE env variable not set") + if not isinstance(details, client.ShellTimeoutError): + self.fail( + f"Incorrect exception '{details}' " + f"({type(details)}) was raised " + f"using command {cmd} ({self.cmds})" + ) + + @unittest.skipUnless( + os.environ.get("AEXPECT_TIME_SENSITIVE"), + "AEXPECT_TIME_SENSITIVE env variable not set", + ) def test_cmd_output_with_inner_timeout(self): """ cmd_output_safe uses 0.5s inner timeout, make sure all lines are present in the output. """ session = client.ShellSession("sh") - out = session.cmd_output_safe("echo FIRST LINE; sleep 2; " - "echo SECOND LINE; sleep 2; " - "echo THIRD LINE") + out = session.cmd_output_safe( + "echo FIRST LINE; sleep 2; echo SECOND LINE; sleep 2; echo THIRD LINE" + ) self.assertIn("FIRST LINE", out) self.assertIn("SECOND LINE", out) self.assertIn("THIRD LINE", out) @@ -132,6 +160,7 @@ def test_fd_leak(self): """ Check file descriptors are not being leaked """ + def get_proc_fds(): """ Returns a set containing the fd names opened under the process @@ -147,10 +176,12 @@ def get_proc_fds(): session = client.ShellSession("sh") session.close() fds_after = get_proc_fds() - self.assertEqual(fds_after, fds_before, - msg="fd leak: Closing the session didn't close " - "the file descriptors") + self.assertEqual( + fds_after, + fds_before, + msg="fd leak: Closing the session didn't close " "the file descriptors", + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_pass_fds.py b/tests/test_pass_fds.py index 3e75171..000ac06 100644 --- a/tests/test_pass_fds.py +++ b/tests/test_pass_fds.py @@ -19,13 +19,11 @@ from aexpect import client -LIST_FD_CMD = ('''python -c "import os; os.system('ls -l /proc/%d/fd' % ''' - '''os.getpid())"''') +LIST_FD_CMD = """python -c 'import os; os.system("ls -l /proc/%d/fd" % os.getpid())'""" class PassfdsTest(unittest.TestCase): - - @unittest.skipUnless(os.path.exists('/proc/1/fd'), "requires Linux") + @unittest.skipUnless(os.path.exists("/proc/1/fd"), "requires Linux") def test_pass_fds_spawn(self): """ Tests fd passing for `client.Spawn` @@ -34,17 +32,15 @@ def test_pass_fds_spawn(self): fd_null = devnull.fileno() child = client.Spawn(LIST_FD_CMD) - self.assertFalse(bool(child.get_status()), - "child terminated abnormally") + self.assertFalse(bool(child.get_status()), "child terminated abnormally") self.assertNotIn(os.devnull, child.get_output()) child.close() child = client.Spawn(LIST_FD_CMD, pass_fds=[fd_null]) - self.assertFalse(bool(child.get_status()), - "child terminated abnormally") + self.assertFalse(bool(child.get_status()), "child terminated abnormally") self.assertIn(os.devnull, child.get_output()) child.close() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/tests/test_remote_door.py b/tests/test_remote_door.py index 7d93520..f61af6a 100644 --- a/tests/test_remote_door.py +++ b/tests/test_remote_door.py @@ -14,8 +14,8 @@ # # selftests pylint: disable=C0111,C0111,W0613,R0913 -import os import glob +import os import re import shutil import unittest.mock @@ -27,22 +27,46 @@ # noinspection PyUnusedLocal -def _local_login(client, host, port, username, password, prompt, - linesep="\n", log_filename=None, log_function=None, - timeout=10, internal_timeout=10, interface=None): +def _local_login( + client, + host, + port, + username, + password, + prompt, + linesep="\n", + log_filename=None, + log_function=None, + timeout=10, + internal_timeout=10, + interface=None, +): return RemoteSession("sh", prompt=prompt, client=client) # noinspection PyUnusedLocal -def _local_copy(address, client, username, password, port, local_path, - remote_path, limit="", log_filename=None, log_function=None, - verbose=False, timeout=600, interface=None, filesize=None, - directory=True): +def _local_copy( + address, + client, + username, + password, + port, + local_path, + remote_path, + limit="", + log_filename=None, + log_function=None, + verbose=False, + timeout=600, + interface=None, + filesize=None, + directory=True, +): shutil.copy(local_path, remote_path) -@mock.patch('aexpect.remote_door.remote.copy_files_to', _local_copy) -@mock.patch('aexpect.remote_door.remote.wait_for_login', _local_login) +@mock.patch("aexpect.remote_door.remote.copy_files_to", _local_copy) +@mock.patch("aexpect.remote_door.remote.wait_for_login", _local_login) class RemoteDoorTest(unittest.TestCase): """Unit test class for the remote door.""" @@ -54,10 +78,13 @@ def setUp(self): def tearDown(self): for control_file in glob.glob("tmp*.control"): os.unlink(control_file) - for control_file in glob.glob(os.path.join(remote_door.REMOTE_CONTROL_DIR, - "tmp*.control")): + for control_file in glob.glob( + os.path.join(remote_door.REMOTE_CONTROL_DIR, "tmp*.control") + ): os.unlink(control_file) - deployed_remote_door = os.path.join(remote_door.REMOTE_PYTHON_PATH, "remote_door.py") + deployed_remote_door = os.path.join( + remote_door.REMOTE_PYTHON_PATH, "remote_door.py" + ) if os.path.exists(deployed_remote_door): os.unlink(deployed_remote_door) os.rmdir(remote_door.REMOTE_PYTHON_PATH) @@ -68,12 +95,14 @@ def test_run_remote_util(self): result = remote_door.run_remote_util(self.session, "math", "gcd", 2, 3) self.assertEqual(int(result), 1) local_controls = glob.glob("tmp*.control") - remote_controls = glob.glob(os.path.join(remote_door.REMOTE_CONTROL_DIR, - "tmp*.control")) + remote_controls = glob.glob( + os.path.join(remote_door.REMOTE_CONTROL_DIR, "tmp*.control") + ) self.assertEqual(len(local_controls), len(remote_controls)) self.assertEqual(len(remote_controls), 1) - self.assertEqual(os.path.basename(local_controls[0]), - os.path.basename(remote_controls[0])) + self.assertEqual( + os.path.basename(local_controls[0]), os.path.basename(remote_controls[0]) + ) with open(remote_controls[0], encoding="utf-8") as handle: control_lines = handle.readlines() self.assertIn("import math\n", control_lines) @@ -81,43 +110,56 @@ def test_run_remote_util(self): def test_run_remote_util_arg_types(self): """Test that a remote utility runs properly with different argument types.""" - result = remote_door.run_remote_util(self.session, "json", "dumps", - ["foo", {"bar": ["baz", None, 1.0, 2]}], - skipkeys=False, separators=None, - # must be boolean but we want to test string - allow_nan="string for yes") + result = remote_door.run_remote_util( + self.session, + "json", + "dumps", + ["foo", {"bar": ["baz", None, 1.0, 2]}], + skipkeys=False, + separators=None, + # must be boolean but we want to test string + allow_nan="string for yes", + ) self.assertEqual(result, '["foo", {"bar": ["baz", null, 1.0, 2]}]') local_controls = glob.glob("tmp*.control") - remote_controls = glob.glob(os.path.join(remote_door.REMOTE_CONTROL_DIR, - "tmp*.control")) + remote_controls = glob.glob( + os.path.join(remote_door.REMOTE_CONTROL_DIR, "tmp*.control") + ) self.assertEqual(len(local_controls), len(remote_controls)) self.assertEqual(len(remote_controls), 1) - self.assertEqual(os.path.basename(local_controls[0]), - os.path.basename(remote_controls[0])) + self.assertEqual( + os.path.basename(local_controls[0]), os.path.basename(remote_controls[0]) + ) with open(remote_controls[0], encoding="utf-8") as handle: control_lines = handle.readlines() self.assertIn("import json\n", control_lines) - self.assertIn("result = json.dumps(['foo', {'bar': ['baz', None, 1.0, 2]}], " - "allow_nan=r'string for yes', separators=None, skipkeys=False)\n", - control_lines) + self.assertIn( + "result = json.dumps(['foo', {'bar': ['baz', None, 1.0, 2]}], " + "allow_nan=r'string for yes', separators=None, skipkeys=False)\n", + control_lines, + ) def test_run_remote_util_object(self): """Test that a remote utility object runs properly.""" - result = remote_door.run_remote_util(self.session, "collections", - "OrderedDict().get", "akey") + result = remote_door.run_remote_util( + self.session, "collections", "OrderedDict().get", "akey" + ) self.assertEqual(result, "None") local_controls = glob.glob("tmp*.control") - remote_controls = glob.glob(os.path.join(remote_door.REMOTE_CONTROL_DIR, - "tmp*.control")) + remote_controls = glob.glob( + os.path.join(remote_door.REMOTE_CONTROL_DIR, "tmp*.control") + ) self.assertEqual(len(local_controls), len(remote_controls)) self.assertEqual(len(remote_controls), 1) - self.assertEqual(os.path.basename(local_controls[0]), - os.path.basename(remote_controls[0])) + self.assertEqual( + os.path.basename(local_controls[0]), os.path.basename(remote_controls[0]) + ) with open(remote_controls[0], encoding="utf-8") as handle: control_lines = handle.readlines() - self.assertIn("result = collections.OrderedDict().get(r'akey')\n", - control_lines) + self.assertIn( + "result = collections.OrderedDict().get(r'akey')\n", control_lines + ) def test_run_remote_decorator(self): """Test that a remote decorated function runs properly.""" @@ -137,12 +179,14 @@ def do_nothing(): self.assertEqual(int(result), 4) local_controls = glob.glob("tmp*.control") - remote_controls = glob.glob(os.path.join(remote_door.REMOTE_CONTROL_DIR, - "tmp*.control")) + remote_controls = glob.glob( + os.path.join(remote_door.REMOTE_CONTROL_DIR, "tmp*.control") + ) self.assertEqual(len(local_controls), len(remote_controls)) self.assertEqual(len(remote_controls), 1) - self.assertEqual(os.path.basename(local_controls[0]), - os.path.basename(remote_controls[0])) + self.assertEqual( + os.path.basename(local_controls[0]), os.path.basename(remote_controls[0]) + ) with open(remote_controls[0], encoding="utf-8") as handle: control_lines = handle.readlines() self.assertIn("def add_one(number):\n", control_lines) @@ -150,11 +194,14 @@ def do_nothing(): def test_get_remote_object(self): """Test that a remote object can be retrieved properly.""" - self.session = mock.MagicMock(name='session') + self.session = mock.MagicMock(name="session") self.session.client = "ssh" remote_door.Pyro4 = mock.MagicMock() disconnect = remote_door.Pyro4.errors.PyroError = Exception - remote_door.Pyro4.Proxy.side_effect = [disconnect("no such object"), mock.DEFAULT] + remote_door.Pyro4.Proxy.side_effect = [ + disconnect("no such object"), + mock.DEFAULT, + ] self.session.get_output.return_value = "Local object sharing ready\n" self.session.get_output.return_value += "RESULT = None\n" @@ -169,37 +216,42 @@ def test_get_remote_object(self): self.assertIsNotNone(match, "A control file has to be called on the peer side") control_file = match.group(1) local_controls = glob.glob("tmp*.control") - remote_controls = glob.glob(os.path.join(remote_door.REMOTE_CONTROL_DIR, - "tmp*.control")) + remote_controls = glob.glob( + os.path.join(remote_door.REMOTE_CONTROL_DIR, "tmp*.control") + ) self.assertEqual(len(local_controls), len(remote_controls)) self.assertEqual(len(remote_controls), 1) - self.assertEqual(os.path.basename(local_controls[0]), - os.path.basename(remote_controls[0])) + self.assertEqual( + os.path.basename(local_controls[0]), os.path.basename(remote_controls[0]) + ) self.assertEqual(control_file, os.path.basename(remote_controls[0])) with open(control_file, encoding="utf-8") as handle: control_lines = handle.readlines() self.assertIn("import remote_door\n", control_lines) - self.assertIn("result = remote_door.share_local_object(r'html', " - "host=r'testhost', port=4242)\n", - control_lines) + self.assertIn( + "result = remote_door.share_local_object(r'html', " + "host=r'testhost', port=4242)\n", + control_lines, + ) # since the local run was face redo it here remote_door.share_local_object("html", None, "testhost", 4242) def test_share_remote_objects(self): """Test that a remote object can be shared properly and remotely.""" - self.session = mock.MagicMock(name='session') + self.session = mock.MagicMock(name="session") self.session.client = "ssh" remote_door.Pyro4 = mock.MagicMock() - control_file = os.path.join(remote_door.REMOTE_CONTROL_DIR, - "tmpxxxxxxxx.control") + control_file = os.path.join( + remote_door.REMOTE_CONTROL_DIR, "tmpxxxxxxxx.control" + ) with open(control_file, "wt", encoding="utf-8") as handle: handle.write("print('Remote objects shared over the network')") - middleware = remote_door.share_remote_objects(self.session, control_file, - "testhost", 4242, - os_type="linux") + middleware = remote_door.share_remote_objects( + self.session, control_file, "testhost", 4242, os_type="linux" + ) # we just test dummy initialization for the remote object control server middleware.close() @@ -213,9 +265,11 @@ def test_share_remote_objects(self): def test_import_remote_exceptions(self): """Test that selected remote exceptions are properly imported and deserialized.""" remote_door.Pyro4 = mock.MagicMock() - preselected_exceptions = ["aexpect.remote.RemoteError", - "aexpect.remote.LoginError", - "aexpect.remote.TransferError"] + preselected_exceptions = [ + "aexpect.remote.RemoteError", + "aexpect.remote.LoginError", + "aexpect.remote.TransferError", + ] remote_door.import_remote_exceptions(preselected_exceptions) register_method = remote_door.Pyro4.util.SerializerBase.register_dict_to_class self.assertEqual(len(register_method.mock_calls), 3)