From eb5d35a82c5aec91d9388e6876b8d57efb5c53f4 Mon Sep 17 00:00:00 2001 From: Alexander Goscinski Date: Wed, 5 Jun 2024 10:43:43 +0200 Subject: [PATCH 1/2] PR #6450 --- pyproject.toml | 7 ++++ tests/cmdline/commands/test_code.py | 4 +-- tests/common/test_timezone.py | 13 ------- tests/conftest.py | 15 ++++++++ tests/orm/data/code/test_installed.py | 36 +++++++++---------- tests/orm/data/code/test_portable.py | 4 +-- .../pytest_fixtures/test_configuration.py | 10 +++--- 7 files changed, 48 insertions(+), 41 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 129746e0fc..c403809be2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -431,6 +431,13 @@ deps = py311: -rrequirements/requirements-py-3.11.txt py312: -rrequirements/requirements-py-3.12.txt +[testenv:py{39,310,311,312}-presto] +passenv = + PYTHONASYNCIODEBUG +setenv = + AIIDA_WARN_v3 = +commands = pytest -m 'presto' {posargs} + [testenv:py{39,310,311,312}] passenv = PYTHONASYNCIODEBUG diff --git a/tests/cmdline/commands/test_code.py b/tests/cmdline/commands/test_code.py index b7d1c5cf5f..32f956d0db 100644 --- a/tests/cmdline/commands/test_code.py +++ b/tests/cmdline/commands/test_code.py @@ -530,7 +530,7 @@ def test_code_test(run_cli_command): @pytest.fixture -def command_options(request, aiida_localhost, tmp_path): +def command_options(request, aiida_localhost, tmp_path, bash_path): """Return tuple of list of options and entry point.""" options = [request.param, '-n', '--label', str(uuid.uuid4())] @@ -550,7 +550,7 @@ def command_options(request, aiida_localhost, tmp_path): '--computer', str(aiida_localhost.pk), '--filepath-executable', - '/usr/bin/bash', + str(bash_path.absolute()), '--engine-command', engine_command, '--image-name', diff --git a/tests/common/test_timezone.py b/tests/common/test_timezone.py index be08c7fef3..8115007de7 100644 --- a/tests/common/test_timezone.py +++ b/tests/common/test_timezone.py @@ -40,19 +40,6 @@ def test_now(): assert from_tz >= ref - dt -def test_make_aware(): - """Test the :func:`aiida.common.timezone.make_aware` function. - - This should make a naive datetime object aware using the timezone of the operating system. - """ - system_tzinfo = datetime.now(timezone.utc).astimezone() # This is how to get the timezone of the OS. - naive = datetime(1970, 1, 1) - aware = make_aware(naive) - assert is_aware(aware) - assert aware.tzinfo.tzname(aware) == system_tzinfo.tzname() - assert aware.tzinfo.utcoffset(aware) == system_tzinfo.utcoffset() - - def test_make_aware_already_aware(): """Test the :func:`aiida.common.timezone.make_aware` function for an already aware datetime. diff --git a/tests/conftest.py b/tests/conftest.py index 3a9cd56336..1dcb644b21 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,7 @@ import dataclasses import os import pathlib +import subprocess import types import typing as t import warnings @@ -860,3 +861,17 @@ def factory(dirpath: pathlib.Path, read_bytes=True) -> dict: return serialized return factory + + +@pytest.fixture(scope='session') +def bash_path() -> Path: + run_process = subprocess.run(['which', 'bash'], capture_output=True, check=True) + path = run_process.stdout.decode('utf-8').strip() + return Path(path) + + +@pytest.fixture(scope='session') +def cat_path() -> Path: + run_process = subprocess.run(['which', 'cat'], capture_output=True, check=True) + path = run_process.stdout.decode('utf-8').strip() + return Path(path) diff --git a/tests/orm/data/code/test_installed.py b/tests/orm/data/code/test_installed.py index 5e1fda4273..928d718080 100644 --- a/tests/orm/data/code/test_installed.py +++ b/tests/orm/data/code/test_installed.py @@ -17,29 +17,29 @@ from aiida.orm.nodes.data.code.installed import InstalledCode -def test_constructor_raises(aiida_localhost): +def test_constructor_raises(aiida_localhost, bash_path): """Test the constructor when it is supposed to raise.""" with pytest.raises(TypeError, match=r'missing .* required positional arguments'): InstalledCode() with pytest.raises(TypeError, match=r'Got object of type .*'): - InstalledCode(computer=aiida_localhost, filepath_executable=pathlib.Path('/usr/bin/bash')) + InstalledCode(computer=aiida_localhost, filepath_executable=bash_path) with pytest.raises(TypeError, match=r'Got object of type .*'): InstalledCode(computer='computer', filepath_executable='/usr/bin/bash') -def test_constructor(aiida_localhost): +def test_constructor(aiida_localhost, bash_path): """Test the constructor.""" - filepath_executable = '/usr/bin/bash' + filepath_executable = str(bash_path.absolute()) code = InstalledCode(computer=aiida_localhost, filepath_executable=filepath_executable) assert code.computer.pk == aiida_localhost.pk assert code.filepath_executable == pathlib.PurePath(filepath_executable) -def test_validate(aiida_localhost): +def test_validate(aiida_localhost, bash_path): """Test the validator is called before storing.""" - filepath_executable = '/usr/bin/bash' + filepath_executable = str(bash_path.absolute()) code = InstalledCode(computer=aiida_localhost, filepath_executable=filepath_executable) code.computer = aiida_localhost @@ -53,18 +53,18 @@ def test_validate(aiida_localhost): assert code.is_stored -def test_can_run_on_computer(aiida_localhost): +def test_can_run_on_computer(aiida_localhost, bash_path): """Test the :meth:`aiida.orm.nodes.data.code.installed.InstalledCode.can_run_on_computer` method.""" - code = InstalledCode(computer=aiida_localhost, filepath_executable='/usr/bin/bash') + code = InstalledCode(computer=aiida_localhost, filepath_executable=str(bash_path.absolute())) computer = Computer() assert code.can_run_on_computer(aiida_localhost) assert not code.can_run_on_computer(computer) -def test_filepath_executable(aiida_localhost): +def test_filepath_executable(aiida_localhost, bash_path, cat_path): """Test the :meth:`aiida.orm.nodes.data.code.installed.InstalledCode.filepath_executable` property.""" - filepath_executable = '/usr/bin/bash' + filepath_executable = str(bash_path.absolute()) code = InstalledCode(computer=aiida_localhost, filepath_executable=filepath_executable) assert code.filepath_executable == pathlib.PurePath(filepath_executable) @@ -74,7 +74,7 @@ def test_filepath_executable(aiida_localhost): assert code.filepath_executable == pathlib.PurePath(filepath_executable) # Change through the property - filepath_executable = '/usr/bin/cat' + filepath_executable = str(cat_path.absolute()) code.filepath_executable = filepath_executable assert code.filepath_executable == pathlib.PurePath(filepath_executable) @@ -101,7 +101,7 @@ def computer(request, aiida_computer_local, aiida_computer_ssh): @pytest.mark.usefixtures('aiida_profile_clean') @pytest.mark.parametrize('computer', ('core.local', 'core.ssh'), indirect=True) -def test_validate_filepath_executable(ssh_key, computer): +def test_validate_filepath_executable(ssh_key, computer, bash_path): """Test the :meth:`aiida.orm.nodes.data.code.installed.InstalledCode.validate_filepath_executable` method.""" filepath_executable = '/usr/bin/not-existing' code = InstalledCode(computer=computer, filepath_executable=filepath_executable) @@ -117,19 +117,19 @@ def test_validate_filepath_executable(ssh_key, computer): with pytest.raises(ValidationError, match=r'The provided remote absolute path .* does not exist on the computer\.'): code.validate_filepath_executable() - code.filepath_executable = '/usr/bin/bash' + code.filepath_executable = str(bash_path.absolute()) code.validate_filepath_executable() -def test_full_label(aiida_localhost): +def test_full_label(aiida_localhost, bash_path): """Test the :meth:`aiida.orm.nodes.data.code.installed.InstalledCode.full_label` property.""" label = 'some-label' - code = InstalledCode(label=label, computer=aiida_localhost, filepath_executable='/usr/bin/bash') + code = InstalledCode(label=label, computer=aiida_localhost, filepath_executable=str(bash_path.absolute())) assert code.full_label == f'{label}@{aiida_localhost.label}' -def test_get_execname(aiida_localhost): +def test_get_execname(aiida_localhost, bash_path): """Test the deprecated :meth:`aiida.orm.nodes.data.code.installed.InstalledCode.get_execname` method.""" - code = InstalledCode(label='some-label', computer=aiida_localhost, filepath_executable='/usr/bin/bash') + code = InstalledCode(label='some-label', computer=aiida_localhost, filepath_executable=str(bash_path.absolute())) with pytest.warns(AiidaDeprecationWarning): - assert code.get_execname() == '/usr/bin/bash' + assert code.get_execname() == str(bash_path.absolute()) diff --git a/tests/orm/data/code/test_portable.py b/tests/orm/data/code/test_portable.py index 839c1ed2c6..14dd6ca700 100644 --- a/tests/orm/data/code/test_portable.py +++ b/tests/orm/data/code/test_portable.py @@ -17,13 +17,13 @@ from aiida.orm.nodes.data.code.portable import PortableCode -def test_constructor_raises(tmp_path): +def test_constructor_raises(tmp_path, bash_path): """Test the constructor when it is supposed to raise.""" with pytest.raises(TypeError, match=r'missing .* required positional argument'): PortableCode() with pytest.raises(TypeError, match=r'Got object of type .*'): - PortableCode(filepath_executable=pathlib.Path('/usr/bin/bash'), filepath_files=tmp_path) + PortableCode(filepath_executable=bash_path, filepath_files=tmp_path) with pytest.raises(TypeError, match=r'Got object of type .*'): PortableCode(filepath_executable='bash', filepath_files='string') diff --git a/tests/tools/pytest_fixtures/test_configuration.py b/tests/tools/pytest_fixtures/test_configuration.py index 574d0d4f6a..954e50fa66 100644 --- a/tests/tools/pytest_fixtures/test_configuration.py +++ b/tests/tools/pytest_fixtures/test_configuration.py @@ -1,24 +1,22 @@ """Test the pytest fixtures.""" -import tempfile - -def test_aiida_config(): +def test_aiida_config(tmp_path_factory): """Test that ``aiida_config`` fixture is loaded by default and creates a config instance in temp directory.""" from aiida.manage.configuration import get_config from aiida.manage.configuration.config import Config config = get_config() assert isinstance(config, Config) - assert config.dirpath.startswith(tempfile.gettempdir()) + assert config.dirpath.startswith(str(tmp_path_factory.getbasetemp())) -def test_aiida_config_tmp(aiida_config_tmp): +def test_aiida_config_tmp(aiida_config_tmp, tmp_path_factory): """Test that ``aiida_config_tmp`` returns a config instance in temp directory.""" from aiida.manage.configuration.config import Config assert isinstance(aiida_config_tmp, Config) - assert aiida_config_tmp.dirpath.startswith(tempfile.gettempdir()) + assert aiida_config_tmp.dirpath.startswith(str(tmp_path_factory.getbasetemp())) def test_aiida_profile(): From c3c259ef62fa3991cfe6882d5cdc3f57cc696a04 Mon Sep 17 00:00:00 2001 From: Alexander Goscinski Date: Wed, 11 Sep 2024 18:15:41 +0200 Subject: [PATCH 2/2] Tests now use ssh key specified by env var AIIDA_PYTEST_SSH_KEY The tests were relying on `~/.ssh/id_rsa` to be used to allow ssh to localhost. This had the side effect that several tests specified a ssh key in their configuration that was then not working in the authentication but a fall back to the default ssh key still allowed a connection. This PR fixes this behaviour and uses also the ssh key specified for the connection. This also allows the user to specify a dedicated ssh key for aiida tests and does not enforce the usage of the default key. --- .github/workflows/ci-code.yml | 11 ++- .github/workflows/setup_ssh.sh | 13 +++- pyproject.toml | 4 ++ src/aiida/manage/tests/pytest_fixtures.py | 66 ++++++++++-------- src/aiida/tools/pytest_fixtures/orm.py | 83 +++++++++++++---------- tests/transports/test_all_plugins.py | 14 ++-- tests/transports/test_ssh.py | 45 ++++++++---- 7 files changed, 145 insertions(+), 91 deletions(-) diff --git a/.github/workflows/ci-code.yml b/.github/workflows/ci-code.yml index ff95047bb4..8eb9faf31c 100644 --- a/.github/workflows/ci-code.yml +++ b/.github/workflows/ci-code.yml @@ -104,7 +104,10 @@ jobs: AIIDA_WARN_v3: 1 # Python 3.12 has a performance regression when running with code coverage # so run code coverage only for python 3.9. - run: pytest -v tests -m 'not nightly' ${{ matrix.python-version == '3.9' && '--cov aiida' || '' }} + run: | + # this env needs to set in run and not env, because we need to access $HOME + export AIIDA_PYTEST_SSH_KEY=$HOME/.ssh/id_rsa_aiida_pytest + pytest -v tests -m 'not nightly' ${{ matrix.python-version == '3.9' && '--cov aiida' || '' }} - name: Upload coverage report if: matrix.python-version == 3.9 && github.repository == 'aiidateam/aiida-core' @@ -139,8 +142,10 @@ jobs: - name: Run test suite env: AIIDA_WARN_v3: 0 - run: pytest -m 'presto' - + run: | + # this env needs to set in run and not env, because we need to access $HOME + export AIIDA_PYTEST_SSH_KEY=$HOME/.ssh/id_rsa_aiida_pytest + pytest -m 'presto' verdi: diff --git a/.github/workflows/setup_ssh.sh b/.github/workflows/setup_ssh.sh index a244f1e470..36ab04bd18 100755 --- a/.github/workflows/setup_ssh.sh +++ b/.github/workflows/setup_ssh.sh @@ -1,9 +1,18 @@ #!/usr/bin/env bash +# Sets up ssh keys to allow a ssh connection to localhost. This is needed +# because localhost is used as remote address to run the tests locally. set -ev -ssh-keygen -q -t rsa -b 4096 -N "" -f "${HOME}/.ssh/id_rsa" -ssh-keygen -y -f "${HOME}/.ssh/id_rsa" >> "${HOME}/.ssh/authorized_keys" +mkdir -p ${HOME}/.ssh +ssh-keygen -q -t rsa -b 4096 -N "" -f "${HOME}/.ssh/id_rsa_aiida_pytest" +ssh-keygen -y -f "${HOME}/.ssh/id_rsa_aiida_pytest" >> "${HOME}/.ssh/authorized_keys" ssh-keyscan -H localhost >> "${HOME}/.ssh/known_hosts" +# to test core.ssh_auto transport plugin we need to append this to the config +cat <> ${HOME}/.ssh/config +Host localhost + IdentityFile ${HOME}/.ssh/id_rsa_aiida_pytest +EOT + # The permissions on the GitHub runner are 777 which will cause SSH to refuse the keys and cause authentication to fail chmod 755 "${HOME}" diff --git a/pyproject.toml b/pyproject.toml index c403809be2..94952ad6e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -430,12 +430,15 @@ deps = py310: -rrequirements/requirements-py-3.10.txt py311: -rrequirements/requirements-py-3.11.txt py312: -rrequirements/requirements-py-3.12.txt +setenv = + AIIDA_PYTEST_SSH_KEY = $HOME/.ssh/id_rsa_aiida_pytest [testenv:py{39,310,311,312}-presto] passenv = PYTHONASYNCIODEBUG setenv = AIIDA_WARN_v3 = + AIIDA_PYTEST_SSH_KEY = $HOME/.ssh/id_rsa_aiida_pytest commands = pytest -m 'presto' {posargs} [testenv:py{39,310,311,312}] @@ -443,6 +446,7 @@ passenv = PYTHONASYNCIODEBUG setenv = AIIDA_WARN_v3 = + AIIDA_PYTEST_SSH_KEY = $HOME/.ssh/id_rsa_aiida_pytest commands = pytest {posargs} [testenv:py{39,310,311,312}-verdi] diff --git a/src/aiida/manage/tests/pytest_fixtures.py b/src/aiida/manage/tests/pytest_fixtures.py index 92856aff66..9cf3100101 100644 --- a/src/aiida/manage/tests/pytest_fixtures.py +++ b/src/aiida/manage/tests/pytest_fixtures.py @@ -504,44 +504,50 @@ def get_code(entry_point, executable, computer=aiida_localhost, label=None, **kw @pytest.fixture(scope='session') def ssh_key(tmp_path_factory) -> t.Generator[pathlib.Path, None, None]: - """Generate a temporary SSH key pair for the test session and return the filepath of the private key. + """Returns a SSH key for the test session. + + If the environment variable ``AIIDA_PYTEST_SSH_KEY`` is set we take the key from this path otherwise we generate a + temporary SSH key pair for the test session and return the filepath of the private key. The filepath of the public key is the same as the private key, but it adds the ``.pub`` file extension. """ - from cryptography.hazmat.backends import default_backend as crypto_default_backend - from cryptography.hazmat.primitives import serialization as crypto_serialization - from cryptography.hazmat.primitives.asymmetric import rsa - - key = rsa.generate_private_key( - backend=crypto_default_backend(), - public_exponent=65537, - key_size=2048, - ) + if (ssh_key_path := os.environ.get('AIIDA_PYTEST_SSH_KEY')) is not None: + yield pathlib.Path(ssh_key_path) + else: + from cryptography.hazmat.backends import default_backend as crypto_default_backend + from cryptography.hazmat.primitives import serialization as crypto_serialization + from cryptography.hazmat.primitives.asymmetric import rsa + + key = rsa.generate_private_key( + backend=crypto_default_backend(), + public_exponent=65537, + key_size=2048, + ) - private_key = key.private_bytes( - crypto_serialization.Encoding.PEM, - crypto_serialization.PrivateFormat.PKCS8, - crypto_serialization.NoEncryption(), - ) + private_key = key.private_bytes( + crypto_serialization.Encoding.PEM, + crypto_serialization.PrivateFormat.PKCS8, + crypto_serialization.NoEncryption(), + ) - public_key = key.public_key().public_bytes( - crypto_serialization.Encoding.OpenSSH, - crypto_serialization.PublicFormat.OpenSSH, - ) + public_key = key.public_key().public_bytes( + crypto_serialization.Encoding.OpenSSH, + crypto_serialization.PublicFormat.OpenSSH, + ) - dirpath = tmp_path_factory.mktemp('keys') - filename = uuid.uuid4().hex - filepath_private_key = dirpath / filename - filepath_public_key = dirpath / f'{filename}.pub' + dirpath = tmp_path_factory.mktemp('keys') + filename = uuid.uuid4().hex + filepath_private_key = dirpath / filename + filepath_public_key = dirpath / f'{filename}.pub' - filepath_private_key.write_bytes(private_key) - filepath_public_key.write_bytes(public_key) + filepath_private_key.write_bytes(private_key) + filepath_public_key.write_bytes(public_key) - try: - yield filepath_private_key - finally: - filepath_private_key.unlink(missing_ok=True) - filepath_public_key.unlink(missing_ok=True) + try: + yield filepath_private_key + finally: + filepath_private_key.unlink(missing_ok=True) + filepath_public_key.unlink(missing_ok=True) @pytest.fixture diff --git a/src/aiida/tools/pytest_fixtures/orm.py b/src/aiida/tools/pytest_fixtures/orm.py index 0ed7ea18d7..1f50d82526 100644 --- a/src/aiida/tools/pytest_fixtures/orm.py +++ b/src/aiida/tools/pytest_fixtures/orm.py @@ -13,48 +13,57 @@ @pytest.fixture(scope='session') def ssh_key(tmp_path_factory) -> t.Generator[pathlib.Path, None, None]: - """Generate a temporary SSH key pair for the test session and return the filepath of the private key. + """Returns a SSH key for the test session. + + If the environment variable ``AIIDA_PYTEST_SSH_KEY`` is set we take the key + from this path otherwise we generate a temporary SSH key pair for the test + session and return the filepath of the private key. The filepath of the public key is the same as the private key, but it adds the ``.pub`` file extension. :returns: The filepath of the generated private key. """ - from uuid import uuid4 - - from cryptography.hazmat.backends import default_backend as crypto_default_backend - from cryptography.hazmat.primitives import serialization as crypto_serialization - from cryptography.hazmat.primitives.asymmetric import rsa - - key = rsa.generate_private_key( - backend=crypto_default_backend(), - public_exponent=65537, - key_size=2048, - ) - - private_key = key.private_bytes( - crypto_serialization.Encoding.PEM, - crypto_serialization.PrivateFormat.PKCS8, - crypto_serialization.NoEncryption(), - ) - - public_key = key.public_key().public_bytes( - crypto_serialization.Encoding.OpenSSH, - crypto_serialization.PublicFormat.OpenSSH, - ) - - dirpath = tmp_path_factory.mktemp('keys') - filename = uuid4().hex - filepath_private_key = dirpath / filename - filepath_public_key = dirpath / f'{filename}.pub' - - filepath_private_key.write_bytes(private_key) - filepath_public_key.write_bytes(public_key) - - try: - yield filepath_private_key - finally: - filepath_private_key.unlink(missing_ok=True) - filepath_public_key.unlink(missing_ok=True) + import os + + if (ssh_key_path := os.environ.get('AIIDA_PYTEST_SSH_KEY')) is not None: + yield pathlib.Path(ssh_key_path) + else: + from uuid import uuid4 + + from cryptography.hazmat.backends import default_backend as crypto_default_backend + from cryptography.hazmat.primitives import serialization as crypto_serialization + from cryptography.hazmat.primitives.asymmetric import rsa + + key = rsa.generate_private_key( + backend=crypto_default_backend(), + public_exponent=65537, + key_size=2048, + ) + + private_key = key.private_bytes( + crypto_serialization.Encoding.PEM, + crypto_serialization.PrivateFormat.PKCS8, + crypto_serialization.NoEncryption(), + ) + + public_key = key.public_key().public_bytes( + crypto_serialization.Encoding.OpenSSH, + crypto_serialization.PublicFormat.OpenSSH, + ) + + dirpath = tmp_path_factory.mktemp('keys') + filename = uuid4().hex + filepath_private_key = dirpath / filename + filepath_public_key = dirpath / f'{filename}.pub' + + filepath_private_key.write_bytes(private_key) + filepath_public_key.write_bytes(public_key) + + try: + yield filepath_private_key + finally: + filepath_private_key.unlink(missing_ok=True) + filepath_public_key.unlink(missing_ok=True) @pytest.fixture diff --git a/tests/transports/test_all_plugins.py b/tests/transports/test_all_plugins.py index c536b196a2..da950ee8ac 100644 --- a/tests/transports/test_all_plugins.py +++ b/tests/transports/test_all_plugins.py @@ -34,18 +34,20 @@ @pytest.fixture(scope='function', params=entry_point.get_entry_point_names('aiida.transports')) -def custom_transport(request, tmp_path, monkeypatch) -> Transport: +def custom_transport(request, tmp_path, monkeypatch, ssh_key) -> Transport: """Fixture that parametrizes over all the registered implementations of the ``CommonRelaxWorkChain``.""" plugin = TransportFactory(request.param) if request.param == 'core.ssh': - kwargs = {'machine': 'localhost', 'timeout': 30, 'load_system_host_keys': True, 'key_policy': 'AutoAddPolicy'} + kwargs = { + 'machine': 'localhost', + 'timeout': 30, + 'load_system_host_keys': True, + 'key_policy': 'AutoAddPolicy', + 'key_filename': str(ssh_key), + } elif request.param == 'core.ssh_auto': kwargs = {'machine': 'localhost'} - filepath_config = tmp_path / 'config' - monkeypatch.setattr(plugin, 'FILEPATH_CONFIG', filepath_config) - if not filepath_config.exists(): - filepath_config.write_text('Host localhost') else: kwargs = {} diff --git a/tests/transports/test_ssh.py b/tests/transports/test_ssh.py index 27698dfa54..6309b53cb1 100644 --- a/tests/transports/test_ssh.py +++ b/tests/transports/test_ssh.py @@ -30,16 +30,27 @@ def test_closed_connection_sftp(): transport.listdir() -def test_auto_add_policy(): +def test_auto_add_policy(ssh_key): """Test the auto add policy.""" - with SshTransport(machine='localhost', timeout=30, load_system_host_keys=True, key_policy='AutoAddPolicy'): + with SshTransport( + machine='localhost', + timeout=30, + load_system_host_keys=True, + key_policy='AutoAddPolicy', + key_filename=str(ssh_key), + ): pass -def test_proxy_jump(): +def test_proxy_jump(ssh_key): """Test the connection with a proxy jump or several""" with SshTransport( - machine='localhost', proxy_jump='localhost', timeout=30, load_system_host_keys=True, key_policy='AutoAddPolicy' + machine='localhost', + proxy_jump='localhost', + timeout=30, + load_system_host_keys=True, + key_policy='AutoAddPolicy', + key_filename=str(ssh_key), ): pass @@ -50,6 +61,7 @@ def test_proxy_jump(): timeout=30, load_system_host_keys=True, key_policy='AutoAddPolicy', + key_filename=str(ssh_key), ): pass @@ -69,7 +81,7 @@ def test_proxy_jump_invalid(): pass -def test_proxy_command(): +def test_proxy_command(ssh_key): """Test the connection with a proxy command""" with SshTransport( machine='localhost', @@ -77,6 +89,7 @@ def test_proxy_command(): timeout=30, load_system_host_keys=True, key_policy='AutoAddPolicy', + key_filename=str(ssh_key), ): pass @@ -94,7 +107,7 @@ def test_no_host_key(): logging.disable(logging.NOTSET) -def test_gotocomputer(): +def test_gotocomputer(ssh_key): """Test gotocomputer""" with SshTransport( machine='localhost', @@ -102,18 +115,21 @@ def test_gotocomputer(): use_login_shell=False, key_policy='AutoAddPolicy', proxy_command='ssh -W localhost:22 localhost', + key_filename=str(ssh_key), ) as transport: cmd_str = transport.gotocomputer_command('/remote_dir/') - expected_str = ( - """ssh -t localhost -o ProxyCommand='ssh -W localhost:22 localhost' "if [ -d '/remote_dir/' ] ;""" + expected_startwith = 'ssh -t localhost -i ' + expected_endwith = ( + """ -o ProxyCommand='ssh -W localhost:22 localhost' "if [ -d '/remote_dir/' ] ;""" """ then cd '/remote_dir/' ; bash ; else echo ' ** The directory' ; """ """echo ' ** /remote_dir/' ; echo ' ** seems to have been deleted, I logout...' ; fi" """ ) - assert cmd_str == expected_str + assert cmd_str.startswith(expected_startwith) + assert cmd_str.endswith(expected_endwith) -def test_gotocomputer_proxyjump(): +def test_gotocomputer_proxyjump(ssh_key): """Test gotocomputer""" with SshTransport( machine='localhost', @@ -121,12 +137,15 @@ def test_gotocomputer_proxyjump(): use_login_shell=False, key_policy='AutoAddPolicy', proxy_jump='localhost', + key_filename=str(ssh_key), ) as transport: cmd_str = transport.gotocomputer_command('/remote_dir/') - expected_str = ( - """ssh -t localhost -o ProxyJump='localhost' "if [ -d '/remote_dir/' ] ;""" + expected_startwith = 'ssh -t localhost -i ' + expected_endwith = ( + """-o ProxyJump='localhost' "if [ -d '/remote_dir/' ] ;""" """ then cd '/remote_dir/' ; bash ; else echo ' ** The directory' ; """ """echo ' ** /remote_dir/' ; echo ' ** seems to have been deleted, I logout...' ; fi" """ ) - assert cmd_str == expected_str + assert cmd_str.startswith(expected_startwith) + assert cmd_str.endswith(expected_endwith)