Skip to content

Commit

Permalink
dnsdist: Always store the OpenSSLTLSIOCtx in the connection
Browse files Browse the repository at this point in the history
(cherry picked from commit 6aac1f0)
  • Loading branch information
rgacogne committed Sep 17, 2024
1 parent 8fd80e2 commit 84f83b3
Showing 1 changed file with 62 additions and 26 deletions.
88 changes: 62 additions & 26 deletions pdns/tcpiohandler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,13 @@ class OpenSSLSession : public TLSSession
std::unique_ptr<SSL_SESSION, void(*)(SSL_SESSION*)> d_sess;
};

class OpenSSLTLSIOCtx;

class OpenSSLTLSConnection: public TLSConnection
{
public:
/* server side connection */
OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<OpenSSLFrontendContext> feContext): d_feContext(feContext), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(d_feContext->d_tlsCtx.get()), SSL_free)), d_timeout(timeout)
OpenSSLTLSConnection(int socket, const struct timeval& timeout, std::shared_ptr<const OpenSSLTLSIOCtx> tlsCtx, std::unique_ptr<SSL, void(*)(SSL*)>&& conn): d_tlsCtx(std::move(tlsCtx)), d_conn(std::move(conn)), d_timeout(timeout)
{
d_socket = socket;

Expand All @@ -99,7 +101,7 @@ class OpenSSLTLSConnection: public TLSConnection
}

/* client-side connection */
OpenSSLTLSConnection(const std::string& hostname, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr<SSL_CTX>& tlsCtx): d_tlsCtx(tlsCtx), d_conn(std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(tlsCtx.get()), SSL_free)), d_hostname(hostname), d_timeout(timeout)
OpenSSLTLSConnection(const std::string& hostname, bool hostIsAddr, int socket, const struct timeval& timeout, std::shared_ptr<const OpenSSLTLSIOCtx> tlsCtx, std::unique_ptr<SSL, void(*)(SSL*)>&& conn): d_tlsCtx(std::move(tlsCtx)), d_conn(std::move(conn)), d_hostname(std::move(hostname)), d_timeout(timeout), d_isClient(true)
{
d_socket = socket;

Expand Down Expand Up @@ -286,7 +288,7 @@ class OpenSSLTLSConnection: public TLSConnection

IOState tryHandshake() override
{
if (!d_feContext) {
if (isClient()) {
/* In client mode, the handshake is initiated by the call to SSL_connect()
done from connect()/tryConnect().
In blocking mode it does not return before the handshake has been finished,
Expand Down Expand Up @@ -314,7 +316,7 @@ class OpenSSLTLSConnection: public TLSConnection

void doHandshake() override
{
if (!d_feContext) {
if (isClient()) {
/* we are a client, nothing to do, see the non-blocking version */
return;
}
Expand All @@ -335,7 +337,7 @@ class OpenSSLTLSConnection: public TLSConnection

IOState tryWrite(const PacketBuffer& buffer, size_t& pos, size_t toWrite) override
{
if (!d_feContext && !d_connected) {
if (isClient() && !d_connected) {
if (d_ktls) {
/* work-around to get kTLS to be started, as we cannot do that until after the socket has been connected */
SSL_set_fd(d_conn.get(), SSL_get_fd(d_conn.get()));
Expand Down Expand Up @@ -563,6 +565,11 @@ class OpenSSLTLSConnection: public TLSConnection
d_ktls = true;
}

bool isClient() const
{
return d_isClient;
}

static void generateConnectionIndexIfNeeded()
{
auto init = s_initTLSConnIndex.lock();
Expand All @@ -588,25 +595,38 @@ class OpenSSLTLSConnection: public TLSConnection
static LockGuarded<bool> s_initTLSConnIndex;
static int s_tlsConnIndex;
std::vector<std::unique_ptr<TLSSession>> d_tlsSessions;
/* server context */
std::shared_ptr<OpenSSLFrontendContext> d_feContext;
/* client context */
std::shared_ptr<SSL_CTX> d_tlsCtx;
std::shared_ptr<const OpenSSLTLSIOCtx> d_tlsCtx; // we need to hold a reference to this to make sure that the context exists for as long as the connection, even if a reload happens in the meantime
std::unique_ptr<SSL, void(*)(SSL*)> d_conn;
std::string d_hostname;
struct timeval d_timeout;
bool d_connected{false};
bool d_ktls{false};
bool d_isClient{false};
};

LockGuarded<bool> OpenSSLTLSConnection::s_initTLSConnIndex{false};
int OpenSSLTLSConnection::s_tlsConnIndex{-1};

class OpenSSLTLSIOCtx: public TLSCtx
class OpenSSLTLSIOCtx: public TLSCtx, public std::enable_shared_from_this<OpenSSLTLSIOCtx>
{
struct Private
{
explicit Private() = default;
};

public:
static std::shared_ptr<OpenSSLTLSIOCtx> createServerSideContext(TLSFrontend& fe)
{
return std::make_shared<OpenSSLTLSIOCtx>(fe, Private());
}

static std::shared_ptr<OpenSSLTLSIOCtx> createClientSideContext(const TLSContextParameters& params)
{
return std::make_shared<OpenSSLTLSIOCtx>(params, Private());
}

/* server side context */
OpenSSLTLSIOCtx(TLSFrontend& fe): d_feContext(std::make_shared<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig))
OpenSSLTLSIOCtx(TLSFrontend& fe, [[maybe_unused]] Private priv): d_feContext(std::make_unique<OpenSSLFrontendContext>(fe.d_addr, fe.d_tlsConfig))
{
OpenSSLTLSConnection::generateConnectionIndexIfNeeded();

Expand Down Expand Up @@ -649,7 +669,7 @@ class OpenSSLTLSIOCtx: public TLSCtx
}

/* client side context */
OpenSSLTLSIOCtx(const TLSContextParameters& params)
OpenSSLTLSIOCtx(const TLSContextParameters& params, [[maybe_unused]] Private)
{
int sslOptions =
SSL_OP_NO_SSLv2 |
Expand Down Expand Up @@ -797,16 +817,24 @@ class OpenSSLTLSIOCtx: public TLSCtx
return 1;
}

SSL_CTX* getOpenSSLContext() const
{
if (d_feContext) {
return d_feContext->d_tlsCtx.get();
}
return d_tlsCtx.get();
}

std::unique_ptr<TLSConnection> getConnection(int socket, const struct timeval& timeout, time_t now) override
{
handleTicketsKeyRotation(now);

return std::make_unique<OpenSSLTLSConnection>(socket, timeout, d_feContext);
return std::make_unique<OpenSSLTLSConnection>(socket, timeout, shared_from_this(), std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(getOpenSSLContext()), SSL_free));
}

std::unique_ptr<TLSConnection> getClientConnection(const std::string& host, bool hostIsAddr, int socket, const struct timeval& timeout) override
{
auto conn = std::make_unique<OpenSSLTLSConnection>(host, hostIsAddr, socket, timeout, d_tlsCtx);
auto conn = std::make_unique<OpenSSLTLSConnection>(host, hostIsAddr, socket, timeout, shared_from_this(), std::unique_ptr<SSL, void(*)(SSL*)>(SSL_new(getOpenSSLContext()), SSL_free));
if (d_ktls) {
conn->enableKTLS();
}
Expand Down Expand Up @@ -841,24 +869,32 @@ class OpenSSLTLSIOCtx: public TLSCtx
return "openssl";
}

bool isServerContext() const
{
return d_feContext != nullptr;
}

bool setALPNProtos(const std::vector<std::vector<uint8_t>>& protos) override
{
if (d_feContext && d_feContext->d_tlsCtx) {
auto* openSSLContext = getOpenSSLContext();
if (openSSLContext == nullptr) {
return false;
}

if (isServerContext()) {
d_alpnProtos = protos;
libssl_set_alpn_select_callback(d_feContext->d_tlsCtx.get(), alpnServerSelectCallback, this);
libssl_set_alpn_select_callback(openSSLContext, alpnServerSelectCallback, this);
return true;
}
if (d_tlsCtx) {
return libssl_set_alpn_protos(d_tlsCtx.get(), protos);
}
return false;

return libssl_set_alpn_protos(openSSLContext, protos);
}

#ifndef DISABLE_NPN
bool setNextProtocolSelectCallback(bool(*cb)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen)) override
{
d_nextProtocolSelectCallback = cb;
libssl_set_npn_select_callback(d_tlsCtx.get(), npnSelectCallback, this);
libssl_set_npn_select_callback(getOpenSSLContext(), npnSelectCallback, this);
return true;
}
#endif /* DISABLE_NPN */
Expand Down Expand Up @@ -910,8 +946,8 @@ class OpenSSLTLSIOCtx: public TLSCtx
}

std::vector<std::vector<uint8_t>> d_alpnProtos; // store the supported ALPN protocols, so that the server can select based on what the client sent
std::shared_ptr<OpenSSLFrontendContext> d_feContext{nullptr};
std::shared_ptr<SSL_CTX> d_tlsCtx{nullptr}; // client context, on a server-side the context is stored in d_feContext->d_tlsCtx
std::unique_ptr<OpenSSLFrontendContext> d_feContext{nullptr};
bool (*d_nextProtocolSelectCallback)(unsigned char** out, unsigned char* outlen, const unsigned char* in, unsigned int inlen){nullptr};
bool d_ktls{false};
};
Expand Down Expand Up @@ -1840,15 +1876,15 @@ bool TLSFrontend::setupTLS()
#endif /* HAVE_GNUTLS */
#ifdef HAVE_LIBSSL
if (d_provider == "openssl") {
newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
newCtx = OpenSSLTLSIOCtx::createServerSideContext(*this);
setupDoTProtocolNegotiation(newCtx);
std::atomic_store_explicit(&d_ctx, newCtx, std::memory_order_release);
return true;
}
#endif /* HAVE_LIBSSL */
}
#ifdef HAVE_LIBSSL
newCtx = std::make_shared<OpenSSLTLSIOCtx>(*this);
newCtx = OpenSSLTLSIOCtx::createServerSideContext(*this);
#else /* HAVE_LIBSSL */
#ifdef HAVE_GNUTLS
newCtx = std::make_shared<GnuTLSIOCtx>(*this);
Expand All @@ -1873,13 +1909,13 @@ std::shared_ptr<TLSCtx> getTLSContext(const TLSContextParameters& params)
#endif /* HAVE_GNUTLS */
#ifdef HAVE_LIBSSL
if (params.d_provider == "openssl") {
return std::make_shared<OpenSSLTLSIOCtx>(params);
return OpenSSLTLSIOCtx::createClientSideContext(params);
}
#endif /* HAVE_LIBSSL */
}

#ifdef HAVE_LIBSSL
return std::make_shared<OpenSSLTLSIOCtx>(params);
return OpenSSLTLSIOCtx::createClientSideContext(params);
#else /* HAVE_LIBSSL */
#ifdef HAVE_GNUTLS
return std::make_shared<GnuTLSIOCtx>(params);
Expand Down

0 comments on commit 84f83b3

Please sign in to comment.