Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1375 #1237 #925 #920 Fix DownstreamRoute DangerousAcceptAnyServerCertificateValidator #1377

Merged
merged 14 commits into from
Sep 28, 2023
Merged
2 changes: 2 additions & 0 deletions src/Ocelot/DependencyInjection/OcelotBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
using Ocelot.Security.IPSecurity;
using Ocelot.ServiceDiscovery;
using Ocelot.ServiceDiscovery.Providers;
using Ocelot.WebSockets;
using System.Reflection;

namespace Ocelot.DependencyInjection
Expand Down Expand Up @@ -138,6 +139,7 @@ public OcelotBuilder(IServiceCollection services, IConfiguration configurationRo
Services.TryAddSingleton<IQoSFactory, QoSFactory>();
Services.TryAddSingleton<IExceptionToErrorMapper, HttpExceptionToErrorMapper>();
Services.TryAddSingleton<IVersionCreator, HttpVersionCreator>();
Services.TryAddSingleton<IWebSocketsFactory, WebSocketsFactory>();

// Add security
Services.TryAddSingleton<ISecurityOptionsCreator, SecurityOptionsCreator>();
Expand Down
25 changes: 15 additions & 10 deletions src/Ocelot/LoadBalancer/LoadBalancers/RoundRobin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,35 @@ namespace Ocelot.LoadBalancer.LoadBalancers
{
public class RoundRobin : ILoadBalancer
{
private readonly Func<Task<List<Service>>> _services;
private readonly Func<Task<List<Service>>> _servicesDelegate;
private readonly object _lock = new();

private int _last;

public RoundRobin(Func<Task<List<Service>>> services)
{
_services = services;
_servicesDelegate = services;
}

public async Task<Response<ServiceHostAndPort>> Lease(HttpContext httpContext)
{
var services = await _services();
lock (_lock)
var services = await _servicesDelegate?.Invoke() ?? new List<Service>();

if (services?.Count != 0)
{
if (_last >= services.Count)
lock (_lock)
{
_last = 0;
}
if (_last >= services.Count)
{
_last = 0;
}

var next = services[_last];
_last++;
return new OkResponse<ServiceHostAndPort>(next.HostAndPort);
var next = services[_last++];
return new OkResponse<ServiceHostAndPort>(next.HostAndPort);
}
}

return new ErrorResponse<ServiceHostAndPort>(new ServicesAreEmptyError($"There were no services in {nameof(RoundRobin)} during {nameof(Lease)} operation."));
}

public void Release(ServiceHostAndPort hostAndPort)
Expand Down
2 changes: 1 addition & 1 deletion src/Ocelot/Middleware/OcelotPipelineExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
using Ocelot.RequestId.Middleware;
using Ocelot.Responder.Middleware;
using Ocelot.Security.Middleware;
using Ocelot.WebSockets.Middleware;
using Ocelot.WebSockets;

namespace Ocelot.Middleware
{
Expand Down
35 changes: 35 additions & 0 deletions src/Ocelot/WebSockets/ClientWebSocketOptionsProxy.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using System.Net.Security;
using System.Net.WebSockets;
using System.Security.Cryptography.X509Certificates;

namespace Ocelot.WebSockets;

public class ClientWebSocketOptionsProxy : IClientWebSocketOptions
{
private readonly ClientWebSocketOptions _real;

public ClientWebSocketOptionsProxy(ClientWebSocketOptions options)
{
_real = options;
}

public Version HttpVersion { get => _real.HttpVersion; set => _real.HttpVersion = value; }
public HttpVersionPolicy HttpVersionPolicy { get => _real.HttpVersionPolicy; set => _real.HttpVersionPolicy = value; }
public bool UseDefaultCredentials { get => _real.UseDefaultCredentials; set => _real.UseDefaultCredentials = value; }
public ICredentials Credentials { get => _real.Credentials; set => _real.Credentials = value; }
public IWebProxy Proxy { get => _real.Proxy; set => _real.Proxy = value; }
public X509CertificateCollection ClientCertificates { get => _real.ClientCertificates; set => _real.ClientCertificates = value; }
public RemoteCertificateValidationCallback RemoteCertificateValidationCallback { get => _real.RemoteCertificateValidationCallback; set => _real.RemoteCertificateValidationCallback = value; }
public CookieContainer Cookies { get => _real.Cookies; set => _real.Cookies = value; }
public TimeSpan KeepAliveInterval { get => _real.KeepAliveInterval; set => _real.KeepAliveInterval = value; }
public WebSocketDeflateOptions DangerousDeflateOptions { get => _real.DangerousDeflateOptions; set => _real.DangerousDeflateOptions = value; }
public bool CollectHttpResponseDetails { get => _real.CollectHttpResponseDetails; set => _real.CollectHttpResponseDetails = value; }

public void AddSubProtocol(string subProtocol) => _real.AddSubProtocol(subProtocol);

public void SetBuffer(int receiveBufferSize, int sendBufferSize) => _real.SetBuffer(receiveBufferSize, sendBufferSize);

public void SetBuffer(int receiveBufferSize, int sendBufferSize, ArraySegment<byte> buffer) => _real.SetBuffer(receiveBufferSize, sendBufferSize, buffer);

public void SetRequestHeader(string headerName, string headerValue) => _real.SetRequestHeader(headerName, headerValue);
}
49 changes: 49 additions & 0 deletions src/Ocelot/WebSockets/ClientWebSocketProxy.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using System.Net.WebSockets;

namespace Ocelot.WebSockets;

public class ClientWebSocketProxy : WebSocket, IClientWebSocket
{
// RealSubject (Service) class of Proxy design pattern
private readonly ClientWebSocket _realService;
private readonly IClientWebSocketOptions _options;

public ClientWebSocketProxy()
{
_realService = new ClientWebSocket();
_options = new ClientWebSocketOptionsProxy(_realService.Options);
}

// ClientWebSocket implementations
public IClientWebSocketOptions Options => _options;

public Task ConnectAsync(Uri uri, CancellationToken cancellationToken)
=> _realService.ConnectAsync(uri, cancellationToken);

// WebSocket implementations
public override WebSocketCloseStatus? CloseStatus => _realService.CloseStatus;

public override string CloseStatusDescription => _realService.CloseStatusDescription;

public override WebSocketState State => _realService.State;

public override string SubProtocol => _realService.SubProtocol;

public override void Abort() => _realService.Abort();

public override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
=> _realService.CloseAsync(closeStatus, statusDescription, cancellationToken);

public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
=> _realService.CloseOutputAsync(closeStatus, statusDescription, cancellationToken);

public override void Dispose() => _realService.Dispose();

public override Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken)
=> _realService.ReceiveAsync(buffer, cancellationToken);

public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
=> _realService.SendAsync(buffer, messageType, endOfMessage, cancellationToken);

public WebSocket ToWebSocket() => _realService;
}
24 changes: 24 additions & 0 deletions src/Ocelot/WebSockets/IClientWebSocket.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using System.Net.WebSockets;

namespace Ocelot.WebSockets;

public interface IClientWebSocket
{
WebSocket ToWebSocket();

// ClientWebSocket definitions
IClientWebSocketOptions Options { get; }
Task ConnectAsync(Uri uri, CancellationToken cancellationToken);

// WebSocket definitions
WebSocketCloseStatus? CloseStatus { get; }
string CloseStatusDescription { get; }
WebSocketState State { get; }
string SubProtocol { get; }
void Abort();
Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken);
Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken);
void Dispose();
Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken);
Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken);
}
24 changes: 24 additions & 0 deletions src/Ocelot/WebSockets/IClientWebSocketOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
using System.Net.Security;
using System.Net.WebSockets;
using System.Security.Cryptography.X509Certificates;

namespace Ocelot.WebSockets;

public interface IClientWebSocketOptions
{
Version HttpVersion { get; set; }
HttpVersionPolicy HttpVersionPolicy { get; set; }
void SetRequestHeader(string headerName, string headerValue);
bool UseDefaultCredentials { get; set; }
ICredentials Credentials { get; set; }
IWebProxy Proxy { get; set; }
X509CertificateCollection ClientCertificates { get; set; }
RemoteCertificateValidationCallback RemoteCertificateValidationCallback { get; set; }
CookieContainer Cookies { get; set; }
void AddSubProtocol(string subProtocol);
TimeSpan KeepAliveInterval { get; set; }
WebSocketDeflateOptions DangerousDeflateOptions { get; set; }
void SetBuffer(int receiveBufferSize, int sendBufferSize);
void SetBuffer(int receiveBufferSize, int sendBufferSize, ArraySegment<byte> buffer);
bool CollectHttpResponseDetails { get; set; }
}
6 changes: 6 additions & 0 deletions src/Ocelot/WebSockets/IWebSocketsFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace Ocelot.WebSockets;

public interface IWebSocketsFactory
{
IClientWebSocket CreateClient();
}
6 changes: 6 additions & 0 deletions src/Ocelot/WebSockets/WebSocketsFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace Ocelot.WebSockets;

public class WebSocketsFactory : IWebSocketsFactory
{
public IClientWebSocket CreateClient() => new ClientWebSocketProxy();
}
Original file line number Diff line number Diff line change
@@ -1,26 +1,33 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
// Modified https://github.com/aspnet/Proxy websockets class to use in Ocelot.

using Microsoft.AspNetCore.Http;
using Ocelot.Configuration;
using Ocelot.Logging;
using Ocelot.Middleware;
using System.Net.WebSockets;

namespace Ocelot.WebSockets.Middleware
namespace Ocelot.WebSockets
{
public class WebSocketsProxyMiddleware : OcelotMiddleware
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

{
private static readonly string[] NotForwardedWebSocketHeaders = new[] { "Connection", "Host", "Upgrade", "Sec-WebSocket-Accept", "Sec-WebSocket-Protocol", "Sec-WebSocket-Key", "Sec-WebSocket-Version", "Sec-WebSocket-Extensions" };
private static readonly string[] NotForwardedWebSocketHeaders = new[]
{
"Connection", "Host", "Upgrade",
"Sec-WebSocket-Accept", "Sec-WebSocket-Protocol", "Sec-WebSocket-Key", "Sec-WebSocket-Version", "Sec-WebSocket-Extensions",
};
private const int DefaultWebSocketBufferSize = 4096;
private const int StreamCopyBufferSize = 81920;
private readonly RequestDelegate _next;
private readonly IWebSocketsFactory _factory;

public WebSocketsProxyMiddleware(RequestDelegate next,
IOcelotLoggerFactory loggerFactory)
: base(loggerFactory.CreateLogger<WebSocketsProxyMiddleware>())
public WebSocketsProxyMiddleware(IOcelotLoggerFactory loggerFactory,
RequestDelegate next,
IWebSocketsFactory factory)
: base(loggerFactory.CreateLogger<WebSocketsProxyMiddleware>())
{
_next = next;
_factory = factory;
}

private static async Task PumpWebSocket(WebSocket source, WebSocket destination, int bufferSize, CancellationToken cancellationToken)
Expand Down Expand Up @@ -67,10 +74,11 @@ private static async Task PumpWebSocket(WebSocket source, WebSocket destination,
public async Task Invoke(HttpContext httpContext)
{
var uri = httpContext.Items.DownstreamRequest().ToUri();
await Proxy(httpContext, uri);
var downstreamRoute = httpContext.Items.DownstreamRoute();
await Proxy(httpContext, uri, downstreamRoute);
}

private static async Task Proxy(HttpContext context, string serverEndpoint)
private async Task Proxy(HttpContext context, string serverEndpoint, DownstreamRoute downstreamRoute)
{
if (context == null)
{
Expand All @@ -87,7 +95,14 @@ private static async Task Proxy(HttpContext context, string serverEndpoint)
throw new InvalidOperationException();
}

var client = new ClientWebSocket();
var client = _factory.CreateClient(); // new ClientWebSocket();

if (downstreamRoute.DangerousAcceptAnyServerCertificateValidator)
{
client.Options.RemoteCertificateValidationCallback = (request, certificate, chain, errors) => true;
Logger.LogWarning($"You have ignored all SSL warnings by using {nameof(DownstreamRoute.DangerousAcceptAnyServerCertificateValidator)} for this downstream route! {nameof(DownstreamRoute.UpstreamPathTemplate)}: '{downstreamRoute.UpstreamPathTemplate}', {nameof(DownstreamRoute.DownstreamPathTemplate)}: '{downstreamRoute.DownstreamPathTemplate}'.");
}

foreach (var protocol in context.WebSockets.WebSocketRequestedProtocols)
{
client.Options.AddSubProtocol(protocol);
Expand All @@ -112,10 +127,12 @@ private static async Task Proxy(HttpContext context, string serverEndpoint)

var destinationUri = new Uri(serverEndpoint);
await client.ConnectAsync(destinationUri, context.RequestAborted);

using (var server = await context.WebSockets.AcceptWebSocketAsync(client.SubProtocol))
{
var bufferSize = DefaultWebSocketBufferSize;
await Task.WhenAll(PumpWebSocket(client, server, bufferSize, context.RequestAborted), PumpWebSocket(server, client, bufferSize, context.RequestAborted));
await Task.WhenAll(
PumpWebSocket(client.ToWebSocket(), server, DefaultWebSocketBufferSize, context.RequestAborted),
PumpWebSocket(server, client.ToWebSocket(), DefaultWebSocketBufferSize, context.RequestAborted));
}
}
}
Expand Down
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The UseWebSocketsProxyMiddleware method enables custom WebSocketsProxyMiddleware in the pipeline.
But there is no usage of native WebSocketMiddleware class which is industry standard. But instead of standard framework middleware a custom WebSocketsProxyMiddleware class was written with custom Invoke method. This is design issue...

Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Microsoft.AspNetCore.Builder;

namespace Ocelot.WebSockets.Middleware
namespace Ocelot.WebSockets
{
public static class WebSocketsProxyMiddlewareExtensions
{
Expand Down
5 changes: 3 additions & 2 deletions test/Ocelot.AcceptanceTests/ConsulWebSocketTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Microsoft.AspNetCore.Http;
using Newtonsoft.Json;
using Ocelot.Configuration.File;
using Ocelot.WebSockets;
using System.Net.WebSockets;
using System.Text;

Expand Down Expand Up @@ -142,7 +143,7 @@ private async Task WhenIStartTheClients()

private async Task StartClient(string url)
{
var client = new ClientWebSocket();
IClientWebSocket client = new ClientWebSocketProxy();

await client.ConnectAsync(new Uri(url), CancellationToken.None);

Expand Down Expand Up @@ -194,7 +195,7 @@ private async Task StartSecondClient(string url)
{
await Task.Delay(500);

var client = new ClientWebSocket();
IClientWebSocket client = new ClientWebSocketProxy();

await client.ConnectAsync(new Uri(url), CancellationToken.None);

Expand Down
5 changes: 3 additions & 2 deletions test/Ocelot.AcceptanceTests/WebSocketTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Ocelot.Configuration.File;
using Ocelot.WebSockets;
using System.Net.WebSockets;
using System.Text;

Expand Down Expand Up @@ -124,7 +125,7 @@ private async Task WhenIStartTheClients()

private async Task StartClient(string url)
{
var client = new ClientWebSocket();
IClientWebSocket client = new ClientWebSocketProxy();

await client.ConnectAsync(new Uri(url), CancellationToken.None);

Expand Down Expand Up @@ -176,7 +177,7 @@ private async Task StartSecondClient(string url)
{
await Task.Delay(500);

var client = new ClientWebSocket();
IClientWebSocket client = new ClientWebSocketProxy();

await client.ConnectAsync(new Uri(url), CancellationToken.None);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
using Ocelot.LoadBalancer.Middleware;
using Ocelot.Middleware;
using Ocelot.Request.Middleware;
using Ocelot.WebSockets.Middleware;
using Ocelot.WebSockets;

namespace Ocelot.UnitTests.Middleware
{
Expand Down
1 change: 1 addition & 0 deletions test/Ocelot.UnitTests/Ocelot.UnitTests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.6.3" />
<PackageReference Include="Nito.AsyncEx" Version="5.1.2" />
<PackageReference Include="xunit" Version="2.5.0" />
<PackageReference Include="xunit.runner.visualstudio" Version="2.5.0">
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
Expand Down
Loading