diff --git a/src/core/coap/coap.cpp b/src/core/coap/coap.cpp index 242ece0f163..b57bb6ae0b0 100644 --- a/src/core/coap/coap.cpp +++ b/src/core/coap/coap.cpp @@ -53,10 +53,11 @@ CoapBase::CoapBase(Instance &aInstance, Sender aSender) { } -void CoapBase::ClearRequestsAndResponses(void) +void CoapBase::ClearAllRequestsAndResponses(void) { ClearRequests(nullptr); // Clear requests matching any address. mResponsesQueue.DequeueAllResponses(); + mRetransmissionTimer.Stop(); } void CoapBase::ClearRequests(const Ip6::Address &aAddress) { ClearRequests(&aAddress); } @@ -1553,7 +1554,11 @@ void ResponsesQueue::UpdateQueue(void) void ResponsesQueue::DequeueResponse(Message &aMessage) { mQueue.DequeueAndFree(aMessage); } -void ResponsesQueue::DequeueAllResponses(void) { mQueue.DequeueAndFreeAll(); } +void ResponsesQueue::DequeueAllResponses(void) +{ + mQueue.DequeueAndFreeAll(); + mTimer.Stop(); +} void ResponsesQueue::HandleTimer(Timer &aTimer) { @@ -1693,7 +1698,7 @@ Error Coap::Stop(void) VerifyOrExit(mSocket.IsBound()); SuccessOrExit(error = mSocket.Close()); - ClearRequestsAndResponses(); + ClearAllRequestsAndResponses(); exit: return error; diff --git a/src/core/coap/coap.hpp b/src/core/coap/coap.hpp index 03f0d4b4de9..e21879dc426 100644 --- a/src/core/coap/coap.hpp +++ b/src/core/coap/coap.hpp @@ -358,9 +358,9 @@ class CoapBase : public InstanceLocator, private NonCopyable typedef Error (*Interceptor)(const Message &aMessage, const Ip6::MessageInfo &aMessageInfo, void *aContext); /** - * Clears requests and responses used by this CoAP agent. + * Clears all requests and responses used by this CoAP agent and stops all timers. */ - void ClearRequestsAndResponses(void); + void ClearAllRequestsAndResponses(void); /** * Clears requests with specified source address used by this CoAP agent. diff --git a/src/core/coap/coap_secure.cpp b/src/core/coap/coap_secure.cpp index 91e236dc56e..1d794ecc535 100644 --- a/src/core/coap/coap_secure.cpp +++ b/src/core/coap/coap_secure.cpp @@ -51,6 +51,13 @@ SecureSession::SecureSession(Instance &aInstance, Dtls::Transport &aDtlsTranspor Dtls::Session::SetReceiveCallback(HandleDtlsReceive, this); } +void SecureSession::Cleanup(void) +{ + ClearAllRequestsAndResponses(); + mTransmitQueue.DequeueAndFreeAll(); + mTransmitTask.Unpost(); +} + #if OPENTHREAD_CONFIG_COAP_BLOCKWISE_TRANSFER_ENABLE Error SecureSession::SendMessage(Message &aMessage, @@ -103,7 +110,7 @@ void SecureSession::HandleDtlsConnectEvent(ConnectEvent aEvent) if (aEvent != kConnected) { mTransmitQueue.DequeueAndFreeAll(); - ClearRequestsAndResponses(); + ClearAllRequestsAndResponses(); } mConnectCallback.InvokeIfSet(aEvent); @@ -151,6 +158,22 @@ void SecureSession::HandleTransmitTask(void) FreeMessageOnError(message, error); } +#if OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE + +MeshCoP::SecureSession *ApplicationCoapSecure::HandleDtlsAccept(void *aContext, const Ip6::MessageInfo &aMessageInfo) +{ + OT_UNUSED_VARIABLE(aMessageInfo); + + return static_cast(aContext)->HandleDtlsAccept(); +} + +SecureSession *ApplicationCoapSecure::HandleDtlsAccept(void) +{ + return IsSessionInUse() ? nullptr : static_cast(this); +} + +#endif + } // namespace Coap } // namespace ot diff --git a/src/core/coap/coap_secure.hpp b/src/core/coap/coap_secure.hpp index 16bbabd5491..4ac022621f2 100644 --- a/src/core/coap/coap_secure.hpp +++ b/src/core/coap/coap_secure.hpp @@ -60,6 +60,11 @@ typedef MeshCoP::Dtls Dtls; class SecureSession : public CoapBase, public Dtls::Session { public: + /** + * Dequeues and frees all queued messages (requests and responses) and stops all timers and tasklets. + */ + void Cleanup(void); + /** * Sets the connection event callback. * @@ -148,8 +153,13 @@ class ApplicationCoapSecure : public Dtls::Transport, public Dtls::Transport::Ex , Dtls::Transport::Extension(static_cast(*this)) , SecureSession(aInstance, static_cast(*this)) { + Dtls::Transport::SetAcceptCallback(HandleDtlsAccept, this); Dtls::Transport::SetExtension(static_cast(*this)); } + +private: + static MeshCoP::SecureSession *HandleDtlsAccept(void *aContext, const Ip6::MessageInfo &aMessageInfo); + SecureSession *HandleDtlsAccept(void); }; #endif // OPENTHREAD_CONFIG_COAP_SECURE_API_ENABLE diff --git a/src/core/meshcop/border_agent.cpp b/src/core/meshcop/border_agent.cpp index 1c3c92820fd..2792bb27be5 100644 --- a/src/core/meshcop/border_agent.cpp +++ b/src/core/meshcop/border_agent.cpp @@ -51,6 +51,7 @@ BorderAgent::BorderAgent(Instance &aInstance) , mUdpReceiver(BorderAgent::HandleUdpReceive, this) , mTimer(aInstance) , mDtlsTransport(aInstance, kNoLinkSecurity) + , mCoapDtlsSession(nullptr) #if OPENTHREAD_CONFIG_BORDER_AGENT_ID_ENABLE , mIdInitialized(false) #endif @@ -129,16 +130,14 @@ Error BorderAgent::Start(uint16_t aUdpPort, const uint8_t *aPsk, uint8_t aPskLen } #endif - mCoapDtlsSession.Reset(CoapDtlsSession::Allocate(GetInstance(), mDtlsTransport)); - VerifyOrExit(mCoapDtlsSession != nullptr, error = kErrorNoBufs); + mDtlsTransport.SetAcceptCallback(HandleAcceptSession, this); + mDtlsTransport.SetRemoveSessionCallback(HandleRemoveSession, this); SuccessOrExit(error = mDtlsTransport.Open()); SuccessOrExit(error = mDtlsTransport.Bind(aUdpPort)); SuccessOrExit(error = mDtlsTransport.SetPsk(aPsk, aPskLength)); - mCoapDtlsSession->SetConnectCallback(HandleConnected, this); - mState = kStateStarted; LogInfo("Border Agent start listening on port %u", GetUdpPort()); @@ -163,7 +162,6 @@ void BorderAgent::Stop(void) mTimer.Stop(); mDtlsTransport.Close(); - mCoapDtlsSession.Free(); mState = kStateStopped; LogInfo("Border Agent stopped"); @@ -175,6 +173,7 @@ void BorderAgent::Stop(void) void BorderAgent::Disconnect(void) { VerifyOrExit(mState == kStateConnected || mState == kStateAccepted); + VerifyOrExit(mCoapDtlsSession != nullptr); mCoapDtlsSession->Disconnect(); @@ -225,11 +224,51 @@ void BorderAgent::HandleNotifierEvents(Events aEvents) void BorderAgent::HandleTimeout(void) { - if (mCoapDtlsSession->IsConnected()) - { - mCoapDtlsSession->Disconnect(); - LogWarn("Reset secure session"); - } + VerifyOrExit(mCoapDtlsSession != nullptr); + VerifyOrExit(mCoapDtlsSession->IsConnected()); + + mCoapDtlsSession->Disconnect(); + LogWarn("Reset secure session"); + +exit: + return; +} + +SecureSession *BorderAgent::HandleAcceptSession(void *aContext, const Ip6::MessageInfo &aMessageInfo) +{ + OT_UNUSED_VARIABLE(aMessageInfo); + + return static_cast(aContext)->HandleAcceptSession(); +} + +BorderAgent::CoapDtlsSession *BorderAgent::HandleAcceptSession(void) +{ + CoapDtlsSession *session = nullptr; + + VerifyOrExit(mCoapDtlsSession == nullptr); + + session = CoapDtlsSession::Allocate(GetInstance(), mDtlsTransport); + VerifyOrExit(session != nullptr); + + session->SetConnectCallback(HandleConnected, this); + mCoapDtlsSession = session; + +exit: + return session; +} + +void BorderAgent::HandleRemoveSession(void *aContext, SecureSession &aSesssion) +{ + static_cast(aContext)->HandleRemoveSession(aSesssion); +} + +void BorderAgent::HandleRemoveSession(SecureSession &aSesssion) +{ + CoapDtlsSession &coapSession = static_cast(aSesssion); + + coapSession.Cleanup(); + coapSession.Free(); + mCoapDtlsSession = nullptr; } void BorderAgent::HandleConnected(Dtls::Session::ConnectEvent aEvent, void *aContext) @@ -315,6 +354,7 @@ Error BorderAgent::ForwardToLeader(const Coap::Message &aMessage, const Ip6::Mes OffsetRange offsetRange; VerifyOrExit(mState != kStateStopped); + VerifyOrExit(mCoapDtlsSession != nullptr); switch (aUri) { @@ -389,6 +429,7 @@ void BorderAgent::HandleCoapResponse(const ForwardContext &aForwardContext, Error error; SuccessOrExit(error = aResult); + VerifyOrExit(mCoapDtlsSession != nullptr); VerifyOrExit((message = mCoapDtlsSession->NewPriorityMessage()) != nullptr, error = kErrorNoBufs); if (aForwardContext.IsPetition() && aResponse->GetCode() == Coap::kCodeChanged) @@ -464,6 +505,7 @@ bool BorderAgent::HandleUdpReceive(const Message &aMessage, const Ip6::MessageIn OffsetRange offsetRange; VerifyOrExit(aMessageInfo.GetSockAddr() == mCommissionerAloc.GetAddress()); + VerifyOrExit(mCoapDtlsSession != nullptr); didHandle = true; @@ -499,9 +541,11 @@ bool BorderAgent::HandleUdpReceive(const Message &aMessage, const Ip6::MessageIn Error BorderAgent::ForwardToCommissioner(Coap::Message &aForwardMessage, const Message &aMessage) { - Error error; + Error error = kErrorNone; OffsetRange offsetRange; + VerifyOrExit(mCoapDtlsSession != nullptr); + offsetRange.InitFromMessageOffsetToEnd(aMessage); SuccessOrExit(error = aForwardMessage.AppendBytesFromMessage(aMessage, offsetRange)); @@ -519,6 +563,8 @@ void BorderAgent::SendErrorMessage(const ForwardContext &aForwardContext, Error Error error = kErrorNone; Coap::Message *message = nullptr; + VerifyOrExit(mCoapDtlsSession != nullptr); + VerifyOrExit((message = mCoapDtlsSession->NewPriorityMessage()) != nullptr, error = kErrorNoBufs); SuccessOrExit(error = aForwardContext.ToHeader(*message, CoapCodeFromError(aError))); SuccessOrExit(error = mCoapDtlsSession->SendMessage(*message)); @@ -533,6 +579,8 @@ void BorderAgent::SendErrorMessage(const Coap::Message &aRequest, bool aSeparate Error error = kErrorNone; Coap::Message *message = nullptr; + VerifyOrExit(mCoapDtlsSession != nullptr); + VerifyOrExit((message = mCoapDtlsSession->NewPriorityMessage()) != nullptr, error = kErrorNoBufs); if (aRequest.IsNonConfirmable() || aSeparate) @@ -593,6 +641,8 @@ template <> void BorderAgent::HandleTmf(Coap::Message &aMessage, co VerifyOrExit(aMessage.IsNonConfirmablePostRequest(), error = kErrorDrop); + VerifyOrExit(mCoapDtlsSession != nullptr); + message = mCoapDtlsSession->NewPriorityNonConfirmablePostMessage(kUriRelayRx); VerifyOrExit(message != nullptr, error = kErrorNoBufs); @@ -681,6 +731,8 @@ void BorderAgent::HandleTmfDatasetGet(Coap::Message &aMessage, const Ip6::Messag Error error = kErrorNone; Coap::Message *response = nullptr; + VerifyOrExit(mCoapDtlsSession != nullptr); + // When processing `MGMT_GET` request directly on Border Agent, // the Security Policy flags (O-bit) should be ignored to allow // the commissioner candidate to get the full Operational Dataset. diff --git a/src/core/meshcop/border_agent.hpp b/src/core/meshcop/border_agent.hpp index 25751c2ea42..960399070c8 100644 --- a/src/core/meshcop/border_agent.hpp +++ b/src/core/meshcop/border_agent.hpp @@ -300,6 +300,11 @@ class BorderAgent : public InstanceLocator, private NonCopyable template void HandleTmf(Coap::Message &aMessage, const Ip6::MessageInfo &aMessageInfo); + static SecureSession *HandleAcceptSession(void *aContext, const Ip6::MessageInfo &aMessageInfo); + CoapDtlsSession *HandleAcceptSession(void); + static void HandleRemoveSession(void *aContext, SecureSession &aSesssion); + void HandleRemoveSession(SecureSession &aSesssion); + static void HandleConnected(Dtls::Session::ConnectEvent aEvent, void *aContext); void HandleConnected(Dtls::Session::ConnectEvent aEvent); static void HandleCoapResponse(void *aContext, @@ -331,7 +336,7 @@ class BorderAgent : public InstanceLocator, private NonCopyable Ip6::Netif::UnicastAddress mCommissionerAloc; TimeoutTimer mTimer; Dtls::Transport mDtlsTransport; - OwnedPtr mCoapDtlsSession; + CoapDtlsSession *mCoapDtlsSession; #if OPENTHREAD_CONFIG_BORDER_AGENT_ID_ENABLE Id mId; bool mIdInitialized; diff --git a/src/core/meshcop/secure_transport.cpp b/src/core/meshcop/secure_transport.cpp index 27030bcf1b1..234228cab08 100644 --- a/src/core/meshcop/secure_transport.cpp +++ b/src/core/meshcop/secure_transport.cpp @@ -51,13 +51,22 @@ RegisterLogModule("SecTransport"); // SecureSession SecureSession::SecureSession(SecureTransport &aTransport) - : mTimerSet(false) - , mState(kStateDisconnected) - , mMessageSubType(Message::kSubTypeNone) - , mConnectEvent(kDisconnectedError) - , mTransport(aTransport) - , mReceiveMessage(nullptr) + : mTransport(aTransport) { + Init(); +} + +void SecureSession::Init(void) +{ + mTimerSet = false; + mIsServer = false; + mState = kStateDisconnected; + mMessageSubType = Message::kSubTypeNone; + mConnectEvent = kDisconnectedError; + mReceiveMessage = nullptr; + mMessageInfo.Clear(); + + MarkAsNotUsed(); ClearAllBytes(mSsl); ClearAllBytes(mConf); #if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_COOKIE_C) @@ -99,35 +108,39 @@ Error SecureSession::Connect(const Ip6::SockAddr &aSockAddr) Error error; VerifyOrExit(mTransport.mIsOpen, error = kErrorInvalidState); - VerifyOrExit(IsDisconnected(), error = kErrorInvalidState); + VerifyOrExit(!IsSessionInUse(), error = kErrorInvalidState); - mTransport.DecremenetRemainingConnectionAttempts(); + Init(); mMessageInfo.SetPeerAddr(aSockAddr.GetAddress()); mMessageInfo.SetPeerPort(aSockAddr.mPort); - mTransport.mIsServer = false; + SuccessOrExit(error = Setup()); - error = Setup(); + mTransport.mSessions.Push(*this); exit: return error; } -void SecureSession::HandleTransportReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo) +void SecureSession::Accept(Message &aMessage, const Ip6::MessageInfo &aMessageInfo) { - if (IsDisconnected()) - { - mTransport.DecremenetRemainingConnectionAttempts(); + mMessageInfo.SetPeerAddr(aMessageInfo.GetPeerAddr()); + mMessageInfo.SetPeerPort(aMessageInfo.GetPeerPort()); + mMessageInfo.SetIsHostInterface(aMessageInfo.IsHostInterface()); + mMessageInfo.SetSockAddr(aMessageInfo.GetSockAddr()); + mMessageInfo.SetSockPort(aMessageInfo.GetSockPort()); - mMessageInfo.SetPeerAddr(aMessageInfo.GetPeerAddr()); - mMessageInfo.SetPeerPort(aMessageInfo.GetPeerPort()); - mMessageInfo.SetIsHostInterface(aMessageInfo.IsHostInterface()); - - mMessageInfo.SetSockAddr(aMessageInfo.GetSockAddr()); - mMessageInfo.SetSockPort(aMessageInfo.GetSockPort()); + mIsServer = true; - SuccessOrExit(Setup()); + if (Setup() == kErrorNone) + { + HandleTransportReceive(aMessage); } +} + +void SecureSession::HandleTransportReceive(Message &aMessage) +{ + VerifyOrExit(!IsDisconnected()); #ifdef MBEDTLS_SSL_SRV_C if (IsConnecting()) @@ -151,17 +164,23 @@ Error SecureSession::Setup(void) OT_ASSERT(mTransport.mCipherSuite != SecureTransport::kUnspecifiedCipherSuite); - VerifyOrExit(mTransport.mIsOpen, error = kErrorInvalidState); - VerifyOrExit(IsDisconnected(), error = kErrorBusy); - SetState(kStateInitializing); + if (mTransport.HasNoRemainingConnectionAttempts()) + { + mConnectEvent = kDisconnectedMaxAttempts; + error = kErrorNoBufs; + ExitNow(); + } + + mTransport.DecremenetRemainingConnectionAttempts(); + //- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - // Setup the mbedtls_ssl_config `mConf`. mbedtls_ssl_config_init(&mConf); - rval = mbedtls_ssl_config_defaults(&mConf, mTransport.mIsServer ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT, + rval = mbedtls_ssl_config_defaults(&mConf, mIsServer ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT, mTransport.mDatagramTransport ? MBEDTLS_SSL_TRANSPORT_DATAGRAM : MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT); @@ -257,7 +276,7 @@ Error SecureSession::Setup(void) { mbedtls_ssl_cookie_init(&mCookieCtx); - if (mTransport.mIsServer) + if (mIsServer) { rval = mbedtls_ssl_cookie_setup(&mCookieCtx, Crypto::MbedTls::CryptoSecurePrng, nullptr); VerifyOrExit(rval == 0); @@ -302,18 +321,11 @@ Error SecureSession::Setup(void) exit: if (IsInitializing()) { - error = Crypto::MbedTls::MapError(rval); + error = (error == kErrorNone) ? Crypto::MbedTls::MapError(rval) : error; - if (mTransport.HasNoRemainingConnectionAttempts()) - { - mTransport.Close(); - mTransport.mAutoCloseCallback.InvokeIfSet(); - } - else - { - SetState(kStateDisconnected); - FreeMbedtls(); - } + SetState(kStateDisconnected); + FreeMbedtls(); + mTransport.mUpdateTask.Post(); } return error; @@ -332,8 +344,6 @@ void SecureSession::Disconnect(ConnectEvent aEvent) mTimerFinish = TimerMilli::GetNow() + kGuardTimeNewConnectionMilli; mTransport.mTimer.FireAtIfEarlier(mTimerFinish); - mMessageInfo.Clear(); - FreeMbedtls(); exit: @@ -488,18 +498,8 @@ void SecureSession::HandleTimer(TimeMilli aNow) ExitNow(); } - if (mTransport.HasNoRemainingConnectionAttempts()) - { - mTransport.Close(); - mConnectEvent = kDisconnectedMaxAttempts; - mTransport.mAutoCloseCallback.InvokeIfSet(); - } - else - { - SetState(kStateDisconnected); - } - - mConnectedCallback.InvokeIfSet(mConnectEvent); + SetState(kStateDisconnected); + mTransport.mUpdateTask.Post(); } exit: @@ -659,15 +659,15 @@ SecureTransport::SecureTransport(Instance &aInstance, LinkSecurityMode aLayerTwo : mLayerTwoSecurity(aLayerTwoSecurity) , mDatagramTransport(aDatagramTransport) , mIsOpen(false) - , mIsServer(true) + , mIsClosing(false) , mVerifyPeerCertificate(true) , mCipherSuite(kUnspecifiedCipherSuite) , mPskLength(0) , mMaxConnectionAttempts(0) , mRemainingConnectionAttempts(0) - , mSession(nullptr) , mSocket(aInstance, *this) - , mTimer(aInstance, SecureTransport::HandleTimer, this) + , mTimer(aInstance, HandleTimer, this) + , mUpdateTask(aInstance, HandleUpdateTask, this) #if OPENTHREAD_CONFIG_TLS_API_ENABLE , mExtension(nullptr) #endif @@ -705,14 +705,29 @@ Error SecureTransport::SetMaxConnectionAttempts(uint16_t aMaxAttempts, AutoClose void SecureTransport::HandleReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo) { + SecureSession *session; + VerifyOrExit(mIsOpen); - if (!mSession->IsDisconnected()) + session = mSessions.FindMatching(aMessageInfo); + + if (session != nullptr) { - VerifyOrExit(mSession->Matches(aMessageInfo)); + session->HandleTransportReceive(aMessage); + ExitNow(); } - mSession->HandleTransportReceive(aMessage, aMessageInfo); + // A new connection request + + VerifyOrExit(mAcceptCallback.IsSet()); + + session = mAcceptCallback.Invoke(aMessageInfo); + VerifyOrExit(session != nullptr); + + session->Init(); + mSessions.Push(*session); + + session->Accept(aMessage, aMessageInfo); exit: return; @@ -725,10 +740,9 @@ Error SecureTransport::Bind(uint16_t aPort) VerifyOrExit(mIsOpen, error = kErrorInvalidState); VerifyOrExit(!mTransportCallback.IsSet(), error = kErrorAlready); - VerifyOrExit(mSession->IsDisconnected(), error = kErrorInvalidState); + VerifyOrExit(mSessions.IsEmpty(), error = kErrorInvalidState); - SuccessOrExit(error = mSocket.Bind(aPort)); - mIsServer = true; + error = mSocket.Bind(aPort); exit: return error; @@ -742,10 +756,9 @@ Error SecureTransport::Bind(TransportCallback aCallback, void *aContext) VerifyOrExit(!mSocket.IsBound(), error = kErrorAlready); VerifyOrExit(!mTransportCallback.IsSet(), error = kErrorAlready); - VerifyOrExit(mSession->IsDisconnected(), error = kErrorInvalidState); + VerifyOrExit(mSessions.IsEmpty(), error = kErrorInvalidState); mTransportCallback.Set(aCallback, aContext); - mIsServer = true; exit: return error; @@ -754,11 +767,26 @@ Error SecureTransport::Bind(TransportCallback aCallback, void *aContext) void SecureTransport::Close(void) { VerifyOrExit(mIsOpen); + VerifyOrExit(!mIsClosing); - mSession->Disconnect(SecureSession::kDisconnectedLocalClosed); - mSession->SetState(SecureSession::kStateDisconnected); + // `mIsClosing` is used to protect against multiple + // calls to `Close()` and re-entry. As the transport is closed, + // all existing sessions are disconnected, which can trigger + // connect and remove callbacks to be invoked. These callbacks + // may call `Close()` again. - mIsOpen = false; + mIsClosing = true; + + for (SecureSession &session : mSessions) + { + session.Disconnect(SecureSession::kDisconnectedLocalClosed); + session.SetState(SecureSession::kStateDisconnected); + } + + RemoveDisconnectedSessions(); + + mIsOpen = false; + mIsClosing = false; mTransportCallback.Clear(); IgnoreError(mSocket.Close()); mTimer.Stop(); @@ -767,6 +795,22 @@ void SecureTransport::Close(void) return; } +void SecureTransport::RemoveDisconnectedSessions(void) +{ + LinkedList disconnectedSessions; + SecureSession *session; + + mSessions.RemoveAllMatching(disconnectedSessions, SecureSession::kStateDisconnected); + + while ((session = disconnectedSessions.Pop()) != nullptr) + { + session->mConnectedCallback.InvokeIfSet(session->mConnectEvent); + session->MarkAsNotUsed(); + session->mMessageInfo.Clear(); + mRemoveSessionCallback.InvokeIfSet(*session); + } +} + void SecureTransport::DecremenetRemainingConnectionAttempts(void) { if (mRemainingConnectionAttempts > 0) @@ -931,6 +975,22 @@ int SecureTransport::HandleMbedtlsExportKeys(const unsigned char *aMasterSecret, #endif // (MBEDTLS_VERSION_NUMBER >= 0x03000000) +void SecureTransport::HandleUpdateTask(Tasklet &aTasklet) +{ + static_cast(static_cast(aTasklet).GetContext())->HandleUpdateTask(); +} + +void SecureTransport::HandleUpdateTask(void) +{ + RemoveDisconnectedSessions(); + + if (mSessions.IsEmpty() && HasNoRemainingConnectionAttempts()) + { + Close(); + mAutoCloseCallback.InvokeIfSet(); + } +} + void SecureTransport::HandleTimer(Timer &aTimer) { static_cast(static_cast(aTimer).GetContext())->HandleTimer(); @@ -938,12 +998,17 @@ void SecureTransport::HandleTimer(Timer &aTimer) void SecureTransport::HandleTimer(void) { - if (mIsOpen) - { - TimeMilli now = TimerMilli::GetNow(); + TimeMilli now = TimerMilli::GetNow(); - mSession->HandleTimer(now); + VerifyOrExit(mIsOpen); + + for (SecureSession &session : mSessions) + { + session.HandleTimer(now); } + +exit: + return; } void SecureTransport::HandleMbedtlsDebug(void *aContext, int aLevel, const char *aFile, int aLine, const char *aStr) @@ -1138,8 +1203,9 @@ Error SecureTransport::Extension::GetPeerCertificateBase64(unsigned char *aPeerC size_t aCertBufferSize) { Error error = kErrorNone; - SecureSession *session = mSecureTransport.mSession; + SecureSession *session = mSecureTransport.mSessions.GetHead(); + VerifyOrExit(session != nullptr, error = kErrorInvalidState); VerifyOrExit(session->IsConnected(), error = kErrorInvalidState); #if (MBEDTLS_VERSION_NUMBER >= 0x03010000) @@ -1174,8 +1240,13 @@ Error SecureTransport::Extension::GetPeerSubjectAttributeByOid(const char *aOid, const mbedtls_asn1_named_data *data; size_t length; size_t attributeBufferSize; - SecureSession *session = mSecureTransport.mSession; - mbedtls_x509_crt *peerCert = const_cast(mbedtls_ssl_get_peer_cert(&session->mSsl)); + SecureSession *session; + mbedtls_x509_crt *peerCert; + + session = mSecureTransport.mSessions.GetHead(); + VerifyOrExit(session != nullptr, error = kErrorInvalidState); + + peerCert = const_cast(mbedtls_ssl_get_peer_cert(&session->mSsl)); VerifyOrExit(aAttributeLength != nullptr, error = kErrorInvalidArgs); attributeBufferSize = *aAttributeLength; @@ -1206,9 +1277,16 @@ Error SecureTransport::Extension::GetThreadAttributeFromPeerCertificate(int uint8_t *aAttributeBuffer, size_t *aAttributeLength) { - const mbedtls_x509_crt *cert = mbedtls_ssl_get_peer_cert(&mSecureTransport.mSession->mSsl); + Error error; + SecureSession *session = mSecureTransport.mSessions.GetHead(); + const mbedtls_x509_crt *cert; - return GetThreadAttributeFromCertificate(cert, aThreadOidDescriptor, aAttributeBuffer, aAttributeLength); + VerifyOrExit(session != nullptr, error = kErrorInvalidState); + cert = mbedtls_ssl_get_peer_cert(&session->mSsl); + error = GetThreadAttributeFromCertificate(cert, aThreadOidDescriptor, aAttributeBuffer, aAttributeLength); + +exit: + return error; } #endif // defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE) @@ -1300,6 +1378,22 @@ Error SecureTransport::Extension::GetThreadAttributeFromCertificate(const mbedtl #endif // OPENTHREAD_CONFIG_TLS_API_ENABLE +#if OPENTHREAD_CONFIG_BLE_TCAT_ENABLE + +//--------------------------------------------------------------------------------------------------------------------- +// Tls + +SecureSession *Tls::HandleAccept(void *aContext, const Ip6::MessageInfo &aMessageInfo) +{ + OT_UNUSED_VARIABLE(aMessageInfo); + + return static_cast(aContext)->HandleAccept(); +} + +SecureSession *Tls::HandleAccept(void) { return IsSessionInUse() ? nullptr : static_cast(this); } + +#endif + } // namespace MeshCoP } // namespace ot diff --git a/src/core/meshcop/secure_transport.hpp b/src/core/meshcop/secure_transport.hpp index 81f4a3ca7f6..0c07e43fef8 100644 --- a/src/core/meshcop/secure_transport.hpp +++ b/src/core/meshcop/secure_transport.hpp @@ -74,6 +74,7 @@ #include #include "common/callback.hpp" +#include "common/linked_list.hpp" #include "common/locator.hpp" #include "common/log.hpp" #include "common/message.hpp" @@ -99,8 +100,10 @@ class Tls; /** * Represents a secure session. */ -class SecureSession : private NonCopyable +class SecureSession : private LinkedListEntry, private NonCopyable { + friend class LinkedListEntry; + friend class LinkedList; friend class SecureTransport; friend class Dtls; #if OPENTHREAD_CONFIG_BLE_TCAT_ENABLE @@ -155,10 +158,16 @@ class SecureSession : private NonCopyable /** * Establishes a secure session (as client). * + * On success, ownership of the session is passed to the associated secure transport (`GetTransport()`). + * The transport will then manage the session. Once the session is disconnected and removed from the transport, the + * secure transport signals this using the `RemoveSessionCallback` callback, where ownership is + * released. + * * @param[in] aSockAddr The server address to connect to. * * @retval kErrorNone Successfully started session establishment * @retval kErrorInvalidState Transport is not ready. + * @retval kErrorNoBufs Has reached max number of allowed connection attempts. */ Error Connect(const Ip6::SockAddr &aSockAddr); @@ -213,6 +222,8 @@ class SecureSession : private NonCopyable protected: explicit SecureSession(SecureTransport &aTransport); + bool IsSessionInUse(void) const { return (mNext != this); } + private: static constexpr uint32_t kGuardTimeNewConnectionMilli = 2000; static constexpr uint16_t kMaxContentLen = OPENTHREAD_CONFIG_DTLS_MAX_CONTENT_LEN; @@ -232,14 +243,18 @@ class SecureSession : private NonCopyable kStateDisconnecting, }; + void Init(void); bool IsDisconnected(void) const { return mState == kStateDisconnected; } bool IsInitializing(void) const { return mState == kStateInitializing; } bool IsConnecting(void) const { return mState == kStateConnecting; } bool IsDisconnecting(void) const { return mState == kStateDisconnecting; } bool IsConnectingOrConnected(void) const { return mState == kStateConnecting || mState == kStateConnected; } + void MarkAsNotUsed(void) { mNext = this; } void SetState(State aState); - bool Matches(const Ip6::MessageInfo &aInfo) { return mMessageInfo.HasSamePeerAddrAndPort(aInfo); } - void HandleTransportReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo); + bool Matches(const Ip6::MessageInfo &aInfo) const { return mMessageInfo.HasSamePeerAddrAndPort(aInfo); } + bool Matches(State aState) const { return (mState == aState); } + void Accept(Message &aMessage, const Ip6::MessageInfo &aMessageInfo); + void HandleTransportReceive(Message &aMessage); Error Setup(void); void Disconnect(ConnectEvent aEvent); void HandleTimer(TimeMilli aNow); @@ -262,11 +277,13 @@ class SecureSession : private NonCopyable #endif bool mTimerSet : 1; + bool mIsServer : 1; State mState; Message::SubType mMessageSubType; ConnectEvent mConnectEvent; TimeMilli mTimerIntermediate; TimeMilli mTimerFinish; + SecureSession *mNext; SecureTransport &mTransport; Message *mReceiveMessage; Ip6::MessageInfo mMessageInfo; @@ -306,6 +323,30 @@ class SecureTransport : private NonCopyable */ typedef void (*AutoCloseCallback)(void *aContext); + /** + * Callback to accept a new session connection request, providing the secure session to use. + * + * This method returns a pointer to a new `SecureSession` to use for the new session. The `SecureTransport` takes + * over the ownership of the given `SecureSession`. Once the session is disconnected and removed from the transport, + * the secure transport signals this using the `RemoveSessionCallback` callback, where ownership is released. + * + * `nullptr` can be returned to reject the new session connection request. + * + * @param[in] aContex A pointer to arbitrary context information. + * @param[in] aMessageInfo The message info from the new session connection request message. + * + * @returns A pointer to `SecureSession` to use for new session or `nullptr` if new connection is rejected. + */ + typedef SecureSession *(*AcceptCallback)(void *aContext, const Ip6::MessageInfo &aMessageInfo); + + /** + * Callback to signal a session is removed, releasing the ownership of the session (by `SecureTransport`). + * + * @param[in] aContex A pointer to arbitrary context information. + * @param[in] aSesssion The session being removed. + */ + typedef void (*RemoveSessionCallback)(void *aContext, SecureSession &aSesssion); + #if OPENTHREAD_CONFIG_TLS_API_ENABLE /** * Represents an API extension for a `SecureTransport` (DTLS or TLS). @@ -560,6 +601,25 @@ class SecureTransport : private NonCopyable */ Error SetMaxConnectionAttempts(uint16_t aMaxAttempts, AutoCloseCallback aCallback, void *aContext); + /** + * Sets the `AcceptCallback` used to accept new session connection requests. + * + * @param[in] aCallback The `AcceptCallback`. + * @param[in] aConext A pointer to arbitrary context to use with `AcceptCallback`. + */ + void SetAcceptCallback(AcceptCallback aCallback, void *aContext) { mAcceptCallback.Set(aCallback, aContext); } + + /** + * Sets the `RemoveSessionCallback` used to signal when a session is removed. + * + * @param[in] aCallback The `RemoveSessionCallback`. + * @param[in] aConext A pointer to arbitrary context to use with `RemoveSessionCallback`. + */ + void SetRemoveSessionCallback(RemoveSessionCallback aCallback, void *aContext) + { + mRemoveSessionCallback.Set(aCallback, aContext); + } + /** * Binds this DTLS to a UDP port. * @@ -629,11 +689,16 @@ class SecureTransport : private NonCopyable */ void HandleReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo); + /** + * Get the list of sessions associated with the `SecureTransport`. + * + * @returns The list of associated sessions. + */ + LinkedList &GetSessions(void) { return mSessions; } + protected: SecureTransport(Instance &aInstance, LinkSecurityMode aLayerTwoSecurity, bool aDatagramTransport); - void SetSession(SecureSession &aSesssion) { mSession = &aSesssion; } - #if OPENTHREAD_CONFIG_TLS_API_ENABLE void SetExtension(Extension &aExtension) { mExtension = &aExtension; } #endif @@ -655,6 +720,7 @@ class SecureTransport : private NonCopyable kUnspecifiedCipherSuite, }; + void RemoveDisconnectedSessions(void); void DecremenetRemainingConnectionAttempts(void); bool HasNoRemainingConnectionAttempts(void) const; int Transmit(const unsigned char *aBuf, @@ -700,6 +766,8 @@ class SecureTransport : private NonCopyable #endif // (MBEDTLS_VERSION_NUMBER >= 0x03000000) #endif // MBEDTLS_SSL_EXPORT_KEYS + static void HandleUpdateTask(Tasklet &aTasklet); + void HandleUpdateTask(void); static void HandleTimer(Timer &aTimer); void HandleTimer(void); @@ -721,21 +789,24 @@ class SecureTransport : private NonCopyable static const int kCipherSuites[][2]; - bool mLayerTwoSecurity : 1; - bool mDatagramTransport : 1; - bool mIsOpen : 1; - bool mIsServer : 1; - bool mVerifyPeerCertificate : 1; - CipherSuite mCipherSuite; - uint8_t mPskLength; - uint16_t mMaxConnectionAttempts; - uint16_t mRemainingConnectionAttempts; - SecureSession *mSession; - TransportSocket mSocket; - uint8_t mPsk[kPskMaxLength]; - TimerMilliContext mTimer; - Callback mAutoCloseCallback; - Callback mTransportCallback; + bool mLayerTwoSecurity : 1; + bool mDatagramTransport : 1; + bool mIsOpen : 1; + bool mIsClosing : 1; + bool mVerifyPeerCertificate : 1; + CipherSuite mCipherSuite; + uint8_t mPskLength; + uint16_t mMaxConnectionAttempts; + uint16_t mRemainingConnectionAttempts; + LinkedList mSessions; + TransportSocket mSocket; + uint8_t mPsk[kPskMaxLength]; + TimerMilliContext mTimer; + TaskletContext mUpdateTask; + Callback mAutoCloseCallback; + Callback mAcceptCallback; + Callback mRemoveSessionCallback; + Callback mTransportCallback; #if OPENTHREAD_CONFIG_TLS_API_ENABLE Extension *mExtension; #endif @@ -767,9 +838,6 @@ class Dtls : SecureTransport(aInstance, aLayerTwoSecurity, /* aDatagramTransport */ true) { } - - private: - void SetSession(Session &aSesssion) { SecureTransport::SetSession(aSesssion); } }; /** @@ -786,7 +854,6 @@ class Dtls Session(Transport &aTransport) : SecureSession(aTransport) { - aTransport.SetSession(*this); } /** @@ -817,9 +884,13 @@ class Tls : public SecureTransport, public SecureSession : SecureTransport(aInstance, aLayerTwoSecurity, /* aDatagramTransport */ false) , SecureSession(*static_cast(this)) { - SetSession(*static_cast(this)); SetExtension(aExtension); + SetAcceptCallback(&HandleAccept, this); } + +private: + static SecureSession *HandleAccept(void *aContext, const Ip6::MessageInfo &aMessageInfo); + SecureSession *HandleAccept(void); }; #endif diff --git a/src/core/thread/tmf.cpp b/src/core/thread/tmf.cpp index d6cf5407835..37b7f763c90 100644 --- a/src/core/thread/tmf.cpp +++ b/src/core/thread/tmf.cpp @@ -276,11 +276,25 @@ SecureAgent::SecureAgent(Instance &aInstance) : Coap::Dtls::Transport(aInstance, kNoLinkSecurity) , Coap::SecureSession(aInstance, static_cast(*this)) { + SetAcceptCallback(&HandleDtlsAccept, this); + #if OPENTHREAD_FTD && OPENTHREAD_CONFIG_COMMISSIONER_ENABLE SetResourceHandler(&HandleResource); #endif } +MeshCoP::SecureSession *SecureAgent::HandleDtlsAccept(void *aContext, const Ip6::MessageInfo &aMessageInfo) +{ + OT_UNUSED_VARIABLE(aMessageInfo); + + return static_cast(aContext)->HandleDtlsAccept(); +} + +Coap::SecureSession *SecureAgent::HandleDtlsAccept(void) +{ + return IsSessionInUse() ? nullptr : static_cast(this); +} + #if OPENTHREAD_FTD && OPENTHREAD_CONFIG_COMMISSIONER_ENABLE bool SecureAgent::HandleResource(CoapBase &aCoapBase, diff --git a/src/core/thread/tmf.hpp b/src/core/thread/tmf.hpp index 386f133bce5..31fecd9cc2b 100644 --- a/src/core/thread/tmf.hpp +++ b/src/core/thread/tmf.hpp @@ -208,6 +208,9 @@ class SecureAgent : public Coap::Dtls::Transport, public Coap::SecureSession explicit SecureAgent(Instance &aInstance); private: + static MeshCoP::SecureSession *HandleDtlsAccept(void *aContext, const Ip6::MessageInfo &aMessageInfo); + Coap::SecureSession *HandleDtlsAccept(void); + #if OPENTHREAD_FTD && OPENTHREAD_CONFIG_COMMISSIONER_ENABLE static bool HandleResource(CoapBase &aCoapBase, const char *aUriPath, diff --git a/tests/nexus/test_dtls.cpp b/tests/nexus/test_dtls.cpp index 9a93106cfd3..6f52b0b7fd3 100644 --- a/tests/nexus/test_dtls.cpp +++ b/tests/nexus/test_dtls.cpp @@ -48,6 +48,7 @@ static const uint8_t kPsk[] = {0x10, 0x20, 0x03, 0x15, 0x10, 0x00, 0x60, 0x16}; static Dtls::Session::ConnectEvent sDtlsEvent[kMaxNodes]; static Array sDtlsLastReceive[kMaxNodes]; static bool sDtlsAutoClosed[kMaxNodes]; +static uint32_t sHeapSessionsAllocated = 0; const char *ConnectEventToString(Dtls::Session::ConnectEvent aEvent) { @@ -131,24 +132,113 @@ OwnedPtr PrepareMessage(Node &aNode) return OwnedPtr(message); } -class DtlsTransportAndSession : public InstanceLocator, public Dtls::Transport, public Dtls::Session +class DtlsTransportAndSingleSession : public InstanceLocator, public Dtls::Transport, public Dtls::Session { + // A DTLS transport and single session public: - explicit DtlsTransportAndSession(Node &aNode) + explicit DtlsTransportAndSingleSession(Node &aNode) : InstanceLocator(aNode.GetInstance()) , Dtls::Transport(aNode.GetInstance(), kWithLinkSecurity) , Dtls::Session(static_cast(*this)) + , mNode(aNode) { + SetAcceptCallback(HandleAccept, this); + + VerifyOrQuit(!IsSessionInUse()); + } + +private: + static MeshCoP::SecureSession *HandleAccept(void *aContext, const Ip6::MessageInfo &aMessageInfo) + { + return static_cast(aContext)->HandleAccept(); } + + Dtls::Session *HandleAccept(void) + { + Dtls::Session *session = IsSessionInUse() ? nullptr : static_cast(this); + + Log(" node%u: HandleAccept(), %s", mNode.GetId(), (session != nullptr) ? "accepted" : "rejected"); + return session; + } + + Node &mNode; }; -void TestDtls(void) +class DtlsTransportAndHeapSession : public InstanceLocator, public Dtls::Transport +{ + // A DTLS session with heap allocated sessions. + +public: + explicit DtlsTransportAndHeapSession(Node &aNode) + : InstanceLocator(aNode.GetInstance()) + , Dtls::Transport(aNode.GetInstance(), kWithLinkSecurity) + , mNode(aNode) + { + SetAcceptCallback(HandleAccept, this); + SetRemoveSessionCallback(HandleRemoveSession, this); + } + +private: + class HeapDtlsSession : public Dtls::Session, public Heap::Allocatable + { + friend Heap::Allocatable; + + private: + HeapDtlsSession(Dtls::Transport &aTransport) + : Dtls::Session(aTransport) + { + sHeapSessionsAllocated++; + } + }; + + static MeshCoP::SecureSession *HandleAccept(void *aContext, const Ip6::MessageInfo &aMessageInfo) + { + DtlsTransportAndHeapSession *transport; + HeapDtlsSession *session; + + VerifyOrQuit(aContext != nullptr); + transport = static_cast(aContext); + + Log(" node%u: HandleAccept()", transport->mNode.GetId()); + + session = HeapDtlsSession::Allocate(*transport); + VerifyOrQuit(session != nullptr); + + session->SetReceiveCallback(&ot::Nexus::HandleReceive, &transport->mNode); + session->SetConnectCallback(&ot::Nexus::HandleConnectEvent, &transport->mNode); + + return session; + } + + static void HandleRemoveSession(void *aContext, MeshCoP::SecureSession &aSesssion) + { + DtlsTransportAndHeapSession *transport; + + VerifyOrQuit(aContext != nullptr); + transport = static_cast(aContext); + + Log(" node%u: HandleRemoveSession()", transport->mNode.GetId()); + + VerifyOrQuit(sHeapSessionsAllocated > 0); + + static_cast(aSesssion).Free(); + sHeapSessionsAllocated--; + } + +private: + Node &mNode; +}; + +void TestDtlsSingleSession(void) { Core nexus; Node &node0 = nexus.CreateNode(); Node &node1 = nexus.CreateNode(); Node &node2 = nexus.CreateNode(); + Log("------------------------------------------------------------------------------------------------------"); + Log("TestDtlsSingleSession"); + nexus.AdvanceTime(0); // Form the topology: node0 leader, with node1 & node2 as its FTD children @@ -167,13 +257,11 @@ void TestDtls(void) nexus.AdvanceTime(20 * Time::kOneSecondInMsec); VerifyOrQuit(node2.Get().IsChild()); - Log("------------------------------------------------------------------------------------------------------"); - { - DtlsTransportAndSession dtls0(node0); - DtlsTransportAndSession dtls1(node1); - DtlsTransportAndSession dtls2(node2); - Ip6::SockAddr sockAddr; + DtlsTransportAndSingleSession dtls0(node0); + DtlsTransportAndSingleSession dtls1(node1); + DtlsTransportAndSingleSession dtls2(node2); + Ip6::SockAddr sockAddr; // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - Log("Start DTLS (server) on node0 bound to port %u", kUdpPort); @@ -370,7 +458,7 @@ void TestDtls(void) SuccessOrQuit(dtls1.Connect(sockAddr)); nexus.AdvanceTime(3 * Time::kOneSecondInMsec); - VerifyOrQuit(sDtlsEvent[node0.GetId()] == Dtls::Session::kDisconnectedMaxAttempts); + VerifyOrQuit(sDtlsEvent[node0.GetId()] == Dtls::Session::kDisconnectedError); VerifyOrQuit(sDtlsEvent[node1.GetId()] == Dtls::Session::kDisconnectedError); VerifyOrQuit(sDtlsAutoClosed[node0.GetId()]); @@ -383,12 +471,190 @@ void TestDtls(void) } } +void TestDtlsMultiSession(void) +{ + Core nexus; + Node &node0 = nexus.CreateNode(); + Node &node1 = nexus.CreateNode(); + Node &node2 = nexus.CreateNode(); + + Log("------------------------------------------------------------------------------------------------------"); + Log("TestDtlsMultiSession"); + + nexus.AdvanceTime(0); + + // Form the topology: node0 leader, with node1 & node2 as its FTD children + + node0.Form(); + nexus.AdvanceTime(50 * Time::kOneSecondInMsec); + VerifyOrQuit(node0.Get().IsLeader()); + + SuccessOrQuit(node1.Get().SetRouterEligible(false)); + node1.Join(node0); + nexus.AdvanceTime(20 * Time::kOneSecondInMsec); + VerifyOrQuit(node1.Get().IsChild()); + + SuccessOrQuit(node2.Get().SetRouterEligible(false)); + node2.Join(node0); + nexus.AdvanceTime(20 * Time::kOneSecondInMsec); + VerifyOrQuit(node2.Get().IsChild()); + + { + DtlsTransportAndHeapSession dtls0(node0); + DtlsTransportAndSingleSession dtls1(node1); + DtlsTransportAndSingleSession dtls2(node2); + Ip6::SockAddr sockAddr; + uint16_t numSessions; + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + Log("Start DTLS (server) on node0 bound to port %u", kUdpPort); + + SuccessOrQuit(dtls0.SetPsk(kPsk, sizeof(kPsk))); + SuccessOrQuit(dtls0.Open()); + SuccessOrQuit(dtls0.Bind(kUdpPort)); + + nexus.AdvanceTime(1 * Time::kOneSecondInMsec); + + VerifyOrQuit(dtls0.GetUdpPort() == kUdpPort); + + sockAddr.SetAddress(node0.Get().GetMeshLocalRloc()); + sockAddr.SetPort(kUdpPort); + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + Log("Establish a DTLS connection with node 0 from node1"); + + memset(sDtlsEvent, Dtls::Session::kDisconnectedError, sizeof(sDtlsEvent)); + + SuccessOrQuit(dtls1.SetPsk(kPsk, sizeof(kPsk))); + dtls1.SetReceiveCallback(HandleReceive, &node1); + dtls1.SetConnectCallback(HandleConnectEvent, &node1); + SuccessOrQuit(dtls1.Open()); + SuccessOrQuit(dtls1.Connect(sockAddr)); + + nexus.AdvanceTime(1 * Time::kOneSecondInMsec); + + VerifyOrQuit(dtls1.IsConnected()); + + VerifyOrQuit(sDtlsEvent[node0.GetId()] == Dtls::Session::kConnected); + VerifyOrQuit(sDtlsEvent[node1.GetId()] == Dtls::Session::kConnected); + + numSessions = 0; + + for (MeshCoP::SecureSession &session : dtls0.GetSessions()) + { + VerifyOrQuit(session.IsConnected()); + numSessions++; + } + + VerifyOrQuit(numSessions == 1); + VerifyOrQuit(sHeapSessionsAllocated == 1); + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + Log("Establish a second DTLS connection with node0 from node2"); + + memset(sDtlsEvent, Dtls::Session::kDisconnectedError, sizeof(sDtlsEvent)); + + SuccessOrQuit(dtls2.SetPsk(kPsk, sizeof(kPsk))); + dtls2.SetReceiveCallback(HandleReceive, &node2); + dtls2.SetConnectCallback(HandleConnectEvent, &node2); + SuccessOrQuit(dtls2.Open()); + SuccessOrQuit(dtls2.Connect(sockAddr)); + + nexus.AdvanceTime(1 * Time::kOneSecondInMsec); + + VerifyOrQuit(dtls2.IsConnected()); + + VerifyOrQuit(sDtlsEvent[node0.GetId()] == Dtls::Session::kConnected); + VerifyOrQuit(sDtlsEvent[node2.GetId()] == Dtls::Session::kConnected); + + numSessions = 0; + + for (MeshCoP::SecureSession &session : dtls0.GetSessions()) + { + VerifyOrQuit(session.IsConnected()); + numSessions++; + } + + VerifyOrQuit(numSessions == 2); + VerifyOrQuit(sHeapSessionsAllocated == 2); + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + Log("Disconnect from node1 - validate the disconnect events"); + + dtls1.Disconnect(); + + nexus.AdvanceTime(3 * Time::kOneSecondInMsec); + + VerifyOrQuit(!dtls1.IsConnected()); + + VerifyOrQuit(sDtlsEvent[node0.GetId()] == Dtls::Session::kDisconnectedPeerClosed); + VerifyOrQuit(sDtlsEvent[node1.GetId()] == Dtls::Session::kDisconnectedLocalClosed); + + numSessions = 0; + + for (MeshCoP::SecureSession &session : dtls0.GetSessions()) + { + VerifyOrQuit(session.IsConnected()); + numSessions++; + } + + VerifyOrQuit(numSessions == 1); + VerifyOrQuit(sHeapSessionsAllocated == 1); + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + Log("Disconnect session with node2 from node0 (server) - validate the disconnect events"); + + memset(sDtlsEvent, Dtls::Session::kConnected, sizeof(sDtlsEvent)); + + dtls0.GetSessions().GetHead()->Disconnect(); + + nexus.AdvanceTime(3 * Time::kOneSecondInMsec); + + VerifyOrQuit(!dtls2.IsConnected()); + + VerifyOrQuit(sDtlsEvent[node0.GetId()] == Dtls::Session::kDisconnectedLocalClosed); + VerifyOrQuit(sDtlsEvent[node2.GetId()] == Dtls::Session::kDisconnectedPeerClosed); + + VerifyOrQuit(dtls0.GetSessions().IsEmpty()); + VerifyOrQuit(sHeapSessionsAllocated == 0); + + // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + Log("Establish two DTLS connections from node1 and node2 at the same time"); + + memset(sDtlsEvent, Dtls::Session::kDisconnectedError, sizeof(sDtlsEvent)); + + SuccessOrQuit(dtls1.Connect(sockAddr)); + SuccessOrQuit(dtls2.Connect(sockAddr)); + + nexus.AdvanceTime(1 * Time::kOneSecondInMsec); + + VerifyOrQuit(dtls1.IsConnected()); + VerifyOrQuit(dtls2.IsConnected()); + + VerifyOrQuit(sDtlsEvent[node0.GetId()] == Dtls::Session::kConnected); + VerifyOrQuit(sDtlsEvent[node1.GetId()] == Dtls::Session::kConnected); + VerifyOrQuit(sDtlsEvent[node2.GetId()] == Dtls::Session::kConnected); + + numSessions = 0; + + for (MeshCoP::SecureSession &session : dtls0.GetSessions()) + { + VerifyOrQuit(session.IsConnected()); + numSessions++; + } + + VerifyOrQuit(numSessions == 2); + VerifyOrQuit(sHeapSessionsAllocated == 2); + } +} + } // namespace Nexus } // namespace ot int main(void) { - ot::Nexus::TestDtls(); + ot::Nexus::TestDtlsSingleSession(); + ot::Nexus::TestDtlsMultiSession(); printf("All tests passed\n"); return 0; }