From 9962eb9ecc416a04e3ff0a40708a41452dd9c2b6 Mon Sep 17 00:00:00 2001 From: Coloride <108619637+realcoloride@users.noreply.github.com> Date: Mon, 29 Jul 2024 20:01:13 +0200 Subject: [PATCH] Add support for custom headers (#200) * Added custom headers * code cleaning --- src/WebSocket4Net/WebSocket.cs | 107 +++++++++++++++------------------ 1 file changed, 50 insertions(+), 57 deletions(-) diff --git a/src/WebSocket4Net/WebSocket.cs b/src/WebSocket4Net/WebSocket.cs index 49fb708..13c1f17 100644 --- a/src/WebSocket4Net/WebSocket.cs +++ b/src/WebSocket4Net/WebSocket.cs @@ -32,18 +32,17 @@ public class WebSocket : EasyClient, IWebSocket public PingPongStatus PingPongStatus { get; private set; } - private string _origin; + private readonly string _origin; - private EndPoint _remoteEndPoint; + private readonly EndPoint _remoteEndPoint; private static readonly IPackageEncoder _packageEncoder = new WebSocketEncoder(); private List _subProtocols; - public IReadOnlyList SubProtocols - { - get { return _subProtocols; } - } + public IReadOnlyList SubProtocols => _subProtocols; + + public Dictionary Headers = new(); public WebSocketState State { get; private set; } = WebSocketState.None; @@ -84,8 +83,6 @@ public WebSocket(string url, ConnectionOptions connectionOptions) private EndPoint ResolveUri(Uri uri, int defaultPort) { - IPAddress ipAddress; - EndPoint remoteEndPoint; var port = uri.Port; @@ -93,7 +90,7 @@ private EndPoint ResolveUri(Uri uri, int defaultPort) if (port <= 0) port = defaultPort; - if (IPAddress.TryParse(uri.Host, out ipAddress)) + if (IPAddress.TryParse(uri.Host, out IPAddress ipAddress)) remoteEndPoint = new IPEndPoint(ipAddress, port); else remoteEndPoint = new DnsEndPoint(uri.Host, port); @@ -105,15 +102,13 @@ public void AddSubProtocol(string protocol) { var subProtocols = _subProtocols; - if (subProtocols == null) - subProtocols = _subProtocols = new List(); - + subProtocols ??= _subProtocols = new List(); subProtocols.Add(protocol); } protected override void SetupConnection(IConnection connection) { - this.Closed += OnConnectionClosed; + Closed += OnConnectionClosed; base.SetupConnection(connection); } @@ -121,14 +116,14 @@ public async ValueTask OpenAsync(CancellationToken cancellationToken = def { State = WebSocketState.Connecting; - if (!await this.ConnectAsync(_remoteEndPoint, cancellationToken)) + if (!await ConnectAsync(_remoteEndPoint, cancellationToken)) { State = WebSocketState.Closed; return false; } var (key, acceptKey) = MakeSecureKey(); - await this.Connection.SendAsync((writer) => WriteHandshakeRequest(writer, key)); + await Connection.SendAsync((writer) => WriteHandshakeRequest(writer, key)); var handshakeResponse = await ReceiveAsync(); @@ -171,13 +166,11 @@ public async ValueTask OpenAsync(CancellationToken cancellationToken = def } private string CalculateChallenge(string secKey, string magic) - { - return Convert.ToBase64String(SHA1.Create().ComputeHash(_asciiEncoding.GetBytes(secKey + magic))); - } + => Convert.ToBase64String(SHA1.Create().ComputeHash(_asciiEncoding.GetBytes(secKey + magic))); private (string, string) MakeSecureKey() { - var secKey = Convert.ToBase64String(_asciiEncoding.GetBytes(Guid.NewGuid().ToString().Substring(0, 16))); + var secKey = Convert.ToBase64String(_asciiEncoding.GetBytes(Guid.NewGuid().ToString()[..16])); return (secKey, CalculateChallenge(secKey, _magic)); } @@ -189,7 +182,7 @@ private void WriteHandshakeRequest(PipeWriter writer, string secKey) writer.Write($"{WebSocketConstant.ResponseConnectionLine}", _asciiEncoding); writer.Write($"{WebSocketConstant.SecWebSocketKey}: {secKey}\r\n", _asciiEncoding); writer.Write($"{WebSocketConstant.Origin}: {_origin}\r\n", _asciiEncoding); - + var subProtocols = _subProtocols; if (subProtocols != null && subProtocols.Count > 0) @@ -199,19 +192,21 @@ private void WriteHandshakeRequest(PipeWriter writer, string secKey) } writer.Write($"{WebSocketConstant.SecWebSocketVersion}: 13\r\n\r\n", _asciiEncoding); - } - public new void StartReceive() - { - base.StartReceive(); + // Write extra headers + foreach (var header in Headers) + writer.Write($"{header.Key}: {header.Value}\r\n", _asciiEncoding); + + // Ensure end of the handshake request handshake + writer.Write("\r\n", _asciiEncoding); } + public new void StartReceive() => base.StartReceive(); + public new async ValueTask ReceiveAsync() - { - return await ReceiveAsync( + => await ReceiveAsync( handleControlPackage: true, returnControlPackage: false); - } internal async ValueTask ReceiveAsync(bool handleControlPackage, bool returnControlPackage) { @@ -251,17 +246,17 @@ private async ValueTask HandleControlPackage(WebSocketPackage package) { switch (package.OpCode) { - case (OpCode.Close): + case OpCode.Close: await HandleCloseHandshake(package); break; - case (OpCode.Ping): + case OpCode.Ping: PingPongStatus.OnPingReceived(package); package.OpCode = OpCode.Pong; - await this.SendAsync(package); + await SendAsync(package); break; - case (OpCode.Pong): + case OpCode.Pong: PingPongStatus.OnPongReceived(package); break; } @@ -269,21 +264,23 @@ private async ValueTask HandleControlPackage(WebSocketPackage package) public async ValueTask SendAsync(string message) { - var package = new WebSocketPackage(); - package.OpCode = OpCode.Text; - package.Message = message; + var package = new WebSocketPackage + { + OpCode = OpCode.Text, + Message = message + }; await SendAsync(package); } internal async ValueTask SendAsync(WebSocketPackage package) - { - await SendAsync(_packageEncoder, package); - } + => await SendAsync(_packageEncoder, package); public new async ValueTask SendAsync(ReadOnlyMemory data) { - var package = new WebSocketPackage(); - package.OpCode = OpCode.Binary; + var package = new WebSocketPackage + { + OpCode = OpCode.Binary + }; var sequenceElement = new SequenceSegment(data); package.Data = new ReadOnlySequence(sequenceElement, 0, sequenceElement, sequenceElement.Memory.Length); @@ -293,27 +290,25 @@ internal async ValueTask SendAsync(WebSocketPackage package) public ValueTask SendAsync(ref ReadOnlySequence sequence) { - var package = new WebSocketPackage(); - package.OpCode = OpCode.Binary; - package.Data = sequence; + var package = new WebSocketPackage + { + OpCode = OpCode.Binary, + Data = sequence + }; return SendAsync(_packageEncoder, package); } - private byte[] GetBuffer(int size) - { - return new byte[size]; - } + private byte[] GetBuffer(int size) => new byte[size]; public override async ValueTask CloseAsync() - { - await CloseAsync(CloseReason.NormalClosure, string.Empty); - } + => await CloseAsync(CloseReason.NormalClosure, string.Empty); public async ValueTask CloseAsync(CloseReason closeReason, string message = null) { - var package = new WebSocketPackage(); - - package.OpCode = OpCode.Close; + var package = new WebSocketPackage + { + OpCode = OpCode.Close + }; var bufferSize = !string.IsNullOrEmpty(message) ? _utf8Encoding.GetMaxByteCount(message.Length) : 0; bufferSize += 2; @@ -379,7 +374,7 @@ private async ValueTask HandleCloseHandshake(WebSocketPackage receivedClosePacka if (closeStatus == null) { - this.State = WebSocketState.CloseReceived; + State = WebSocketState.CloseReceived; closeStatusFromRemote.RemoteInitiated = true; CloseStatus = closeStatusFromRemote; // Send close pong message to server side. @@ -399,9 +394,7 @@ private async ValueTask HandleCloseHandshake(WebSocketPackage receivedClosePacka } } - private void OnConnectionClosed(object sender, EventArgs eventArgs) - { - this.State = WebSocketState.Closed; - } + private void OnConnectionClosed(object sender, EventArgs eventArgs) + => State = WebSocketState.Closed; } }