diff --git a/flask_redis/__init__.py b/flask_redis/__init__.py index a51d95d..33538de 100644 --- a/flask_redis/__init__.py +++ b/flask_redis/__init__.py @@ -1,17 +1,16 @@ from .client import FlaskRedis - -__version__ = "0.5.0.dev0" +__version__ = "0.6.0" __title__ = "flask-redis" -__description__ = "A nice way to use Redis in your Flask app" -__url__ = "https://github.com/underyx/flask-redis/" +__description__ = "A nice way to use Redis in your Flask app with sentinel support" +__url__ = "https://github.com/cyrinux/flask-redis/" __uri__ = __url__ -__author__ = "Bence Nagy" -__email__ = "bence@underyx.me" +__author__ = "Cyrinux" +__email__ = "python@levis.name" -__license__ = "Blue Oak License" -__copyright__ = "Copyright (c) 2019 Bence Nagy" +__license__ = "Blue Oak Model License" +__copyright__ = "Copyright (c) 2024" __all__ = [FlaskRedis] diff --git a/flask_redis/client.py b/flask_redis/client.py index c33004f..89e19a2 100644 --- a/flask_redis/client.py +++ b/flask_redis/client.py @@ -1,28 +1,38 @@ +import ssl +from urllib.parse import parse_qs, unquote, urlparse + try: import redis + from redis.sentinel import Sentinel except ImportError: - # We can still allow custom provider-only usage without redis-py being installed + # Allow usage without redis-py being installed redis = None + Sentinel = None class FlaskRedis(object): - def __init__(self, app=None, strict=True, config_prefix="REDIS", **kwargs): + def __init__( + self, + app=None, + strict=True, + config_prefix="REDIS", + decode_responses=True, + **kwargs, + ): self._redis_client = None self.provider_class = redis.StrictRedis if strict else redis.Redis - self.provider_kwargs = kwargs self.config_prefix = config_prefix + self.decode_responses = decode_responses + self.provider_kwargs = kwargs if app is not None: self.init_app(app) @classmethod def from_custom_provider(cls, provider, app=None, **kwargs): - assert provider is not None, "your custom provider is None, come on" + assert provider is not None, "Your custom provider is None." - # We never pass the app parameter here, so we can call init_app - # ourselves later, after the provider class has been set instance = cls(**kwargs) - instance.provider_class = provider if app is not None: instance.init_app(app) @@ -30,18 +40,167 @@ def from_custom_provider(cls, provider, app=None, **kwargs): def init_app(self, app, **kwargs): redis_url = app.config.get( - "{0}_URL".format(self.config_prefix), "redis://localhost:6379/0" + f"{self.config_prefix}_URL", "redis://localhost:6379/0" ) self.provider_kwargs.update(kwargs) - self._redis_client = self.provider_class.from_url( - redis_url, **self.provider_kwargs - ) + + parsed_url = urlparse(redis_url) + scheme = parsed_url.scheme + + if scheme in ["redis+sentinel", "rediss+sentinel"]: + if Sentinel is None: + raise ImportError("redis-py must be installed to use Redis Sentinel.") + self._init_sentinel_client(parsed_url) + else: + self._init_standard_client(redis_url) if not hasattr(app, "extensions"): app.extensions = {} app.extensions[self.config_prefix.lower()] = self + def _init_standard_client(self, redis_url): + self._redis_client = self.provider_class.from_url( + redis_url, decode_responses=self.decode_responses, **self.provider_kwargs + ) + + def _init_sentinel_client(self, parsed_url): + sentinel_kwargs, client_kwargs = self._parse_sentinel_parameters(parsed_url) + + sentinel = Sentinel( + sentinel_kwargs["hosts"], + socket_timeout=sentinel_kwargs["socket_timeout"], + **sentinel_kwargs["ssl_params"], + **sentinel_kwargs["auth_params"], + **self.provider_kwargs, + ) + + self._redis_client = sentinel.master_for( + sentinel_kwargs["master_name"], + db=client_kwargs["db"], + socket_timeout=client_kwargs["socket_timeout"], + decode_responses=self.decode_responses, + **client_kwargs["ssl_params"], + **client_kwargs["auth_params"], + **self.provider_kwargs, + ) + + def _parse_sentinel_parameters(self, parsed_url): + username, password = self._extract_credentials(parsed_url) + hosts = self._parse_hosts(parsed_url) + master_name, db = self._parse_master_and_db(parsed_url) + query_params = parse_qs(parsed_url.query) + + socket_timeout = self._parse_socket_timeout(query_params) + ssl_enabled = self._parse_ssl_enabled(parsed_url.scheme, query_params) + ssl_params = self._parse_ssl_params(query_params, ssl_enabled) + auth_params = self._parse_auth_params(username, password) + + sentinel_kwargs = { + "hosts": hosts, + "socket_timeout": socket_timeout, + "ssl_params": ssl_params, + "auth_params": auth_params, + "master_name": master_name, + } + + client_kwargs = { + "db": db, + "socket_timeout": socket_timeout, + "ssl_params": ssl_params, + "auth_params": auth_params, + } + + return sentinel_kwargs, client_kwargs + + def _extract_credentials(self, parsed_url): + username = parsed_url.username + password = parsed_url.password + return username, password + + def _parse_hosts(self, parsed_url): + netloc = parsed_url.netloc + if "@" in netloc: + hosts_part = netloc.split("@", 1)[1] + else: + hosts_part = netloc + + hosts = [] + for host_port in hosts_part.split(","): + if ":" in host_port: + host, port = host_port.split(":", 1) + port = int(port) + else: + host = host_port + port = 26379 # Default Sentinel port + hosts.append((host, port)) + return hosts + + def _parse_master_and_db(self, parsed_url): + path = parsed_url.path.lstrip("/") + if "/" in path: + master_name, db_part = path.split("/", 1) + db = int(db_part) + else: + master_name = path + db = 0 # Default DB + return master_name, db + + def _parse_socket_timeout(self, query_params): + socket_timeout = query_params.get("socket_timeout", [None])[0] + if socket_timeout is not None: + return float(socket_timeout) + return None + + def _parse_ssl_enabled(self, scheme, query_params): + if scheme == "rediss+sentinel": + return True + ssl_param = query_params.get("ssl", ["False"])[0].lower() + return ssl_param == "true" + + def _parse_ssl_params(self, query_params, ssl_enabled): + ssl_params = {} + if ssl_enabled: + ssl_cert_reqs = self._parse_ssl_cert_reqs(query_params) + ssl_keyfile = query_params.get("ssl_keyfile", [None])[0] + ssl_certfile = query_params.get("ssl_certfile", [None])[0] + ssl_ca_certs = query_params.get("ssl_ca_certs", [None])[0] + + ssl_params = {"ssl": True} + if ssl_cert_reqs is not None: + ssl_params["ssl_cert_reqs"] = ssl_cert_reqs + if ssl_keyfile: + ssl_params["ssl_keyfile"] = ssl_keyfile + if ssl_certfile: + ssl_params["ssl_certfile"] = ssl_certfile + if ssl_ca_certs: + ssl_params["ssl_ca_certs"] = ssl_ca_certs + return ssl_params + + def _parse_ssl_cert_reqs(self, query_params): + ssl_cert_reqs = query_params.get("ssl_cert_reqs", [None])[0] + if ssl_cert_reqs: + ssl_cert_reqs = ssl_cert_reqs.lower() + return { + "required": ssl.CERT_REQUIRED, + "optional": ssl.CERT_OPTIONAL, + "none": ssl.CERT_NONE, + }.get(ssl_cert_reqs) + return None + + def _parse_auth_params(self, username, password): + auth_params = {} + if username: + auth_params["username"] = username + if password: + auth_params["password"] = password + return auth_params + + def hmset(self, name, mapping): + # Implement hmset for compatibility + # Use hset with mapping parameter + return self._redis_client.hset(name, mapping=mapping) + def __getattr__(self, name): return getattr(self._redis_client, name) diff --git a/requirements.txt b/requirements.txt index 375c09a..d95827b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ Flask>=0.9 -redis>=2.6.2 +redis>=5.0.0 diff --git a/test/integration/test_client.py b/test/integration/test_client.py index cbc6b20..7d60c5b 100644 --- a/test/integration/test_client.py +++ b/test/integration/test_client.py @@ -4,8 +4,9 @@ import flask import pytest +from unittest import mock -from flask_redis import client as uut +from flask_redis import FlaskRedis # Adjust the import based on your package structure @pytest.fixture @@ -14,70 +15,107 @@ def app(): def test_constructor(app): - """Test that a constructor with app instance will initialize the - connection""" - redis = uut.FlaskRedis(app) - assert redis._redis_client is not None - assert hasattr(redis._redis_client, "connection_pool") + """Test that a constructor with app instance initializes the connection.""" + redis_client = FlaskRedis(app) + assert redis_client._redis_client is not None + assert hasattr(redis_client._redis_client, "connection_pool") def test_init_app(app): - """Test that a constructor without app instance will not initialize the - connection. - - After FlaskRedis.init_app(app) is called, the connection will be - initialized.""" - redis = uut.FlaskRedis() - assert redis._redis_client is None - redis.init_app(app) - assert redis._redis_client is not None - assert hasattr(redis._redis_client, "connection_pool") + """Test that a constructor without app instance does not initialize the connection. + + After FlaskRedis.init_app(app) is called, the connection is initialized.""" + redis_client = FlaskRedis() + assert redis_client._redis_client is None + redis_client.init_app(app) + assert redis_client._redis_client is not None + assert hasattr(redis_client._redis_client, "connection_pool") if hasattr(app, "extensions"): assert "redis" in app.extensions - assert app.extensions["redis"] == redis + assert app.extensions["redis"] == redis_client def test_custom_prefix(app): - """Test that config prefixes enable distinct connections""" + """Test that config prefixes enable distinct connections.""" app.config["DBA_URL"] = "redis://localhost:6379/1" app.config["DBB_URL"] = "redis://localhost:6379/2" - redis_a = uut.FlaskRedis(app, config_prefix="DBA") - redis_b = uut.FlaskRedis(app, config_prefix="DBB") - assert redis_a.connection_pool.connection_kwargs["db"] == 1 - assert redis_b.connection_pool.connection_kwargs["db"] == 2 + redis_a = FlaskRedis(app, config_prefix="DBA") + redis_b = FlaskRedis(app, config_prefix="DBB") + assert redis_a._redis_client.connection_pool.connection_kwargs["db"] == 1 + assert redis_b._redis_client.connection_pool.connection_kwargs["db"] == 2 @pytest.mark.parametrize( ["strict_flag", "allowed_names"], [ - [ - True, - # StrictRedis points to Redis in newer versions - {"Redis", "StrictRedis"}, - ], + [True, {"Redis", "StrictRedis"}], [False, {"Redis"}], ], ) def test_strict_parameter(app, strict_flag, allowed_names): - """Test that initializing with the strict parameter set to True will use - StrictRedis, and that False will keep using the old Redis class.""" + """Test that initializing with the strict parameter uses the correct client class.""" + redis_client = FlaskRedis(app, strict=strict_flag) + assert redis_client._redis_client is not None + assert type(redis_client._redis_client).__name__ in allowed_names + + +def test_sentinel_connection(app, mocker): + """Test that FlaskRedis can connect to Redis Sentinel.""" + app.config["REDIS_URL"] = "redis+sentinel://localhost:26379/mymaster/0" + + # Mock Sentinel to prevent actual network calls + mock_sentinel = mocker.patch("flask_redis.Sentinel", autospec=True) + mock_sentinel_instance = mock_sentinel.return_value + mock_master_for = mock_sentinel_instance.master_for + mock_master_for.return_value = mock.MagicMock() + + redis_client = FlaskRedis(app) + + # Verify that Sentinel was initialized with the correct parameters + mock_sentinel.assert_called_once() + mock_master_for.assert_called_once_with( + "mymaster", + db=0, + socket_timeout=None, + decode_responses=True, + ssl_params={}, + auth_params={}, + ) + assert redis_client._redis_client is not None + + +def test_ssl_connection(app): + """Test that FlaskRedis can connect with SSL parameters.""" + app.config["REDIS_URL"] = "rediss://localhost:6379/0" + redis_client = FlaskRedis(app) + assert redis_client._redis_client is not None + assert redis_client._redis_client.connection_pool.connection_kwargs.get("ssl") is True + + +def test_ssl_sentinel_connection(app, mocker): + """Test that FlaskRedis can connect to Redis Sentinel with SSL.""" + app.config[ + "REDIS_URL" + ] = "rediss+sentinel://localhost:26379/mymaster/0?ssl_cert_reqs=required" + + # Mock Sentinel to prevent actual network calls + mock_sentinel = mocker.patch("flask_redis.Sentinel", autospec=True) + mock_sentinel_instance = mock_sentinel.return_value + mock_master_for = mock_sentinel_instance.master_for + mock_master_for.return_value = mock.MagicMock() + + redis_client = FlaskRedis(app) + + # Verify that Sentinel was initialized with SSL parameters + expected_ssl_params = {"ssl": True, "ssl_cert_reqs": ssl.CERT_REQUIRED} + mock_sentinel.assert_called_once() + mock_master_for.assert_called_once_with( + "mymaster", + db=0, + socket_timeout=None, + decode_responses=True, + ssl_params=expected_ssl_params, + auth_params={}, + ) + assert redis_client._redis_client is not None - redis = uut.FlaskRedis(app, strict=strict_flag) - assert redis._redis_client is not None - assert type(redis._redis_client).__name__ in allowed_names - - -def test_custom_provider(app): - """Test that FlaskRedis can be instructed to use a different Redis client, - like StrictRedis""" - - class FakeProvider(object): - @classmethod - def from_url(cls, *args, **kwargs): - return cls() - - redis = uut.FlaskRedis.from_custom_provider(FakeProvider) - assert redis._redis_client is None - redis.init_app(app) - assert redis._redis_client is not None - assert isinstance(redis._redis_client, FakeProvider) diff --git a/test/unit/test_client.py b/test/unit/test_client.py index b02c07d..23e1499 100644 --- a/test/unit/test_client.py +++ b/test/unit/test_client.py @@ -1,11 +1,13 @@ -from flask_redis import client as uut +from unittest import mock + +from flask_redis import FlaskRedis def test_constructor_app(mocker): - """Test that the constructor passes the app to FlaskRedis.init_app""" - mocker.patch.object(uut.FlaskRedis, "init_app", autospec=True) + """Test that the constructor passes the app to FlaskRedis.init_app.""" + mocker.patch.object(FlaskRedis, "init_app", autospec=True) app_stub = mocker.stub(name="app_stub") - uut.FlaskRedis(app_stub) + FlaskRedis(app_stub) - uut.FlaskRedis.init_app.assert_called_once_with(mocker.ANY, app_stub) + FlaskRedis.init_app.assert_called_once_with(mock.ANY, app_stub) diff --git a/tox.ini b/tox.ini index ca6d5a4..94a5ca4 100644 --- a/tox.ini +++ b/tox.ini @@ -11,7 +11,7 @@ isolated_build = true [testenv] deps = - oldpy2deps: redis==2.6.2 + oldpy2deps: redis==5.0.0 oldpy2deps: flask==0.8.0 oldpy2deps: werkzeug==0.8.3 oldpy3deps: redis==2.6.2