diff --git a/CMakeLists.txt b/CMakeLists.txt index afa26c59b..90cc4e6c9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -44,19 +44,10 @@ else() target_compile_options(usrsctp-static PRIVATE -Wno-error=address-of-packed-member -Wno-error=format-truncation) endif() +option(USE_GNUTLS "Use GnuTLS instead of OpenSSL" OFF) -find_package(GnuTLS REQUIRED) find_package(LibNice REQUIRED) -if(NOT TARGET GnuTLS::GnuTLS) - add_library(GnuTLS::GnuTLS UNKNOWN IMPORTED) - set_target_properties(GnuTLS::GnuTLS PROPERTIES - INTERFACE_INCLUDE_DIRECTORIES "${GNUTLS_INCLUDE_DIRS}" - INTERFACE_COMPILE_DEFINITIONS "${GNUTLS_DEFINITIONS}" - IMPORTED_LINK_INTERFACE_LANGUAGES "C" - IMPORTED_LOCATION "${GNUTLS_LIBRARIES}") -endif() - add_library(datachannel SHARED ${LIBDATACHANNEL_SOURCES}) set_target_properties(datachannel PROPERTIES VERSION ${PROJECT_VERSION} @@ -65,7 +56,7 @@ set_target_properties(datachannel PROPERTIES target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc) target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) -target_link_libraries(datachannel usrsctp-static GnuTLS::GnuTLS LibNice::LibNice) +target_link_libraries(datachannel usrsctp-static LibNice::LibNice) add_library(datachannel-static STATIC EXCLUDE_FROM_ALL ${LIBDATACHANNEL_SOURCES}) set_target_properties(datachannel-static PROPERTIES @@ -75,7 +66,29 @@ set_target_properties(datachannel-static PROPERTIES target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include) target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc) target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src) -target_link_libraries(datachannel-static usrsctp-static GnuTLS::GnuTLS LibNice::LibNice) +target_link_libraries(datachannel-static usrsctp-static LibNice::LibNice) + +if (USE_GNUTLS) + find_package(GnuTLS REQUIRED) + if(NOT TARGET GnuTLS::GnuTLS) + add_library(GnuTLS::GnuTLS UNKNOWN IMPORTED) + set_target_properties(GnuTLS::GnuTLS PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${GNUTLS_INCLUDE_DIRS}" + INTERFACE_COMPILE_DEFINITIONS "${GNUTLS_DEFINITIONS}" + IMPORTED_LINK_INTERFACE_LANGUAGES "C" + IMPORTED_LOCATION "${GNUTLS_LIBRARIES}") + endif() + target_compile_definitions(datachannel PRIVATE USE_GNUTLS=1) + target_link_libraries(datachannel GnuTLS::GnuTLS) + target_compile_definitions(datachannel-static PRIVATE USE_GNUTLS=1) + target_link_libraries(datachannel-static GnuTLS::GnuTLS) +else() + find_package(OpenSSL REQUIRED) + target_compile_definitions(datachannel PRIVATE USE_GNUTLS=0) + target_link_libraries(datachannel OpenSSL::SSL) + target_compile_definitions(datachannel-static PRIVATE USE_GNUTLS=0) + target_link_libraries(datachannel-static OpenSSL::SSL) +endif() add_library(LibDataChannel::LibDataChannel ALIAS datachannel) add_library(LibDataChannel::LibDataChannelStatic ALIAS datachannel-static) diff --git a/Jamfile b/Jamfile index 45f525a12..97a06a3ce 100644 --- a/Jamfile +++ b/Jamfile @@ -6,13 +6,15 @@ lib libdatachannel [ glob ./src/*.cpp ] : # requirements ./include/rtc - "`pkg-config --cflags gnutls glib-2.0 gobject-2.0 nice`" + USE_GNUTLS=0 + "`pkg-config --cflags openssl glib-2.0 gobject-2.0 nice`" /libdatachannel//usrsctp : # default build static : # usage requirements ./include - "`pkg-config --libs gnutls glib-2.0 gobject-2.0 nice`" + -pthread + "`pkg-config --libs openssl glib-2.0 gobject-2.0 nice`" ; alias usrsctp diff --git a/Makefile b/Makefile index ad1e9d871..b6140f3e2 100644 --- a/Makefile +++ b/Makefile @@ -7,12 +7,21 @@ RM=rm -f CPPFLAGS=-O2 -pthread -fPIC -Wall -Wno-address-of-packed-member CXXFLAGS=-std=c++17 LDFLAGS=-pthread -LIBS=gnutls glib-2.0 gobject-2.0 nice +LIBS=glib-2.0 gobject-2.0 nice +USRSCTP_DIR=usrsctp + +USE_GNUTLS ?= 0 +ifeq ($(USE_GNUTLS), 1) + CPPFLAGS+= -DUSE_GNUTLS=1 + LIBS+= gnutls +else + CPPFLAGS+= -DUSE_GNUTLS=0 + LIBS+= openssl +endif + LDLIBS= $(shell pkg-config --libs $(LIBS)) INCLUDES=-Iinclude/rtc -I$(USRSCTP_DIR)/usrsctplib $(shell pkg-config --cflags $(LIBS)) -USRSCTP_DIR:=usrsctp - SRCS=$(shell printf "%s " src/*.cpp) OBJS=$(subst .cpp,.o,$(SRCS)) diff --git a/README.md b/README.md index 79fd35ee4..e3329e4cb 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ The library aims at fully implementing SCTP DataChannels ([draft-ietf-rtcweb-dat ## Dependencies - libnice: https://github.com/libnice/libnice -- GnuTLS: https://www.gnutls.org/ +- GnuTLS: https://www.gnutls.org/ or OpenSSL: https://www.openssl.org/ Submodules: - usrsctp: https://github.com/sctplab/usrsctp @@ -24,7 +24,7 @@ Submodules: $ git submodule update --init --recursive $ mkdir build $ cd build -$ cmake .. +$ cmake -DUSE_GNUTLS=1 .. $ make ``` diff --git a/src/certificate.cpp b/src/certificate.cpp index 28bc26dec..c96341195 100644 --- a/src/certificate.cpp +++ b/src/certificate.cpp @@ -25,10 +25,13 @@ #include #include -#include - using std::shared_ptr; using std::string; +using std::unique_ptr; + +#if USE_GNUTLS + +#include namespace { @@ -117,10 +120,10 @@ Certificate::Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey) "Unable to set certificate and key pair in credentials"); } -string Certificate::fingerprint() const { return mFingerprint; } - gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; } +string Certificate::fingerprint() const { return mFingerprint; } + string make_fingerprint(gnutls_x509_crt_t crt) { const size_t size = 32; unsigned char buffer[size]; @@ -177,3 +180,120 @@ shared_ptr make_certificate(const string &commonName) { } } // namespace rtc + +#else + +#include +#include +#include + +namespace rtc { + +Certificate::Certificate(string crt_pem, string key_pem) { + BIO *bio; + + bio = BIO_new(BIO_s_mem()); + BIO_write(bio, crt_pem.data(), crt_pem.size()); + mX509 = shared_ptr(PEM_read_bio_X509(bio, nullptr, 0, 0), X509_free); + BIO_free(bio); + if (!mX509) + throw std::invalid_argument("Unable to import certificate PEM"); + + bio = BIO_new(BIO_s_mem()); + BIO_write(bio, key_pem.data(), key_pem.size()); + mPKey = shared_ptr(PEM_read_bio_PrivateKey(bio, nullptr, 0, 0), EVP_PKEY_free); + BIO_free(bio); + if (!mPKey) + throw std::invalid_argument("Unable to import PEM key PEM"); + + mFingerprint = make_fingerprint(mX509.get()); +} + +Certificate::Certificate(shared_ptr x509, shared_ptr pkey) : + mX509(std::move(x509)), mPKey(std::move(pkey)) +{ + mFingerprint = make_fingerprint(mX509.get()); +} + +string Certificate::fingerprint() const { return mFingerprint; } + +std::tuple Certificate::credentials() const { return {mX509.get(), mPKey.get()}; } + +string make_fingerprint(X509 *x509) { + const size_t size = 32; + unsigned char buffer[size]; + unsigned int len = size; + if (!X509_digest(x509, EVP_sha256(), buffer, &len)) + throw std::runtime_error("X509 fingerprint error"); + + std::ostringstream oss; + oss << std::hex << std::uppercase << std::setfill('0'); + for (size_t i = 0; i < len; ++i) { + if (i) + oss << std::setw(1) << ':'; + oss << std::setw(2) << unsigned(buffer[i]); + } + return oss.str(); +} + + +shared_ptr make_certificate(const string &commonName) { + static std::unordered_map> cache; + static std::mutex cacheMutex; + + std::lock_guard lock(cacheMutex); + if (auto it = cache.find(commonName); it != cache.end()) + return it->second; + + if (cache.empty()) { + // This is the first call to OpenSSL + OPENSSL_init_ssl(0, NULL); + SSL_load_error_strings(); + ERR_load_crypto_strings(); + } + + shared_ptr x509(X509_new(), X509_free); + shared_ptr pkey(EVP_PKEY_new(), EVP_PKEY_free); + + unique_ptr rsa(RSA_new(), RSA_free); + unique_ptr exponent(BN_new(), BN_free); + unique_ptr serial_number(BN_new(), BN_free); + unique_ptr name(X509_NAME_new(), X509_NAME_free); + + if (!x509 || !pkey || !rsa || !exponent || !serial_number || !name) + throw std::runtime_error("Unable allocate structures for certificate generation"); + + const int bits = 4096; + const unsigned int e = 65537; // 2^16 + 1 + + if (!pkey || !rsa || !exponent || !BN_set_word(exponent.get(), e) || + !RSA_generate_key_ex(rsa.get(), bits, exponent.get(), NULL) || + !EVP_PKEY_assign_RSA(pkey.get(), rsa.release())) // the key will be freed when pkey is freed + throw std::runtime_error("Unable to generate key pair"); + + const size_t serialSize = 16; + const auto *commonNameBytes = reinterpret_cast(commonName.c_str()); + + if (!X509_gmtime_adj(X509_get_notBefore(x509.get()), 3600 * -1) || + !X509_gmtime_adj(X509_get_notAfter(x509.get()), 3600 * 24 * 365) || + !X509_set_version(x509.get(), 1) || !X509_set_pubkey(x509.get(), pkey.get()) || + !BN_pseudo_rand(serial_number.get(), serialSize, 0, 0) || + !BN_to_ASN1_INTEGER(serial_number.get(), X509_get_serialNumber(x509.get())) || + !X509_NAME_add_entry_by_NID(name.get(), NID_commonName, MBSTRING_UTF8, commonNameBytes, -1, + -1, 0) || + !X509_set_subject_name(x509.get(), name.get()) || + !X509_set_issuer_name(x509.get(), name.get())) + throw std::runtime_error("Unable to set certificate properties"); + + if (!X509_sign(x509.get(), pkey.get(), EVP_sha256())) + throw std::runtime_error("Unable to auto-sign certificate"); + + auto certificate = std::make_shared(x509, pkey); + cache.emplace(std::make_pair(commonName, certificate)); + return certificate; +} + +} // namespace rtc + +#endif + diff --git a/src/certificate.hpp b/src/certificate.hpp index a50ddfef5..cda2fb516 100644 --- a/src/certificate.hpp +++ b/src/certificate.hpp @@ -21,24 +21,47 @@ #include "include.hpp" +#include + +#if USE_GNUTLS #include +#else +#include +#endif namespace rtc { class Certificate { public: - Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey); Certificate(string crt_pem, string key_pem); - string fingerprint() const; +#if USE_GNUTLS + Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey); gnutls_certificate_credentials_t credentials() const; +#else + Certificate(std::shared_ptr x509, std::shared_ptr pkey); + std::tuple credentials() const; +#endif + + string fingerprint() const; private: +#if USE_GNUTLS std::shared_ptr mCredentials; +#else + std::shared_ptr mX509; + std::shared_ptr mPKey; +#endif + string mFingerprint; }; +#if USE_GNUTLS string make_fingerprint(gnutls_x509_crt_t crt); +#else +string make_fingerprint(X509 *x509); +#endif + std::shared_ptr make_certificate(const string &commonName); } // namespace rtc diff --git a/src/dtlstransport.cpp b/src/dtlstransport.cpp index faaf0a931..02d7dedd6 100644 --- a/src/dtlstransport.cpp +++ b/src/dtlstransport.cpp @@ -24,10 +24,14 @@ #include #include -#include - using std::shared_ptr; using std::string; +using std::unique_ptr; +using std::weak_ptr; + +#if USE_GNUTLS + +#include namespace { @@ -44,8 +48,6 @@ static bool check_gnutls(int ret, const string &message = "GnuTLS error") { namespace rtc { -using std::shared_ptr; - DtlsTransport::DtlsTransport(shared_ptr lower, shared_ptr certificate, verifier_callback verifierCallback, state_callback stateChangeCallback) @@ -61,7 +63,7 @@ DtlsTransport::DtlsTransport(shared_ptr lower, shared_ptr +#include +#include +#include + +namespace { + +const int BIO_EOF = -1; + +string openssl_error_string(unsigned long err) { + const size_t bufferSize = 256; + char buffer[bufferSize]; + ERR_error_string_n(err, buffer, bufferSize); + return string(buffer); +} + +bool check_openssl(int success, const string &message = "OpenSSL error") { + if (success) + return true; + else + throw std::runtime_error(message + ": " + openssl_error_string(ERR_get_error())); +} + +bool check_openssl_ret(SSL *ssl, int ret, const string &message = "OpenSSL error") { + if (ret == BIO_EOF) + return true; + + unsigned long err = SSL_get_error(ssl, ret); + if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) + return true; + else if (err == SSL_ERROR_ZERO_RETURN) + return false; + else + throw std::runtime_error(message + ": " + openssl_error_string(err)); +} + +} // namespace + +namespace rtc { + +int DtlsTransport::TransportExIndex = -1; +std::mutex DtlsTransport::GlobalMutex; + +void DtlsTransport::GlobalInit() { + std::lock_guard lock(GlobalMutex); + if (TransportExIndex < 0) { + TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL); + } +} + +DtlsTransport::DtlsTransport(shared_ptr lower, shared_ptr certificate, + verifier_callback verifierCallback, state_callback stateChangeCallback) + : Transport(lower), mCertificate(certificate), mState(State::Disconnected), + mVerifierCallback(std::move(verifierCallback)), + mStateChangeCallback(std::move(stateChangeCallback)) { + + GlobalInit(); + + if (!(mCtx = SSL_CTX_new(DTLS_method()))) + throw std::runtime_error("Unable to create SSL context"); + + check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"), + "Unable to set SSL priorities"); + + // RFC 8261: SCTP performs segmentation and reassembly based on the path MTU. + // Therefore, the DTLS layer MUST NOT use any compression algorithm. + // See https://tools.ietf.org/html/rfc8261#section-5 + SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION); + SSL_CTX_set_min_proto_version(mCtx, DTLS1_VERSION); + SSL_CTX_set_read_ahead(mCtx, 1); + SSL_CTX_set_quiet_shutdown(mCtx, 1); + SSL_CTX_set_info_callback(mCtx, InfoCallback); + SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, + CertificateCallback); + SSL_CTX_set_verify_depth(mCtx, 1); + + X509 *x509; + EVP_PKEY *pkey; + std::tie(x509, pkey) = mCertificate->credentials(); + SSL_CTX_use_certificate(mCtx, x509); + SSL_CTX_use_PrivateKey(mCtx, pkey); + + check_openssl(SSL_CTX_check_private_key(mCtx), "SSL local private key check failed"); + + if (!(mSsl = SSL_new(mCtx))) + throw std::runtime_error("Unable to create SSL instance"); + + SSL_set_ex_data(mSsl, TransportExIndex, this); + SSL_set_mtu(mSsl, 1280 - 40 - 8); // min MTU over UDP/IPv6 + + if (lower->role() == Description::Role::Active) + SSL_set_connect_state(mSsl); + else + SSL_set_accept_state(mSsl); + + if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BIO_s_mem()))) + throw std::runtime_error("Unable to create BIO"); + + BIO_set_mem_eof_return(mInBio, BIO_EOF); + BIO_set_mem_eof_return(mOutBio, BIO_EOF); + SSL_set_bio(mSsl, mInBio, mOutBio); + + auto ecdh = unique_ptr( + EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free); + SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE); + SSL_set_tmp_ecdh(mSsl, ecdh.get()); + + mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this); +} + +DtlsTransport::~DtlsTransport() { + onRecv(nullptr); // unset recv callback + + mIncomingQueue.stop(); + + if (mRecvThread.joinable()) + mRecvThread.join(); + + SSL_shutdown(mSsl); + SSL_free(mSsl); + SSL_CTX_free(mCtx); +} + +DtlsTransport::State DtlsTransport::state() const { return mState; } + +bool DtlsTransport::send(message_ptr message) { + const size_t bufferSize = 4096; + byte buffer[bufferSize]; + + if (!message || mState != State::Connected) + return false; + + int ret = SSL_write(mSsl, message->data(), message->size()); + if (!check_openssl_ret(mSsl, ret)) { + return false; + } + + while (BIO_ctrl_pending(mOutBio) > 0) { + int ret = BIO_read(mOutBio, buffer, bufferSize); + if (check_openssl_ret(mSsl, ret) && ret > 0) + outgoing(make_message(buffer, buffer + ret)); + } + + return true; +} + +void DtlsTransport::incoming(message_ptr message) { mIncomingQueue.push(message); } + +void DtlsTransport::changeState(State state) { + if (mState.exchange(state) != state) + mStateChangeCallback(state); +} + +void DtlsTransport::runRecvLoop() { + const size_t bufferSize = 4096; + byte buffer[bufferSize]; + + try { + changeState(State::Connecting); + + SSL_do_handshake(mSsl); + while (BIO_ctrl_pending(mOutBio) > 0) { + int ret = BIO_read(mOutBio, buffer, bufferSize); + if (check_openssl_ret(mSsl, ret) && ret > 0) + outgoing(make_message(buffer, buffer + ret)); + } + + while (auto next = mIncomingQueue.pop()) { + auto message = *next; + BIO_write(mInBio, message->data(), message->size()); + int ret = SSL_read(mSsl, buffer, bufferSize); + if (!check_openssl_ret(mSsl, ret)) + break; + + auto decrypted = ret > 0 ? make_message(buffer, buffer + ret) : nullptr; + + if (mState == State::Connecting) { + if (unsigned long err = ERR_get_error()) + throw std::runtime_error("handshake failed: " + openssl_error_string(err)); + + while (BIO_ctrl_pending(mOutBio) > 0) { + ret = BIO_read(mOutBio, buffer, bufferSize); + if (check_openssl_ret(mSsl, ret) && ret > 0) + outgoing(make_message(buffer, buffer + ret)); + } + + if (SSL_is_init_finished(mSsl)) + changeState(State::Connected); + } + + if (decrypted) + recv(decrypted); + } + } catch (const std::exception &e) { + std::cerr << "DTLS recv: " << e.what() << std::endl; + } + + if (mState == State::Connected) { + changeState(State::Disconnected); + recv(nullptr); + } else { + changeState(State::Failed); + } +} + +int DtlsTransport::CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx) { + SSL *ssl = + static_cast(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx())); + DtlsTransport *t = + static_cast(SSL_get_ex_data(ssl, DtlsTransport::TransportExIndex)); + + X509 *crt = X509_STORE_CTX_get_current_cert(ctx); + std::string fingerprint = make_fingerprint(crt); + + return t->mVerifierCallback(fingerprint) ? 1 : 0; +} + +void DtlsTransport::InfoCallback(const SSL *ssl, int where, int ret) { + DtlsTransport *t = + static_cast(SSL_get_ex_data(ssl, DtlsTransport::TransportExIndex)); + + if (where & SSL_CB_ALERT) { + if (ret != 256) // Close Notify + std::cerr << "DTLS alert: " << SSL_alert_desc_string_long(ret) << std::endl; + t->mIncomingQueue.stop(); // Close the connection + } +} + +} // namespace rtc + +#endif + diff --git a/src/dtlstransport.hpp b/src/dtlstransport.hpp index 819d44806..27fe56eac 100644 --- a/src/dtlstransport.hpp +++ b/src/dtlstransport.hpp @@ -28,9 +28,14 @@ #include #include #include +#include #include +#if USE_GNUTLS #include +#else +#include +#endif namespace rtc { @@ -58,7 +63,6 @@ class DtlsTransport : public Transport { const std::shared_ptr mCertificate; - gnutls_session_t mSession; Queue mIncomingQueue; std::atomic mState; std::thread mRecvThread; @@ -66,10 +70,25 @@ class DtlsTransport : public Transport { verifier_callback mVerifierCallback; state_callback mStateChangeCallback; +#if USE_GNUTLS + gnutls_session_t mSession; + static int CertificateCallback(gnutls_session_t session); static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len); static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen); static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms); +#else + SSL_CTX *mCtx; + SSL *mSsl; + BIO *mInBio, *mOutBio; + + static int TransportExIndex; + static std::mutex GlobalMutex; + + static void GlobalInit(); + static int CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx); + static void InfoCallback(const SSL *ssl, int where, int ret); +#endif }; } // namespace rtc diff --git a/src/icetransport.cpp b/src/icetransport.cpp index b43e1976c..5ad11f31d 100644 --- a/src/icetransport.cpp +++ b/src/icetransport.cpp @@ -132,7 +132,8 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role, IceTransport::~IceTransport() { g_main_loop_quit(mMainLoop.get()); - mMainLoopThread.join(); + if (mMainLoopThread.joinable()) + mMainLoopThread.join(); } Description::Role IceTransport::role() const { return mRole; } diff --git a/src/sctptransport.cpp b/src/sctptransport.cpp index 4b614874c..eb2fd142f 100644 --- a/src/sctptransport.cpp +++ b/src/sctptransport.cpp @@ -130,7 +130,8 @@ SctpTransport::~SctpTransport() { usrsctp_close(mSock); } - mSendThread.join(); + if (mSendThread.joinable()) + mSendThread.join(); usrsctp_deregister_address(this); GlobalCleanup();