Skip to content

Commit

Permalink
Add option to listen to unix socket
Browse files Browse the repository at this point in the history
  • Loading branch information
shiomax committed Dec 14, 2022
1 parent 71d55fc commit 63eb24d
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 34 deletions.
20 changes: 18 additions & 2 deletions websockify/websocketproxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
'''

import signal, socket, optparse, time, os, sys, subprocess, logging, errno, ssl
import signal, socket, optparse, time, os, sys, subprocess, logging, errno, ssl, stat
from socketserver import ThreadingMixIn
from http.server import HTTPServer

Expand Down Expand Up @@ -112,7 +112,9 @@ def new_websocket_client(self):
self.server.target_host, self.server.target_port, e)
raise self.CClose(1011, "Failed to connect to downstream server")

self.request.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
# Option unavailable when listening to unix socket
if not self.server.unix_listen:
self.request.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
if not self.server.wrap_cmd and not self.server.unix_target:
tsock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)

Expand Down Expand Up @@ -467,6 +469,10 @@ def websockify_init():
parser.add_option("--ssl-ciphers", action="store",
help="list of ciphers allowed for connection. For a list of "
"supported ciphers run `openssl ciphers`")
parser.add_option("--unix-listen",
help="listen to unix socket", metavar="FILE", default=None)
parser.add_option("--unix-listen-mode", default=None,
help="specify mode for unix socket (defaults to 0600)")
parser.add_option("--unix-target",
help="connect to unix socket target", metavar="FILE")
parser.add_option("--inetd",
Expand Down Expand Up @@ -617,6 +623,16 @@ def websockify_init():

if opts.inetd:
opts.listen_fd = sys.stdin.fileno()
elif opts.unix_listen:
if opts.unix_listen_mode:
try:
# Parse octal notation (like 750)
opts.unix_listen_mode = int(opts.unix_listen_mode, 8)
except ValueError:
parser.error("Error parsing listen unix socket mode")
else:
# Default to 0600 (Owner Read/Write)
opts.unix_listen_mode = stat.S_IREAD | stat.S_IWRITE
else:
if len(args) < 1:
parser.error("Too few arguments")
Expand Down
92 changes: 60 additions & 32 deletions websockify/websockifyserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,37 +325,40 @@ def __init__(self, RequestHandlerClass, listen_fd=None,
file_only=False,
run_once=False, timeout=0, idle_timeout=0, traffic=False,
tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None,
tcp_keepintvl=None, ssl_ciphers=None, ssl_options=0):
tcp_keepintvl=None, ssl_ciphers=None, ssl_options=0,
unix_listen=None, unix_listen_mode=None):

# settings
self.RequestHandlerClass = RequestHandlerClass
self.verbose = verbose
self.listen_fd = listen_fd
self.listen_host = listen_host
self.listen_port = listen_port
self.prefer_ipv6 = source_is_ipv6
self.ssl_only = ssl_only
self.ssl_ciphers = ssl_ciphers
self.ssl_options = ssl_options
self.verify_client = verify_client
self.daemon = daemon
self.run_once = run_once
self.timeout = timeout
self.idle_timeout = idle_timeout
self.traffic = traffic
self.file_only = file_only
self.web_auth = web_auth

self.launch_time = time.time()
self.ws_connection = False
self.handler_id = 1
self.terminating = False

self.logger = self.get_logger()
self.tcp_keepalive = tcp_keepalive
self.tcp_keepcnt = tcp_keepcnt
self.tcp_keepidle = tcp_keepidle
self.tcp_keepintvl = tcp_keepintvl
self.verbose = verbose
self.listen_fd = listen_fd
self.unix_listen = unix_listen
self.unix_listen_mode = unix_listen_mode
self.listen_host = listen_host
self.listen_port = listen_port
self.prefer_ipv6 = source_is_ipv6
self.ssl_only = ssl_only
self.ssl_ciphers = ssl_ciphers
self.ssl_options = ssl_options
self.verify_client = verify_client
self.daemon = daemon
self.run_once = run_once
self.timeout = timeout
self.idle_timeout = idle_timeout
self.traffic = traffic
self.file_only = file_only
self.web_auth = web_auth

self.launch_time = time.time()
self.ws_connection = False
self.handler_id = 1
self.terminating = False

self.logger = self.get_logger()
self.tcp_keepalive = tcp_keepalive
self.tcp_keepcnt = tcp_keepcnt
self.tcp_keepidle = tcp_keepidle
self.tcp_keepintvl = tcp_keepintvl

# keyfile path must be None if not specified
self.key = None
Expand Down Expand Up @@ -387,6 +390,8 @@ def __init__(self, RequestHandlerClass, listen_fd=None,
self.msg("WebSocket server settings:")
if self.listen_fd != None:
self.msg(" - Listen for inetd connections")
elif self.unix_listen != None:
self.msg(" - Listen on unix socket %s", self.unix_listen)
else:
self.msg(" - Listen on %s:%s",
self.listen_host, self.listen_port)
Expand Down Expand Up @@ -421,8 +426,9 @@ def get_logger():

@staticmethod
def socket(host, port=None, connect=False, prefer_ipv6=False,
unix_socket=None, use_ssl=False, tcp_keepalive=True,
tcp_keepcnt=None, tcp_keepidle=None, tcp_keepintvl=None):
unix_socket=None, unix_socket_mode=None, unix_socket_listen=False,
use_ssl=False, tcp_keepalive=True, tcp_keepcnt=None,
tcp_keepidle=None, tcp_keepintvl=None):
""" Resolve a host (and optional port) to an IPv4 or IPv6
address. Create a socket. Bind to it if listen is set,
otherwise connect to it. Return the socket.
Expand Down Expand Up @@ -470,8 +476,22 @@ def socket(host, port=None, connect=False, prefer_ipv6=False,
sock.bind(addrs[0][4])
sock.listen(100)
else:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.connect(unix_socket)
if unix_socket_listen:
# Make sure the socket does not already exist
try:
os.unlink(unix_socket)
except FileNotFoundError:
pass
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
oldmask = os.umask(0o777 ^ unix_socket_mode)
try:
sock.bind(unix_socket)
finally:
os.umask(oldmask)
sock.listen(100)
else:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.connect(unix_socket)

return sock

Expand Down Expand Up @@ -700,6 +720,11 @@ def start_server(self):

if self.listen_fd != None:
lsock = socket.fromfd(self.listen_fd, socket.AF_INET, socket.SOCK_STREAM)
elif self.unix_listen != None:
lsock = self.socket(host=None,
unix_socket=self.unix_listen,
unix_socket_mode=self.unix_listen_mode,
unix_socket_listen=True)
else:
lsock = self.socket(self.listen_host, self.listen_port, False,
self.prefer_ipv6,
Expand Down Expand Up @@ -766,6 +791,9 @@ def start_server(self):
ready = select.select([lsock], [], [], 1)[0]
if lsock in ready:
startsock, address = lsock.accept()
# Unix Socket will not report address (empty string), but address[0] is logged a bunch
if self.unix_listen != None:
address = [ self.unix_listen ]
else:
continue
except self.Terminate:
Expand Down

0 comments on commit 63eb24d

Please sign in to comment.