diff --git a/docs/guides/websockets.md b/docs/guides/websockets.md
index 2e91eb4b8..5c0518e39 100644
--- a/docs/guides/websockets.md
+++ b/docs/guides/websockets.md
@@ -45,6 +45,11 @@ The maximum payload size that a connection accepts can be adjusted either global
By default, this limit is disabled. To disable the global setting in specific routes, you only need to call `#!cpp CROW_WEBSOCKET_ROUTE(app, "/url").max_payload(UINT64_MAX)`.
+## Subprotocols
+[:octicons-feed-tag-16: master](https://github.com/CrowCpp/Crow)
+
+Specifies the possible subprotocols that are available for the client. If specified, the first match with the client's requested subprotocols will be returned in the "Sec-WebSocket-Protocol" header of the handshake response. Otherwise, the connection will be closed. If no subprotocol are specified on both the client and the server side, the connection process will continue normally. It can be specified by using `#!cpp CROW_WEBSOCKET_ROUTE(app, "/url").subprotocols()`.
+
For more info about websocket routes go [here](../reference/classcrow_1_1_web_socket_rule.html).
diff --git a/include/crow/routing.h b/include/crow/routing.h
index 150937ef4..5477d0fac 100644
--- a/include/crow/routing.h
+++ b/include/crow/routing.h
@@ -461,12 +461,12 @@ namespace crow // NOTE: Already documented in "crow/app.h"
void handle_upgrade(const request& req, response&, SocketAdaptor&& adaptor) override
{
max_payload_ = max_payload_override_ ? max_payload_ : app_->websocket_max_payload();
- new crow::websocket::Connection(req, std::move(adaptor), app_, max_payload_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_);
+ new crow::websocket::Connection(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_);
}
#ifdef CROW_ENABLE_SSL
void handle_upgrade(const request& req, response&, SSLAdaptor&& adaptor) override
{
- new crow::websocket::Connection(req, std::move(adaptor), app_, max_payload_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_);
+ new crow::websocket::Connection(req, std::move(adaptor), app_, max_payload_, subprotocols_, open_handler_, message_handler_, close_handler_, error_handler_, accept_handler_);
}
#endif
@@ -478,6 +478,12 @@ namespace crow // NOTE: Already documented in "crow/app.h"
return *this;
}
+ self_t& subprotocols(const std::vector& subprotocols)
+ {
+ subprotocols_ = subprotocols;
+ return *this;
+ }
+
template
self_t& onopen(Func f)
{
@@ -522,6 +528,7 @@ namespace crow // NOTE: Already documented in "crow/app.h"
std::function accept_handler_;
uint64_t max_payload_;
bool max_payload_override_ = false;
+ std::vector subprotocols_;
};
/// Allows the user to assign parameters using functions.
diff --git a/include/crow/utility.h b/include/crow/utility.h
index 21bdeb6eb..af162dc3a 100644
--- a/include/crow/utility.h
+++ b/include/crow/utility.h
@@ -907,5 +907,44 @@ namespace crow
return v.substr(begin, end - begin);
}
+
+ /**
+ * @brief splits a string based on a separator
+ */
+ inline static std::vector split(const std::string& v, const std::string& separator)
+ {
+ std::vector result;
+ size_t startPos = 0;
+
+ for (size_t foundPos = v.find(separator); foundPos != std::string::npos; foundPos = v.find(separator, startPos))
+ {
+ result.push_back(v.substr(startPos, foundPos - startPos));
+ startPos = foundPos + separator.size();
+ }
+
+ result.push_back(v.substr(startPos));
+ return result;
+ }
+
+ /**
+ * @brief Returns the first occurence that matches between two ranges of iterators
+ * @param first1 begin() iterator of the first range
+ * @param last1 end() iterator of the first range
+ * @param first2 begin() iterator of the second range
+ * @param last2 end() iterator of the second range
+ * @return first occurence that matches between two ranges of iterators
+ */
+ template
+ inline static Iter1 find_first_of(Iter1 first1, Iter1 last1, Iter2 first2, Iter2 last2)
+ {
+ for (; first1 != last1; ++first1)
+ {
+ if (std::find(first2, last2, *first1) != last2)
+ {
+ return first1;
+ }
+ }
+ return last1;
+ }
} // namespace utility
} // namespace crow
diff --git a/include/crow/websocket.h b/include/crow/websocket.h
index 7c6694e32..8e5bd8731 100644
--- a/include/crow/websocket.h
+++ b/include/crow/websocket.h
@@ -66,6 +66,7 @@ namespace crow // NOTE: Already documented in "crow/app.h"
virtual void send_pong(std::string msg) = 0;
virtual void close(std::string const& msg = "quit", uint16_t status_code = CloseStatusCode::NormalClosure) = 0;
virtual std::string get_remote_ip() = 0;
+ virtual std::string get_subprotocol() const = 0;
virtual ~connection() = default;
void userdata(void* u) { userdata_ = u; }
@@ -109,7 +110,8 @@ namespace crow // NOTE: Already documented in "crow/app.h"
///
/// Requires a request with an "Upgrade: websocket" header.
/// Automatically handles the handshake.
- Connection(const crow::request& req, Adaptor&& adaptor, Handler* handler, uint64_t max_payload,
+ Connection(const crow::request& req, Adaptor&& adaptor, Handler* handler,
+ uint64_t max_payload, const std::vector& subprotocols,
std::function open_handler,
std::function message_handler,
std::function close_handler,
@@ -132,6 +134,17 @@ namespace crow // NOTE: Already documented in "crow/app.h"
return;
}
+ std::string requested_subprotocols_header = req.get_header_value("Sec-WebSocket-Protocol");
+ if (!subprotocols.empty() || !requested_subprotocols_header.empty())
+ {
+ auto requested_subprotocols = utility::split(requested_subprotocols_header, ", ");
+ auto subprotocol = utility::find_first_of(subprotocols.begin(), subprotocols.end(), requested_subprotocols.begin(), requested_subprotocols.end());
+ if (subprotocol != subprotocols.end())
+ {
+ subprotocol_ = *subprotocol;
+ }
+ }
+
if (accept_handler_)
{
void* ud = nullptr;
@@ -268,6 +281,12 @@ namespace crow // NOTE: Already documented in "crow/app.h"
max_payload_bytes_ = payload;
}
+ /// Returns the matching client/server subprotocol, empty string if none matched.
+ std::string get_subprotocol() const override
+ {
+ return subprotocol_;
+ }
+
protected:
/// Generate the websocket headers using an opcode and the message size (in bytes).
std::string build_header(int opcode, size_t size)
@@ -307,6 +326,12 @@ namespace crow // NOTE: Already documented in "crow/app.h"
write_buffers_.emplace_back(header);
write_buffers_.emplace_back(std::move(hello));
write_buffers_.emplace_back(crlf);
+ if (!subprotocol_.empty())
+ {
+ write_buffers_.emplace_back("Sec-WebSocket-Protocol: ");
+ write_buffers_.emplace_back(subprotocol_);
+ write_buffers_.emplace_back(crlf);
+ }
write_buffers_.emplace_back(crlf);
do_write();
if (open_handler_)
@@ -779,6 +804,7 @@ namespace crow // NOTE: Already documented in "crow/app.h"
uint16_t remaining_length16_{0};
uint64_t remaining_length_{0};
uint64_t max_payload_bytes_{UINT64_MAX};
+ std::string subprotocol_;
bool close_connection_{false};
bool is_reading{false};
bool has_mask_{false};
diff --git a/tests/unittest.cpp b/tests/unittest.cpp
index bbf402177..fe48576f5 100644
--- a/tests/unittest.cpp
+++ b/tests/unittest.cpp
@@ -3203,6 +3203,57 @@ TEST_CASE("websocket_max_payload")
app.stop();
} // websocket_max_payload
+TEST_CASE("websocket_subprotocols")
+{
+ static std::string http_message = "GET /ws HTTP/1.1\r\nConnection: keep-alive, Upgrade\r\nupgrade: websocket\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Protocol: myprotocol\r\nSec-WebSocket-Version: 13\r\nHost: localhost\r\n\r\n";
+
+ static websocket::connection* connection = nullptr;
+ static bool connected{false};
+
+ SimpleApp app;
+
+ CROW_WEBSOCKET_ROUTE(app, "/ws")
+ .subprotocols({"anotherprotocol", "myprotocol"})
+ .onaccept([&](const crow::request& req, void**) {
+ CROW_LOG_INFO << "Accepted websocket with URL " << req.url;
+ return true;
+ })
+ .onopen([&](websocket::connection& con) {
+ connected = true;
+ connection = &con;
+ CROW_LOG_INFO << "Connected websocket and subprotocol is " << con.get_subprotocol();
+ })
+ .onclose([&](websocket::connection&, const std::string&, uint16_t) {
+ CROW_LOG_INFO << "Closing websocket";
+ });
+
+ app.validate();
+
+ auto _ = app.bindaddr(LOCALHOST_ADDRESS).port(45451).run_async();
+ app.wait_for_server_start();
+ asio::io_service is;
+
+ asio::ip::tcp::socket c(is);
+ c.connect(asio::ip::tcp::endpoint(
+ asio::ip::address::from_string(LOCALHOST_ADDRESS), 45451));
+
+
+ char buf[2048];
+
+ //----------Handshake----------
+ {
+ std::fill_n(buf, 2048, 0);
+ c.send(asio::buffer(http_message));
+
+ c.receive(asio::buffer(buf, 2048));
+ std::this_thread::sleep_for(std::chrono::milliseconds(5));
+ CHECK(connected);
+ CHECK(connection->get_subprotocol() == "myprotocol");
+ }
+
+ app.stop();
+}
+
#ifdef CROW_ENABLE_COMPRESSION
TEST_CASE("zlib_compression")
{