Skip to content

Commit

Permalink
driver/sshdriver: extend with su handling
Browse files Browse the repository at this point in the history
Add two new attributes to the driver which will use su to switch to a
user to run a command. The su_password is required for this feature to
be used, su_username only needs to be set if another user than root
should be switched to.

Signed-off-by: Rouven Czerwinski <[email protected]>
Co-developed-by: Jan Luebbe <[email protected]>
  • Loading branch information
Emantor committed Aug 1, 2023
1 parent b0fe8dc commit 6720c2f
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 3 deletions.
7 changes: 7 additions & 0 deletions doc/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,13 @@ Arguments:
target.
- explicit_sftp_mode (bool, default=False): if set to True, `put()` and `get()` will
explicitly use the SFTP protocol for file transfers instead of scp's default protocol
- su_username(str, default="root"): only used if su_password is set
- su_prompt(str, default="Passowrd:"): prompt string for su
- su_password(str, default=None): su password for the user set via su_username

.. note::
Using the su support will automatically enable stderr_merge, since ssh this
is required to interact with the password prompt.

UBootDriver
~~~~~~~~~~~
Expand Down
11 changes: 11 additions & 0 deletions examples/ssh-su-example/conf.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
targets:
main:
resources:
NetworkService:
address: 127.0.0.1
username: <login_username>
drivers:
SSHDriver:
su_password: <the_password>
su_username: <the_username>
stderr_merge: true
14 changes: 14 additions & 0 deletions examples/ssh-su-example/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import logging

from labgrid import Environment

logging.basicConfig(
level=logging.DEBUG
)

env = Environment("conf.yaml")
target = env.get_target()
ssh = target.get_driver("SSHDriver")
out, _, code = ssh.run("ps -p $PPID")
print(code)
print(out)
72 changes: 69 additions & 3 deletions labgrid/driver/sshdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import time

import attr
from pexpect.fdpexpect import fdspawn
from pexpect.exceptions import EOF, TIMEOUT

from ..factory import target_factory
from ..protocol import CommandProtocol, FileTransferProtocol
Expand All @@ -22,6 +24,7 @@
from ..util.proxy import proxymanager
from ..util.timeout import Timeout
from ..util.ssh import get_ssh_connect_timeout
from ..util.marker import gen_marker


@target_factory.reg_driver
Expand All @@ -34,6 +37,9 @@ class SSHDriver(CommandMixin, Driver, CommandProtocol, FileTransferProtocol):
stderr_merge = attr.ib(default=False, validator=attr.validators.instance_of(bool))
connection_timeout = attr.ib(default=float(get_ssh_connect_timeout()), validator=attr.validators.instance_of(float))
explicit_sftp_mode = attr.ib(default=False, validator=attr.validators.instance_of(bool))
su_password = attr.ib(default=None, validator=attr.validators.optional(attr.validators.instance_of(str)))
su_username = attr.ib(default="root", validator=attr.validators.instance_of(str))
su_prompt = attr.ib(default="Password:", validator=attr.validators.instance_of(str))

def __attrs_post_init__(self):
super().__attrs_post_init__()
Expand Down Expand Up @@ -180,6 +186,40 @@ def _start_own_master_once(self, timeout):
def run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None):
return self._run(cmd, codec=codec, decodeerrors=decodeerrors, timeout=timeout)

def handle_password(self, fd, stdin, marker):
p = fdspawn(fd, timeout=15)
try:
p.expect([f"{marker}\n"])
except TIMEOUT:
raise ExecutionError(f"Failed to find marker before su: {p.buffer!r}")
except EOF:
raise ExecutionError("Unexpected disconnect before su")

try:
index = p.expect([f"{marker}\n", self.su_prompt])
except TIMEOUT:
raise ExecutionError(f"Unexpected output from su: {p.buffer!r}")
except EOF:
raise ExecutionError("Unexpected disconnect after starting su")

if index == 0:
# no password needed
return p.after

stdin.write(f"{self.su_password}".encode("utf-8"))
# It seems we need to close stdin here to reliably get su to accept the
# password. \n doesn't seem to work.
stdin.close()

try:
p.expect([f"{marker}\n"])
except TIMEOUT:
raise ExecutionError(f"Unexpected output from su after entering password: {p.buffer!r}")
except EOF:
raise ExecutionError(f"Unexpected disconnect after after entering su password: {p.before!r}")

return p.after

def _run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None):
"""Execute `cmd` on the target.
Expand All @@ -196,22 +236,48 @@ def _run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None):
complete_cmd = ["ssh", "-x", *self.ssh_prefix,
"-p", str(self.networkservice.port), "-l", self.networkservice.username,
self.networkservice.address
] + cmd.split(" ")
]
if self.su_password:
self.stderr_merge = True # with -tt, we get all output on stdout
marker = gen_marker()
complete_cmd += ["-tt", "--", "echo", f"{marker};", "su", self.su_username, "--"]
inner_cmd = f"echo '{marker[:4]}''{marker[4:]}'; {cmd}"
complete_cmd += ["-c", shlex.quote(inner_cmd)]
else:
complete_cmd += ["--"] + cmd.split(" ")

self.logger.debug("Sending command: %s", complete_cmd)
if self.stderr_merge:
stderr_pipe = subprocess.STDOUT
else:
stderr_pipe = subprocess.PIPE
stdin = subprocess.PIPE if self.su_password else None
stdout, stderr = b"", b""
try:
sub = subprocess.Popen(
complete_cmd, stdout=subprocess.PIPE, stderr=stderr_pipe
complete_cmd, stdout=subprocess.PIPE, stderr=stderr_pipe, stdin=stdin,
)
except:
raise ExecutionError(
f"error executing command: {complete_cmd}"
)

stdout, stderr = sub.communicate(timeout=timeout)
if self.su_password:
fd = sub.stdout if self.stderr_merge else sub.stderr
output = self.handle_password(fd, sub.stdin, marker)
sub.stdin.close()
self.logger.debug(f"su leftover output: %s", output)
if self.stderr_merge:
stderr += output
else:
stdout += output

sub.stdin = None # never try to write to stdin here
comout, comerr = sub.communicate(timeout=timeout)
stdout += comout
if not self.stderr_merge:
stderr += comerr

stdout = stdout.decode(codec, decodeerrors).split('\n')
if stdout[-1] == '':
stdout.pop()
Expand Down

0 comments on commit 6720c2f

Please sign in to comment.