Skip to content

Commit

Permalink
feat: add sentinel support
Browse files Browse the repository at this point in the history
  • Loading branch information
cyrinux committed Nov 16, 2024
1 parent 3ad7661 commit 102a6f0
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 21 deletions.
15 changes: 7 additions & 8 deletions flask_redis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from .client import FlaskRedis


__version__ = "0.5.0.dev0"
__version__ = "0.6.0"

__title__ = "flask-redis"
__description__ = "A nice way to use Redis in your Flask app"
__url__ = "https://github.com/underyx/flask-redis/"
__description__ = "A nice way to use Redis in your Flask app with sentinel support"
__url__ = "https://github.com/cyrinux/flask-redis/"
__uri__ = __url__

__author__ = "Bence Nagy"
__email__ = "[email protected]"
__author__ = "Cyrinux"
__email__ = "[email protected]"

__license__ = "Blue Oak License"
__copyright__ = "Copyright (c) 2019 Bence Nagy"
__license__ = "MIT"
__copyright__ = "Copyright (c) 2024"

__all__ = [FlaskRedis]
181 changes: 170 additions & 11 deletions flask_redis/client.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,206 @@
import ssl
from urllib.parse import parse_qs, unquote, urlparse

try:
import redis
from redis.sentinel import Sentinel
except ImportError:
# We can still allow custom provider-only usage without redis-py being installed
# Allow usage without redis-py being installed
redis = None
Sentinel = None


class FlaskRedis(object):
def __init__(self, app=None, strict=True, config_prefix="REDIS", **kwargs):
def __init__(
self,
app=None,
strict=True,
config_prefix="REDIS",
decode_responses=True,
**kwargs,
):
self._redis_client = None
self.provider_class = redis.StrictRedis if strict else redis.Redis
self.provider_kwargs = kwargs
self.config_prefix = config_prefix
self.decode_responses = decode_responses
self.provider_kwargs = kwargs

if app is not None:
self.init_app(app)

@classmethod
def from_custom_provider(cls, provider, app=None, **kwargs):
assert provider is not None, "your custom provider is None, come on"
assert provider is not None, "Your custom provider is None."

# We never pass the app parameter here, so we can call init_app
# ourselves later, after the provider class has been set
instance = cls(**kwargs)

instance.provider_class = provider
if app is not None:
instance.init_app(app)
return instance

def init_app(self, app, **kwargs):
redis_url = app.config.get(
"{0}_URL".format(self.config_prefix), "redis://localhost:6379/0"
f"{self.config_prefix}_URL", "redis://localhost:6379/0"
)

self.provider_kwargs.update(kwargs)
self._redis_client = self.provider_class.from_url(
redis_url, **self.provider_kwargs
)

parsed_url = urlparse(redis_url)
scheme = parsed_url.scheme

if scheme in ["redis+sentinel", "rediss+sentinel"]:
if Sentinel is None:
raise ImportError("redis-py must be installed to use Redis Sentinel.")
self._init_sentinel_client(parsed_url)
else:
self._init_standard_client(redis_url)

if not hasattr(app, "extensions"):
app.extensions = {}
app.extensions[self.config_prefix.lower()] = self

def _init_standard_client(self, redis_url):
self._redis_client = self.provider_class.from_url(
redis_url, decode_responses=self.decode_responses, **self.provider_kwargs
)

def _init_sentinel_client(self, parsed_url):
sentinel_kwargs, client_kwargs = self._parse_sentinel_parameters(parsed_url)

sentinel = Sentinel(
sentinel_kwargs["hosts"],
socket_timeout=sentinel_kwargs["socket_timeout"],
**sentinel_kwargs["ssl_params"],
**sentinel_kwargs["auth_params"],
**self.provider_kwargs,
)

self._redis_client = sentinel.master_for(
sentinel_kwargs["master_name"],
db=client_kwargs["db"],
socket_timeout=client_kwargs["socket_timeout"],
decode_responses=self.decode_responses,
**client_kwargs["ssl_params"],
**client_kwargs["auth_params"],
**self.provider_kwargs,
)

def _parse_sentinel_parameters(self, parsed_url):
username, password = self._extract_credentials(parsed_url)
hosts = self._parse_hosts(parsed_url)
master_name, db = self._parse_master_and_db(parsed_url)
query_params = parse_qs(parsed_url.query)

socket_timeout = self._parse_socket_timeout(query_params)
ssl_enabled = self._parse_ssl_enabled(parsed_url.scheme, query_params)
ssl_params = self._parse_ssl_params(query_params, ssl_enabled)
auth_params = self._parse_auth_params(username, password)

sentinel_kwargs = {
"hosts": hosts,
"socket_timeout": socket_timeout,
"ssl_params": ssl_params,
"auth_params": auth_params,
"master_name": master_name,
}

client_kwargs = {
"db": db,
"socket_timeout": socket_timeout,
"ssl_params": ssl_params,
"auth_params": auth_params,
}

return sentinel_kwargs, client_kwargs

def _extract_credentials(self, parsed_url):
username = parsed_url.username
password = parsed_url.password
return username, password

def _parse_hosts(self, parsed_url):
netloc = parsed_url.netloc
if "@" in netloc:
hosts_part = netloc.split("@", 1)[1]
else:
hosts_part = netloc

hosts = []
for host_port in hosts_part.split(","):
if ":" in host_port:
host, port = host_port.split(":", 1)
port = int(port)
else:
host = host_port
port = 26379 # Default Sentinel port
hosts.append((host, port))
return hosts

def _parse_master_and_db(self, parsed_url):
path = parsed_url.path.lstrip("/")
if "/" in path:
master_name, db_part = path.split("/", 1)
db = int(db_part)
else:
master_name = path
db = 0 # Default DB
return master_name, db

def _parse_socket_timeout(self, query_params):
socket_timeout = query_params.get("socket_timeout", [None])[0]
if socket_timeout is not None:
return float(socket_timeout)
return None

def _parse_ssl_enabled(self, scheme, query_params):
if scheme == "rediss+sentinel":
return True
ssl_param = query_params.get("ssl", ["False"])[0].lower()
return ssl_param == "true"

def _parse_ssl_params(self, query_params, ssl_enabled):
ssl_params = {}
if ssl_enabled:
ssl_cert_reqs = self._parse_ssl_cert_reqs(query_params)
ssl_keyfile = query_params.get("ssl_keyfile", [None])[0]
ssl_certfile = query_params.get("ssl_certfile", [None])[0]
ssl_ca_certs = query_params.get("ssl_ca_certs", [None])[0]

ssl_params = {"ssl": True}
if ssl_cert_reqs is not None:
ssl_params["ssl_cert_reqs"] = ssl_cert_reqs
if ssl_keyfile:
ssl_params["ssl_keyfile"] = ssl_keyfile
if ssl_certfile:
ssl_params["ssl_certfile"] = ssl_certfile
if ssl_ca_certs:
ssl_params["ssl_ca_certs"] = ssl_ca_certs
return ssl_params

def _parse_ssl_cert_reqs(self, query_params):
ssl_cert_reqs = query_params.get("ssl_cert_reqs", [None])[0]
if ssl_cert_reqs:
ssl_cert_reqs = ssl_cert_reqs.lower()
return {
"required": ssl.CERT_REQUIRED,
"optional": ssl.CERT_OPTIONAL,
"none": ssl.CERT_NONE,
}.get(ssl_cert_reqs)
return None

def _parse_auth_params(self, username, password):
auth_params = {}
if username:
auth_params["username"] = username
if password:
auth_params["password"] = password
return auth_params

def hmset(self, name, mapping):
# Implement hmset for compatibility
# Use hset with mapping parameter
return self._redis_client.hset(name, mapping=mapping)

def __getattr__(self, name):
return getattr(self._redis_client, name)

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Flask>=0.9
redis>=2.6.2
redis>=5.0.0
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ isolated_build = true

[testenv]
deps =
oldpy2deps: redis==2.6.2
oldpy2deps: redis==5.0.0
oldpy2deps: flask==0.8.0
oldpy2deps: werkzeug==0.8.3
oldpy3deps: redis==2.6.2
Expand Down

0 comments on commit 102a6f0

Please sign in to comment.