diff --git a/gvm/connections.py b/gvm/connections.py
index e1a1fb38a..cf77cfccd 100644
--- a/gvm/connections.py
+++ b/gvm/connections.py
@@ -13,9 +13,10 @@
import ssl
import sys
import time
+from abc import ABC, abstractmethod
from os import PathLike
from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, Protocol, Union, runtime_checkable
import paramiko
import paramiko.ssh_exception
@@ -41,6 +42,19 @@
Data = Union[str, bytes]
+@runtime_checkable
+class GvmConnection(Protocol):
+ def connect(self) -> None: ...
+
+ def disconnect(self) -> None: ...
+
+ def send(self, data: Data) -> None: ...
+
+ def read(self) -> str: ...
+
+ def finish_send(self): ...
+
+
class XmlReader:
"""
Read a XML command until its closing element
@@ -77,7 +91,7 @@ def feed_xml(self, data: Data) -> None:
) from None
-class GvmConnection:
+class AbstractGvmConnection(ABC):
"""
Base class for establishing a connection to a remote server daemon.
@@ -97,6 +111,7 @@ def _read(self) -> bytes:
return self._socket.recv(BUF_SIZE)
+ @abstractmethod
def connect(self) -> None:
"""Establish a connection to a remote server"""
raise NotImplementedError
@@ -164,7 +179,7 @@ def finish_send(self):
self._socket.shutdown(socketlib.SHUT_WR)
-class SSHConnection(GvmConnection):
+class SSHConnection(AbstractGvmConnection):
"""
SSH Class to connect, read and write from GVM via SSH
@@ -174,7 +189,7 @@ class SSHConnection(GvmConnection):
127.0.0.1.
port: Port of the remote SSH server. Default is port 22.
username: Username to use for SSH login. Default is "gmp".
- password: Passwort to use for SSH login. Default is "".
+ password: Password to use for SSH login. Default is "".
"""
def __init__(
@@ -188,8 +203,7 @@ def __init__(
known_hosts_file: Optional[Union[str, PathLike]] = None,
auto_accept_host: Optional[bool] = None,
) -> None:
- super().__init__(timeout=timeout)
-
+ super().__init__(timeout)
self.hostname = hostname if hostname is not None else DEFAULT_HOSTNAME
self.port = int(port) if port is not None else DEFAULT_SSH_PORT
self.username = (
@@ -414,11 +428,11 @@ def connect(self) -> None:
def _read(self) -> bytes:
return self._stdout.channel.recv(BUF_SIZE)
- def send(self, data: Union[bytes, str]) -> int:
+ def send(self, data: Data) -> None:
if isinstance(data, str):
- return self._send_all(data.encode())
-
- return self._send_all(data)
+ self._send_all(data.encode())
+ else:
+ self._send_all(data)
def finish_send(self) -> None:
# shutdown socket for sending. only allow reading data afterwards
@@ -439,7 +453,7 @@ def disconnect(self) -> None:
del self._socket, self._stdin, self._stdout, self._stderr
-class TLSConnection(GvmConnection):
+class TLSConnection(AbstractGvmConnection):
"""
TLS class to connect, read and write from a remote GVM daemon via TLS
secured socket.
@@ -524,7 +538,7 @@ def disconnect(self):
return super().disconnect()
-class UnixSocketConnection(GvmConnection):
+class UnixSocketConnection(AbstractGvmConnection):
"""
UNIX-Socket class to connect, read, write from a daemon via direct
communicating UNIX-Socket
diff --git a/tests/connections/test_gvm_connection.py b/tests/connections/test_gvm_connection.py
index fea552073..48b415062 100644
--- a/tests/connections/test_gvm_connection.py
+++ b/tests/connections/test_gvm_connection.py
@@ -6,7 +6,13 @@
import unittest
from unittest.mock import patch
-from gvm.connections import DEFAULT_TIMEOUT, GvmConnection, XmlReader
+from gvm.connections import (
+ DEFAULT_TIMEOUT,
+ AbstractGvmConnection,
+ DebugConnection,
+ GvmConnection,
+ XmlReader,
+)
from gvm.errors import GvmError
@@ -19,39 +25,49 @@ def test_is_end_xml_false(self):
self.assertFalse(false)
+class TestConnection(AbstractGvmConnection):
+ def connect(self) -> None:
+ pass
+
+
class GvmConnectionTestCase(unittest.TestCase):
# pylint: disable=protected-access
def test_init_no_args(self):
- connection = GvmConnection()
+ connection = TestConnection()
self.check_for_default_values(connection)
def test_init_with_none(self):
- connection = GvmConnection(timeout=None)
+ connection = TestConnection(timeout=None)
self.check_for_default_values(connection)
def check_for_default_values(self, gvm_connection: GvmConnection):
self.assertIsNone(gvm_connection._socket)
self.assertEqual(gvm_connection._timeout, DEFAULT_TIMEOUT)
- def test_connect_not_implemented(self):
- connection = GvmConnection()
- with self.assertRaises(NotImplementedError):
- connection.connect()
-
- @patch("gvm.connections.GvmConnection._read")
+ @patch("gvm.connections.AbstractGvmConnection._read")
def test_read_no_data(self, _read_mock):
_read_mock.return_value = None
- connection = GvmConnection()
+ connection = TestConnection()
with self.assertRaises(GvmError, msg="Remote closed the connection"):
connection.read()
- @patch("gvm.connections.GvmConnection._read")
+ @patch("gvm.connections.AbstractGvmConnection._read")
def test_read_trigger_timeout(self, _read_mock):
# mocking the response into two parts, so we run into the timeout
# check in the loop
_read_mock.side_effect = [b"xyz", b""]
- connection = GvmConnection(timeout=0)
+ connection = TestConnection(timeout=0)
with self.assertRaises(
GvmError, msg="Timeout while reading the response"
):
connection.read()
+
+ def test_is_gvm_connection(self):
+ connection = TestConnection()
+ self.assertTrue(isinstance(connection, GvmConnection))
+
+
+class DebugConnectionTestCase(unittest.TestCase):
+ def test_is_gvm_connection(self):
+ connection = DebugConnection(TestConnection())
+ self.assertTrue(isinstance(connection, GvmConnection))
diff --git a/tests/connections/test_ssh_connection.py b/tests/connections/test_ssh_connection.py
index 029e7922d..75d5fe70b 100644
--- a/tests/connections/test_ssh_connection.py
+++ b/tests/connections/test_ssh_connection.py
@@ -17,6 +17,7 @@
DEFAULT_SSH_PASSWORD,
DEFAULT_SSH_PORT,
DEFAULT_SSH_USERNAME,
+ GvmConnection,
SSHConnection,
)
from gvm.errors import GvmError
@@ -177,7 +178,7 @@ def test_connect_adding_and_save_hostkey(self, input_mock, _print_mock):
)
with self.assertLogs("gvm.connections", level="INFO") as cm:
- hostkeys = paramiko.HostKeys(filename=self.known_hosts_file)
+ hostkeys = paramiko.HostKeys(filename=str(self.known_hosts_file))
ssh_connection._ssh_authentication_input_loop(
hostkeys=hostkeys, key=key
)
@@ -229,7 +230,7 @@ def test_connect_adding_and_dont_save_hostkey(
)
with self.assertLogs("gvm.connections", level="INFO") as cm:
- hostkeys = paramiko.HostKeys(filename=self.known_hosts_file)
+ hostkeys = paramiko.HostKeys(filename=str(self.known_hosts_file))
ssh_connection._ssh_authentication_input_loop(
hostkeys=hostkeys, key=key
)
@@ -274,7 +275,7 @@ def test_connect_wrong_input(self, stdout_mock, input_mock):
ssh_connection._socket = paramiko.SSHClient()
with self.assertLogs("gvm.connections", level="INFO") as cm:
- hostkeys = paramiko.HostKeys(filename=self.known_hosts_file)
+ hostkeys = paramiko.HostKeys(filename=str(self.known_hosts_file))
ssh_connection._ssh_authentication_input_loop(
hostkeys=hostkeys, key=key
)
@@ -323,7 +324,7 @@ def test_user_denies_auth(self, input_mock):
with self.assertRaises(
SystemExit, msg="User denied key. Host key verification failed."
):
- hostkeys = paramiko.HostKeys(filename=self.known_hosts_file)
+ hostkeys = paramiko.HostKeys(filename=str(self.known_hosts_file))
ssh_connection._ssh_authentication_input_loop(
hostkeys=hostkeys, key=key
)
@@ -441,8 +442,7 @@ def test_send(self):
)
ssh_connection.connect()
- req = ssh_connection.send("blah")
- self.assertEqual(req, 4)
+ ssh_connection.send("blah")
ssh_connection.disconnect()
def test_send_error(self):
@@ -473,8 +473,7 @@ def test_send_and_slice(self):
)
ssh_connection.connect()
- req = ssh_connection.send("blah")
- self.assertEqual(req, 4)
+ ssh_connection.send("blah")
stdin.channel.send.assert_called()
with self.assertRaises(AssertionError):
@@ -495,3 +494,7 @@ def test_read(self):
recved = ssh_connection._read()
self.assertEqual(recved, b"foo bar baz")
ssh_connection.disconnect()
+
+ def test_is_gvm_connection(self):
+ ssh_connection = SSHConnection(known_hosts_file=self.known_hosts_file)
+ self.assertTrue(isinstance(ssh_connection, GvmConnection))
diff --git a/tests/connections/test_tls_connection.py b/tests/connections/test_tls_connection.py
index f479f16da..7b1cf9dd6 100644
--- a/tests/connections/test_tls_connection.py
+++ b/tests/connections/test_tls_connection.py
@@ -10,6 +10,7 @@
DEFAULT_GVM_PORT,
DEFAULT_HOSTNAME,
DEFAULT_TIMEOUT,
+ GvmConnection,
TLSConnection,
)
@@ -62,3 +63,7 @@ def test_connect_auth(self):
context_mock.load_cert_chain.assert_called_once()
context_mock.wrap_socket.assert_called_once()
self.assertFalse(context_mock.check_hostname)
+
+ def test_is_gvm_connection(self):
+ connection = TLSConnection()
+ self.assertTrue(isinstance(connection, GvmConnection))
diff --git a/tests/connections/test_unix_socket_connection.py b/tests/connections/test_unix_socket_connection.py
index 478928edd..4568f78e6 100644
--- a/tests/connections/test_unix_socket_connection.py
+++ b/tests/connections/test_unix_socket_connection.py
@@ -14,6 +14,7 @@
from gvm.connections import (
DEFAULT_TIMEOUT,
DEFAULT_UNIX_SOCKET_PATH,
+ GvmConnection,
UnixSocketConnection,
)
from gvm.errors import GvmError
@@ -65,8 +66,7 @@ def test_unix_socket_connection_connect_send_bytes_read(self):
path=self.socketname, timeout=DEFAULT_TIMEOUT
)
connection.connect()
- req = connection.send(bytes("", "utf-8"))
- self.assertIsNone(req)
+ connection.send(bytes("", "utf-8"))
resp = connection.read()
self.assertEqual(resp, '')
connection.disconnect()
@@ -76,8 +76,7 @@ def test_unix_socket_connection_connect_send_str_read(self):
path=self.socketname, timeout=DEFAULT_TIMEOUT
)
connection.connect()
- req = connection.send("")
- self.assertIsNone(req)
+ connection.send("")
resp = connection.read()
self.assertEqual(resp, '')
connection.disconnect()
@@ -120,6 +119,6 @@ def check_default_values(self, connection: UnixSocketConnection):
self.assertEqual(connection._timeout, DEFAULT_TIMEOUT)
self.assertEqual(connection.path, DEFAULT_UNIX_SOCKET_PATH)
-
-if __name__ == "__main__":
- unittest.main()
+ def test_is_gvm_connection(self):
+ connection = UnixSocketConnection()
+ self.assertTrue(isinstance(connection, GvmConnection))