Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(repeater): resolve bridges service communication #171

Merged
merged 13 commits into from
Jun 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/SecTester.Repeater/Bus/DefaultRepeaterBusFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using SecTester.Core.Utils;
using SocketIO.Serializer.MessagePack;
using SocketIOClient;
using SocketIOClient.Transport;

namespace SecTester.Repeater.Bus;

Expand Down Expand Up @@ -37,7 +38,7 @@ public IRepeaterBus Create(string repeaterId)
ReconnectionAttempts = options.ReconnectionAttempts,
ReconnectionDelayMax = options.ReconnectionDelayMax,
ConnectionTimeout = options.ConnectionTimeout,
AutoUpgrade = false,
Transport = TransportProtocol.WebSocket,
Auth = new { token = _config.Credentials.Token, domain = repeaterId }
})
{
Expand Down
81 changes: 72 additions & 9 deletions src/SecTester.Repeater/Bus/IncomingRequest.cs
Original file line number Diff line number Diff line change
@@ -1,19 +1,82 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using MessagePack;
using SecTester.Core.Bus;
using SecTester.Repeater.Internal;
using SecTester.Repeater.Runners;

namespace SecTester.Repeater.Bus;

[MessagePackObject(true)]
public record IncomingRequest(Uri Url) : Event, IRequest
[MessagePackObject]
public record IncomingRequest(Uri Url) : IRequest
{
public string? Body { get; set; }
public HttpMethod Method { get; set; } = HttpMethod.Get;
public Protocol Protocol { get; set; } = Protocol.Http;
public Uri Url { get; set; } = Url ?? throw new ArgumentNullException(nameof(Url));
public IEnumerable<KeyValuePair<string, IEnumerable<string>>> Headers { get; set; } =
new List<KeyValuePair<string, IEnumerable<string>>>();
private const string UrlKey = "url";
private const string MethodKey = "method";
private const string HeadersKey = "headers";
private const string BodyKey = "body";
private const string ProtocolKey = "protocol";

[Key(ProtocolKey)] public Protocol Protocol { get; set; } = Protocol.Http;

[Key(HeadersKey)] public IEnumerable<KeyValuePair<string, IEnumerable<string>>> Headers { get; set; } = new Dictionary<string, IEnumerable<string>>();

[Key(BodyKey)] public string? Body { get; set; }

[Key(MethodKey)] public HttpMethod Method { get; set; } = HttpMethod.Get;

[Key(UrlKey)] public Uri Url { get; set; } = Url ?? throw new ArgumentNullException(nameof(Url));

public static IncomingRequest FromDictionary(Dictionary<object, object> dictionary)
{
var protocol = GetProtocolFromDictionary(dictionary);
var headers = GetHeadersFromDictionary(dictionary);
var body = GetBodyFromDictionary(dictionary);
var method = GetMethodFromDictionary(dictionary);
var url = GetUrlFromDictionary(dictionary);

return new IncomingRequest(url!)
{
Protocol = protocol,
Headers = headers,
Body = body,
Method = method
};
}

private static Protocol GetProtocolFromDictionary(Dictionary<object, object> dictionary) =>
dictionary.TryGetValue(ProtocolKey, out var protocolObj) && protocolObj is string protocolStr
? (Protocol)Enum.Parse(typeof(Protocol), protocolStr, true)
: Protocol.Http;

private static IEnumerable<KeyValuePair<string, IEnumerable<string>>> GetHeadersFromDictionary(Dictionary<object, object> dictionary) =>
dictionary.TryGetValue(HeadersKey, out var headersObj) && headersObj is Dictionary<object, object> headersDict
? ConvertToHeaders(headersDict)
: new Dictionary<string, IEnumerable<string>>();

private static string? GetBodyFromDictionary(Dictionary<object, object> dictionary) =>
dictionary.TryGetValue(BodyKey, out var bodyObj) ? bodyObj?.ToString() : null;

private static HttpMethod GetMethodFromDictionary(Dictionary<object, object> dictionary) =>
dictionary.TryGetValue(MethodKey, out var methodObj) && methodObj is string methodStr
? HttpMethods.Items.TryGetValue(methodStr, out var m) && m is not null
? m
: HttpMethod.Get
: HttpMethod.Get;

private static Uri? GetUrlFromDictionary(Dictionary<object, object> dictionary) =>
dictionary.TryGetValue(UrlKey, out var urlObj) && urlObj is string urlStr
? new Uri(urlStr)
: null;

private static IEnumerable<KeyValuePair<string, IEnumerable<string>>> ConvertToHeaders(Dictionary<object, object> headers) =>
headers.ToDictionary(
kvp => kvp.Key.ToString()!,
kvp => kvp.Value switch
{
IEnumerable<object> list => list.Select(v => v.ToString()!),
string str => new[] { str },
_ => Enumerable.Empty<string>()
}
);
}
18 changes: 14 additions & 4 deletions src/SecTester.Repeater/Bus/OutgoingResponse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,24 @@

namespace SecTester.Repeater.Bus;

[MessagePackObject(true)]
[MessagePackObject]
public record OutgoingResponse : IResponse
{
[Key("protocol")]
public Protocol Protocol { get; set; } = Protocol.Http;

[Key("statusCode")]
public int? StatusCode { get; set; }

[Key("body")]
public string? Body { get; set; }

[Key("message")]
public string? Message { get; set; }

[Key("errorCode")]
public string? ErrorCode { get; set; }
public Protocol Protocol { get; set; } = Protocol.Http;
public IEnumerable<KeyValuePair<string, IEnumerable<string>>> Headers { get; set; } =
new List<KeyValuePair<string, IEnumerable<string>>>();

[Key("headers")]
public IEnumerable<KeyValuePair<string, IEnumerable<string>>> Headers { get; set; } = new Dictionary<string, IEnumerable<string>>();
}
3 changes: 2 additions & 1 deletion src/SecTester.Repeater/Bus/RepeaterError.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

namespace SecTester.Repeater.Bus;

[MessagePackObject(true)]
[MessagePackObject]
public sealed record RepeaterError
{
[Key("message")]
public string Message { get; set; } = null!;
}
3 changes: 2 additions & 1 deletion src/SecTester.Repeater/Bus/RepeaterInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

namespace SecTester.Repeater.Bus;

[MessagePackObject(true)]
[MessagePackObject]
public sealed record RepeaterInfo
{
[Key("repeaterId")]
public string RepeaterId { get; set; } = null!;
}
3 changes: 2 additions & 1 deletion src/SecTester.Repeater/Bus/RepeaterVersion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

namespace SecTester.Repeater.Bus;

[MessagePackObject(true)]
[MessagePackObject]
public sealed record RepeaterVersion
{
[Key("version")]
public string Version { get; set; } = null!;
}
3 changes: 2 additions & 1 deletion src/SecTester.Repeater/Bus/SocketIoRepeaterBus.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using System.Timers;
Expand Down Expand Up @@ -47,7 +48,7 @@
_connection.On("error", response =>
{
var err = response.GetValue<RepeaterError>();
ErrorOccurred?.Invoke(new(err.Message));

Check warning on line 51 in src/SecTester.Repeater/Bus/SocketIoRepeaterBus.cs

View workflow job for this annotation

GitHub Actions / windows-2019

Exception type System.Exception is not sufficiently specific
});

_connection.On("update-available", response =>
Expand All @@ -64,7 +65,7 @@
}

var ct = new CancellationTokenSource(_options.AckTimeout);
var request = response.GetValue<IncomingRequest>();
var request = IncomingRequest.FromDictionary(response.GetValue<Dictionary<object, object>>());
var result = await RequestReceived.Invoke(request).ConfigureAwait(false);
await response.CallbackAsync(ct.Token, result).ConfigureAwait(false);
});
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using MessagePack;
using MessagePack.Resolvers;

namespace SecTester.Repeater.Internal;

internal static class DefaultMessagePackSerializerOptions
{
internal static readonly MessagePackSerializerOptions Instance = new(
CompositeResolver.Create(
CompositeResolver.Create(
new MessagePackHttpHeadersFormatter(),
new MessagePackStringEnumMemberFormatter<Protocol>(MessagePackNamingPolicy.SnakeCase),
new MessagePackHttpMethodFormatter()),
StandardResolver.Instance
)
);
}
30 changes: 30 additions & 0 deletions src/SecTester.Repeater/Internal/HttpMethods.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Reflection;

namespace SecTester.Repeater.Internal;

public class HttpMethods
{
public static IDictionary<string, HttpMethod> Items { get; } = typeof(HttpMethod)
.GetProperties(BindingFlags.Public | BindingFlags.Static | BindingFlags.DeclaredOnly)
.Where(x => x.PropertyType.IsAssignableFrom(typeof(HttpMethod)))
.Select(x => x.GetValue(null))
.Cast<HttpMethod>()
.Concat(new List<HttpMethod>
{
new("PATCH"),
new("COPY"),
new("LINK"),
new("UNLINK"),
new("PURGE"),
new("LOCK"),
new("UNLOCK"),
new("PROPFIND"),
new("VIEW")
})
.Distinct()
.ToDictionary(x => x.Method, x => x, StringComparer.InvariantCultureIgnoreCase);
}
155 changes: 155 additions & 0 deletions src/SecTester.Repeater/Internal/MessagePackHttpHeadersFormatter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
using System.Collections.Generic;
using System.Linq;
using MessagePack;
using MessagePack.Formatters;

namespace SecTester.Repeater.Internal;

// Headers formatter is to be supporting javascript `undefined` which is treated as null (0xC0)
// https://www.npmjs.com/package/@msgpack/msgpack#messagepack-mapping-table
// https://github.com/msgpack/msgpack/blob/master/spec.md#nil-format

internal class MessagePackHttpHeadersFormatter : IMessagePackFormatter<
IEnumerable<KeyValuePair<string, IEnumerable<string>>>?
>
{
public void Serialize(ref MessagePackWriter writer, IEnumerable<KeyValuePair<string, IEnumerable<string>>>? value,
MessagePackSerializerOptions options)
{
if (value == null)
{
writer.WriteNil();
}
else
{
var count = value.Count();

writer.WriteMapHeader(count);

Serialize(ref writer, value);
}
}

private static void Serialize(ref MessagePackWriter writer, IEnumerable<KeyValuePair<string, IEnumerable<string>>> value)
{
foreach (var item in value)
{
writer.Write(item.Key);

Serialize(ref writer, item);
}
}

private static void Serialize(ref MessagePackWriter writer, KeyValuePair<string, IEnumerable<string>> item)
{
var headersCount = item.Value.Count();

if (headersCount == 1)
{
writer.Write(item.Value.First());
}
else
{
writer.WriteArrayHeader(headersCount);

foreach (var subItem in item.Value)
{
writer.Write(subItem);
}
}
}

public IEnumerable<KeyValuePair<string, IEnumerable<string>>>? Deserialize(ref MessagePackReader reader,
MessagePackSerializerOptions options)
{
if (reader.NextMessagePackType == MessagePackType.Nil)
{
reader.ReadNil();
return null;
}

if (reader.NextMessagePackType != MessagePackType.Map)
{
throw new MessagePackSerializationException($"Unrecognized code: 0x{reader.NextCode:X2} but expected to be a map or null");
}

var length = reader.ReadMapHeader();

options.Security.DepthStep(ref reader);

try
{
return DeserializeMap(ref reader, length, options);
}
finally
{
reader.Depth--;
}
}

private static IEnumerable<KeyValuePair<string, IEnumerable<string>>> DeserializeMap(ref MessagePackReader reader, int length,
MessagePackSerializerOptions options)
{
var result = new List<KeyValuePair<string, IEnumerable<string>>>(length);

for (var i = 0 ; i < length ; i++)
{
var key = DeserializeString(ref reader);

result.Add(new KeyValuePair<string, IEnumerable<string>>(
key,
DeserializeValue(ref reader, options)
));
}

return result;
}

private static IEnumerable<string> DeserializeArray(ref MessagePackReader reader, int length, MessagePackSerializerOptions options)
{
var result = new List<string>(length);

options.Security.DepthStep(ref reader);

try
{
for (var i = 0 ; i < length ; i++)
{
result.Add(DeserializeString(ref reader));
}
}
finally
{
reader.Depth--;
}

return result;
}

private static IEnumerable<string> DeserializeValue(ref MessagePackReader reader, MessagePackSerializerOptions options)
{
switch (reader.NextMessagePackType)
{
case MessagePackType.Nil:
reader.ReadNil();
return new List<string>();
case MessagePackType.String:
return new List<string> { DeserializeString(ref reader) };
case MessagePackType.Array:
return DeserializeArray(ref reader, reader.ReadArrayHeader(), options);
default:
throw new MessagePackSerializationException(
$"Unrecognized code: 0x{reader.NextCode:X2} but expected to be either a string or an array.");
}
}

private static string DeserializeString(ref MessagePackReader reader)
{
if (reader.NextMessagePackType != MessagePackType.String)
{
throw new MessagePackSerializationException($"Unrecognized code: 0x{reader.NextCode:X2} but expected to be a string.");
}

return reader.ReadString() ?? string.Empty;
}
}
Loading
Loading