Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit e4b5d48
Merge: a24ff07 ec52843
Author: Benoit Chesneau <[email protected]>
Date:   Sat Aug 10 10:23:20 2024 +0200

    Merge pull request benoitc#3188 from tnusser/patch-1

    Update sock.py by making bind list creation more readable

commit ec52843
Author: Tobias Nusser <[email protected]>
Date:   Wed Apr 17 12:48:26 2024 +0200

    Update sock.py by making bind list creation more readable

    Implementing suggestion by reviewer benoitc#3127 (comment) to make code more readable.

commit a24ff07
Author: Randall Leeds <[email protected]>
Date:   Thu Dec 28 00:57:50 2023 -0800

    Use plain socket objects instead of wrapper classes

    Refactor socket creation to remove the socket wrapper classes so that
    these objects have less surprising behavior when used in worker hooks,
    worker classes, and custom applications.

    Close benoitc#3013.
  • Loading branch information
rverelabs committed Sep 27, 2024
1 parent 497ad24 commit 01dbf55
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 194 deletions.
4 changes: 2 additions & 2 deletions gunicorn/arbiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def start(self):

self.LISTENERS = sock.create_sockets(self.cfg, self.log, fds)

listeners_str = ",".join([str(lnr) for lnr in self.LISTENERS])
listeners_str = ",".join([sock.get_uri(lnr) for lnr in self.LISTENERS])
self.log.debug("Arbiter booted")
self.log.info("Listening at: %s (%s)", listeners_str, self.pid)
self.log.info("Using worker: %s", self.cfg.worker_class_str)
Expand Down Expand Up @@ -460,7 +460,7 @@ def reload(self):
lnr.close()
# init new listeners
self.LISTENERS = sock.create_sockets(self.cfg, self.log)
listeners_str = ",".join([str(lnr) for lnr in self.LISTENERS])
listeners_str = ",".join([sock.get_uri(lnr) for lnr in self.LISTENERS])
self.log.info("Listening at: %s", listeners_str)

# do some actions on reload
Expand Down
4 changes: 2 additions & 2 deletions gunicorn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2101,7 +2101,7 @@ class KeyFile(Setting):
section = "SSL"
cli = ["--keyfile"]
meta = "FILE"
validator = validate_string
validator = validate_file_exists
default = None
desc = """\
SSL key file
Expand All @@ -2113,7 +2113,7 @@ class CertFile(Setting):
section = "SSL"
cli = ["--certfile"]
meta = "FILE"
validator = validate_string
validator = validate_file_exists
default = None
desc = """\
SSL certificate file
Expand Down
277 changes: 106 additions & 171 deletions gunicorn/sock.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,130 +13,56 @@
from gunicorn import util


class BaseSocket:

def __init__(self, address, conf, log, fd=None):
self.log = log
self.conf = conf

self.cfg_addr = address
if fd is None:
sock = socket.socket(self.FAMILY, socket.SOCK_STREAM)
bound = False
def _get_socket_family(addr):
if isinstance(addr, tuple):
if util.is_ipv6(addr[0]):
return socket.AF_INET6
else:
sock = socket.fromfd(fd, self.FAMILY, socket.SOCK_STREAM)
os.close(fd)
bound = True

self.sock = self.set_options(sock, bound=bound)

def __str__(self):
return "<socket %d>" % self.sock.fileno()

def __getattr__(self, name):
return getattr(self.sock, name)

def set_options(self, sock, bound=False):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if (self.conf.reuse_port
and hasattr(socket, 'SO_REUSEPORT')): # pragma: no cover
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except OSError as err:
if err.errno not in (errno.ENOPROTOOPT, errno.EINVAL):
raise
if not bound:
self.bind(sock)
sock.setblocking(0)
return socket.AF_INET

# make sure that the socket can be inherited
if hasattr(sock, "set_inheritable"):
sock.set_inheritable(True)
if isinstance(addr, (str, bytes)):
return socket.AF_UNIX

sock.listen(self.conf.backlog)
return sock
raise TypeError("Unable to determine socket family for: %r" % addr)

def bind(self, sock):
sock.bind(self.cfg_addr)

def close(self):
if self.sock is None:
return
def create_socket(conf, log, addr):
family = _get_socket_family(addr)

if family is socket.AF_UNIX:
# remove any existing socket at the given path
try:
self.sock.close()
st = os.stat(addr)
except OSError as e:
self.log.info("Error while closing socket %s", str(e))

self.sock = None


class TCPSocket(BaseSocket):

FAMILY = socket.AF_INET

def __str__(self):
if self.conf.is_ssl:
scheme = "https"
if e.args[0] != errno.ENOENT:
raise
else:
scheme = "http"

addr = self.sock.getsockname()
return "%s://%s:%d" % (scheme, addr[0], addr[1])

def set_options(self, sock, bound=False):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
return super().set_options(sock, bound=bound)


class TCP6Socket(TCPSocket):

FAMILY = socket.AF_INET6

def __str__(self):
(host, port, _, _) = self.sock.getsockname()
return "http://[%s]:%d" % (host, port)


class UnixSocket(BaseSocket):

FAMILY = socket.AF_UNIX

def __init__(self, addr, conf, log, fd=None):
if fd is None:
try:
st = os.stat(addr)
except OSError as e:
if e.args[0] != errno.ENOENT:
raise
if stat.S_ISSOCK(st.st_mode):
os.remove(addr)
else:
if stat.S_ISSOCK(st.st_mode):
os.remove(addr)
else:
raise ValueError("%r is not a socket" % addr)
super().__init__(addr, conf, log, fd=fd)

def __str__(self):
return "unix:%s" % self.cfg_addr

def bind(self, sock):
old_umask = os.umask(self.conf.umask)
sock.bind(self.cfg_addr)
util.chown(self.cfg_addr, self.conf.uid, self.conf.gid)
os.umask(old_umask)
raise ValueError("%r is not a socket" % addr)

for i in range(5):
try:
sock = socket.socket(family)
sock.bind(addr)
sock.listen(conf.backlog)
if family is socket.AF_UNIX:
util.chown(addr, conf.uid, conf.gid)
return sock
except socket.error as e:
if e.args[0] == errno.EADDRINUSE:
log.error("Connection in use: %s", str(addr))
if e.args[0] == errno.EADDRNOTAVAIL:
log.error("Invalid address: %s", str(addr))
if i < 5:
msg = "connection to {addr} failed: {error}"
log.debug(msg.format(addr=str(addr), error=str(e)))
log.error("Retrying in 1 second.")
time.sleep(1)

def _sock_type(addr):
if isinstance(addr, tuple):
if util.is_ipv6(addr[0]):
sock_type = TCP6Socket
else:
sock_type = TCPSocket
elif isinstance(addr, (str, bytes)):
sock_type = UnixSocket
else:
raise TypeError("Unable to create socket from: %r" % addr)
return sock_type
log.error("Can't connect to %s", str(addr))
sys.exit(1)


def create_sockets(conf, log, fds=None):
Expand All @@ -149,72 +75,79 @@ def create_sockets(conf, log, fds=None):
"""
listeners = []

# get it only once
addr = conf.address
fdaddr = [bind for bind in addr if isinstance(bind, int)]
if fds:
fdaddr += list(fds)
laddr = [bind for bind in addr if not isinstance(bind, int)]

# check ssl config early to raise the error on startup
# only the certfile is needed since it can contains the keyfile
if conf.certfile and not os.path.exists(conf.certfile):
raise ValueError('certfile "%s" does not exist' % conf.certfile)

if conf.keyfile and not os.path.exists(conf.keyfile):
raise ValueError('keyfile "%s" does not exist' % conf.keyfile)

# sockets are already bound
if fdaddr:
for fd in fdaddr:
sock = socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM)
sock_name = sock.getsockname()
sock_type = _sock_type(sock_name)
listener = sock_type(sock_name, conf, log, fd=fd)
listeners.append(listener)

return listeners

# no sockets is bound, first initialization of gunicorn in this env.
for addr in laddr:
sock_type = _sock_type(addr)
sock = None
for i in range(5):
try:
sock = sock_type(addr, conf, log)
except OSError as e:
if e.args[0] == errno.EADDRINUSE:
log.error("Connection in use: %s", str(addr))
if e.args[0] == errno.EADDRNOTAVAIL:
log.error("Invalid address: %s", str(addr))
msg = "connection to {addr} failed: {error}"
log.error(msg.format(addr=str(addr), error=str(e)))
if i < 5:
log.debug("Retrying in 1 second.")
time.sleep(1)
else:
break

if sock is None:
log.error("Can't connect to %s", str(addr))
sys.exit(1)

listeners.append(sock)
# sockets are already bound
listeners = []
for fd in list(fds) + [a for a in conf.address if isinstance(a, int)]:
sock = socket.socket(fileno=fd)
set_socket_options(conf, sock)
listeners.append(sock)
else:
# first initialization of gunicorn
old_umask = os.umask(conf.umask)
try:
bind_list = [bind for bind in conf.address if not isinstance(bind, int)]
for addr in bind_list:
sock = create_socket(conf, log, addr)
set_socket_options(conf, sock)
listeners.append(sock)
finally:
os.umask(old_umask)

return listeners


def close_sockets(listeners, unlink=True):
for sock in listeners:
sock_name = sock.getsockname()
sock.close()
if unlink and _sock_type(sock_name) is UnixSocket:
os.unlink(sock_name)
try:
if unlink and sock.family is socket.AF_UNIX:
sock_name = sock.getsockname()
os.unlink(sock_name)
finally:
sock.close()


def get_uri(listener, is_ssl=False):
addr = listener.getsockname()
family = _get_socket_family(addr)
scheme = "https" if is_ssl else "http"

if family is socket.AF_INET:
(host, port) = listener.getsockname()
return f"{scheme}://{host}:{port}"

if family is socket.AF_INET6:
(host, port, _, _) = listener.getsockname()
return f"{scheme}://[{host}]:{port}"

if family is socket.AF_UNIX:
path = listener.getsockname()
return f"unix://{path}"


def set_socket_options(conf, sock):
sock.setblocking(False)

# make sure that the socket can be inherited
if hasattr(sock, "set_inheritable"):
sock.set_inheritable(True)

if sock.family in (socket.AF_INET, socket.AF_INET6):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if conf.reuse_port and hasattr(socket, "SO_REUSEPORT"): # pragma: no cover
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
except socket.error as err:
if err.errno not in (errno.ENOPROTOOPT, errno.EINVAL):
raise


def ssl_context(conf):
def default_ssl_context_factory():
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH, cafile=conf.ca_certs)
context = ssl.create_default_context(
ssl.Purpose.CLIENT_AUTH, cafile=conf.ca_certs
)
context.load_cert_chain(certfile=conf.certfile, keyfile=conf.keyfile)
context.verify_mode = conf.cert_reqs
if conf.ciphers:
Expand All @@ -225,7 +158,9 @@ def default_ssl_context_factory():


def ssl_wrap_socket(sock, conf):
return ssl_context(conf).wrap_socket(sock,
server_side=True,
suppress_ragged_eofs=conf.suppress_ragged_eofs,
do_handshake_on_connect=conf.do_handshake_on_connect)
return ssl_context(conf).wrap_socket(
sock,
server_side=True,
suppress_ragged_eofs=conf.suppress_ragged_eofs,
do_handshake_on_connect=conf.do_handshake_on_connect,
)
Loading

0 comments on commit 01dbf55

Please sign in to comment.