Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSHDriver: implement user switching via su #1220

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,13 @@ Arguments:
target.
- explicit_sftp_mode (bool, default=False): if set to True, `put()` and `get()` will
explicitly use the SFTP protocol for file transfers instead of scp's default protocol
- su_username(str, default="root"): only used if su_password is set
- su_prompt(str, default="Password:"): prompt string for su
- su_password(str): optional, su password for the user set via su_username

.. note::
Using the su support will automatically enable ``stderr_merge``, since this
is required to interact with the password prompt.

UBootDriver
~~~~~~~~~~~
Expand Down
11 changes: 11 additions & 0 deletions examples/ssh-su-example/conf.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
targets:
main:
resources:
NetworkService:
address: 127.0.0.1
username: <login_username>
drivers:
SSHDriver:
su_password: <the_password>
su_username: <the_username>
stderr_merge: true
14 changes: 14 additions & 0 deletions examples/ssh-su-example/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import logging

from labgrid import Environment

logging.basicConfig(
level=logging.DEBUG
)

env = Environment("conf.yaml")
target = env.get_target()
ssh = target.get_driver("SSHDriver")
out, _, code = ssh.run("ps -p $PPID")
print(code)
print(out)
70 changes: 66 additions & 4 deletions labgrid/driver/sshdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import time

import attr
from pexpect.fdpexpect import fdspawn
from pexpect.exceptions import EOF, TIMEOUT

from ..factory import target_factory
from ..protocol import CommandProtocol, FileTransferProtocol
Expand All @@ -22,6 +24,7 @@
from ..util.proxy import proxymanager
from ..util.timeout import Timeout
from ..util.ssh import get_ssh_connect_timeout
from ..util.marker import gen_marker


@target_factory.reg_driver
Expand All @@ -34,6 +37,9 @@ class SSHDriver(CommandMixin, Driver, CommandProtocol, FileTransferProtocol):
stderr_merge = attr.ib(default=False, validator=attr.validators.instance_of(bool))
connection_timeout = attr.ib(default=float(get_ssh_connect_timeout()), validator=attr.validators.instance_of(float))
explicit_sftp_mode = attr.ib(default=False, validator=attr.validators.instance_of(bool))
su_password = attr.ib(default=None, validator=attr.validators.optional(attr.validators.instance_of(str)))
su_username = attr.ib(default="root", validator=attr.validators.instance_of(str))
su_prompt = attr.ib(default="Password:", validator=attr.validators.instance_of(str))
Emantor marked this conversation as resolved.
Show resolved Hide resolved

def __attrs_post_init__(self):
super().__attrs_post_init__()
Expand Down Expand Up @@ -180,6 +186,40 @@ def _start_own_master_once(self, timeout):
def run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None):
return self._run(cmd, codec=codec, decodeerrors=decodeerrors, timeout=timeout)

def handle_password(self, fd, stdin, marker):
p = fdspawn(fd, timeout=15)
try:
p.expect([f"{marker}\n"])
except TIMEOUT:
raise ExecutionError(f"Failed to find marker before su: {p.buffer!r}")
except EOF:
raise ExecutionError("Unexpected disconnect before su")

try:
index = p.expect([f"{marker}\n", self.su_prompt])
except TIMEOUT:
raise ExecutionError(f"Unexpected output from su: {p.buffer!r}")
except EOF:
raise ExecutionError("Unexpected disconnect after starting su")

if index == 0:
# no password needed
return p.after

stdin.write(self.su_password.encode("utf-8"))
# It seems we need to close stdin here to reliably get su to accept the
# password. \n doesn't seem to work.
stdin.close()

try:
p.expect([f"{marker}\n"])
except TIMEOUT:
raise ExecutionError(f"Unexpected output from su after entering password: {p.buffer!r}")
except EOF:
raise ExecutionError(f"Unexpected disconnect after after entering su password: {p.before!r}")

return p.after

def _run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None):
"""Execute `cmd` on the target.

Expand All @@ -193,25 +233,47 @@ def _run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None):
if not self._check_keepalive():
raise ExecutionError("Keepalive no longer running")

complete_cmd = ["ssh", "-x", *self.ssh_prefix,
complete_cmd = ["ssh", "-x", "-o", "LogLevel=QUIET", *self.ssh_prefix,
"-p", str(self.networkservice.port), "-l", self.networkservice.username,
self.networkservice.address
] + cmd.split(" ")
]
if self.su_password:
self.stderr_merge = True # with -tt, we get all output on stdout
Emantor marked this conversation as resolved.
Show resolved Hide resolved
marker = gen_marker()
complete_cmd += ["-tt", "--", "echo", f"{marker};", "su", self.su_username, "--"]
inner_cmd = f"echo '{marker[:4]}''{marker[4:]}'; {cmd}"
complete_cmd += ["-c", shlex.quote(inner_cmd)]
else:
complete_cmd += ["--"] + cmd.split(" ")

self.logger.debug("Sending command: %s", complete_cmd)
if self.stderr_merge:
stderr_pipe = subprocess.STDOUT
else:
stderr_pipe = subprocess.PIPE
stdin = subprocess.PIPE if self.su_password else None
stdout, stderr = b"", b""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't work with..

if stderr is None:

..below.

try:
sub = subprocess.Popen(
complete_cmd, stdout=subprocess.PIPE, stderr=stderr_pipe
complete_cmd, stdout=subprocess.PIPE, stderr=stderr_pipe, stdin=stdin,
)
except:
raise ExecutionError(
f"error executing command: {complete_cmd}"
)

stdout, stderr = sub.communicate(timeout=timeout)
if self.su_password:
output = self.handle_password(sub.stdout, sub.stdin, marker)
sub.stdin.close()
self.logger.debug("su leftover output: %s", output)
stderr += output
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this appended to stderr? Why does this have to be error output?

sub.stdin = None # never try to write to stdin here

comout, comerr = sub.communicate(timeout=timeout)
stdout += comout
if not self.stderr_merge:
stderr += comerr
Comment on lines +272 to +275
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why isn't this simply..

stdout, stderr = sub.communicate(timeout=timeout)

..as it was before?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because in case we are manging the password, we wan't to prepend the additional bytes we might have read after password input.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still having problems understanding this:

Why is the comout variable needed? Couldn't this be stdout directly?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we use stdout directly, the previous contents will be overwritten if we read additional bytes after password input.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is..

stdout = b""
comout, comerr = sub.communicate(timeout=timeout)
stdout += comout

..any different than..

stdout, comerr = sub.communicate(timeout=timeout)

..? What am I missing here?


stdout = stdout.decode(codec, decodeerrors).split('\n')
if stdout[-1] == '':
stdout.pop()
Expand Down