Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to listen to unix socket #541

Merged
merged 1 commit into from
Dec 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions websockify/websocketproxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

'''

import signal, socket, optparse, time, os, sys, subprocess, logging, errno, ssl
import signal, socket, optparse, time, os, sys, subprocess, logging, errno, ssl, stat
from socketserver import ThreadingMixIn
from http.server import HTTPServer

Expand Down Expand Up @@ -112,7 +112,9 @@ def new_websocket_client(self):
self.server.target_host, self.server.target_port, e)
raise self.CClose(1011, "Failed to connect to downstream server")

self.request.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
# Option unavailable when listening to unix socket
if not self.server.unix_listen:
self.request.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
if not self.server.wrap_cmd and not self.server.unix_target:
tsock.setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)

Expand Down Expand Up @@ -467,6 +469,10 @@ def websockify_init():
parser.add_option("--ssl-ciphers", action="store",
help="list of ciphers allowed for connection. For a list of "
"supported ciphers run `openssl ciphers`")
parser.add_option("--unix-listen",
help="listen to unix socket", metavar="FILE", default=None)
parser.add_option("--unix-listen-mode", default=None,
help="specify mode for unix socket (defaults to 0600)")
parser.add_option("--unix-target",
help="connect to unix socket target", metavar="FILE")
parser.add_option("--inetd",
Expand Down Expand Up @@ -617,6 +623,16 @@ def websockify_init():

if opts.inetd:
opts.listen_fd = sys.stdin.fileno()
elif opts.unix_listen:
if opts.unix_listen_mode:
try:
# Parse octal notation (like 750)
opts.unix_listen_mode = int(opts.unix_listen_mode, 8)
except ValueError:
parser.error("Error parsing listen unix socket mode")
else:
# Default to 0600 (Owner Read/Write)
opts.unix_listen_mode = stat.S_IREAD | stat.S_IWRITE
else:
if len(args) < 1:
parser.error("Too few arguments")
Expand Down
92 changes: 60 additions & 32 deletions websockify/websockifyserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,37 +325,40 @@ def __init__(self, RequestHandlerClass, listen_fd=None,
file_only=False,
run_once=False, timeout=0, idle_timeout=0, traffic=False,
tcp_keepalive=True, tcp_keepcnt=None, tcp_keepidle=None,
tcp_keepintvl=None, ssl_ciphers=None, ssl_options=0):
tcp_keepintvl=None, ssl_ciphers=None, ssl_options=0,
unix_listen=None, unix_listen_mode=None):

# settings
self.RequestHandlerClass = RequestHandlerClass
self.verbose = verbose
self.listen_fd = listen_fd
self.listen_host = listen_host
self.listen_port = listen_port
self.prefer_ipv6 = source_is_ipv6
self.ssl_only = ssl_only
self.ssl_ciphers = ssl_ciphers
self.ssl_options = ssl_options
self.verify_client = verify_client
self.daemon = daemon
self.run_once = run_once
self.timeout = timeout
self.idle_timeout = idle_timeout
self.traffic = traffic
self.file_only = file_only
self.web_auth = web_auth

self.launch_time = time.time()
self.ws_connection = False
self.handler_id = 1
self.terminating = False

self.logger = self.get_logger()
self.tcp_keepalive = tcp_keepalive
self.tcp_keepcnt = tcp_keepcnt
self.tcp_keepidle = tcp_keepidle
self.tcp_keepintvl = tcp_keepintvl
self.verbose = verbose
self.listen_fd = listen_fd
self.unix_listen = unix_listen
self.unix_listen_mode = unix_listen_mode
self.listen_host = listen_host
self.listen_port = listen_port
self.prefer_ipv6 = source_is_ipv6
self.ssl_only = ssl_only
self.ssl_ciphers = ssl_ciphers
self.ssl_options = ssl_options
self.verify_client = verify_client
self.daemon = daemon
self.run_once = run_once
self.timeout = timeout
self.idle_timeout = idle_timeout
self.traffic = traffic
self.file_only = file_only
self.web_auth = web_auth

self.launch_time = time.time()
self.ws_connection = False
self.handler_id = 1
self.terminating = False

self.logger = self.get_logger()
self.tcp_keepalive = tcp_keepalive
self.tcp_keepcnt = tcp_keepcnt
self.tcp_keepidle = tcp_keepidle
self.tcp_keepintvl = tcp_keepintvl

# keyfile path must be None if not specified
self.key = None
Expand Down Expand Up @@ -387,6 +390,8 @@ def __init__(self, RequestHandlerClass, listen_fd=None,
self.msg("WebSocket server settings:")
if self.listen_fd != None:
self.msg(" - Listen for inetd connections")
elif self.unix_listen != None:
self.msg(" - Listen on unix socket %s", self.unix_listen)
else:
self.msg(" - Listen on %s:%s",
self.listen_host, self.listen_port)
Expand Down Expand Up @@ -421,8 +426,9 @@ def get_logger():

@staticmethod
def socket(host, port=None, connect=False, prefer_ipv6=False,
unix_socket=None, use_ssl=False, tcp_keepalive=True,
tcp_keepcnt=None, tcp_keepidle=None, tcp_keepintvl=None):
unix_socket=None, unix_socket_mode=None, unix_socket_listen=False,
use_ssl=False, tcp_keepalive=True, tcp_keepcnt=None,
tcp_keepidle=None, tcp_keepintvl=None):
""" Resolve a host (and optional port) to an IPv4 or IPv6
address. Create a socket. Bind to it if listen is set,
otherwise connect to it. Return the socket.
Expand Down Expand Up @@ -470,8 +476,22 @@ def socket(host, port=None, connect=False, prefer_ipv6=False,
sock.bind(addrs[0][4])
sock.listen(100)
else:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.connect(unix_socket)
if unix_socket_listen:
# Make sure the socket does not already exist
try:
os.unlink(unix_socket)
except FileNotFoundError:
pass
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
oldmask = os.umask(0o777 ^ unix_socket_mode)
try:
sock.bind(unix_socket)
finally:
os.umask(oldmask)
sock.listen(100)
else:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.connect(unix_socket)

return sock

Expand Down Expand Up @@ -700,6 +720,11 @@ def start_server(self):

if self.listen_fd != None:
lsock = socket.fromfd(self.listen_fd, socket.AF_INET, socket.SOCK_STREAM)
elif self.unix_listen != None:
lsock = self.socket(host=None,
unix_socket=self.unix_listen,
unix_socket_mode=self.unix_listen_mode,
unix_socket_listen=True)
else:
lsock = self.socket(self.listen_host, self.listen_port, False,
shiomax marked this conversation as resolved.
Show resolved Hide resolved
self.prefer_ipv6,
Expand Down Expand Up @@ -766,6 +791,9 @@ def start_server(self):
ready = select.select([lsock], [], [], 1)[0]
if lsock in ready:
startsock, address = lsock.accept()
# Unix Socket will not report address (empty string), but address[0] is logged a bunch
if self.unix_listen != None:
address = [ self.unix_listen ]
else:
continue
except self.Terminate:
Expand Down