From 389aa9f5fbd79637c8720528894ba05baf980ab8 Mon Sep 17 00:00:00 2001 From: Will Sobel Date: Thu, 2 May 2024 16:56:23 -0400 Subject: [PATCH] Added initial websocket test --- src/mtconnect/sink/rest_sink/routing.hpp | 6 +- .../sink/rest_sink/websocket_session.hpp | 11 +- test_package/CMakeLists.txt | 1 + test_package/websockets_test.cpp | 246 ++++++++++++++++++ 4 files changed, 260 insertions(+), 4 deletions(-) create mode 100644 test_package/websockets_test.cpp diff --git a/src/mtconnect/sink/rest_sink/routing.hpp b/src/mtconnect/sink/rest_sink/routing.hpp index 0710f15b..bfa584ee 100644 --- a/src/mtconnect/sink/rest_sink/routing.hpp +++ b/src/mtconnect/sink/rest_sink/routing.hpp @@ -241,7 +241,11 @@ namespace mtconnect::sink::rest_sink { /// @brief Sets the command associated with this routing for use with websockets /// @param command the command - void command(const std::string &command) { m_command = command; } + auto &command(const std::string &command) + { + m_command = command; + return *this; + } protected: void pathParameters(std::string s) diff --git a/src/mtconnect/sink/rest_sink/websocket_session.hpp b/src/mtconnect/sink/rest_sink/websocket_session.hpp index 013d8601..84eda57e 100644 --- a/src/mtconnect/sink/rest_sink/websocket_session.hpp +++ b/src/mtconnect/sink/rest_sink/websocket_session.hpp @@ -104,7 +104,14 @@ namespace mtconnect::sink::rest_sink { if (!m_isOpen) return; - auto ptr = derived().shared_ptr(); + m_isOpen = false; + + auto wptr = weak_from_this(); + std::shared_ptr ptr; + if (!wptr.expired()) + { + ptr = wptr.lock(); + } m_request.reset(); m_requests.clear(); @@ -117,8 +124,6 @@ namespace mtconnect::sink::rest_sink { } } closeStream(); - - m_isOpen = false; } void writeResponse(ResponsePtr &&response, Complete complete = nullptr) override diff --git a/test_package/CMakeLists.txt b/test_package/CMakeLists.txt index ec596101..b0fb4c83 100644 --- a/test_package/CMakeLists.txt +++ b/test_package/CMakeLists.txt @@ -245,6 +245,7 @@ add_agent_test(qname FALSE entity) add_agent_test(file_cache FALSE sink/rest_sink) add_agent_test(http_server FALSE sink/rest_sink TRUE) +add_agent_test(websockets FALSE sink/rest_sink TRUE) add_agent_test(tls_http_server FALSE sink/rest_sink TRUE) add_agent_test(routing FALSE sink/rest_sink) diff --git a/test_package/websockets_test.cpp b/test_package/websockets_test.cpp new file mode 100644 index 00000000..0ca15442 --- /dev/null +++ b/test_package/websockets_test.cpp @@ -0,0 +1,246 @@ +// +// Copyright Copyright 2009-2024, AMT – The Association For Manufacturing Technology (“AMT”) +// All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Ensure that gtest is the first header otherwise Windows raises an error +#include +// Keep this comment to keep gtest.h above. (clang-format off/on is not working here!) + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "mtconnect/logging.hpp" +#include "mtconnect/sink/rest_sink/server.hpp" + +using namespace std; +using namespace mtconnect; +using namespace mtconnect::sink::rest_sink; + +namespace asio = boost::asio; +namespace beast = boost::beast; +namespace http = boost::beast::http; +using tcp = boost::asio::ip::tcp; +namespace websocket = beast::websocket; + +// main +int main(int argc, char* argv[]) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} + +class Client +{ +public: + Client(asio::io_context& ioc) : m_context(ioc), m_stream(ioc) {} + + ~Client() { close(); } + + void fail(beast::error_code ec, char const* what) + { + LOG(error) << what << ": " << ec.message() << "\n"; + m_done = true; + m_ec = ec; + } + + void connect(unsigned short port, asio::yield_context yield) + { + beast::error_code ec; + + // These objects perform our I/O + tcp::endpoint server(asio::ip::address_v4::from_string("127.0.0.1"), port); + + // Make the connection on the IP address we get from a lookup + beast::get_lowest_layer(m_stream).async_connect(server, yield[ec]); + + if (ec) + { + return fail(ec, "connect"); + } + + m_stream.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client)); + + m_stream.set_option(websocket::stream_base::decorator([](websocket::request_type& req) { + req.set(http::field::user_agent, + std::string(BOOST_BEAST_VERSION_STRING) + " websocket-client"); + })); + + string host = "127.0.0.1:" + std::to_string(port); + m_stream.async_handshake(host, "/", yield[ec]); + + if (ec) + { + return fail(ec, "connect"); + } + + m_connected = true; + + m_stream.async_read(m_buffer, beast::bind_front_handler(&Client::onRead, this)); + } + + void onRead(beast::error_code ec, std::size_t bytes_transferred) + { + m_result = beast::buffers_to_string(m_buffer.data()); + m_buffer.consume(m_buffer.size()); + + m_done = true; + } + + void request(const string& payload, asio::yield_context yield) + { + cout << "spawnRequest: done: false" << endl; + m_done = false; + beast::error_code ec; + + m_stream.async_write(asio::buffer(payload), yield[ec]); + + waitFor(2s, [this]() { return m_done; }); + } + + template + bool waitFor(const chrono::duration& time, function pred) + { + boost::asio::steady_timer timer(m_context); + timer.expires_from_now(time); + bool timeout = false; + timer.async_wait([&timeout](boost::system::error_code ec) { + if (!ec) + { + timeout = true; + } + }); + + while (!timeout && !pred()) + { + m_context.run_for(500ms); + } + timer.cancel(); + + return pred(); + } + + void close() + { + beast::error_code ec; + + // Gracefully close the socket + m_stream.next_layer().shutdown(tcp::socket::shutdown_both, ec); + } + + bool m_connected {false}; + int m_status; + std::string m_result; + asio::io_context& m_context; + bool m_done {false}; + websocket::stream m_stream; + beast::flat_buffer m_buffer; + boost::beast::error_code m_ec; + beast::flat_buffer m_b; + int m_count {0}; +}; + +class WebsocketsTest : public testing::Test +{ +protected: + void SetUp() override + { + using namespace mtconnect::configuration; + m_server = make_unique(m_context, ConfigOptions {{Port, 0}, {ServerIp, "127.0.0.1"s}}); + } + + void createServer(const ConfigOptions& options) + { + using namespace mtconnect::configuration; + ConfigOptions opts {{Port, 0}, {ServerIp, "127.0.0.1"s}}; + opts.merge(ConfigOptions(options)); + m_server = make_unique(m_context, opts); + } + + void start() + { + m_server->start(); + while (!m_server->isListening()) + m_context.run_one(); + m_client = make_unique(m_context); + } + + void startClient() + { + m_client->m_connected = false; + asio::spawn(m_context, + std::bind(&Client::connect, m_client.get(), + static_cast(m_server->getPort()), std::placeholders::_1)); + + m_client->waitFor(1s, [this]() { return m_client->m_connected; }); + } + + void TearDown() override + { + m_server.reset(); + m_client.reset(); + } + + asio::io_context m_context; + unique_ptr m_server; + unique_ptr m_client; +}; + +TEST_F(WebsocketsTest, should_connect_to_server) +{ + start(); + startClient(); + + ASSERT_TRUE(m_client->m_connected); +} + +TEST_F(WebsocketsTest, should_make_simple_request) +{ + weak_ptr savedSession; + + auto probe = [&](SessionPtr session, RequestPtr request) -> bool { + savedSession = session; + ResponsePtr resp = make_unique(status::ok); + resp->m_body = "All Devices for "s + *request->m_requestId; + resp->m_requestId = request->m_requestId; + session->writeResponse(std::move(resp), []() { cout << "Written" << endl; }); + return true; + }; + + m_server->addRouting({boost::beast::http::verb::get, "/probe", probe}).command("probe"); + m_server->addCommands(); + + start(); + startClient(); + + asio::spawn(m_context, std::bind(&Client::request, m_client.get(), + "{\"id\":\"1\",\"request\":\"probe\"}"s, std::placeholders::_1)); + + m_client->waitFor(2s, [this]() { return m_client->m_done; }); + + ASSERT_TRUE(m_client->m_done); + ASSERT_EQ("All Devices for 1", m_client->m_result); +}