Skip to content

Commit

Permalink
Add support for custom headers (#200)
Browse files Browse the repository at this point in the history
* Added custom headers

* code cleaning
  • Loading branch information
realcoloride authored Jul 29, 2024
1 parent e496ee9 commit 9962eb9
Showing 1 changed file with 50 additions and 57 deletions.
107 changes: 50 additions & 57 deletions src/WebSocket4Net/WebSocket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,17 @@ public class WebSocket : EasyClient<WebSocketPackage>, IWebSocket

public PingPongStatus PingPongStatus { get; private set; }

private string _origin;
private readonly string _origin;

private EndPoint _remoteEndPoint;
private readonly EndPoint _remoteEndPoint;

private static readonly IPackageEncoder<WebSocketPackage> _packageEncoder = new WebSocketEncoder();

private List<string> _subProtocols;

public IReadOnlyList<string> SubProtocols
{
get { return _subProtocols; }
}
public IReadOnlyList<string> SubProtocols => _subProtocols;

public Dictionary<string, string> Headers = new();

public WebSocketState State { get; private set; } = WebSocketState.None;

Expand Down Expand Up @@ -84,16 +83,14 @@ public WebSocket(string url, ConnectionOptions connectionOptions)

private EndPoint ResolveUri(Uri uri, int defaultPort)
{
IPAddress ipAddress;

EndPoint remoteEndPoint;

var port = uri.Port;

if (port <= 0)
port = defaultPort;

if (IPAddress.TryParse(uri.Host, out ipAddress))
if (IPAddress.TryParse(uri.Host, out IPAddress ipAddress))
remoteEndPoint = new IPEndPoint(ipAddress, port);
else
remoteEndPoint = new DnsEndPoint(uri.Host, port);
Expand All @@ -105,30 +102,28 @@ public void AddSubProtocol(string protocol)
{
var subProtocols = _subProtocols;

if (subProtocols == null)
subProtocols = _subProtocols = new List<string>();

subProtocols ??= _subProtocols = new List<string>();
subProtocols.Add(protocol);
}

protected override void SetupConnection(IConnection connection)
{
this.Closed += OnConnectionClosed;
Closed += OnConnectionClosed;
base.SetupConnection(connection);
}

public async ValueTask<bool> OpenAsync(CancellationToken cancellationToken = default)
{
State = WebSocketState.Connecting;

if (!await this.ConnectAsync(_remoteEndPoint, cancellationToken))
if (!await ConnectAsync(_remoteEndPoint, cancellationToken))
{
State = WebSocketState.Closed;
return false;
}

var (key, acceptKey) = MakeSecureKey();
await this.Connection.SendAsync((writer) => WriteHandshakeRequest(writer, key));
await Connection.SendAsync((writer) => WriteHandshakeRequest(writer, key));

var handshakeResponse = await ReceiveAsync();

Expand Down Expand Up @@ -171,13 +166,11 @@ public async ValueTask<bool> OpenAsync(CancellationToken cancellationToken = def
}

private string CalculateChallenge(string secKey, string magic)
{
return Convert.ToBase64String(SHA1.Create().ComputeHash(_asciiEncoding.GetBytes(secKey + magic)));
}
=> Convert.ToBase64String(SHA1.Create().ComputeHash(_asciiEncoding.GetBytes(secKey + magic)));

private (string, string) MakeSecureKey()
{
var secKey = Convert.ToBase64String(_asciiEncoding.GetBytes(Guid.NewGuid().ToString().Substring(0, 16)));
var secKey = Convert.ToBase64String(_asciiEncoding.GetBytes(Guid.NewGuid().ToString()[..16]));
return (secKey, CalculateChallenge(secKey, _magic));
}

Expand All @@ -189,7 +182,7 @@ private void WriteHandshakeRequest(PipeWriter writer, string secKey)
writer.Write($"{WebSocketConstant.ResponseConnectionLine}", _asciiEncoding);
writer.Write($"{WebSocketConstant.SecWebSocketKey}: {secKey}\r\n", _asciiEncoding);
writer.Write($"{WebSocketConstant.Origin}: {_origin}\r\n", _asciiEncoding);

var subProtocols = _subProtocols;

if (subProtocols != null && subProtocols.Count > 0)
Expand All @@ -199,19 +192,21 @@ private void WriteHandshakeRequest(PipeWriter writer, string secKey)
}

writer.Write($"{WebSocketConstant.SecWebSocketVersion}: 13\r\n\r\n", _asciiEncoding);
}

public new void StartReceive()
{
base.StartReceive();
// Write extra headers
foreach (var header in Headers)
writer.Write($"{header.Key}: {header.Value}\r\n", _asciiEncoding);

// Ensure end of the handshake request handshake
writer.Write("\r\n", _asciiEncoding);
}

public new void StartReceive() => base.StartReceive();

public new async ValueTask<WebSocketPackage> ReceiveAsync()
{
return await ReceiveAsync(
=> await ReceiveAsync(
handleControlPackage: true,
returnControlPackage: false);
}

internal async ValueTask<WebSocketPackage> ReceiveAsync(bool handleControlPackage, bool returnControlPackage)
{
Expand Down Expand Up @@ -251,39 +246,41 @@ private async ValueTask HandleControlPackage(WebSocketPackage package)
{
switch (package.OpCode)
{
case (OpCode.Close):
case OpCode.Close:
await HandleCloseHandshake(package);
break;

case (OpCode.Ping):
case OpCode.Ping:
PingPongStatus.OnPingReceived(package);
package.OpCode = OpCode.Pong;
await this.SendAsync(package);
await SendAsync(package);
break;

case (OpCode.Pong):
case OpCode.Pong:
PingPongStatus.OnPongReceived(package);
break;
}
}

public async ValueTask SendAsync(string message)
{
var package = new WebSocketPackage();
package.OpCode = OpCode.Text;
package.Message = message;
var package = new WebSocketPackage
{
OpCode = OpCode.Text,
Message = message
};
await SendAsync(package);
}

internal async ValueTask SendAsync(WebSocketPackage package)
{
await SendAsync(_packageEncoder, package);
}
=> await SendAsync(_packageEncoder, package);

public new async ValueTask SendAsync(ReadOnlyMemory<byte> data)
{
var package = new WebSocketPackage();
package.OpCode = OpCode.Binary;
var package = new WebSocketPackage
{
OpCode = OpCode.Binary
};

var sequenceElement = new SequenceSegment(data);
package.Data = new ReadOnlySequence<byte>(sequenceElement, 0, sequenceElement, sequenceElement.Memory.Length);
Expand All @@ -293,27 +290,25 @@ internal async ValueTask SendAsync(WebSocketPackage package)

public ValueTask SendAsync(ref ReadOnlySequence<byte> sequence)
{
var package = new WebSocketPackage();
package.OpCode = OpCode.Binary;
package.Data = sequence;
var package = new WebSocketPackage
{
OpCode = OpCode.Binary,
Data = sequence
};
return SendAsync(_packageEncoder, package);
}

private byte[] GetBuffer(int size)
{
return new byte[size];
}
private byte[] GetBuffer(int size) => new byte[size];

public override async ValueTask CloseAsync()
{
await CloseAsync(CloseReason.NormalClosure, string.Empty);
}
=> await CloseAsync(CloseReason.NormalClosure, string.Empty);

public async ValueTask CloseAsync(CloseReason closeReason, string message = null)
{
var package = new WebSocketPackage();

package.OpCode = OpCode.Close;
var package = new WebSocketPackage
{
OpCode = OpCode.Close
};

var bufferSize = !string.IsNullOrEmpty(message) ? _utf8Encoding.GetMaxByteCount(message.Length) : 0;
bufferSize += 2;
Expand Down Expand Up @@ -379,7 +374,7 @@ private async ValueTask HandleCloseHandshake(WebSocketPackage receivedClosePacka

if (closeStatus == null)
{
this.State = WebSocketState.CloseReceived;
State = WebSocketState.CloseReceived;
closeStatusFromRemote.RemoteInitiated = true;
CloseStatus = closeStatusFromRemote;
// Send close pong message to server side.
Expand All @@ -399,9 +394,7 @@ private async ValueTask HandleCloseHandshake(WebSocketPackage receivedClosePacka
}
}

private void OnConnectionClosed(object sender, EventArgs eventArgs)
{
this.State = WebSocketState.Closed;
}
private void OnConnectionClosed(object sender, EventArgs eventArgs)
=> State = WebSocketState.Closed;
}
}

0 comments on commit 9962eb9

Please sign in to comment.