diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 2b11178..f17e487 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -11,4 +11,5 @@ jobs: steps: - uses: actions/checkout@v3 - uses: actions/setup-python@v3 + - run: python -m pip install .[tests] - uses: pre-commit/action@v3.0.1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 898981e..baeef2d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,3 +22,4 @@ repos: args: [--strict, --ignore-missing-imports, --check-untyped-defs] additional_dependencies: - types-PyYAML + - types-paramiko==3.4.0.* diff --git a/pyproject.toml b/pyproject.toml index f841c8a..97da775 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,3 +24,6 @@ classifiers = [ [project.urls] Homepage = "https://scs.community" + +[tool.pytest.ini_options] +pythonpath = [ "src" ] diff --git a/setup.cfg b/setup.cfg index b1254e4..d1edfe8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,3 +3,7 @@ platforms = any [options] install_requires=file:requirements.txt + +[options.extras_require] +tests = + pytest==8.0.2 diff --git a/src/rookify/modules/module.py b/src/rookify/modules/module.py index 8d1a75a..18d28ba 100644 --- a/src/rookify/modules/module.py +++ b/src/rookify/modules/module.py @@ -4,7 +4,8 @@ import yaml import json import abc -import rados + +# import rados import kubernetes import fabric import jinja2 diff --git a/tests/mock_ceph.py b/tests/mock_ceph.py new file mode 100644 index 0000000..9116a75 --- /dev/null +++ b/tests/mock_ceph.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- + +import json +from collections.abc import Callable +from rookify.modules.module import ModuleException +from threading import RLock +from typing import Any, Dict, List, Optional, Tuple + + +class MockCeph(object): + def __init__(self, config: Dict[str, Any]): + self._callback_handler: Optional[ + Callable[[str, bytes], Tuple[int, bytes, str]] + ] = None + self._thread_lock = RLock() + + def handle_with_callback( + self, _callable: Callable[[str, bytes], Tuple[int, bytes, str]] + ) -> None: + with self._thread_lock: + if self._callback_handler is not None: + raise RuntimeError("Callback handler already registered") + + self._callback_handler = _callable + + def mon_command( + self, command: str, inbuf: bytes, **kwargs: Any + ) -> Dict[str, Any] | List[Any]: + if not callable(self._callback_handler): + raise RuntimeError("Handler function given is invalid") + + ret, outbuf, outstr = self._callback_handler(command, inbuf, **kwargs) + if ret != 0: + raise ModuleException("Ceph did return an error: {0!r}".format(outbuf)) + + data = json.loads(outbuf) + assert isinstance(data, dict) or isinstance(data, list) + return data + + def stop_handler(self) -> None: + self._callback_handler = None diff --git a/tests/mock_ssh_server.py b/tests/mock_ssh_server.py new file mode 100644 index 0000000..01d0b55 --- /dev/null +++ b/tests/mock_ssh_server.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- + + +from collections.abc import Callable +from socket import AF_INET, IPPROTO_TCP, SO_REUSEADDR, SOCK_STREAM, SOL_SOCKET, socket +from threading import Event, RLock +from typing import Any, Optional + +from paramiko import ( # type: ignore[attr-defined] + AUTH_FAILED, + AUTH_SUCCESSFUL, + OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, + OPEN_SUCCEEDED, + AutoAddPolicy, + Channel, + PKey, + RSAKey, + ServerInterface, + SSHClient, + Transport, +) + + +class MockSSHServer(ServerInterface): + """An ssh server accepting the pre-generated key.""" + + ssh_username = "pytest" + ssh_key = RSAKey.generate(4096) + + def __init__(self) -> None: + ServerInterface.__init__(self) + + self._callback_handler: Optional[Callable[[bytes, Channel], None]] = None + self._channel: Any = None + self._client: Optional[SSHClient] = None + self._command: Optional[bytes] = None + self.event = Event() + self._server_transport: Optional[Transport] = None + self._thread_lock = RLock() + + def __del__(self) -> None: + self.close() + + @property + def client(self) -> SSHClient: + with self._thread_lock: + if self._client is None: + connection_event = Event() + + server_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP) + server_socket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) + server_socket.bind(("127.0.0.1", 0)) + server_socket.listen() + + server_address = server_socket.getsockname() + + client_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP) + client_socket.connect(server_address) + + (transport_socket, _) = server_socket.accept() + + self._server_transport = Transport(transport_socket) + self._server_transport.add_server_key(self.__class__.ssh_key) + self._server_transport.start_server(connection_event, self) + + self._client = SSHClient() + self._client.set_missing_host_key_policy(AutoAddPolicy()) + + self._client.connect( + server_address[0], + server_address[1], + username=self.__class__.ssh_username, + pkey=self.__class__.ssh_key, + sock=client_socket, + ) + + connection_event.wait() + + return self._client + + def check_channel_request(self, kind: str, chanid: int) -> int: + if kind == "session": + return OPEN_SUCCEEDED # type: ignore[no-any-return] + return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED # type: ignore[no-any-return] + + def check_auth_password(self, username: str, password: str) -> int: + return AUTH_FAILED # type: ignore[no-any-return] + + def check_auth_publickey(self, username: str, key: PKey) -> int: + if username == self.__class__.ssh_username and key == self.__class__.ssh_key: + return AUTH_SUCCESSFUL # type: ignore[no-any-return] + return AUTH_FAILED # type: ignore[no-any-return] + + def check_channel_exec_request(self, channel: Channel, command: bytes) -> bool: + if self.event.is_set(): + return False + + self.event.set() + + with self._thread_lock: + self._channel = channel + self._command = command + + if self._callback_handler is not None: + self.handle_exec_request(self._callback_handler) + + return True + + def close(self) -> None: + self.stop_exec_requests_handler() + + if self._server_transport is not None: + self._server_transport.close() + self._server_transport = None + + def get_allowed_auths(self, username: str) -> str: + if username == self.__class__.ssh_username: + return "publickey" + return "" + + def handle_exec_request(self, _callable: Callable[[bytes, Channel], None]) -> None: + if not callable(_callable): + raise RuntimeError("Handler function given is invalid") + + _callable(self._command, self._channel) # type: ignore[arg-type] + + if self._channel.recv_ready() is not True: + self._channel.send( + bytes("Command {0!r} invalid\n".format(self._command), "utf-8") + ) + + self._channel = None + self._client = None + + self.event.clear() + + def handle_exec_requests_with_callback( + self, _callable: Callable[[bytes, Channel], None] + ) -> None: + with self._thread_lock: + if self._callback_handler is not None: + raise RuntimeError("Callback handler already registered") + + self._callback_handler = _callable + + def stop_exec_requests_handler(self) -> None: + self._callback_handler = None diff --git a/tests/modules/__init__.py b/tests/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/modules/test_example.py b/tests/modules/test_example.py new file mode 100644 index 0000000..1ff294a --- /dev/null +++ b/tests/modules/test_example.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +import pytest + +from rookify.modules.example.main import ExampleHandler +from rookify.modules.module import ModuleException + + +def test_preflight() -> None: + with pytest.raises(ModuleException): + ExampleHandler({}, {}, "").preflight() diff --git a/tests/test_mock_ceph.py b/tests/test_mock_ceph.py new file mode 100644 index 0000000..79ad9cd --- /dev/null +++ b/tests/test_mock_ceph.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- + +from typing import Any, Dict, Tuple +from unittest import TestCase + +from .mock_ceph import MockCeph + + +class TestMockCeph(TestCase): + ceph: Any = None + + @classmethod + def setUpClass(cls) -> None: + cls.ceph = MockCeph({}) + + def setUp(self) -> None: + self.__class__.ceph.handle_with_callback(self._command_callback) + + def tearDown(self) -> None: + self.__class__.ceph.stop_handler() + + def _command_callback( + self, command: str, inbuf: bytes, **kwargs: Dict[Any, Any] + ) -> Tuple[int, bytes, str]: + if command == "test": + return 0, b'["ok"]', "" + return -1, b'["Command not found"]', "" + + def test_self(self) -> None: + res = self.__class__.ceph.mon_command("test", b"") + self.assertEqual(res, ["ok"]) diff --git a/tests/test_mock_ssh_server.py b/tests/test_mock_ssh_server.py new file mode 100644 index 0000000..bbf498e --- /dev/null +++ b/tests/test_mock_ssh_server.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- + +from paramiko import Channel +from typing import Any +from unittest import TestCase + +from .mock_ssh_server import MockSSHServer + + +class TestMockSSHServer(TestCase): + ssh_client: Any = None + ssh_server: Any = None + + @classmethod + def setUpClass(cls) -> None: + cls.ssh_server = MockSSHServer() + cls.ssh_client = cls.ssh_server.client + + @classmethod + def tearDownClass(cls) -> None: + cls.ssh_server.close() + + def setUp(self) -> None: + self.__class__.ssh_server.handle_exec_requests_with_callback( + self._command_callback + ) + + def tearDown(self) -> None: + self.__class__.ssh_server.stop_exec_requests_handler() + + def _command_callback(self, command: bytes, channel: Channel) -> None: + if command == b"test": + channel.send(b"ok\n") + + def test_self(self) -> None: + _, stdout, _ = self.__class__.ssh_client.exec_command("test") + self.assertEqual(stdout.readline(), "ok\n")