diff --git a/tests/test_token_plugins.py b/tests/test_token_plugins.py index f09ae8a8..a9fd256d 100644 --- a/tests/test_token_plugins.py +++ b/tests/test_token_plugins.py @@ -7,7 +7,29 @@ from unittest.mock import patch, mock_open, MagicMock from jwcrypto import jwt, jwk -from websockify.token_plugins import ReadOnlyTokenFile, JWTTokenApi, TokenRedis +from websockify.token_plugins import parse_source_args, ReadOnlyTokenFile, JWTTokenApi, TokenRedis + +class ParseSourceArgumentsTestCase(unittest.TestCase): + def test_parameterized(self): + params = [ + ('', ['']), + (':', ['', '']), + ('::', ['', '', '']), + ('"', ['"']), + ('""', ['""']), + ('"""', ['"""']), + ('"localhost"', ['localhost']), + ('"localhost":', ['localhost', '']), + ('"localhost"::', ['localhost', '', '']), + ('"local:host"', ['local:host']), + ('"local:host:"pass"', ['"local', 'host', "pass"]), + ('"local":"host"', ['local', 'host']), + ('"local":host"', ['local', 'host"']), + ('localhost:6379:1:pass"word:"my-app-namespace:dev"', + ['localhost', '6379', '1', 'pass"word', 'my-app-namespace:dev']), + ] + for src, args in params: + self.assertEqual(args, parse_source_args(src)) class ReadOnlyTokenFileTestCase(unittest.TestCase): patch('os.path.isdir', MagicMock(return_value=False)) @@ -267,6 +289,42 @@ def test_invalid_token(self, mock_redis): instance.get.assert_called_once_with('testhost') self.assertIsNone(result) + @patch('redis.Redis') + def test_token_without_namespace(self, mock_redis): + plugin = TokenRedis('127.0.0.1:1234') + token = 'testhost' + + def mock_redis_get(key): + self.assertEqual(key, token) + return b'remote_host:remote_port' + + instance = mock_redis.return_value + instance.get = mock_redis_get + + result = plugin.lookup(token) + + self.assertIsNotNone(result) + self.assertEqual(result[0], 'remote_host') + self.assertEqual(result[1], 'remote_port') + + @patch('redis.Redis') + def test_token_with_namespace(self, mock_redis): + plugin = TokenRedis('127.0.0.1:1234:::namespace') + token = 'testhost' + + def mock_redis_get(key): + self.assertEqual(key, "namespace:" + token) + return b'remote_host:remote_port' + + instance = mock_redis.return_value + instance.get = mock_redis_get + + result = plugin.lookup(token) + + self.assertIsNotNone(result) + self.assertEqual(result[0], 'remote_host') + self.assertEqual(result[1], 'remote_port') + def test_src_only_host(self): plugin = TokenRedis('127.0.0.1') @@ -274,6 +332,7 @@ def test_src_only_host(self): self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 0) self.assertEqual(plugin._password, None) + self.assertEqual(plugin._namespace, "") def test_src_with_host_port(self): plugin = TokenRedis('127.0.0.1:1234') @@ -282,6 +341,7 @@ def test_src_with_host_port(self): self.assertEqual(plugin._port, 1234) self.assertEqual(plugin._db, 0) self.assertEqual(plugin._password, None) + self.assertEqual(plugin._namespace, "") def test_src_with_host_port_db(self): plugin = TokenRedis('127.0.0.1:1234:2') @@ -290,6 +350,7 @@ def test_src_with_host_port_db(self): self.assertEqual(plugin._port, 1234) self.assertEqual(plugin._db, 2) self.assertEqual(plugin._password, None) + self.assertEqual(plugin._namespace, "") def test_src_with_host_port_db_pass(self): plugin = TokenRedis('127.0.0.1:1234:2:verysecret') @@ -298,67 +359,112 @@ def test_src_with_host_port_db_pass(self): self.assertEqual(plugin._port, 1234) self.assertEqual(plugin._db, 2) self.assertEqual(plugin._password, 'verysecret') + self.assertEqual(plugin._namespace, "") - def test_src_with_host_empty_port_empty_db_pass(self): + def test_src_with_host_port_db_pass_namespace(self): + plugin = TokenRedis('127.0.0.1:1234:2:verysecret:namespace') + + self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._port, 1234) + self.assertEqual(plugin._db, 2) + self.assertEqual(plugin._password, 'verysecret') + self.assertEqual(plugin._namespace, "namespace:") + + def test_src_with_host_empty_port_empty_db_pass_no_namespace(self): plugin = TokenRedis('127.0.0.1:::verysecret') self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 0) self.assertEqual(plugin._password, 'verysecret') + self.assertEqual(plugin._namespace, "") + + def test_src_with_host_empty_port_empty_db_empty_pass_empty_namespace(self): + plugin = TokenRedis('127.0.0.1::::') - def test_src_with_host_empty_port_empty_db_empty_pass(self): + self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._port, 6379) + self.assertEqual(plugin._db, 0) + self.assertEqual(plugin._password, None) + self.assertEqual(plugin._namespace, "") + + def test_src_with_host_empty_port_empty_db_empty_pass_no_namespace(self): plugin = TokenRedis('127.0.0.1:::') self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 0) self.assertEqual(plugin._password, None) + self.assertEqual(plugin._namespace, "") - def test_src_with_host_empty_port_empty_db_no_pass(self): + def test_src_with_host_empty_port_empty_db_no_pass_no_namespace(self): plugin = TokenRedis('127.0.0.1::') self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 0) self.assertEqual(plugin._password, None) + self.assertEqual(plugin._namespace, "") - def test_src_with_host_empty_port_no_db_no_pass(self): + def test_src_with_host_empty_port_no_db_no_pass_no_namespace(self): plugin = TokenRedis('127.0.0.1:') self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 0) self.assertEqual(plugin._password, None) + self.assertEqual(plugin._namespace, "") + + def test_src_with_host_empty_port_empty_db_empty_pass_namespace(self): + plugin = TokenRedis('127.0.0.1::::namespace') + + self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._port, 6379) + self.assertEqual(plugin._db, 0) + self.assertEqual(plugin._password, None) + self.assertEqual(plugin._namespace, "namespace:") + + def test_src_with_host_empty_port_empty_db_empty_pass_nested_namespace(self): + plugin = TokenRedis('127.0.0.1::::"ns1:ns2"') + + self.assertEqual(plugin._server, '127.0.0.1') + self.assertEqual(plugin._port, 6379) + self.assertEqual(plugin._db, 0) + self.assertEqual(plugin._password, None) + self.assertEqual(plugin._namespace, "ns1:ns2:") - def test_src_with_host_empty_port_db_no_pass(self): + def test_src_with_host_empty_port_db_no_pass_no_namespace(self): plugin = TokenRedis('127.0.0.1::2') self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 2) self.assertEqual(plugin._password, None) + self.assertEqual(plugin._namespace, "") - def test_src_with_host_port_empty_db_pass(self): + def test_src_with_host_port_empty_db_pass_no_namespace(self): plugin = TokenRedis('127.0.0.1:1234::verysecret') self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._port, 1234) self.assertEqual(plugin._db, 0) self.assertEqual(plugin._password, 'verysecret') + self.assertEqual(plugin._namespace, "") - def test_src_with_host_empty_port_db_pass(self): + def test_src_with_host_empty_port_db_pass_no_namespace(self): plugin = TokenRedis('127.0.0.1::2:verysecret') self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 2) self.assertEqual(plugin._password, 'verysecret') + self.assertEqual(plugin._namespace, "") - def test_src_with_host_empty_port_db_empty_pass(self): + def test_src_with_host_empty_port_db_empty_pass_no_namespace(self): plugin = TokenRedis('127.0.0.1::2:') self.assertEqual(plugin._server, '127.0.0.1') self.assertEqual(plugin._port, 6379) self.assertEqual(plugin._db, 2) self.assertEqual(plugin._password, None) + self.assertEqual(plugin._namespace, "") diff --git a/websockify/token_plugins.py b/websockify/token_plugins.py index 9c36383a..5a95490c 100644 --- a/websockify/token_plugins.py +++ b/websockify/token_plugins.py @@ -7,6 +7,24 @@ logger = logging.getLogger(__name__) +_SOURCE_SPLIT_REGEX = re.compile( + r'(?<=^)"([^"]+)"(?=:|$)' + r'|(?<=:)"([^"]+)"(?=:|$)' + r'|(?<=^)([^:]*)(?=:|$)' + r'|(?<=:)([^:]*)(?=:|$)', +) + + +def parse_source_args(src): + """It works like src.split(":") but with the ability to use a colon + if you wrap the word in quotation marks. + + a:b:c:d -> ['a', 'b', 'c', 'd' + a:"b:c":c -> ['a', 'b:c', 'd'] + """ + matches = _SOURCE_SPLIT_REGEX.findall(src) + return [m[0] or m[1] or m[2] or m[3] for m in matches] + class BasePlugin(): def __init__(self, src): @@ -178,9 +196,9 @@ class TokenRedis(BasePlugin): The token source is in the format: - host[:port[:db[:password]]] + host[:port[:db[:password[:namespace]]]] - where port, db and password are optional. If port or db are left empty + where port, db, password and namespace are optional. If port or db are left empty they will take its default value, ie. 6379 and 0 respectively. If your redis server is using the default port (6379) then you can use: @@ -192,9 +210,18 @@ class TokenRedis(BasePlugin): my-redis-host:::verysecretpass + You can also specify a namespace. In this case, the tokens + will be stored in the format '{namespace}:{token}' + + my-redis-host::::my-app-namespace + + Or if your namespace is nested, you can wrap it in quotes: + + my-redis-host::::"first-ns:second-ns" + In the more general case you will use: - my-redis-host:6380:1:verysecretpass + my-redis-host:6380:1:verysecretpass:my-app-namespace The TokenRedis plugin expects the format of the target in one of these two formats: @@ -234,8 +261,9 @@ def __init__(self, src): self._port = 6379 self._db = 0 self._password = None + self._namespace = "" try: - fields = src.split(":") + fields = parse_source_args(src) if len(fields) == 1: self._server = fields[0] elif len(fields) == 2: @@ -256,15 +284,28 @@ def __init__(self, src): self._db = 0 if not self._password: self._password = None + elif len(fields) == 5: + self._server, self._port, self._db, self._password, self._namespace = fields + if not self._port: + self._port = 6379 + if not self._db: + self._db = 0 + if not self._password: + self._password = None + if not self._namespace: + self._namespace = "" else: raise ValueError self._port = int(self._port) self._db = int(self._db) - logger.info("TokenRedis backend initilized (%s:%s)" % + if self._namespace: + self._namespace += ":" + + logger.info("TokenRedis backend initialized (%s:%s)" % (self._server, self._port)) except ValueError: logger.error("The provided --token-source='%s' is not in the " - "expected format [:[:[:]]]" % + "expected format [:[:[:[:]]]]" % src) sys.exit() @@ -278,7 +319,7 @@ def lookup(self, token): logger.info("resolving token '%s'" % token) client = redis.Redis(host=self._server, port=self._port, db=self._db, password=self._password) - stuff = client.get(token) + stuff = client.get(self._namespace + token) if stuff is None: return None else: