Skip to content

Commit

Permalink
Simplify event dispatch in WebSocket
Browse files Browse the repository at this point in the history
Instead of using custom class hierarchy of messages, we can just use native C# lambdas stored as Actions.

OnConnectError case can be further simplified by merging two branches for different exception types into one.
  • Loading branch information
RReverser committed May 13, 2024
1 parent e79dcea commit 6a85c0c
Showing 1 changed file with 23 additions and 114 deletions.
137 changes: 23 additions & 114 deletions src/WebSocket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,99 +8,6 @@

namespace SpacetimeDB
{
internal abstract class MainThreadDispatch
{
public abstract void Execute();
}

class OnConnectMessage : MainThreadDispatch
{
private WebSocketOpenEventHandler receiver;

public OnConnectMessage(WebSocketOpenEventHandler receiver)
{
this.receiver = receiver;
}

public override void Execute()
{
receiver.Invoke();
}
}

class OnDisconnectMessage : MainThreadDispatch
{
private WebSocketCloseEventHandler receiver;
private WebSocketError? error;
private WebSocketCloseStatus? status;

public OnDisconnectMessage(WebSocketCloseEventHandler receiver, WebSocketCloseStatus? status,
WebSocketError? error)
{
this.receiver = receiver;
this.error = error;
this.status = status;
}

public override void Execute()
{
receiver.Invoke(status, error);
}
}

class OnConnectErrorMessage : MainThreadDispatch
{
private WebSocketConnectErrorEventHandler receiver;
private WebSocketError? error;
private string? errorMsg;

public OnConnectErrorMessage(WebSocketConnectErrorEventHandler receiver, WebSocketError? error, string? errorMsg)
{
this.receiver = receiver;
this.error = error;
this.errorMsg = errorMsg;
}

public override void Execute()
{
receiver.Invoke(error, errorMsg);
}
}

class OnSendErrorMessage : MainThreadDispatch
{
private WebSocketSendErrorEventHandler receiver;
private Exception e;

public OnSendErrorMessage(WebSocketSendErrorEventHandler receiver, Exception e)
{
this.receiver = receiver;
this.e = e;
}

public override void Execute()
{
receiver.Invoke(e);
}
}

class OnMessage : MainThreadDispatch
{
private WebSocketMessageEventHandler receiver;
private byte[] message;

public OnMessage(WebSocketMessageEventHandler receiver, byte[] message)
{
this.receiver = receiver;
this.message = message;
}

public override void Execute()
{
receiver.Invoke(message);
}
}

public delegate void WebSocketOpenEventHandler();

public delegate void WebSocketMessageEventHandler(byte[] message);
Expand All @@ -124,7 +31,7 @@ public class WebSocket
// Connection parameters
private readonly ConnectOptions _options;
private readonly byte[] _receiveBuffer = new byte[MAXMessageSize];
private readonly ConcurrentQueue<MainThreadDispatch> dispatchQueue = new();
private readonly ConcurrentQueue<Action> dispatchQueue = new();

protected ClientWebSocket Ws = new();

Expand Down Expand Up @@ -161,24 +68,22 @@ public async Task Connect(string auth, string host, string nameOrAddress, Addres
try
{
await Ws.ConnectAsync(url, source.Token);
if (OnConnect != null) dispatchQueue.Enqueue(new OnConnectMessage(OnConnect));
if (OnConnect != null) dispatchQueue.Enqueue(() => OnConnect());
}
catch (WebSocketException ex)
catch (Exception ex)
{
string message = ex.Message;
if (ex.WebSocketErrorCode == WebSocketError.NotAWebSocket)
Logger.LogException(ex);
if (OnConnectError != null)
{
// not a websocket happens when there is no module published under the address specified
message = $"{message} Did you forget to publish your module?";
var message = ex.Message;
var code = (ex as WebSocketException)?.WebSocketErrorCode;
if (code == WebSocketError.NotAWebSocket)
{
// not a websocket happens when there is no module published under the address specified
message = $"{message} Did you forget to publish your module?";
}
dispatchQueue.Enqueue(() => OnConnectError(code, message));
}
Logger.LogException(ex);
if (OnConnectError != null) dispatchQueue.Enqueue(new OnConnectErrorMessage(OnConnectError, ex.WebSocketErrorCode, message));
return;
}
catch (Exception e)
{
Logger.LogException(e);
if (OnConnectError != null) dispatchQueue.Enqueue(new OnConnectErrorMessage(OnConnectError, null, e.Message));
return;
}

Expand All @@ -195,7 +100,7 @@ public async Task Connect(string auth, string host, string nameOrAddress, Addres
await Ws.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty,
CancellationToken.None);
}
if (OnClose != null) dispatchQueue.Enqueue(new OnDisconnectMessage(OnClose, receiveResult.CloseStatus, null));
if (OnClose != null) dispatchQueue.Enqueue(() => OnClose(receiveResult.CloseStatus, null));
return;
}

Expand All @@ -208,7 +113,7 @@ await Ws.CloseAsync(WebSocketCloseStatus.NormalClosure, string.Empty,
var closeMessage = $"Maximum message size: {MAXMessageSize} bytes.";
await Ws.CloseAsync(WebSocketCloseStatus.MessageTooBig, closeMessage,
CancellationToken.None);
if (OnClose != null) dispatchQueue.Enqueue(new OnDisconnectMessage(OnClose, WebSocketCloseStatus.MessageTooBig, null));
if (OnClose != null) dispatchQueue.Enqueue(() => OnClose(WebSocketCloseStatus.MessageTooBig, null));
return;
}

Expand All @@ -218,11 +123,15 @@ await Ws.CloseAsync(WebSocketCloseStatus.MessageTooBig, closeMessage,
count += receiveResult.Count;
}

if (OnMessage != null) dispatchQueue.Enqueue(new OnMessage(OnMessage, _receiveBuffer.Take(count).ToArray()));
if (OnMessage != null)
{
var message = _receiveBuffer.Take(count).ToArray();
dispatchQueue.Enqueue(() => OnMessage(message));
}
}
catch (WebSocketException ex)
{
if (OnClose != null) dispatchQueue.Enqueue(new OnDisconnectMessage(OnClose, null, ex.WebSocketErrorCode));
if (OnClose != null) dispatchQueue.Enqueue(() => OnClose(null, ex.WebSocketErrorCode));
return;
}
}
Expand Down Expand Up @@ -278,7 +187,7 @@ private async Task ProcessSendQueue()
catch (Exception e)
{
senderTask = null;
if (OnSendError != null) dispatchQueue.Enqueue(new OnSendErrorMessage(OnSendError, e));
if (OnSendError != null) dispatchQueue.Enqueue(() => OnSendError(e));
}
}

Expand All @@ -291,7 +200,7 @@ public void Update()
{
while (dispatchQueue.TryDequeue(out var result))
{
result.Execute();
result();
}
}
}
Expand Down

0 comments on commit 6a85c0c

Please sign in to comment.