Skip to content

Commit

Permalink
Merge pull request #217 from klattimer/master
Browse files Browse the repository at this point in the history
exclude certain headers when requested
  • Loading branch information
tito authored May 18, 2017
2 parents b69394a + ca3c890 commit cdf50bd
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 10 deletions.
13 changes: 9 additions & 4 deletions ws4py/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class WebSocketBaseClient(WebSocket):
def __init__(self, url, protocols=None, extensions=None,
heartbeat_freq=None, ssl_options=None, headers=None):
heartbeat_freq=None, ssl_options=None, headers=None, exclude_headers=None):
"""
A websocket client that implements :rfc:`6455` and provides a simple
interface to communicate with a websocket server.
Expand Down Expand Up @@ -78,6 +78,8 @@ def __init__(self, url, protocols=None, extensions=None,
self.resource = None
self.ssl_options = ssl_options or {}
self.extra_headers = headers or []
self.exclude_headers = exclude_headers or []
self.exclude_headers = [x.lower() for x in self.exclude_headers]

if self.scheme == "wss":
# Prevent check_hostname requires server_hostname (ref #187)
Expand Down Expand Up @@ -211,7 +213,7 @@ def connect(self):
# default port is now 443; upgrade self.sender to send ssl
self.sock = ssl.wrap_socket(self.sock, **self.ssl_options)
self._is_secure = True

self.sock.connect(self.bind_addr)

self._write(self.handshake_request)
Expand Down Expand Up @@ -257,14 +259,15 @@ def handshake_headers(self):
('Sec-WebSocket-Key', self.key.decode('utf-8')),
('Sec-WebSocket-Version', str(max(WS_VERSION)))
]

if self.protocols:
headers.append(('Sec-WebSocket-Protocol', ','.join(self.protocols)))

if self.extra_headers:
headers.extend(self.extra_headers)

if not any(x for x in headers if x[0].lower() == 'origin'):
if not any(x for x in headers if x[0].lower() == 'origin') and \
'origin' not in self.exclude_headers:

scheme, url = self.url.split(":", 1)
parsed = urlsplit(url, scheme="http")
Expand All @@ -277,6 +280,8 @@ def handshake_headers(self):
origin = origin + ':' + str(parsed.port)
headers.append(('Origin', origin))

headers = [x for x in headers if x[0].lower() not in self.exclude_headers]

return headers

@property
Expand Down
4 changes: 2 additions & 2 deletions ws4py/client/geventclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
__all__ = ['WebSocketClient']

class WebSocketClient(WebSocketBaseClient):
def __init__(self, url, protocols=None, extensions=None, heartbeat_freq=None, ssl_options=None, headers=None):
def __init__(self, url, protocols=None, extensions=None, heartbeat_freq=None, ssl_options=None, headers=None, exclude_headers=None):
"""
WebSocket client that executes the
:meth:`run() <ws4py.websocket.WebSocket.run>` into a gevent greenlet.
Expand Down Expand Up @@ -41,7 +41,7 @@ def outgoing():
gevent.joinall(greenlets)
"""
WebSocketBaseClient.__init__(self, url, protocols, extensions, heartbeat_freq,
ssl_options=ssl_options, headers=headers)
ssl_options=ssl_options, headers=headers, exclude_headers=exclude_headers)
self._th = Greenlet(self.run)

self.messages = Queue()
Expand Down
4 changes: 2 additions & 2 deletions ws4py/client/threadedclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class WebSocketClient(WebSocketBaseClient):
def __init__(self, url, protocols=None, extensions=None, heartbeat_freq=None,
ssl_options=None, headers=None):
ssl_options=None, headers=None, exclude_headers=None):
"""
.. code-block:: python
Expand All @@ -32,7 +32,7 @@ def received_message(self, m):
"""
WebSocketBaseClient.__init__(self, url, protocols, extensions, heartbeat_freq,
ssl_options, headers=headers)
ssl_options, headers=headers, exclude_headers=exclude_headers)
self._th = threading.Thread(target=self.run, name='WebSocketClient')
self._th.daemon = True

Expand Down
4 changes: 2 additions & 2 deletions ws4py/client/tornadoclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class TornadoWebSocketClient(WebSocketBaseClient):
def __init__(self, url, protocols=None, extensions=None,
io_loop=None, ssl_options=None, headers=None):
io_loop=None, ssl_options=None, headers=None, exclude_headers=None):
"""
.. code-block:: python
Expand All @@ -32,7 +32,7 @@ def closed(self, code, reason=None):
ioloop.IOLoop.instance().start()
"""
WebSocketBaseClient.__init__(self, url, protocols, extensions,
ssl_options=ssl_options, headers=headers)
ssl_options=ssl_options, headers=headers, exclude_headers=exclude_headers)
if self.scheme == "wss":
self.sock = ssl.wrap_socket(self.sock, do_handshake_on_connect=False, **self.ssl_options)
self._is_secure = True
Expand Down

0 comments on commit cdf50bd

Please sign in to comment.