Skip to content

Commit

Permalink
Add handshake that confirms whether ModMessages can be sent over the …
Browse files Browse the repository at this point in the history
…control channel
  • Loading branch information
mircearoata committed Jun 28, 2024
1 parent 79e6892 commit 5d08f79
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 1 deletion.
100 changes: 99 additions & 1 deletion Mods/SML/Source/SML/Private/Network/NetworkHandler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@ DEFINE_LOG_CATEGORY(LogModNetworkHandler);
DEFINE_CONTROL_CHANNEL_MESSAGE_THREEPARAM(ModMessage, 40, FString, int32, FString);
IMPLEMENT_CONTROL_CHANNEL_MESSAGE(ModMessage);

struct FConnectionSMLSupport {
bool bSupportsModMessageType{false};
TArray<TTuple<FMessageType, FString>> PendingMessages;

FORCEINLINE bool IsDefault() const { return !bSupportsModMessageType && PendingMessages.IsEmpty(); }
};

static FUObjectAnnotationSparse<FConnectionSMLSupport, true> GConnectionMetadata;

static FString GSML_HELLO = TEXT("SML_HELLO");

FMessageEntry& UModNetworkHandler::RegisterMessageType(const FMessageType& MessageType) {
UE_LOG(LogModNetworkHandler, Display, TEXT("Registering message type %s:%d"), *MessageType.ModReference, MessageType.MessageId);
TMap<int32, FMessageEntry>& ModEntries = MessageHandlers.FindOrAdd(MessageType.ModReference);
Expand All @@ -27,8 +38,40 @@ void UModNetworkHandler::CloseWithFailureMessage(UNetConnection* Connection, con
}

void UModNetworkHandler::SendMessage(UNetConnection* Connection, FMessageType MessageType, FString Data) {
FNetControlMessage<NMT_ModMessage>::Send(Connection, MessageType.ModReference, MessageType.MessageId, Data);
FConnectionSMLSupport ConnectionMetadata = GConnectionMetadata.GetAnnotation(Connection);
if (ConnectionMetadata.bSupportsModMessageType) {
FNetControlMessage<NMT_ModMessage>::Send(Connection, MessageType.ModReference, MessageType.MessageId, Data);
Connection->FlushNet(true);
} else {
ConnectionMetadata.PendingMessages.Add({MessageType, Data});
GConnectionMetadata.AddAnnotation(Connection, ConnectionMetadata);
}
}

void UModNetworkHandler::SetConnectionSupportsModMessages(UNetConnection* Connection) {
FConnectionSMLSupport ConnectionMetadata = GConnectionMetadata.GetAnnotation(Connection);

if (ConnectionMetadata.bSupportsModMessageType) {
return;
}

ConnectionMetadata.bSupportsModMessageType = true;

// Let other side know we support mod messages
FNetControlMessage<NMT_DebugText>::Send(Connection, GSML_HELLO);
Connection->FlushNet(true);

// Send all pending messages now that we know they are supported
for (const TTuple<FMessageType, FString>& Message : ConnectionMetadata.PendingMessages) {
FMessageType ModMessageType = Message.Get<0>();
FString Data = Message.Get<1>();
FNetControlMessage<NMT_ModMessage>::Send(Connection, ModMessageType.ModReference, ModMessageType.MessageId, Data);
}

Connection->FlushNet(true);

ConnectionMetadata.PendingMessages.Empty();
GConnectionMetadata.AddAnnotation(Connection, ConnectionMetadata);
}

UGameInstance* UModNetworkHandler::GetGameInstanceFromNetDriver( const UNetDriver* NetDriver )
Expand Down Expand Up @@ -62,6 +105,35 @@ void UModNetworkHandler::ReceiveMessage(UNetConnection* Connection, const FStrin
}
}

/**
* SML handshake is done in the following way:
*
* Server Client
* | SML_HELLO |
* |------------------------->|
* | SML_HELLO | bSupportsModMessageType = true
* |<-------------------------|
* bSupportsModMessageType = true | |
* | any pending mod messages |
* |<-------------------------|
* | SML_HELLO |
* |------------------------->|
* | any pending mod messages |
* |------------------------->|
*
* If the client is not running SML, it will not respond to the SML_HELLO message,
* so the server will not mark the connection as supporting mod messages, and never send any mod messages.
*
* If the server is not running SML, it will not send the initial SML_HELLO message,
* so the client will not mark the connection as supporting mod messages, and never send any mod messages.
*
* We cannot simply send an SML_HELLO message from each side at the beginning of the connection,
* because the server expects a specific message order from the client, up until the NMT_Welcome message,
* and disconnects if that order is not followed.
* So if the server is not running SML, nothing will intercept the SML_HELLO message, and it will reach
* UWorld::NotifyControlMessage, which will disconnect the client.
*
*/
void UModNetworkHandler::InitializePatches() {

UWorld* WorldObjectInstance = GetMutableDefault<UWorld>();
Expand All @@ -81,6 +153,32 @@ void UModNetworkHandler::InitializePatches() {
});

auto MessageHandler = [=](auto& Call, void*, UNetConnection* Connection, uint8 MessageType, class FInBunch& Bunch) {
if (MessageType == NMT_Hello) {
// NMT_Hello is only received on the server, sent by UPendingNetGame::SendInitialJoin
// Initiate the SML handshake

FNetControlMessage<NMT_DebugText>::Send(Connection, GSML_HELLO);
Connection->FlushNet(true);
}

if (MessageType == NMT_DebugText) {
const int64 Pos = Bunch.GetPosBits();

FString Text;
if (FNetControlMessage<NMT_DebugText>::Receive(Bunch, Text)) {
if(Text == GSML_HELLO) {
SetConnectionSupportsModMessages(Connection);
}
}

// Only forward the message to the engine if it can be handled (see handshake explanation)
if (Connection->IsClientMsgTypeValid(NMT_DebugText)) {
Bunch.SetReadPosition(Pos);
} else {
Call.Cancel();
}
}

if (MessageType == NMT_ModMessage) {
FString ModId; int32 MessageId; FString Content;
if (FNetControlMessage<NMT_ModMessage>::Receive(Bunch, ModId, MessageId, Content)) {
Expand Down
5 changes: 5 additions & 0 deletions Mods/SML/Source/SML/Public/Network/NetworkHandler.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ class SML_API UModNetworkHandler : public UEngineSubsystem {
*/
static void SendMessage(class UNetConnection* Connection, FMessageType MessageType, FString Data);

/**
* Set the connection to support mod messages and send all pending messages
*/
static void SetConnectionSupportsModMessages(class UNetConnection* Connection);

/**
* Retrieves the game instance owning the specified net driver
*/
Expand Down

0 comments on commit 5d08f79

Please sign in to comment.