Skip to content

Commit

Permalink
Merge pull request #62 from noisyboiler/refactor-wampy-websockets
Browse files Browse the repository at this point in the history
Refactor wampy websockets
  • Loading branch information
noisyboiler authored Mar 17, 2018
2 parents 869ae32 + 5a88fab commit a941484
Show file tree
Hide file tree
Showing 9 changed files with 368 additions and 177 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"pytest-capturelog==0.7",
"colorlog",
"flake8==3.5.0",
"gevent-websocket==0.10.1",
],
'docs': [
"Sphinx==1.4.5",
Expand Down
6 changes: 4 additions & 2 deletions test/test_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import datetime
from time import sleep

import gevent
import pytest

from wampy.peers.clients import Client
Expand Down Expand Up @@ -169,8 +169,10 @@ class MyClient(Client):
client.start()
wait_for_session(client)

sleep(5)
gevent.sleep(5)

# this is purely to demonstrate we can make calls while sending
# pongs
client.publish(topic="test", message="test")
client.stop()
except Exception as e:
Expand Down
File renamed without changes.
78 changes: 78 additions & 0 deletions test/transports/test_websockets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import logging
from collections import OrderedDict

import pytest
import gevent
from gevent import Greenlet
from geventwebsocket import (
WebSocketApplication, WebSocketServer, Resource,
)
from mock import ANY
from mock import call, patch

from wampy.transports.websocket.connection import WebSocket
from wampy.transports.websocket.frames import Ping

logger = logging.getLogger(__name__)


class TestApplication(WebSocketApplication):
pass


@pytest.fixture
def server():
s = WebSocketServer(
('0.0.0.0', 8001),
Resource(OrderedDict([('/', TestApplication)]))
)
s.start()
thread = Greenlet.spawn(s.serve_forever)
yield s
s.stop()
thread.kill()


def test_send_ping(server):
websocket = WebSocket(server_url='ws://0.0.0.0:8001')
with patch.object(websocket, 'handle_ping') as mock_handle:
assert websocket.connected is False

websocket.connect(upgrade=False)

def connection_handler():
while True:
try:
message = websocket.receive()
except Exception:
logger.execption('connection handler exploded')
raise
if message:
logger.info('got message: %s', message)

assert websocket.connected is True

# the first bytes sent down the connection are the response bytes
# to the TCP connection and upgrade. we receieve in this thread
# because it will block all execution
Greenlet.spawn(connection_handler)
gevent.sleep(0.01) # enough for the upgrade to happen

clients = server.clients
assert len(clients) == 1

client_handler = list(clients.values())[0]
socket = client_handler.ws

ping_frame = Ping()
socket.send(ping_frame.frame)

with gevent.Timeout(5):
while mock_handle.call_count != 1:
gevent.sleep(0.01)

assert mock_handle.call_count == 1
assert mock_handle.call_args == call(ping_frame=ANY)

call_param = mock_handle.call_args[1]['ping_frame']
assert isinstance(call_param, Ping)
18 changes: 10 additions & 8 deletions wampy/peers/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,20 @@ def __init__(
# as WebSocket messages by default (well, actually... that's because no
# other transports are supported!)
if self.router.scheme == "ws":
self.transport = WebSocket()
self.transport = WebSocket(
server_url=self.router.url, ipv=self.router.ipv,
)
elif self.router.scheme == "wss":
self.transport = SecureWebSocket()
self.transport = SecureWebSocket(
server_url=self.router.url, ipv=self.router.ipv,
certificate_path=self.router.certificate,
)
else:
raise WampyError(
'Network protocl must be "ws" or "wss"'
)

# the transport is responsible for the connection.
self.transport.register_router(self.router)

# generally ``name`` is used for debuggubg and logging only
# generally ``name`` is used for debugging and logging only
self.name = name or self.__class__.__name__

self._session = None
Expand Down Expand Up @@ -152,8 +154,8 @@ def publish(self):
return PublishProxy(client=self)

def start(self):
# establish the underlying connection. this will raise on error.
connection = self.transport.connect()
# establish the underlying connection and upgrade it to WAMP.
connection = self.transport.connect(upgrade=True)

# create a Session repr between ourselves and the Router.
# pass in the live connection over a transport that the Session
Expand Down
1 change: 0 additions & 1 deletion wampy/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import logging
from functools import partial

import gevent
import gevent.queue
Expand Down
4 changes: 0 additions & 4 deletions wampy/transports/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
@six.add_metaclass(abc.ABCMeta)
class Transport(object):

@abc.abstractmethod
def register_router(self, router):
pass

@abc.abstractmethod
def connect(self):
""" should return ``self`` as the "connection" object """
Expand Down
71 changes: 37 additions & 34 deletions wampy/transports/websocket/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,32 @@
IncompleteFrameError, ConnectionError, WampProtocolError, WampyError)
from wampy.mixins import ParseUrlMixin
from wampy.transports.interface import Transport
from wampy.serializers import json_serialize

from . frames import ClientFrame, ServerFrame, PongFrame
from . frames import ClientFrame, FrameFactory, PongFrame

logger = logging.getLogger(__name__)


class WebSocket(Transport, ParseUrlMixin):

def register_router(self, router):
self.url = router.url
def __init__(self, server_url, ipv=4):
self.url = server_url
self.ipv = ipv

self.host = None
self.port = None
self.ipv = router.ipv
self.resource = None

self.parse_url()
self.websocket_location = self.resource
self.key = encodestring(uuid.uuid4().bytes).decode('utf-8').strip()
self.socket = None
self.connected = False

def connect(self):
def connect(self, upgrade=True):
# TCP connection
self._connect()
self._upgrade()
self._handshake(upgrade=upgrade)
return self

def disconnect(self):
Expand All @@ -52,53 +53,48 @@ def disconnect(self):
self.socket.close()

def send(self, message):
serialized_message = json_serialize(message)
frame = ClientFrame(serialized_message)
websocket_message = frame.payload
frame = ClientFrame(message)
websocket_message = frame.frame
self._send_raw(websocket_message)

def _send_raw(self, websocket_message):
logger.debug('send raw: %s', websocket_message)
self.socket.sendall(websocket_message)

def receive(self, bufsize=1):
frame = None
received_bytes = bytearray()

while True:
logger.debug("waiting for %s bytes", bufsize)

try:
bytes = self.socket.recv(bufsize)
except gevent.greenlet.GreenletExit as exc:
raise ConnectionError('Connection closed: "{}"'.format(exc))
except socket.timeout as e:
message = str(e)
raise ConnectionError('timeout: "{}"'.format(message))
except Exception as exc:
raise ConnectionError(
'unexpected error reading from socket: "{}"'.format(exc)
)

if not bytes:
break

logger.debug("received %s bytes", bufsize)
received_bytes.extend(bytes)

try:
frame = ServerFrame(received_bytes)
frame = FrameFactory.from_bytes(received_bytes)
except IncompleteFrameError as exc:
bufsize = exc.required_bytes
logger.debug('now requesting the missing %s bytes', bufsize)
else:
if frame.opcode == 9:
if frame.opcode == frame.OPCODE_PING:
# Opcode 0x9 marks a ping frame. It does not contain wamp
# data, so the frame is not returned.
# Still it must be handled or the server will close the
# connection.
self._send_raw(PongFrame(frame.payload).payload)
self.handle_ping(ping_frame=frame)
received_bytes = bytearray()
continue
if frame.opcode == frame.OPCODE_BINARY:
break

break

if frame is None:
Expand Down Expand Up @@ -143,8 +139,8 @@ def _connect(self):
self.socket = _socket
logger.debug("socket connected")

def _upgrade(self):
handshake_headers = self._get_handshake_headers()
def _handshake(self, upgrade):
handshake_headers = self._get_handshake_headers(upgrade=upgrade)
handshake = '\r\n'.join(handshake_headers) + "\r\n\r\n"

self.socket.send(handshake.encode())
Expand All @@ -159,7 +155,7 @@ def _upgrade(self):

logger.debug("connection upgraded")

def _get_handshake_headers(self):
def _get_handshake_headers(self, upgrade):
""" Do an HTTP upgrade handshake with the server.
Websockets upgrade from HTTP rather than TCP largely because it was
Expand All @@ -184,8 +180,11 @@ def _get_handshake_headers(self):
headers.append("Sec-WebSocket-Key: {}".format(self.key))
headers.append("Origin: ws://{}:{}".format(self.host, self.port))
headers.append("Sec-WebSocket-Version: {}".format(WEBSOCKET_VERSION))
headers.append("Sec-WebSocket-Protocol: {}".format(
WEBSOCKET_SUBPROTOCOLS))

if upgrade:
headers.append("Sec-WebSocket-Protocol: {}".format(
WEBSOCKET_SUBPROTOCOLS)
)

logger.debug("connection headers: %s", headers)

Expand All @@ -200,8 +199,8 @@ def _read_handshake_response(self):
def read_line():
bytes_cache = []
received_bytes = None
while received_bytes != b'\r\n':
received_bytes = self.socket.recv(2)
while received_bytes not in [b'\r\n', b'\n', b'\n\r']:
received_bytes = self.socket.recv(1)
bytes_cache.append(received_bytes)
return b''.join(bytes_cache)

Expand Down Expand Up @@ -241,15 +240,19 @@ def read_line():
headers[key.lower()] = value.strip().lower()

logger.info("handshake complete: %s : %s", status, headers)

self.connected = True
return status, headers

def handle_ping(self, ping_frame):
pong_frame = PongFrame(ping_frame=ping_frame)
bytes = pong_frame.frame
logger.info('sending pong: %s', bytes)
self._send_raw(bytes)

class SecureWebSocket(WebSocket):
def register_router(self, router):
super(SecureWebSocket, self).register_router(router)

self.ipv = router.ipv
class SecureWebSocket(WebSocket):
def __init__(self, server_url, certificate_path, ipv=4):
super(SecureWebSocket, self).__init__(server_url=server_url, ipv=ipv)

# PROTOCOL_TLSv1_1 and PROTOCOL_TLSv1_2 are only available if Python is
# linked with OpenSSL 1.0.1 or later.
Expand All @@ -258,7 +261,7 @@ def register_router(self, router):
except AttributeError:
raise WampyError("Your Python Environment does not support TLS")

self.certificate = router.certificate
self.certificate = certificate_path

def _connect(self):
_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
Expand Down
Loading

0 comments on commit a941484

Please sign in to comment.