Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Https subscription and root certs #341

Merged
merged 1 commit into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 59 additions & 42 deletions src/client/include/RestClientNative.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@
std::mutex _subscriptionLock;
std::map<URI<STRICT>, httplib::Client> _subscription1;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
X509_STORE *_client_cert_store = nullptr;
std::map<URI<STRICT>, httplib::SSLClient> _subscription2;
#endif

Expand All @@ -166,12 +165,6 @@
, _maxIoThreads(detail::find_argument_value<true, MaxIoThreads>([] { return MaxIoThreads(); }, initArgs...))
, _thread_pool(detail::find_argument_value<true, ThreadPoolType>([this] { return std::make_shared<BasicThreadPool<IO_BOUND>>(_name, _minIoThreads, _maxIoThreads); }, initArgs...))
, _caCertificate(detail::find_argument_value<true, ClientCertificates>([] { return rest::DefaultCertificate().get(); }, initArgs...)) {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
if (_client_cert_store != nullptr) {
X509_STORE_free(_client_cert_store);
}
_client_cert_store = detail::createCertificateStore(_caCertificate);
#endif
}
~RestClient() override { RestClient::stop(); };

Expand Down Expand Up @@ -285,7 +278,8 @@
if (cmd.topic.scheme() && equal_with_case_ignore(cmd.topic.scheme().value(), "https")) {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
httplib::SSLClient client(cmd.topic.hostName().value(), cmd.topic.port() ? cmd.topic.port().value() : 443);
client.set_ca_cert_store(_client_cert_store);
// client owns its certificate store and destroys it after use. create a store for each client
client.set_ca_cert_store(detail::createCertificateStore(_caCertificate));
ablepharus marked this conversation as resolved.
Show resolved Hide resolved
client.enable_server_certificate_verification(CHECK_CERTIFICATES);
callback(client);
#else
Expand Down Expand Up @@ -315,45 +309,68 @@
|| equal_with_case_ignore(*cmd.topic.scheme(), "https")
#endif
) {
auto it = _subscription1.find(cmd.topic);
if (it == _subscription1.end()) {
auto &client = _subscription1.try_emplace(cmd.topic, httplib::Client(cmd.topic.hostName().value(), cmd.topic.port().value())).first->second;
client.set_follow_location(true);

auto longPollingEndpoint = [&] {
if (!cmd.topic.queryParamMap().contains(LONG_POLLING_IDX_TAG)) {
return URI<>::factory(cmd.topic).addQueryParameter(LONG_POLLING_IDX_TAG, "Next").build();
} else {
return URI<>::factory(cmd.topic).build();
}
}();

const auto pollHeaders = getPreferredContentTypeHeader(longPollingEndpoint);
auto endpoint = longPollingEndpoint.relativeRef().value();
client.set_read_timeout(cmd.timeout); // default keep-alive value
while (_run) {
auto redirect_get = [&client](auto url, auto headers) {
for (;;) {
auto result = client.Get(url, headers);
if (!result) return result;

if (result->status >= 300 && result->status < 400) {
url = httplib::detail::decode_url(result.value().get_header_value("location"), true);
} else {
return result;
}
auto createNewSubscription = [&](auto &client) {
{
client.set_follow_location(true);

auto longPollingEndpoint = [&] {
if (!cmd.topic.queryParamMap().contains(LONG_POLLING_IDX_TAG)) {
return URI<>::factory(cmd.topic).addQueryParameter(LONG_POLLING_IDX_TAG, "Next").build();
} else {
return URI<>::factory(cmd.topic).build();
}
};
if (const httplib::Result &result = redirect_get(endpoint, pollHeaders)) {
returnMdpMessage(cmd, result);
} else { // failed or server is down -> wait until retry
std::this_thread::sleep_for(cmd.timeout); // time-out until potential retry
if (_run) {
returnMdpMessage(cmd, result, fmt::format("Long-Polling-GET request failed for {}: {}", cmd.topic.str(), static_cast<int>(result.error())));
}();

const auto pollHeaders = getPreferredContentTypeHeader(longPollingEndpoint);
auto endpoint = longPollingEndpoint.relativeRef().value();
client.set_read_timeout(cmd.timeout); // default keep-alive value
while (_run) {
auto redirect_get = [&client](auto url, auto headers) {
for (;;) {
auto result = client.Get(url, headers);
if (!result) return result;

if (result->status >= 300 && result->status < 400) {
url = httplib::detail::decode_url(result.value().get_header_value("location"), true);
} else {
return result;
}
}
};
if (const httplib::Result &result = redirect_get(endpoint, pollHeaders)) {
returnMdpMessage(cmd, result);
} else { // failed or server is down -> wait until retry
std::this_thread::sleep_for(cmd.timeout); // time-out until potential retry
if (_run) {
returnMdpMessage(cmd, result, fmt::format("Long-Polling-GET request failed for {}: {}", cmd.topic.str(), static_cast<int>(result.error())));

Check warning on line 345 in src/client/include/RestClientNative.hpp

View check run for this annotation

Codecov / codecov/patch

src/client/include/RestClientNative.hpp#L345

Added line #L345 was not covered by tests
}
}
}
}
};
if (equal_with_case_ignore(*cmd.topic.scheme(), "http")) {
auto it = _subscription1.find(cmd.topic);
if (it == _subscription1.end()) {
_subscription1.emplace(cmd.topic, httplib::Client(cmd.topic.hostName().value(), cmd.topic.port().value()));
createNewSubscription(_subscription1.at(cmd.topic));
}
} else {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
if (auto it = _subscription2.find(cmd.topic); it == _subscription2.end()) {
_subscription2.emplace(
std::piecewise_construct,
std::forward_as_tuple(cmd.topic),
std::forward_as_tuple(cmd.topic.hostName().value(), cmd.topic.port().value()));
auto &client = _subscription2.at(cmd.topic);
client.set_ca_cert_store(detail::createCertificateStore(_caCertificate));
client.enable_server_certificate_verification(CHECK_CERTIFICATES);
createNewSubscription(_subscription2.at(cmd.topic));
}
#else
throw std::invalid_argument("https is not supported");
#endif
}

} else {
throw std::invalid_argument(fmt::format("unsupported scheme '{}' for requested subscription '{}'", cmd.topic.scheme(), cmd.topic.str()));
}
Expand Down
9 changes: 9 additions & 0 deletions src/client/include/RestDefaultClientCertificates.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@
_concatenated_certificates += root_certificates[1];
_concatenated_certificates += root_certificates[2];
_concatenated_certificates += root_certificates[3];

if (auto filename = std::getenv("OPENCMW_REST_CERT_FILE"); filename) {
std::ifstream ifs{ filename };

Check warning on line 23 in src/client/include/RestDefaultClientCertificates.hpp

View check run for this annotation

Codecov / codecov/patch

src/client/include/RestDefaultClientCertificates.hpp#L23

Added line #L23 was not covered by tests
if (!ifs.is_open()) {
std::string cert;
ifs >> cert;
_concatenated_certificates += cert;
}
}

Check warning on line 29 in src/client/include/RestDefaultClientCertificates.hpp

View check run for this annotation

Codecov / codecov/patch

src/client/include/RestDefaultClientCertificates.hpp#L25-L29

Added lines #L25 - L29 were not covered by tests
}
constexpr std::string get() const noexcept {
return _concatenated_certificates;
Expand Down
173 changes: 171 additions & 2 deletions src/client/test/RestClient_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,6 @@ TEST_CASE("Basic Rest Client Constructor and API Tests", "[Client]") {
RestClient client5("clientName", DefaultContentTypeHeader(MIME::HTML), MinIoThreads(2), MaxIoThreads(5), ClientCertificates(testCertificate));
REQUIRE(client5.defaultMimeType() == MIME::HTML);
REQUIRE(client5.threadPool()->poolName() == "clientName");

REQUIRE_THROWS_AS(RestClient(ClientCertificates("Invalid Certificate Format")), std::invalid_argument);
}

TEST_CASE("Basic Rest Client Get/Set Test - HTTP", "[Client]") {
Expand Down Expand Up @@ -123,6 +121,73 @@ TEST_CASE("Basic Rest Client Get/Set Test - HTTP", "[Client]") {
}

#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
TEST_CASE("Multiple Rest Client Get/Set Test - HTTPS", "[Client]") {
using namespace opencmw::client;
RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate));
REQUIRE(RestClient::CHECK_CERTIFICATES);
RestClient::CHECK_CERTIFICATES = true; // 'false' disables certificate check
REQUIRE(client.name() == "TestSSLClient");
REQUIRE(client.defaultMimeType() == MIME::JSON);

// HTTP
X509 *cert = opencmw::client::detail::readServerCertificateFromFile(testServerCertificates.serverCertificate);
EVP_PKEY *pkey = opencmw::client::detail::readServerPrivateKeyFromFile(testServerCertificates.serverKey);
if (const X509_STORE *ca_store = opencmw::client::detail::createCertificateStore(testServerCertificates.caCertificate); !cert || !pkey || !ca_store) {
FAIL(fmt::format("Failed to load certificate: {}", ERR_error_string(ERR_get_error(), nullptr)));
}
httplib::SSLServer server(cert, pkey);

std::string acceptHeader;
server.Get("/endPoint", [&acceptHeader](const httplib::Request &req, httplib::Response &res) {
if (req.headers.contains("accept")) {
acceptHeader = req.headers.find("accept")->second;
} else {
FAIL("no accept headers found");
}
res.set_content("Hello World!", acceptHeader);
});
client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); });
while (!server.is_running()) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
REQUIRE(server.is_running());

std::array<std::atomic<bool>, 4> dones;
dones[0] = false;
dones[1] = false;
dones[2] = false;
dones[3] = false;
std::atomic<int> counter{ 0 };
auto makeCommand = [&]() {
IoBuffer data;
data.put('A');
data.put('B');
data.put('C');
data.put(0);

Command command;
command.command = mdp::Command::Get;
command.topic = URI<STRICT>("https://localhost:8080/endPoint");
command.data = std::move(data);
command.callback = [&dones, &counter](const mdp::Message &/*rep*/) {
int currentCounter = counter.fetch_add(1, std::memory_order_relaxed);
dones[currentCounter].store(true, std::memory_order_release);
// Assuming you have access to 'done' variable, uncomment the following line
dones[currentCounter].notify_all();
};
client.request(command);
};
for (int i = 0; i < 4; i++)
makeCommand();

for (auto &done : dones) {
done.wait(false);
}
REQUIRE(std::ranges::all_of(dones, [](auto &done) { return done.load(std::memory_order_acquire); }));
REQUIRE(acceptHeader == MIME::JSON.typeName());
server.stop();
}

TEST_CASE("Basic Rest Client Get/Set Test - HTTPS", "[Client]") {
using namespace opencmw::client;
RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate));
Expand Down Expand Up @@ -296,4 +361,108 @@ TEST_CASE("Basic Rest Client Subscribe/Unsubscribe Test", "[Client]") {
std::cout << "server stopped" << std::endl;
}

TEST_CASE("Basic Rest Client Subscribe/Unsubscribe Test HTTPS", "[Client]") {
// HTTP
X509 *cert = opencmw::client::detail::readServerCertificateFromFile(testServerCertificates.serverCertificate);
EVP_PKEY *pkey = opencmw::client::detail::readServerPrivateKeyFromFile(testServerCertificates.serverKey);
if (const X509_STORE *ca_store = opencmw::client::detail::createCertificateStore(testServerCertificates.caCertificate); !cert || !pkey || !ca_store) {
FAIL(fmt::format("Failed to load certificate: {}", ERR_error_string(ERR_get_error(), nullptr)));
}
using namespace opencmw::client;

std::atomic<int> updateCounter{ 0 };
detail::EventDispatcher eventDispatcher;
httplib::SSLServer server(cert, pkey);
server.Get("/event", [&eventDispatcher, &updateCounter](const httplib::Request &req, httplib::Response &res) {
DEBUG_LOG("Server received request");
auto acceptType = req.headers.find("accept");
if (acceptType == req.headers.end() || MIME::EVENT_STREAM.typeName() != acceptType->second) { // non-SSE request -> return default response
#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16))
res.set_content(fmt::format("update counter = {}", updateCounter.load()), MIME::TEXT);
#else
res.set_content(fmt::format("update counter = {}", updateCounter.load()), std::string(MIME::TEXT.typeName()));
#endif
return;
} else {
fmt::print("server received SSE request on path '{}' body = '{}'\n", req.path, req.body);
#if not defined(__EMSCRIPTEN__) and (not defined(__clang__) or (__clang_major__ >= 16))
res.set_chunked_content_provider(MIME::EVENT_STREAM, [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) {
#else
res.set_chunked_content_provider(std::string(MIME::EVENT_STREAM.typeName()), [&eventDispatcher](size_t /*offset*/, httplib::DataSink &sink) {
#endif
eventDispatcher.wait_event(sink);
return true;
});
}
});
server.Get("/endPoint", [](const httplib::Request &req, httplib::Response &res) {
fmt::print("server received request on path '{}' body = '{}'\n", req.path, req.body);
res.set_content("Hello World!", "text/plain");
});

RestClient client("TestSSLClient", ClientCertificates(testServerCertificates.caCertificate));

client.threadPool()->execute<"RestServer">([&server] { server.listen("localhost", 8080); });
while (!server.is_running()) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
REQUIRE(server.is_running());
REQUIRE(RestClient::CHECK_CERTIFICATES);
RestClient::CHECK_CERTIFICATES = true; // 'false' disables certificate check
REQUIRE(client.name() == "TestSSLClient");
REQUIRE(client.defaultMimeType() == MIME::JSON);

std::atomic<int> receivedRegular(0);
std::atomic<int> receivedError(0);
IoBuffer data;
data.put('A');
data.put('B');
data.put('C');
data.put(0);

Command command;
command.command = mdp::Command::Subscribe;
command.topic = URI<STRICT>("https://localhost:8080/event");
command.data = std::move(data);
command.callback = [&receivedRegular, &receivedError](const mdp::Message &rep) {
fmt::print("SSE client received reply = '{}' - body size: '{}'\n", rep.data.asString(), rep.data.size());
if (rep.error.size() == 0) {
receivedRegular.fetch_add(1, std::memory_order_relaxed);
} else {
receivedError.fetch_add(1, std::memory_order_relaxed);
}
receivedRegular.notify_all();
receivedError.notify_all();
};

client.request(command);

std::cout << "client request launched" << std::endl;
std::this_thread::sleep_for(std::chrono::milliseconds(100));
eventDispatcher.send_event("test-event meta data");
std::jthread dispatcher([&updateCounter, &eventDispatcher] {
while (updateCounter < 5) {
std::this_thread::sleep_for(std::chrono::milliseconds(20));
eventDispatcher.send_event(fmt::format("test-event {}", updateCounter++));
}
});
dispatcher.join();

while (receivedRegular.load(std::memory_order_relaxed) < 5) {
std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
std::cout << "done waiting" << std::endl;
REQUIRE(receivedRegular.load(std::memory_order_acquire) >= 5);

command.command = mdp::Command::Unsubscribe;
client.request(command);
std::this_thread::sleep_for(std::chrono::milliseconds(100));
std::cout << "done Unsubscribe" << std::endl;

client.stop();
server.stop();
eventDispatcher.send_event(fmt::format("test-event {}", updateCounter++));
std::cout << "server stopped" << std::endl;
}

} // namespace opencmw::rest_client_test
Loading