diff --git a/eventstream_rpc/include/aws/eventstreamrpc/EventStreamClient.h b/eventstream_rpc/include/aws/eventstreamrpc/EventStreamClient.h index 34240a7ee..14e701e86 100644 --- a/eventstream_rpc/include/aws/eventstreamrpc/EventStreamClient.h +++ b/eventstream_rpc/include/aws/eventstreamrpc/EventStreamClient.h @@ -47,7 +47,7 @@ namespace Aws /** * A callback prototype that is called upon flushing a message over the wire. - * @param errorCode A non-zero value if an error occured while attempting to flush the message. + * @param errorCode A non-zero value if an error occurred while attempting to flush the message. */ using OnMessageFlushCallback = std::function; @@ -71,37 +71,16 @@ namespace Aws EventStreamHeader( const struct aws_event_stream_header_value_pair &header, Crt::Allocator *allocator = Crt::g_allocator); - EventStreamHeader(const Crt::String &name, bool value); - EventStreamHeader(const Crt::String &name, int8_t value); - EventStreamHeader(const Crt::String &name, int16_t value); - EventStreamHeader(const Crt::String &name, int32_t value); - EventStreamHeader(const Crt::String &name, int64_t value); - EventStreamHeader(const Crt::String &name, Crt::DateTime &value); EventStreamHeader( const Crt::String &name, const Crt::String &value, Crt::Allocator *allocator = Crt::g_allocator) noexcept; - EventStreamHeader(const Crt::String &name, Crt::ByteBuf &value); - EventStreamHeader(const Crt::String &name, Crt::UUID value); - HeaderValueType GetHeaderValueType(); Crt::String GetHeaderName() const noexcept; - void SetHeaderName(const Crt::String &); - - bool GetValueAsBoolean(bool &); - bool GetValueAsByte(int8_t &); - bool GetValueAsShort(int16_t &); - bool GetValueAsInt(int32_t &); - bool GetValueAsLong(int64_t &); - bool GetValueAsTimestamp(Crt::DateTime &); bool GetValueAsString(Crt::String &) const noexcept; - bool GetValueAsBytes(Crt::ByteBuf &); - bool GetValueAsUUID(Crt::UUID &); const struct aws_event_stream_header_value_pair *GetUnderlyingHandle() const; - bool operator==(const EventStreamHeader &other) const noexcept; - private: Crt::Allocator *m_allocator; Crt::ByteBuf m_valueByteBuf; @@ -117,7 +96,7 @@ namespace Aws { public: MessageAmendment(const MessageAmendment &lhs); - MessageAmendment(MessageAmendment &&rhs); + MessageAmendment(MessageAmendment &&rhs) noexcept; MessageAmendment &operator=(const MessageAmendment &lhs); ~MessageAmendment() noexcept; explicit MessageAmendment(Crt::Allocator *allocator = Crt::g_allocator) noexcept; @@ -162,10 +141,10 @@ namespace Aws OnMessageFlushCallback GetConnectRequestCallback() const noexcept { return m_connectRequestCallback; } ConnectMessageAmender GetConnectMessageAmender() const noexcept { - return [&](void) -> const MessageAmendment & { return m_connectAmendment; }; + return [&]() -> const MessageAmendment & { return m_connectAmendment; }; } - void SetHostName(Crt::String hostName) noexcept { m_hostName = hostName; } + void SetHostName(Crt::String hostName) noexcept { m_hostName = std::move(hostName); } void SetPort(uint32_t port) noexcept { m_port = port; } void SetSocketOptions(const Crt::Io::SocketOptions &socketOptions) noexcept { @@ -185,7 +164,7 @@ namespace Aws } void SetConnectRequestCallback(OnMessageFlushCallback connectRequestCallback) noexcept { - m_connectRequestCallback = connectRequestCallback; + m_connectRequestCallback = std::move(connectRequestCallback); } protected: @@ -226,6 +205,7 @@ namespace Aws class AWS_EVENTSTREAMRPC_API ConnectionLifecycleHandler { public: + virtual ~ConnectionLifecycleHandler() noexcept = default; /** * This callback is only invoked upon receiving a CONNECT_ACK with the * CONNECTION_ACCEPTED flag set by the server. Therefore, once this callback @@ -292,7 +272,7 @@ namespace Aws * the TERMINATE_STREAM flag, or when the connection shuts down. */ virtual void OnContinuationClosed() = 0; - virtual ~ClientContinuationHandler() noexcept; + virtual ~ClientContinuationHandler() noexcept = default; private: friend class ClientContinuation; @@ -307,6 +287,13 @@ namespace Aws ClientContinuationHandler &continuationHandler, Crt::Allocator *allocator) noexcept; ~ClientContinuation() noexcept; + + ClientContinuation(const ClientContinuation &other) = default; + ClientContinuation(ClientContinuation &&other) noexcept = default; + + ClientContinuation &operator=(const ClientContinuation &other) = delete; + ClientContinuation &operator=(ClientContinuation &&other) noexcept = delete; + std::future Activate( const Crt::String &operation, const Crt::List &headers, @@ -315,7 +302,6 @@ namespace Aws uint32_t messageFlags, OnMessageFlushCallback onMessageFlushCallback) noexcept; bool IsClosed() noexcept; - void Release() noexcept; std::future SendMessage( const Crt::List &headers, const Crt::Optional &payload, @@ -342,7 +328,7 @@ namespace Aws { public: AbstractShapeBase() noexcept; - virtual ~AbstractShapeBase() noexcept; + virtual ~AbstractShapeBase() noexcept = default; static void s_customDeleter(AbstractShapeBase *shape) noexcept; virtual void SerializeToJsonObject(Crt::JsonObject &payloadObject) const = 0; virtual Crt::String GetModelName() const noexcept = 0; @@ -354,9 +340,9 @@ namespace Aws class AWS_EVENTSTREAMRPC_API OperationError : public AbstractShapeBase { public: - explicit OperationError() noexcept; + explicit OperationError() noexcept = default; static void s_customDeleter(OperationError *shape) noexcept; - virtual void SerializeToJsonObject(Crt::JsonObject &payloadObject) const override; + void SerializeToJsonObject(Crt::JsonObject &payloadObject) const override; virtual Crt::Optional GetMessage() noexcept = 0; }; @@ -368,6 +354,8 @@ namespace Aws class AWS_EVENTSTREAMRPC_API StreamResponseHandler { public: + virtual ~StreamResponseHandler() noexcept = default; + /** * Invoked when stream is closed, so no more messages will be received. */ @@ -424,7 +412,7 @@ namespace Aws } OperationResult(Crt::ScopedResource &&error) noexcept : m_error(std::move(error)) {} OperationResult() noexcept : m_response(nullptr) {} - ~OperationResult() noexcept {}; + ~OperationResult() noexcept {} Crt::ScopedResource m_response; Crt::ScopedResource m_error; }; @@ -446,6 +434,7 @@ namespace Aws { /* An interface shared by all operations for retrieving the response object given the model name. */ public: + virtual ~ResponseRetriever() noexcept = default; virtual ExpectedResponseFactory GetInitialResponseFromModelName( const Crt::String &modelName) const noexcept = 0; virtual ExpectedResponseFactory GetStreamingResponseFromModelName( @@ -457,6 +446,7 @@ namespace Aws class AWS_EVENTSTREAMRPC_API ServiceModel { public: + virtual ~ServiceModel() noexcept = default; virtual Crt::ScopedResource AllocateOperationErrorFromPayload( const Crt::String &errorModelName, Crt::StringView stringView, @@ -467,6 +457,7 @@ namespace Aws { public: OperationModelContext(const ServiceModel &serviceModel) noexcept; + virtual ~OperationModelContext() noexcept = default; virtual Crt::ScopedResource AllocateInitialResponseFromPayload( Crt::StringView stringView, Crt::Allocator *allocator) const noexcept = 0; @@ -497,7 +488,7 @@ namespace Aws std::shared_ptr streamHandler, const OperationModelContext &operationModelContext, Crt::Allocator *allocator) noexcept; - ~ClientOperation() noexcept; + virtual ~ClientOperation() noexcept; ClientOperation(const ClientOperation &clientOperation) noexcept = delete; ClientOperation(ClientOperation &&clientOperation) noexcept = delete; @@ -512,9 +503,6 @@ namespace Aws std::future Activate( const AbstractShapeBase *shape, OnMessageFlushCallback onMessageFlushCallback) noexcept; - std::future SendStreamEvent( - AbstractShapeBase *shape, - OnMessageFlushCallback onMessageFlushCallback) noexcept; virtual Crt::String GetModelName() const noexcept = 0; const OperationModelContext &m_operationModelContext; std::launch m_asyncLaunchMode; @@ -624,14 +612,14 @@ namespace Aws DISCONNECTING, }; /* This recursive mutex protects m_clientState & m_connectionWillSetup */ - std::recursive_mutex m_stateMutex; + std::mutex m_closeReasonMutex; Crt::Allocator *m_allocator; struct aws_event_stream_rpc_client_connection *m_underlyingConnection; - ClientState m_clientState; + std::atomic m_clientState; ConnectionLifecycleHandler *m_lifecycleHandler; ConnectMessageAmender m_connectMessageAmender; std::promise m_connectionSetupPromise; - bool m_connectionWillSetup; + std::atomic m_connectionWillSetup; std::promise m_connectAckedPromise; std::promise m_closedPromise; bool m_onConnectCalled; diff --git a/eventstream_rpc/source/EventStreamClient.cpp b/eventstream_rpc/source/EventStreamClient.cpp index ad5a9159b..36142ef70 100644 --- a/eventstream_rpc/source/EventStreamClient.cpp +++ b/eventstream_rpc/source/EventStreamClient.cpp @@ -8,8 +8,8 @@ #include #include -#include -#include +#include +#include #include @@ -79,6 +79,11 @@ namespace Aws MessageAmendment &MessageAmendment::operator=(const MessageAmendment &lhs) { + if (this == &lhs) + { + return *this; + } + m_headers = lhs.m_headers; if (lhs.m_payload.has_value()) { @@ -90,8 +95,8 @@ namespace Aws return *this; } - MessageAmendment::MessageAmendment(MessageAmendment &&rhs) - : m_headers(std::move(rhs.m_headers)), m_payload(rhs.m_payload), m_allocator(rhs.m_allocator) + MessageAmendment::MessageAmendment(MessageAmendment &&rhs) noexcept + : m_headers(std::move(rhs.m_headers)), m_payload(std::move(rhs.m_payload)), m_allocator(rhs.m_allocator) { rhs.m_allocator = nullptr; rhs.m_payload = Crt::Optional(); @@ -149,13 +154,20 @@ namespace Aws } }; + // FIXME This assignment operator is completely broken. + // 1. rhs' internal state can be changed while members are copying one by one, which can lead to this being + // inconsistent/corrupted. + // 2. if rhs is in the CONNECTED state, then a pointer to it has already been passed to the underlying + // libraries and is used in callbacks. + // 3. If this is connected, what will happen to its underlying connection, state, callbacks, etc.? + // + // Option 1 (preferable): This operator should be marked as deleted. It'll be a BREAKING CHANGE. + // Option 2: As an ugly alternative, throw runtime error if `Connect` method was called on rhs or this. ClientConnection &ClientConnection::operator=(ClientConnection &&rhs) noexcept { m_allocator = std::move(rhs.m_allocator); m_underlyingConnection = rhs.m_underlyingConnection; - rhs.m_stateMutex.lock(); - m_clientState = rhs.m_clientState; - rhs.m_stateMutex.unlock(); + m_clientState = rhs.m_clientState.load(); m_lifecycleHandler = rhs.m_lifecycleHandler; m_connectMessageAmender = rhs.m_connectMessageAmender; m_connectAckedPromise = std::move(rhs.m_connectAckedPromise); @@ -174,6 +186,7 @@ namespace Aws return *this; } + // FIXME Mark as deleted. See comment to the assignment operator. ClientConnection::ClientConnection(ClientConnection &&rhs) noexcept : m_lifecycleHandler(rhs.m_lifecycleHandler) { *this = std::move(rhs); @@ -182,28 +195,23 @@ namespace Aws ClientConnection::ClientConnection(Crt::Allocator *allocator) noexcept : m_allocator(allocator), m_underlyingConnection(nullptr), m_clientState(DISCONNECTED), m_lifecycleHandler(nullptr), m_connectMessageAmender(nullptr), m_connectionWillSetup(false), + m_onConnectCalled(false), m_closeReason{EVENT_STREAM_RPC_UNINITIALIZED, 0}, m_onConnectRequestCallback(nullptr) { } ClientConnection::~ClientConnection() noexcept { - m_stateMutex.lock(); if (m_connectionWillSetup) { - m_stateMutex.unlock(); m_connectionSetupPromise.get_future().wait(); } - m_stateMutex.lock(); + if (m_clientState != DISCONNECTED) { Close(); - m_stateMutex.unlock(); m_closedPromise.get_future().wait(); } - /* Cover the case in which the if statements are not hit. */ - m_stateMutex.unlock(); - m_stateMutex.unlock(); m_underlyingConnection = nullptr; } @@ -269,10 +277,9 @@ namespace Aws Crt::Io::ClientBootstrap &clientBootstrap) noexcept { EventStreamRpcStatusCode baseError = EVENT_STREAM_RPC_SUCCESS; - struct aws_event_stream_rpc_client_connection_options connOptions; + aws_event_stream_rpc_client_connection_options connOptions{}; { - const std::lock_guard lock(m_stateMutex); if (m_clientState == DISCONNECTED) { m_clientState = CONNECTING_SOCKET; @@ -280,7 +287,10 @@ namespace Aws m_connectionSetupPromise = {}; m_connectAckedPromise = {}; m_closedPromise = {}; - m_closeReason = {EVENT_STREAM_RPC_UNINITIALIZED, 0}; + { + const std::lock_guard lock(m_closeReasonMutex); + m_closeReason = {EVENT_STREAM_RPC_UNINITIALIZED, 0}; + } m_connectionConfig = connectionConfig; m_lifecycleHandler = connectionLifecycleHandler; } @@ -323,7 +333,6 @@ namespace Aws errorPromise.set_value({baseError, 0}); if (baseError == EVENT_STREAM_RPC_NULL_PARAMETER) { - const std::lock_guard lock(m_stateMutex); m_clientState = DISCONNECTED; } return errorPromise.get_future(); @@ -358,13 +367,11 @@ namespace Aws "A CRT error occurred while attempting to establish the connection: %s", Crt::ErrorDebugString(crtError)); errorPromise.set_value({EVENT_STREAM_RPC_CRT_ERROR, crtError}); - const std::lock_guard lock(m_stateMutex); m_clientState = DISCONNECTED; return errorPromise.get_future(); } else { - const std::lock_guard lock(m_stateMutex); m_connectionWillSetup = true; } @@ -376,7 +383,7 @@ namespace Aws const Crt::Optional &payload, OnMessageFlushCallback onMessageFlushCallback) noexcept { - return s_sendPing(this, headers, payload, onMessageFlushCallback); + return s_sendPing(this, headers, payload, std::move(onMessageFlushCallback)); } std::future ClientConnection::SendPingResponse( @@ -384,7 +391,7 @@ namespace Aws const Crt::Optional &payload, OnMessageFlushCallback onMessageFlushCallback) noexcept { - return s_sendPingResponse(this, headers, payload, onMessageFlushCallback); + return s_sendPingResponse(this, headers, payload, std::move(onMessageFlushCallback)); } std::future ClientConnection::s_sendPing( @@ -394,7 +401,12 @@ namespace Aws OnMessageFlushCallback onMessageFlushCallback) noexcept { return s_sendProtocolMessage( - connection, headers, payload, AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_PING, 0, onMessageFlushCallback); + connection, + headers, + payload, + AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_PING, + 0, + std::move(onMessageFlushCallback)); } std::future ClientConnection::s_sendPingResponse( @@ -409,7 +421,7 @@ namespace Aws payload, AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_PING_RESPONSE, 0, - onMessageFlushCallback); + std::move(onMessageFlushCallback)); } std::future ClientConnection::SendProtocolMessage( @@ -419,7 +431,8 @@ namespace Aws uint32_t messageFlags, OnMessageFlushCallback onMessageFlushCallback) noexcept { - return s_sendProtocolMessage(this, headers, payload, messageType, messageFlags, onMessageFlushCallback); + return s_sendProtocolMessage( + this, headers, payload, messageType, messageFlags, std::move(onMessageFlushCallback)); } void ClientConnection::s_protocolMessageCallback(int errorCode, void *userData) noexcept @@ -458,7 +471,7 @@ namespace Aws { std::promise onFlushPromise; OnMessageFlushCallbackContainer *callbackContainer = nullptr; - struct aws_array_list headersArray; + aws_array_list headersArray{}; /* The caller should never pass a NULL connection. */ AWS_PRECONDITION(connection != nullptr); @@ -468,8 +481,8 @@ namespace Aws if (!errorCode) { - struct aws_event_stream_rpc_message_args msg_args; - msg_args.headers = (struct aws_event_stream_header_value_pair *)headersArray.data; + aws_event_stream_rpc_message_args msg_args{}; + msg_args.headers = static_cast(headersArray.data); msg_args.headers_count = headers.size(); msg_args.payload = payload.has_value() ? (aws_byte_buf *)(&(payload.value())) : nullptr; msg_args.message_type = messageType; @@ -479,7 +492,7 @@ namespace Aws * returns. */ callbackContainer = Crt::New(connection->m_allocator, connection->m_allocator); - callbackContainer->onMessageFlushCallback = onMessageFlushCallback; + callbackContainer->onMessageFlushCallback = std::move(onMessageFlushCallback); callbackContainer->onFlushPromise = std::move(onFlushPromise); errorCode = aws_event_stream_rpc_client_connection_send_protocol_message( @@ -497,6 +510,7 @@ namespace Aws if (errorCode) { + // FIXME Null pointer dereference if s_fillNativeHeadersArray fails. onFlushPromise = std::move(callbackContainer->onFlushPromise); AWS_LOGF_ERROR( AWS_LS_EVENT_STREAM_RPC_CLIENT, @@ -515,8 +529,6 @@ namespace Aws void ClientConnection::Close() noexcept { - const std::lock_guard lock(m_stateMutex); - if (IsOpen()) { aws_event_stream_rpc_client_connection_close(this->m_underlyingConnection, AWS_OP_SUCCESS); @@ -531,6 +543,7 @@ namespace Aws m_clientState = DISCONNECTING; } + const std::lock_guard lock(m_closeReasonMutex); if (m_closeReason.baseStatus == EVENT_STREAM_RPC_UNINITIALIZED) { m_closeReason = {EVENT_STREAM_RPC_CONNECTION_CLOSED, 0}; @@ -584,6 +597,10 @@ namespace Aws EventStreamHeader &EventStreamHeader::operator=(const EventStreamHeader &lhs) noexcept { + if (this == &lhs) + { + return *this; + } m_allocator = lhs.m_allocator; m_valueByteBuf = Crt::ByteBufNewCopy(lhs.m_allocator, lhs.m_valueByteBuf.buffer, lhs.m_valueByteBuf.len); m_underlyingHandle = lhs.m_underlyingHandle; @@ -637,8 +654,6 @@ namespace Aws /* The `userData` pointer is used to pass `this` of a `ClientConnection` object. */ auto *thisConnection = static_cast(userData); - const std::lock_guard lock(thisConnection->m_stateMutex); - if (errorCode) { thisConnection->m_clientState = DISCONNECTED; @@ -655,7 +670,10 @@ namespace Aws else if (thisConnection->m_clientState == DISCONNECTING || thisConnection->m_clientState == DISCONNECTED) { thisConnection->m_underlyingConnection = connection; - thisConnection->m_closeReason = {EVENT_STREAM_RPC_CONNECTION_CLOSED, 0}; + { + const std::lock_guard lock(thisConnection->m_closeReasonMutex); + thisConnection->m_closeReason = {EVENT_STREAM_RPC_CONNECTION_CLOSED, 0}; + } thisConnection->Close(); } else @@ -706,19 +724,21 @@ namespace Aws /* The `userData` pointer is used to pass `this` of a `ClientConnection` object. */ auto *thisConnection = static_cast(userData); - const std::lock_guard lock(thisConnection->m_stateMutex); - - if (thisConnection->m_closeReason.baseStatus == EVENT_STREAM_RPC_UNINITIALIZED && errorCode) { - thisConnection->m_closeReason = {EVENT_STREAM_RPC_CRT_ERROR, errorCode}; - } + const std::lock_guard lock(thisConnection->m_closeReasonMutex); + + if (thisConnection->m_closeReason.baseStatus == EVENT_STREAM_RPC_UNINITIALIZED && errorCode) + { + thisConnection->m_closeReason = {EVENT_STREAM_RPC_CRT_ERROR, errorCode}; + } - thisConnection->m_underlyingConnection = nullptr; + thisConnection->m_underlyingConnection = nullptr; - if (thisConnection->m_closeReason.baseStatus != EVENT_STREAM_RPC_UNINITIALIZED && - !thisConnection->m_onConnectCalled) - { - thisConnection->m_connectAckedPromise.set_value(thisConnection->m_closeReason); + if (thisConnection->m_closeReason.baseStatus != EVENT_STREAM_RPC_UNINITIALIZED && + !thisConnection->m_onConnectCalled) + { + thisConnection->m_connectAckedPromise.set_value(thisConnection->m_closeReason); + } } thisConnection->m_clientState = DISCONNECTED; @@ -765,7 +785,6 @@ namespace Aws switch (messageArgs->message_type) { case AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_CONNECT_ACK: - thisConnection->m_stateMutex.lock(); if (thisConnection->m_clientState == WAITING_FOR_CONNECT_ACK) { if (messageArgs->message_flags & AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_CONNECTION_ACCEPTED) @@ -777,7 +796,10 @@ namespace Aws } else { - thisConnection->m_closeReason = {EVENT_STREAM_RPC_CONNECTION_ACCESS_DENIED, 0}; + { + const std::lock_guard lock(thisConnection->m_closeReasonMutex); + thisConnection->m_closeReason = {EVENT_STREAM_RPC_CONNECTION_ACCESS_DENIED, 0}; + } thisConnection->Close(); } } @@ -785,7 +807,6 @@ namespace Aws { /* Unexpected CONNECT_ACK received. */ } - thisConnection->m_stateMutex.unlock(); break; @@ -793,8 +814,7 @@ namespace Aws for (size_t i = 0; i < messageArgs->headers_count; ++i) { - pingHeaders.emplace_back( - EventStreamHeader(messageArgs->headers[i], thisConnection->m_allocator)); + pingHeaders.emplace_back(messageArgs->headers[i], thisConnection->m_allocator); } if (messageArgs->payload) @@ -847,7 +867,7 @@ namespace Aws Crt::Allocator *allocator) noexcept : m_allocator(allocator), m_continuationHandler(continuationHandler), m_continuationToken(nullptr) { - struct aws_event_stream_rpc_client_stream_continuation_options options; + aws_event_stream_rpc_client_stream_continuation_options options{}; options.on_continuation = ClientContinuation::s_onContinuationMessage; options.on_continuation_closed = ClientContinuation::s_onContinuationClosed; @@ -878,6 +898,9 @@ namespace Aws } if (m_callbackData != nullptr) { + // FIXME Setting `m_callbackData->continuationDestroyed` indicates that another actor is supposed + // to check this flag (see `ClientContinuation::s_onContinuationMessage`). However, we delete + // `m_callbackData` right after setting the flag, so it doesn't work as intended. { const std::lock_guard lock(m_callbackData->callbackMutex); m_callbackData->continuationDestroyed = true; @@ -886,8 +909,6 @@ namespace Aws } } - ClientContinuationHandler::~ClientContinuationHandler() noexcept {} - void ClientContinuation::s_onContinuationMessage( struct aws_event_stream_rpc_client_continuation_token *continuationToken, const struct aws_event_stream_rpc_message_args *messageArgs, @@ -895,12 +916,18 @@ namespace Aws { (void)continuationToken; /* The `userData` pointer is used to pass a `ContinuationCallbackData` object. */ + // FIXME Can `callbackData` be destroyed at this point? See `ClientContinuation::~ClientContinuation`. + // Probably `callbackData` is guaranteed to be alive after this PR: + // https://github.com/aws/aws-iot-device-sdk-cpp-v2/pull/437. But then we need to get rid of the + // `continuationDestroyed` flag. auto *callbackData = static_cast(userData); auto *thisContinuation = callbackData->clientContinuation; Crt::List continuationMessageHeaders; for (size_t i = 0; i < messageArgs->headers_count; ++i) { + // FIXME Considering that below we check if thisContinuation is alive, this line looks super suspicious. + // Keep allocator in callbackData? continuationMessageHeaders.emplace_back( EventStreamHeader(messageArgs->headers[i], thisContinuation->m_allocator)); } @@ -948,7 +975,7 @@ namespace Aws uint32_t messageFlags, OnMessageFlushCallback onMessageFlushCallback) noexcept { - struct aws_array_list headersArray; + aws_array_list headersArray{}; OnMessageFlushCallbackContainer *callbackContainer = nullptr; std::promise onFlushPromise; @@ -979,7 +1006,7 @@ namespace Aws if (!errorCode) { struct aws_event_stream_rpc_message_args msg_args; - msg_args.headers = (struct aws_event_stream_header_value_pair *)headersArray.data; + msg_args.headers = (aws_event_stream_header_value_pair *)headersArray.data; msg_args.headers_count = headers.size(); msg_args.payload = payload.has_value() ? (aws_byte_buf *)(&(payload.value())) : nullptr; msg_args.message_type = messageType; @@ -1007,6 +1034,7 @@ namespace Aws if (errorCode) { + // FIXME Null pointer dereference when s_fillNativeHeadersArray fails. onFlushPromise = std::move(callbackContainer->onFlushPromise); onFlushPromise.set_value({EVENT_STREAM_RPC_CRT_ERROR, errorCode}); Crt::Delete(callbackContainer, m_allocator); @@ -1022,7 +1050,7 @@ namespace Aws uint32_t messageFlags, OnMessageFlushCallback onMessageFlushCallback) noexcept { - struct aws_array_list headersArray; + aws_array_list headersArray{}; OnMessageFlushCallbackContainer *callbackContainer = nullptr; std::promise onFlushPromise; @@ -1037,7 +1065,7 @@ namespace Aws if (!errorCode) { - struct aws_event_stream_rpc_message_args msg_args; + aws_event_stream_rpc_message_args msg_args{}; msg_args.headers = (struct aws_event_stream_header_value_pair *)headersArray.data; msg_args.headers_count = headers.size(); msg_args.payload = payload.has_value() ? (aws_byte_buf *)(&(payload.value())) : nullptr; @@ -1047,7 +1075,7 @@ namespace Aws /* This heap allocation is necessary so that the flush callback can still be invoked when this function * returns. */ callbackContainer = Crt::New(m_allocator, m_allocator); - callbackContainer->onMessageFlushCallback = onMessageFlushCallback; + callbackContainer->onMessageFlushCallback = std::move(onMessageFlushCallback); callbackContainer->onFlushPromise = std::move(onFlushPromise); if (m_continuationToken) @@ -1068,6 +1096,7 @@ namespace Aws if (errorCode) { + // FIXME Null pointer dereference when s_fillNativeHeadersArray fails. onFlushPromise = std::move(callbackContainer->onFlushPromise); AWS_LOGF_ERROR( AWS_LS_EVENT_STREAM_RPC_CLIENT, @@ -1101,21 +1130,17 @@ namespace Aws { } - OperationError::OperationError() noexcept {} - void OperationError::SerializeToJsonObject(Crt::JsonObject &payloadObject) const { (void)payloadObject; } AbstractShapeBase::AbstractShapeBase() noexcept : m_allocator(nullptr) {} - AbstractShapeBase::~AbstractShapeBase() noexcept {} - ClientOperation::ClientOperation( ClientConnection &connection, std::shared_ptr streamHandler, const OperationModelContext &operationModelContext, Crt::Allocator *allocator) noexcept : m_operationModelContext(operationModelContext), m_asyncLaunchMode(std::launch::deferred), - m_messageCount(0), m_allocator(allocator), m_streamHandler(streamHandler), + m_messageCount(0), m_allocator(allocator), m_streamHandler(std::move(streamHandler)), m_clientContinuation(connection.NewStream(*this)), m_expectedCloses(0), m_streamClosedCalled(false) { } @@ -1238,11 +1263,11 @@ namespace Aws const Crt::List &headers, const Crt::String &name) noexcept { - for (auto it = headers.begin(); it != headers.end(); ++it) + for (const auto &header : headers) { - if (name == it->GetHeaderName()) + if (header.GetHeaderName() == name) { - return &(*it); + return &header; } } return nullptr; @@ -1354,7 +1379,7 @@ namespace Aws return true; } - void StreamResponseHandler::OnStreamEvent(Crt::ScopedResource response) {} + void StreamResponseHandler::OnStreamEvent(Crt::ScopedResource /* response */) {} void StreamResponseHandler::OnStreamClosed() {} @@ -1365,8 +1390,6 @@ namespace Aws uint32_t messageFlags) { EventStreamRpcStatusCode errorCode = EVENT_STREAM_RPC_SUCCESS; - const EventStreamHeader *modelHeader = nullptr; - const EventStreamHeader *contentHeader = nullptr; Crt::String modelName; if (messageFlags & AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_TERMINATE_STREAM) @@ -1377,7 +1400,7 @@ namespace Aws m_messageCount += 1; - modelHeader = GetHeaderByName(headers, Crt::String(SERVICE_MODEL_TYPE_HEADER)); + const EventStreamHeader *modelHeader = GetHeaderByName(headers, Crt::String(SERVICE_MODEL_TYPE_HEADER)); if (modelHeader == nullptr) { /* Missing required service model type header. */ @@ -1416,7 +1439,7 @@ namespace Aws if (!errorCode) { Crt::String contentType; - contentHeader = GetHeaderByName(headers, Crt::String(CONTENT_TYPE_HEADER)); + const EventStreamHeader *contentHeader = GetHeaderByName(headers, Crt::String(CONTENT_TYPE_HEADER)); if (contentHeader == nullptr) { /* Missing required content type header. */ @@ -1456,7 +1479,7 @@ namespace Aws { const std::lock_guard lock(m_continuationMutex); m_resultReceived = true; - RpcError promiseValue = {(EventStreamRpcStatusCode)errorCode, 0}; + RpcError promiseValue = {errorCode, 0}; m_initialResponsePromise.set_value(TaggedResult(promiseValue)); } else @@ -1478,17 +1501,36 @@ namespace Aws { /* Promises must be reset in case the client would like to send a subsequent request with the same * `ClientOperation`. */ + // TODO When compiling with fno-exceptions, resetting promise can lead to abort() if someone is waiting + // for the associated future. m_initialResponsePromise = {}; { + // FIXME Possible race condition with ClientOperation::HandleData and/or other functions. + // t2: Calls ClientOperation::HandleData + // t2: Acquires m_continuationMutex + // t1: "Resets" m_initialResponsePromise + // t1: Locks here on m_continuationMutex + // t2: m_initialResponsePromise.set_value() + // t2: m_resultReceived = true + // t2: Releases m_continuationMutex (and completes ClientOperation::HandleData) + // t1: Acquires m_continuationMutex + // t1: m_resultReceived = false + // t1: Releases m_continuationMutex + // t2: Calls ClientOperation::OnContinuationClosed + // t2: Acquires m_continuationMutex + // t2: !m_resultReceived is true + // t2: m_initialResponsePromise.set_value() is called the second time + // t2: Calls abort() + // Not sure if this exact scenario is possible, but considering that this scheme is used in other + // methods, it'll be safer to fix this construction. const std::lock_guard lock(m_continuationMutex); m_resultReceived = false; } Crt::List headers; - headers.emplace_back(EventStreamHeader( - Crt::String(CONTENT_TYPE_HEADER), Crt::String(CONTENT_TYPE_APPLICATION_JSON), m_allocator)); headers.emplace_back( - EventStreamHeader(Crt::String(SERVICE_MODEL_TYPE_HEADER), GetModelName(), m_allocator)); + Crt::String(CONTENT_TYPE_HEADER), Crt::String(CONTENT_TYPE_APPLICATION_JSON), m_allocator); + headers.emplace_back(Crt::String(SERVICE_MODEL_TYPE_HEADER), GetModelName(), m_allocator); Crt::JsonObject payloadObject; shape->SerializeToJsonObject(payloadObject); Crt::String payloadString = payloadObject.View().WriteCompact(); @@ -1498,7 +1540,7 @@ namespace Aws Crt::ByteBufFromCString(payloadString.c_str()), AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_APPLICATION_MESSAGE, 0, - onMessageFlushCallback); + std::move(onMessageFlushCallback)); } void ClientOperation::OnContinuationClosed() @@ -1533,53 +1575,37 @@ namespace Aws errorPromise.set_value({EVENT_STREAM_RPC_CONTINUATION_CLOSED, 0}); return errorPromise.get_future(); } - else - { - std::promise onTerminatePromise; - int errorCode = AWS_OP_ERR; - struct aws_event_stream_rpc_message_args msg_args; - msg_args.headers = nullptr; - msg_args.headers_count = 0; - msg_args.payload = nullptr; - msg_args.message_type = AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_APPLICATION_MESSAGE; - msg_args.message_flags = AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_TERMINATE_STREAM; + aws_event_stream_rpc_message_args msg_args{}; + msg_args.message_type = AWS_EVENT_STREAM_RPC_MESSAGE_TYPE_APPLICATION_MESSAGE; + msg_args.message_flags = AWS_EVENT_STREAM_RPC_MESSAGE_FLAG_TERMINATE_STREAM; - /* This heap allocation is necessary so that the flush callback can still be invoked when this function - * returns. */ - OnMessageFlushCallbackContainer *callbackContainer = - Crt::New(m_allocator, m_allocator); - callbackContainer->onMessageFlushCallback = onMessageFlushCallback; - callbackContainer->onFlushPromise = std::move(onTerminatePromise); + /* This heap allocation is necessary so that the flush callback can still be invoked when this function + * returns. */ + auto *callbackContainer = Crt::New(m_allocator, m_allocator); + callbackContainer->onMessageFlushCallback = std::move(onMessageFlushCallback); - if (m_clientContinuation.m_continuationToken) - { - errorCode = aws_event_stream_rpc_client_continuation_send_message( - m_clientContinuation.m_continuationToken, - &msg_args, - ClientConnection::s_protocolMessageCallback, - reinterpret_cast(callbackContainer)); - } + int errorCode = aws_event_stream_rpc_client_continuation_send_message( + m_clientContinuation.m_continuationToken, + &msg_args, + ClientConnection::s_protocolMessageCallback, + callbackContainer); - if (errorCode) - { - onTerminatePromise = std::move(callbackContainer->onFlushPromise); - std::promise errorPromise; - AWS_LOGF_ERROR( - AWS_LS_EVENT_STREAM_RPC_CLIENT, - "A CRT error occurred while closing the stream: %s", - Crt::ErrorDebugString(errorCode)); - onTerminatePromise.set_value({EVENT_STREAM_RPC_CRT_ERROR, errorCode}); - Crt::Delete(callbackContainer, m_allocator); - } - else - { - m_expectedCloses.fetch_add(1); - return callbackContainer->onFlushPromise.get_future(); - } + if (errorCode) + { + AWS_LOGF_ERROR( + AWS_LS_EVENT_STREAM_RPC_CLIENT, + "A CRT error occurred while closing the stream: %s", + Crt::ErrorDebugString(errorCode)); + Crt::Delete(callbackContainer, m_allocator); + std::promise onTerminatePromise; + onTerminatePromise.set_value({EVENT_STREAM_RPC_CRT_ERROR, errorCode}); return onTerminatePromise.get_future(); } + + m_expectedCloses.fetch_add(1); + return callbackContainer->onFlushPromise.get_future(); } void OperationError::s_customDeleter(OperationError *shape) noexcept