diff --git a/doc/configuration.rst b/doc/configuration.rst index 553418d03..addce0b29 100644 --- a/doc/configuration.rst +++ b/doc/configuration.rst @@ -1573,6 +1573,13 @@ Arguments: target. - explicit_sftp_mode (bool, default=False): if set to True, `put()` and `get()` will explicitly use the SFTP protocol for file transfers instead of scp's default protocol + - su_username(str, default="root"): only used if su_password is set + - su_prompt(str, default="Passowrd:"): prompt string for su + - su_password(str, default=None): su password for the user set via su_username + +.. note:: + Using the su support will automatically enable stderr_merge, since ssh this + is required to interact with the password prompt. UBootDriver ~~~~~~~~~~~ diff --git a/examples/ssh-su-example/conf.yaml b/examples/ssh-su-example/conf.yaml new file mode 100644 index 000000000..18844c9f1 --- /dev/null +++ b/examples/ssh-su-example/conf.yaml @@ -0,0 +1,11 @@ +targets: + main: + resources: + NetworkService: + address: 127.0.0.1 + username: + drivers: + SSHDriver: + su_password: + su_username: + stderr_merge: true diff --git a/examples/ssh-su-example/test.py b/examples/ssh-su-example/test.py new file mode 100644 index 000000000..1f7792c0b --- /dev/null +++ b/examples/ssh-su-example/test.py @@ -0,0 +1,14 @@ +import logging + +from labgrid import Environment + +logging.basicConfig( + level=logging.DEBUG +) + +env = Environment("conf.yaml") +target = env.get_target() +ssh = target.get_driver("SSHDriver") +out, _, code = ssh.run("ps -p $PPID") +print(code) +print(out) diff --git a/labgrid/driver/sshdriver.py b/labgrid/driver/sshdriver.py index 3ad6fafa5..52a48e6e9 100644 --- a/labgrid/driver/sshdriver.py +++ b/labgrid/driver/sshdriver.py @@ -11,6 +11,8 @@ import time import attr +from pexpect.fdpexpect import fdspawn +from pexpect.exceptions import EOF, TIMEOUT from ..factory import target_factory from ..protocol import CommandProtocol, FileTransferProtocol @@ -22,6 +24,7 @@ from ..util.proxy import proxymanager from ..util.timeout import Timeout from ..util.ssh import get_ssh_connect_timeout +from ..util.marker import gen_marker @target_factory.reg_driver @@ -34,6 +37,9 @@ class SSHDriver(CommandMixin, Driver, CommandProtocol, FileTransferProtocol): stderr_merge = attr.ib(default=False, validator=attr.validators.instance_of(bool)) connection_timeout = attr.ib(default=float(get_ssh_connect_timeout()), validator=attr.validators.instance_of(float)) explicit_sftp_mode = attr.ib(default=False, validator=attr.validators.instance_of(bool)) + su_password = attr.ib(default=None, validator=attr.validators.optional(attr.validators.instance_of(str))) + su_username = attr.ib(default="root", validator=attr.validators.instance_of(str)) + su_prompt = attr.ib(default="Password:", validator=attr.validators.instance_of(str)) def __attrs_post_init__(self): super().__attrs_post_init__() @@ -180,6 +186,40 @@ def _start_own_master_once(self, timeout): def run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None): return self._run(cmd, codec=codec, decodeerrors=decodeerrors, timeout=timeout) + def handle_password(self, fd, stdin, marker): + p = fdspawn(fd, timeout=15) + try: + p.expect([f"{marker}\n"]) + except TIMEOUT: + raise ExecutionError(f"Failed to find marker before su: {p.buffer!r}") + except EOF: + raise ExecutionError("Unexpected disconnect before su") + + try: + index = p.expect([f"{marker}\n", self.su_prompt]) + except TIMEOUT: + raise ExecutionError(f"Unexpected output from su: {p.buffer!r}") + except EOF: + raise ExecutionError("Unexpected disconnect after starting su") + + if index == 0: + # no password needed + return p.after + + stdin.write(f"{self.su_password}".encode("utf-8")) + # It seems we need to close stdin here to reliably get su to accept the + # password. \n doesn't seem to work. + stdin.close() + + try: + p.expect([f"{marker}\n"]) + except TIMEOUT: + raise ExecutionError(f"Unexpected output from su after entering password: {p.buffer!r}") + except EOF: + raise ExecutionError(f"Unexpected disconnect after after entering su password: {p.before!r}") + + return p.after + def _run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None): """Execute `cmd` on the target. @@ -196,22 +236,48 @@ def _run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None): complete_cmd = ["ssh", "-x", *self.ssh_prefix, "-p", str(self.networkservice.port), "-l", self.networkservice.username, self.networkservice.address - ] + cmd.split(" ") + ] + if self.su_password: + self.stderr_merge = True # with -tt, we get all output on stdout + marker = gen_marker() + complete_cmd += ["-tt", "--", "echo", f"{marker};", "su", self.su_username, "--"] + inner_cmd = f"echo '{marker[:4]}''{marker[4:]}'; {cmd}" + complete_cmd += ["-c", shlex.quote(inner_cmd)] + else: + complete_cmd += ["--"] + cmd.split(" ") + self.logger.debug("Sending command: %s", complete_cmd) if self.stderr_merge: stderr_pipe = subprocess.STDOUT else: stderr_pipe = subprocess.PIPE + stdin = subprocess.PIPE if self.su_password else None + stdout, stderr = b"", b"" try: sub = subprocess.Popen( - complete_cmd, stdout=subprocess.PIPE, stderr=stderr_pipe + complete_cmd, stdout=subprocess.PIPE, stderr=stderr_pipe, stdin=stdin, ) except: raise ExecutionError( f"error executing command: {complete_cmd}" ) - stdout, stderr = sub.communicate(timeout=timeout) + if self.su_password: + fd = sub.stdout if self.stderr_merge else sub.stderr + output = self.handle_password(fd, sub.stdin, marker) + sub.stdin.close() + self.logger.debug(f"su leftover output: %s", output) + if self.stderr_merge: + stderr += output + else: + stdout += output + + sub.stdin = None # never try to write to stdin here + comout, comerr = sub.communicate(timeout=timeout) + stdout += comout + if not self.stderr_merge: + stderr += comerr + stdout = stdout.decode(codec, decodeerrors).split('\n') if stdout[-1] == '': stdout.pop()