diff --git a/eventstream_rpc/include/aws/eventstreamrpc/EventStreamClient.h b/eventstream_rpc/include/aws/eventstreamrpc/EventStreamClient.h index d2d8d293c..399399c3a 100644 --- a/eventstream_rpc/include/aws/eventstreamrpc/EventStreamClient.h +++ b/eventstream_rpc/include/aws/eventstreamrpc/EventStreamClient.h @@ -47,7 +47,7 @@ namespace Aws /** * A callback prototype that is called upon flushing a message over the wire. - * @param errorCode A non-zero value if an error occured while attempting to flush the message. + * @param errorCode A non-zero value if an error occurred while attempting to flush the message. */ using OnMessageFlushCallback = std::function; @@ -77,13 +77,10 @@ namespace Aws Crt::Allocator *allocator = Crt::g_allocator) noexcept; Crt::String GetHeaderName() const noexcept; - bool GetValueAsString(Crt::String &) const noexcept; const struct aws_event_stream_header_value_pair *GetUnderlyingHandle() const; - bool operator==(const EventStreamHeader &other) const noexcept; - private: Crt::Allocator *m_allocator; Crt::ByteBuf m_valueByteBuf; @@ -99,7 +96,7 @@ namespace Aws { public: MessageAmendment(const MessageAmendment &lhs); - MessageAmendment(MessageAmendment &&rhs); + MessageAmendment(MessageAmendment &&rhs) noexcept; MessageAmendment &operator=(const MessageAmendment &lhs); ~MessageAmendment() noexcept; explicit MessageAmendment(Crt::Allocator *allocator = Crt::g_allocator) noexcept; @@ -144,10 +141,10 @@ namespace Aws OnMessageFlushCallback GetConnectRequestCallback() const noexcept { return m_connectRequestCallback; } ConnectMessageAmender GetConnectMessageAmender() const noexcept { - return [&](void) -> const MessageAmendment & { return m_connectAmendment; }; + return [&]() -> const MessageAmendment & { return m_connectAmendment; }; } - void SetHostName(Crt::String hostName) noexcept { m_hostName = hostName; } + void SetHostName(Crt::String hostName) noexcept { m_hostName = std::move(hostName); } void SetPort(uint32_t port) noexcept { m_port = port; } void SetSocketOptions(const Crt::Io::SocketOptions &socketOptions) noexcept { @@ -167,7 +164,7 @@ namespace Aws } void SetConnectRequestCallback(OnMessageFlushCallback connectRequestCallback) noexcept { - m_connectRequestCallback = connectRequestCallback; + m_connectRequestCallback = std::move(connectRequestCallback); } protected: @@ -212,7 +209,6 @@ namespace Aws { public: virtual ~ConnectionLifecycleHandler() noexcept = default; - /** * This callback is only invoked upon receiving a CONNECT_ACK with the * CONNECTION_ACCEPTED flag set by the server. Therefore, once this callback @@ -312,6 +308,12 @@ namespace Aws Crt::Allocator *allocator) noexcept; ~ClientContinuation() noexcept; + ClientContinuation(const ClientContinuation &other) = default; + ClientContinuation(ClientContinuation &&other) noexcept = default; + + ClientContinuation &operator=(const ClientContinuation &other) = delete; + ClientContinuation &operator=(ClientContinuation &&other) noexcept = delete; + /** * Initiate a new client stream. Send new message for the new stream. * @param operation Name for the operation to be invoked by the peer endpoint. @@ -394,7 +396,7 @@ namespace Aws public: explicit OperationError() noexcept = default; static void s_customDeleter(OperationError *shape) noexcept; - virtual void SerializeToJsonObject(Crt::JsonObject &payloadObject) const override; + void SerializeToJsonObject(Crt::JsonObject &payloadObject) const override; virtual Crt::Optional GetMessage() noexcept = 0; }; @@ -486,7 +488,7 @@ namespace Aws } OperationResult(Crt::ScopedResource &&error) noexcept : m_error(std::move(error)) {} OperationResult() noexcept : m_response(nullptr) {} - ~OperationResult() noexcept {}; + ~OperationResult() noexcept {} Crt::ScopedResource m_response; Crt::ScopedResource m_error; }; @@ -613,7 +615,7 @@ namespace Aws std::shared_ptr streamHandler, const OperationModelContext &operationModelContext, Crt::Allocator *allocator) noexcept; - ~ClientOperation() noexcept; + virtual ~ClientOperation() noexcept; ClientOperation(const ClientOperation &clientOperation) noexcept = delete; ClientOperation(ClientOperation &&clientOperation) noexcept = delete; @@ -790,14 +792,14 @@ namespace Aws DISCONNECTING, }; /* This recursive mutex protects m_clientState & m_connectionWillSetup */ - std::recursive_mutex m_stateMutex; + std::mutex m_closeReasonMutex; Crt::Allocator *m_allocator; struct aws_event_stream_rpc_client_connection *m_underlyingConnection; - ClientState m_clientState; + std::atomic m_clientState; ConnectionLifecycleHandler *m_lifecycleHandler; ConnectMessageAmender m_connectMessageAmender; std::promise m_connectionSetupPromise; - bool m_connectionWillSetup; + std::atomic m_connectionWillSetup; std::promise m_connectAckedPromise; std::promise m_closedPromise; bool m_onConnectCalled; diff --git a/eventstream_rpc/source/EventStreamClient.cpp b/eventstream_rpc/source/EventStreamClient.cpp index efa46ddf7..36142ef70 100644 --- a/eventstream_rpc/source/EventStreamClient.cpp +++ b/eventstream_rpc/source/EventStreamClient.cpp @@ -8,8 +8,8 @@ #include #include -#include -#include +#include +#include #include @@ -79,6 +79,11 @@ namespace Aws MessageAmendment &MessageAmendment::operator=(const MessageAmendment &lhs) { + if (this == &lhs) + { + return *this; + } + m_headers = lhs.m_headers; if (lhs.m_payload.has_value()) { @@ -90,8 +95,8 @@ namespace Aws return *this; } - MessageAmendment::MessageAmendment(MessageAmendment &&rhs) - : m_headers(std::move(rhs.m_headers)), m_payload(rhs.m_payload), m_allocator(rhs.m_allocator) + MessageAmendment::MessageAmendment(MessageAmendment &&rhs) noexcept + : m_headers(std::move(rhs.m_headers)), m_payload(std::move(rhs.m_payload)), m_allocator(rhs.m_allocator) { rhs.m_allocator = nullptr; rhs.m_payload = Crt::Optional(); @@ -149,13 +154,20 @@ namespace Aws } }; + // FIXME This assignment operator is completely broken. + // 1. rhs' internal state can be changed while members are copying one by one, which can lead to this being + // inconsistent/corrupted. + // 2. if rhs is in the CONNECTED state, then a pointer to it has already been passed to the underlying + // libraries and is used in callbacks. + // 3. If this is connected, what will happen to its underlying connection, state, callbacks, etc.? + // + // Option 1 (preferable): This operator should be marked as deleted. It'll be a BREAKING CHANGE. + // Option 2: As an ugly alternative, throw runtime error if `Connect` method was called on rhs or this. ClientConnection &ClientConnection::operator=(ClientConnection &&rhs) noexcept { m_allocator = std::move(rhs.m_allocator); m_underlyingConnection = rhs.m_underlyingConnection; - rhs.m_stateMutex.lock(); - m_clientState = rhs.m_clientState; - rhs.m_stateMutex.unlock(); + m_clientState = rhs.m_clientState.load(); m_lifecycleHandler = rhs.m_lifecycleHandler; m_connectMessageAmender = rhs.m_connectMessageAmender; m_connectAckedPromise = std::move(rhs.m_connectAckedPromise); @@ -174,6 +186,7 @@ namespace Aws return *this; } + // FIXME Mark as deleted. See comment to the assignment operator. ClientConnection::ClientConnection(ClientConnection &&rhs) noexcept : m_lifecycleHandler(rhs.m_lifecycleHandler) { *this = std::move(rhs); @@ -182,28 +195,23 @@ namespace Aws ClientConnection::ClientConnection(Crt::Allocator *allocator) noexcept : m_allocator(allocator), m_underlyingConnection(nullptr), m_clientState(DISCONNECTED), m_lifecycleHandler(nullptr), m_connectMessageAmender(nullptr), m_connectionWillSetup(false), + m_onConnectCalled(false), m_closeReason{EVENT_STREAM_RPC_UNINITIALIZED, 0}, m_onConnectRequestCallback(nullptr) { } ClientConnection::~ClientConnection() noexcept { - m_stateMutex.lock(); if (m_connectionWillSetup) { - m_stateMutex.unlock(); m_connectionSetupPromise.get_future().wait(); } - m_stateMutex.lock(); + if (m_clientState != DISCONNECTED) { Close(); - m_stateMutex.unlock(); m_closedPromise.get_future().wait(); } - /* Cover the case in which the if statements are not hit. */ - m_stateMutex.unlock(); - m_stateMutex.unlock(); m_underlyingConnection = nullptr; } @@ -269,10 +277,9 @@ namespace Aws Crt::Io::ClientBootstrap &clientBootstrap) noexcept { EventStreamRpcStatusCode baseError = EVENT_STREAM_RPC_SUCCESS; - struct aws_event_stream_rpc_client_connection_options connOptions; + aws_event_stream_rpc_client_connection_options connOptions{}; { - const std::lock_guard lock(m_stateMutex); if (m_clientState == DISCONNECTED) { m_clientState = CONNECTING_SOCKET; @@ -280,7 +287,10 @@ namespace Aws m_connectionSetupPromise = {}; m_connectAckedPromise = {}; m_closedPromise = {}; - m_closeReason = {EVENT_STREAM_RPC_UNINITIALIZED, 0}; + { + const std::lock_guard lock(m_closeReasonMutex); + m_closeReason = {EVENT_STREAM_RPC_UNINITIALIZED, 0}; + } m_connectionConfig = connectionConfig; m_lifecycleHandler = connectionLifecycleHandler; } @@ -323,7 +333,6 @@ namespace Aws errorPromise.set_value({baseError, 0}); if (baseError == EVENT_STREAM_RPC_NULL_PARAMETER) { - const std::lock_guard lock(m_stateMutex); m_clientState = DISCONNECTED; } return errorPromise.get_future(); @@ -358,13 +367,11 @@ namespace Aws "A CRT error occurred while attempting to establish the connection: %s", Crt::ErrorDebugString(crtError)); errorPromise.set_value({EVENT_STREAM_RPC_CRT_ERROR, crtError}); - const std::lock_guard lock(m_stateMutex); m_clientState = DISCONNECTED; return errorPromise.get_future(); } else { - const std::lock_guard lock(m_stateMutex); m_connectionWillSetup = true; } @@ -376,7 +383,7 @@ namespace Aws const Crt::Optional &payload, OnMessageFlushCallback onMessageFlushCallback) noexcept { - return s_sendPing(this, headers, payload, onMessageFlushCallback); + return s_sendPing(this, headers, payload, std::move(onMessageFlushCallback)); } std::future ClientConnection::SendPingResponse( @@ -384,7 +391,7 @@ namespace Aws const Crt::Optional &payload, OnMessageFlushCallback onMessageFlushCallback) noexcept { - return s_sendPingResponse(this, headers, payload, onMessageFlushCallback); + return s_sendPingResponse(this, headers, payload, std::move(onMessageFlushCallback)); } std::future ClientConnection::s_sendPing( @@ -394,7 +401,12 @@ namespace Aws OnMessageFlushCallback onMessageFlushCallback) noexcept { return s_sendProtocolMessage( - connection, headers, payload, AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_PING, 0, onMessageFlushCallback); + connection, + headers, + payload, + AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_PING, + 0, + std::move(onMessageFlushCallback)); } std::future ClientConnection::s_sendPingResponse( @@ -409,7 +421,7 @@ namespace Aws payload, AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_PING_RESPONSE, 0, - onMessageFlushCallback); + std::move(onMessageFlushCallback)); } std::future ClientConnection::SendProtocolMessage( @@ -419,7 +431,8 @@ namespace Aws uint32_t messageFlags, OnMessageFlushCallback onMessageFlushCallback) noexcept { - return s_sendProtocolMessage(this, headers, payload, messageType, messageFlags, onMessageFlushCallback); + return s_sendProtocolMessage( + this, headers, payload, messageType, messageFlags, std::move(onMessageFlushCallback)); } void ClientConnection::s_protocolMessageCallback(int errorCode, void *userData) noexcept @@ -458,7 +471,7 @@ namespace Aws { std::promise onFlushPromise; OnMessageFlushCallbackContainer *callbackContainer = nullptr; - struct aws_array_list headersArray; + aws_array_list headersArray{}; /* The caller should never pass a NULL connection. */ AWS_PRECONDITION(connection != nullptr); @@ -468,8 +481,8 @@ namespace Aws if (!errorCode) { - struct aws_event_stream_rpc_message_args msg_args; - msg_args.headers = (struct aws_event_stream_header_value_pair *)headersArray.data; + aws_event_stream_rpc_message_args msg_args{}; + msg_args.headers = static_cast(headersArray.data); msg_args.headers_count = headers.size(); msg_args.payload = payload.has_value() ? (aws_byte_buf *)(&(payload.value())) : nullptr; msg_args.message_type = messageType; @@ -479,7 +492,7 @@ namespace Aws * returns. */ callbackContainer = Crt::New(connection->m_allocator, connection->m_allocator); - callbackContainer->onMessageFlushCallback = onMessageFlushCallback; + callbackContainer->onMessageFlushCallback = std::move(onMessageFlushCallback); callbackContainer->onFlushPromise = std::move(onFlushPromise); errorCode = aws_event_stream_rpc_client_connection_send_protocol_message( @@ -497,6 +510,7 @@ namespace Aws if (errorCode) { + // FIXME Null pointer dereference if s_fillNativeHeadersArray fails. onFlushPromise = std::move(callbackContainer->onFlushPromise); AWS_LOGF_ERROR( AWS_LS_EVENT_STREAM_RPC_CLIENT, @@ -515,8 +529,6 @@ namespace Aws void ClientConnection::Close() noexcept { - const std::lock_guard lock(m_stateMutex); - if (IsOpen()) { aws_event_stream_rpc_client_connection_close(this->m_underlyingConnection, AWS_OP_SUCCESS); @@ -531,6 +543,7 @@ namespace Aws m_clientState = DISCONNECTING; } + const std::lock_guard lock(m_closeReasonMutex); if (m_closeReason.baseStatus == EVENT_STREAM_RPC_UNINITIALIZED) { m_closeReason = {EVENT_STREAM_RPC_CONNECTION_CLOSED, 0}; @@ -584,6 +597,10 @@ namespace Aws EventStreamHeader &EventStreamHeader::operator=(const EventStreamHeader &lhs) noexcept { + if (this == &lhs) + { + return *this; + } m_allocator = lhs.m_allocator; m_valueByteBuf = Crt::ByteBufNewCopy(lhs.m_allocator, lhs.m_valueByteBuf.buffer, lhs.m_valueByteBuf.len); m_underlyingHandle = lhs.m_underlyingHandle; @@ -637,8 +654,6 @@ namespace Aws /* The `userData` pointer is used to pass `this` of a `ClientConnection` object. */ auto *thisConnection = static_cast(userData); - const std::lock_guard lock(thisConnection->m_stateMutex); - if (errorCode) { thisConnection->m_clientState = DISCONNECTED; @@ -655,7 +670,10 @@ namespace Aws else if (thisConnection->m_clientState == DISCONNECTING || thisConnection->m_clientState == DISCONNECTED) { thisConnection->m_underlyingConnection = connection; - thisConnection->m_closeReason = {EVENT_STREAM_RPC_CONNECTION_CLOSED, 0}; + { + const std::lock_guard lock(thisConnection->m_closeReasonMutex); + thisConnection->m_closeReason = {EVENT_STREAM_RPC_CONNECTION_CLOSED, 0}; + } thisConnection->Close(); } else @@ -706,19 +724,21 @@ namespace Aws /* The `userData` pointer is used to pass `this` of a `ClientConnection` object. */ auto *thisConnection = static_cast(userData); - const std::lock_guard lock(thisConnection->m_stateMutex); - - if (thisConnection->m_closeReason.baseStatus == EVENT_STREAM_RPC_UNINITIALIZED && errorCode) { - thisConnection->m_closeReason = {EVENT_STREAM_RPC_CRT_ERROR, errorCode}; - } + const std::lock_guard lock(thisConnection->m_closeReasonMutex); - thisConnection->m_underlyingConnection = nullptr; + if (thisConnection->m_closeReason.baseStatus == EVENT_STREAM_RPC_UNINITIALIZED && errorCode) + { + thisConnection->m_closeReason = {EVENT_STREAM_RPC_CRT_ERROR, errorCode}; + } - if (thisConnection->m_closeReason.baseStatus != EVENT_STREAM_RPC_UNINITIALIZED && - !thisConnection->m_onConnectCalled) - { - thisConnection->m_connectAckedPromise.set_value(thisConnection->m_closeReason); + thisConnection->m_underlyingConnection = nullptr; + + if (thisConnection->m_closeReason.baseStatus != EVENT_STREAM_RPC_UNINITIALIZED && + !thisConnection->m_onConnectCalled) + { + thisConnection->m_connectAckedPromise.set_value(thisConnection->m_closeReason); + } } thisConnection->m_clientState = DISCONNECTED; @@ -765,7 +785,6 @@ namespace Aws switch (messageArgs->message_type) { case AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_CONNECT_ACK: - thisConnection->m_stateMutex.lock(); if (thisConnection->m_clientState == WAITING_FOR_CONNECT_ACK) { if (messageArgs->message_flags & AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_CONNECTION_ACCEPTED) @@ -777,7 +796,10 @@ namespace Aws } else { - thisConnection->m_closeReason = {EVENT_STREAM_RPC_CONNECTION_ACCESS_DENIED, 0}; + { + const std::lock_guard lock(thisConnection->m_closeReasonMutex); + thisConnection->m_closeReason = {EVENT_STREAM_RPC_CONNECTION_ACCESS_DENIED, 0}; + } thisConnection->Close(); } } @@ -785,7 +807,6 @@ namespace Aws { /* Unexpected CONNECT_ACK received. */ } - thisConnection->m_stateMutex.unlock(); break; @@ -793,8 +814,7 @@ namespace Aws for (size_t i = 0; i < messageArgs->headers_count; ++i) { - pingHeaders.emplace_back( - EventStreamHeader(messageArgs->headers[i], thisConnection->m_allocator)); + pingHeaders.emplace_back(messageArgs->headers[i], thisConnection->m_allocator); } if (messageArgs->payload) @@ -847,7 +867,7 @@ namespace Aws Crt::Allocator *allocator) noexcept : m_allocator(allocator), m_continuationHandler(continuationHandler), m_continuationToken(nullptr) { - struct aws_event_stream_rpc_client_stream_continuation_options options; + aws_event_stream_rpc_client_stream_continuation_options options{}; options.on_continuation = ClientContinuation::s_onContinuationMessage; options.on_continuation_closed = ClientContinuation::s_onContinuationClosed; @@ -878,6 +898,9 @@ namespace Aws } if (m_callbackData != nullptr) { + // FIXME Setting `m_callbackData->continuationDestroyed` indicates that another actor is supposed + // to check this flag (see `ClientContinuation::s_onContinuationMessage`). However, we delete + // `m_callbackData` right after setting the flag, so it doesn't work as intended. { const std::lock_guard lock(m_callbackData->callbackMutex); m_callbackData->continuationDestroyed = true; @@ -893,12 +916,18 @@ namespace Aws { (void)continuationToken; /* The `userData` pointer is used to pass a `ContinuationCallbackData` object. */ + // FIXME Can `callbackData` be destroyed at this point? See `ClientContinuation::~ClientContinuation`. + // Probably `callbackData` is guaranteed to be alive after this PR: + // https://github.com/aws/aws-iot-device-sdk-cpp-v2/pull/437. But then we need to get rid of the + // `continuationDestroyed` flag. auto *callbackData = static_cast(userData); auto *thisContinuation = callbackData->clientContinuation; Crt::List continuationMessageHeaders; for (size_t i = 0; i < messageArgs->headers_count; ++i) { + // FIXME Considering that below we check if thisContinuation is alive, this line looks super suspicious. + // Keep allocator in callbackData? continuationMessageHeaders.emplace_back( EventStreamHeader(messageArgs->headers[i], thisContinuation->m_allocator)); } @@ -946,7 +975,7 @@ namespace Aws uint32_t messageFlags, OnMessageFlushCallback onMessageFlushCallback) noexcept { - struct aws_array_list headersArray; + aws_array_list headersArray{}; OnMessageFlushCallbackContainer *callbackContainer = nullptr; std::promise onFlushPromise; @@ -977,7 +1006,7 @@ namespace Aws if (!errorCode) { struct aws_event_stream_rpc_message_args msg_args; - msg_args.headers = (struct aws_event_stream_header_value_pair *)headersArray.data; + msg_args.headers = (aws_event_stream_header_value_pair *)headersArray.data; msg_args.headers_count = headers.size(); msg_args.payload = payload.has_value() ? (aws_byte_buf *)(&(payload.value())) : nullptr; msg_args.message_type = messageType; @@ -1005,6 +1034,7 @@ namespace Aws if (errorCode) { + // FIXME Null pointer dereference when s_fillNativeHeadersArray fails. onFlushPromise = std::move(callbackContainer->onFlushPromise); onFlushPromise.set_value({EVENT_STREAM_RPC_CRT_ERROR, errorCode}); Crt::Delete(callbackContainer, m_allocator); @@ -1020,7 +1050,7 @@ namespace Aws uint32_t messageFlags, OnMessageFlushCallback onMessageFlushCallback) noexcept { - struct aws_array_list headersArray; + aws_array_list headersArray{}; OnMessageFlushCallbackContainer *callbackContainer = nullptr; std::promise onFlushPromise; @@ -1035,7 +1065,7 @@ namespace Aws if (!errorCode) { - struct aws_event_stream_rpc_message_args msg_args; + aws_event_stream_rpc_message_args msg_args{}; msg_args.headers = (struct aws_event_stream_header_value_pair *)headersArray.data; msg_args.headers_count = headers.size(); msg_args.payload = payload.has_value() ? (aws_byte_buf *)(&(payload.value())) : nullptr; @@ -1045,7 +1075,7 @@ namespace Aws /* This heap allocation is necessary so that the flush callback can still be invoked when this function * returns. */ callbackContainer = Crt::New(m_allocator, m_allocator); - callbackContainer->onMessageFlushCallback = onMessageFlushCallback; + callbackContainer->onMessageFlushCallback = std::move(onMessageFlushCallback); callbackContainer->onFlushPromise = std::move(onFlushPromise); if (m_continuationToken) @@ -1066,6 +1096,7 @@ namespace Aws if (errorCode) { + // FIXME Null pointer dereference when s_fillNativeHeadersArray fails. onFlushPromise = std::move(callbackContainer->onFlushPromise); AWS_LOGF_ERROR( AWS_LS_EVENT_STREAM_RPC_CLIENT, @@ -1109,7 +1140,7 @@ namespace Aws const OperationModelContext &operationModelContext, Crt::Allocator *allocator) noexcept : m_operationModelContext(operationModelContext), m_asyncLaunchMode(std::launch::deferred), - m_messageCount(0), m_allocator(allocator), m_streamHandler(streamHandler), + m_messageCount(0), m_allocator(allocator), m_streamHandler(std::move(streamHandler)), m_clientContinuation(connection.NewStream(*this)), m_expectedCloses(0), m_streamClosedCalled(false) { } @@ -1232,11 +1263,11 @@ namespace Aws const Crt::List &headers, const Crt::String &name) noexcept { - for (auto it = headers.begin(); it != headers.end(); ++it) + for (const auto &header : headers) { - if (name == it->GetHeaderName()) + if (header.GetHeaderName() == name) { - return &(*it); + return &header; } } return nullptr; @@ -1348,7 +1379,7 @@ namespace Aws return true; } - void StreamResponseHandler::OnStreamEvent(Crt::ScopedResource response) {} + void StreamResponseHandler::OnStreamEvent(Crt::ScopedResource /* response */) {} void StreamResponseHandler::OnStreamClosed() {} @@ -1359,8 +1390,6 @@ namespace Aws uint32_t messageFlags) { EventStreamRpcStatusCode errorCode = EVENT_STREAM_RPC_SUCCESS; - const EventStreamHeader *modelHeader = nullptr; - const EventStreamHeader *contentHeader = nullptr; Crt::String modelName; if (messageFlags & AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_TERMINATE_STREAM) @@ -1371,7 +1400,7 @@ namespace Aws m_messageCount += 1; - modelHeader = GetHeaderByName(headers, Crt::String(SERVICE_MODEL_TYPE_HEADER)); + const EventStreamHeader *modelHeader = GetHeaderByName(headers, Crt::String(SERVICE_MODEL_TYPE_HEADER)); if (modelHeader == nullptr) { /* Missing required service model type header. */ @@ -1410,7 +1439,7 @@ namespace Aws if (!errorCode) { Crt::String contentType; - contentHeader = GetHeaderByName(headers, Crt::String(CONTENT_TYPE_HEADER)); + const EventStreamHeader *contentHeader = GetHeaderByName(headers, Crt::String(CONTENT_TYPE_HEADER)); if (contentHeader == nullptr) { /* Missing required content type header. */ @@ -1450,7 +1479,7 @@ namespace Aws { const std::lock_guard lock(m_continuationMutex); m_resultReceived = true; - RpcError promiseValue = {(EventStreamRpcStatusCode)errorCode, 0}; + RpcError promiseValue = {errorCode, 0}; m_initialResponsePromise.set_value(TaggedResult(promiseValue)); } else @@ -1472,17 +1501,36 @@ namespace Aws { /* Promises must be reset in case the client would like to send a subsequent request with the same * `ClientOperation`. */ + // TODO When compiling with fno-exceptions, resetting promise can lead to abort() if someone is waiting + // for the associated future. m_initialResponsePromise = {}; { + // FIXME Possible race condition with ClientOperation::HandleData and/or other functions. + // t2: Calls ClientOperation::HandleData + // t2: Acquires m_continuationMutex + // t1: "Resets" m_initialResponsePromise + // t1: Locks here on m_continuationMutex + // t2: m_initialResponsePromise.set_value() + // t2: m_resultReceived = true + // t2: Releases m_continuationMutex (and completes ClientOperation::HandleData) + // t1: Acquires m_continuationMutex + // t1: m_resultReceived = false + // t1: Releases m_continuationMutex + // t2: Calls ClientOperation::OnContinuationClosed + // t2: Acquires m_continuationMutex + // t2: !m_resultReceived is true + // t2: m_initialResponsePromise.set_value() is called the second time + // t2: Calls abort() + // Not sure if this exact scenario is possible, but considering that this scheme is used in other + // methods, it'll be safer to fix this construction. const std::lock_guard lock(m_continuationMutex); m_resultReceived = false; } Crt::List headers; - headers.emplace_back(EventStreamHeader( - Crt::String(CONTENT_TYPE_HEADER), Crt::String(CONTENT_TYPE_APPLICATION_JSON), m_allocator)); headers.emplace_back( - EventStreamHeader(Crt::String(SERVICE_MODEL_TYPE_HEADER), GetModelName(), m_allocator)); + Crt::String(CONTENT_TYPE_HEADER), Crt::String(CONTENT_TYPE_APPLICATION_JSON), m_allocator); + headers.emplace_back(Crt::String(SERVICE_MODEL_TYPE_HEADER), GetModelName(), m_allocator); Crt::JsonObject payloadObject; shape->SerializeToJsonObject(payloadObject); Crt::String payloadString = payloadObject.View().WriteCompact(); @@ -1492,7 +1540,7 @@ namespace Aws Crt::ByteBufFromCString(payloadString.c_str()), AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_APPLICATION_MESSAGE, 0, - onMessageFlushCallback); + std::move(onMessageFlushCallback)); } void ClientOperation::OnContinuationClosed() @@ -1527,53 +1575,37 @@ namespace Aws errorPromise.set_value({EVENT_STREAM_RPC_CONTINUATION_CLOSED, 0}); return errorPromise.get_future(); } - else - { - std::promise onTerminatePromise; - int errorCode = AWS_OP_ERR; - struct aws_event_stream_rpc_message_args msg_args; - msg_args.headers = nullptr; - msg_args.headers_count = 0; - msg_args.payload = nullptr; - msg_args.message_type = AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_APPLICATION_MESSAGE; - msg_args.message_flags = AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_TERMINATE_STREAM; + aws_event_stream_rpc_message_args msg_args{}; + msg_args.message_type = AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_APPLICATION_MESSAGE; + msg_args.message_flags = AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_TERMINATE_STREAM; - /* This heap allocation is necessary so that the flush callback can still be invoked when this function - * returns. */ - OnMessageFlushCallbackContainer *callbackContainer = - Crt::New(m_allocator, m_allocator); - callbackContainer->onMessageFlushCallback = onMessageFlushCallback; - callbackContainer->onFlushPromise = std::move(onTerminatePromise); + /* This heap allocation is necessary so that the flush callback can still be invoked when this function + * returns. */ + auto *callbackContainer = Crt::New(m_allocator, m_allocator); + callbackContainer->onMessageFlushCallback = std::move(onMessageFlushCallback); - if (m_clientContinuation.m_continuationToken) - { - errorCode = aws_event_stream_rpc_client_continuation_send_message( - m_clientContinuation.m_continuationToken, - &msg_args, - ClientConnection::s_protocolMessageCallback, - reinterpret_cast(callbackContainer)); - } + int errorCode = aws_event_stream_rpc_client_continuation_send_message( + m_clientContinuation.m_continuationToken, + &msg_args, + ClientConnection::s_protocolMessageCallback, + callbackContainer); - if (errorCode) - { - onTerminatePromise = std::move(callbackContainer->onFlushPromise); - std::promise errorPromise; - AWS_LOGF_ERROR( - AWS_LS_EVENT_STREAM_RPC_CLIENT, - "A CRT error occurred while closing the stream: %s", - Crt::ErrorDebugString(errorCode)); - onTerminatePromise.set_value({EVENT_STREAM_RPC_CRT_ERROR, errorCode}); - Crt::Delete(callbackContainer, m_allocator); - } - else - { - m_expectedCloses.fetch_add(1); - return callbackContainer->onFlushPromise.get_future(); - } + if (errorCode) + { + AWS_LOGF_ERROR( + AWS_LS_EVENT_STREAM_RPC_CLIENT, + "A CRT error occurred while closing the stream: %s", + Crt::ErrorDebugString(errorCode)); + Crt::Delete(callbackContainer, m_allocator); + std::promise onTerminatePromise; + onTerminatePromise.set_value({EVENT_STREAM_RPC_CRT_ERROR, errorCode}); return onTerminatePromise.get_future(); } + + m_expectedCloses.fetch_add(1); + return callbackContainer->onFlushPromise.get_future(); } void OperationError::s_customDeleter(OperationError *shape) noexcept