From 8ce430e31b011a0e4efbe0aa83e9c5d551bb9745 Mon Sep 17 00:00:00 2001 From: johnxwork <44791188+johnxwork@users.noreply.github.com> Date: Thu, 23 May 2024 11:11:35 +0800 Subject: [PATCH] Fix WebSocket implementation in Cobalt 25 (#3328) Fixed the following bugs to make it work: - Deleted the whole concept of flow control; - The data frame received from net was not populated; - net now sends only a string view which does not promise validaty later. So we store a copy of the data between consecutive reads of the same frame. b/339672080 (cherry picked from commit a7b54abd873b5bca8268e6c7b2f5655cbf9b1f21) --- cobalt/websocket/BUILD.gn | 4 - cobalt/websocket/buffered_amount_tracker.cc | 62 -- cobalt/websocket/buffered_amount_tracker.h | 64 -- .../websocket/buffered_amount_tracker_test.cc | 144 ----- .../cobalt_web_socket_event_handler.cc | 19 +- .../cobalt_web_socket_event_handler.h | 5 +- cobalt/websocket/sec_web_socket_key.h | 59 -- .../websocket/web_socket_event_interface.cc | 126 ---- cobalt/websocket/web_socket_event_interface.h | 150 ----- .../websocket/web_socket_frame_container.cc | 196 ------ cobalt/websocket/web_socket_frame_container.h | 133 ----- .../web_socket_frame_container_test.cc | 437 -------------- .../websocket/web_socket_handshake_helper.cc | 336 ----------- .../websocket/web_socket_handshake_helper.h | 76 --- .../web_socket_handshake_helper_test.cc | 564 ------------------ cobalt/websocket/web_socket_impl.cc | 58 +- cobalt/websocket/web_socket_impl.h | 10 - cobalt/websocket/web_socket_impl_test.cc | 232 ------- .../websocket/web_socket_message_container.cc | 127 ---- .../websocket/web_socket_message_container.h | 95 --- .../web_socket_message_container_test.cc | 174 ------ 21 files changed, 21 insertions(+), 3050 deletions(-) delete mode 100644 cobalt/websocket/buffered_amount_tracker.cc delete mode 100644 cobalt/websocket/buffered_amount_tracker.h delete mode 100644 cobalt/websocket/buffered_amount_tracker_test.cc delete mode 100644 cobalt/websocket/sec_web_socket_key.h delete mode 100644 cobalt/websocket/web_socket_event_interface.cc delete mode 100644 cobalt/websocket/web_socket_event_interface.h delete mode 100644 cobalt/websocket/web_socket_frame_container.cc delete mode 100644 cobalt/websocket/web_socket_frame_container.h delete mode 100644 cobalt/websocket/web_socket_frame_container_test.cc delete mode 100644 cobalt/websocket/web_socket_handshake_helper.cc delete mode 100644 cobalt/websocket/web_socket_handshake_helper.h delete mode 100644 cobalt/websocket/web_socket_handshake_helper_test.cc delete mode 100644 cobalt/websocket/web_socket_impl_test.cc delete mode 100644 cobalt/websocket/web_socket_message_container.cc delete mode 100644 cobalt/websocket/web_socket_message_container.h delete mode 100644 cobalt/websocket/web_socket_message_container_test.cc diff --git a/cobalt/websocket/BUILD.gn b/cobalt/websocket/BUILD.gn index dc0608ebbd35..37383b215e90 100644 --- a/cobalt/websocket/BUILD.gn +++ b/cobalt/websocket/BUILD.gn @@ -16,12 +16,9 @@ static_library("websocket") { has_pedantic_warnings = true sources = [ - "buffered_amount_tracker.cc", - "buffered_amount_tracker.h", "close_event.h", "cobalt_web_socket_event_handler.cc", "cobalt_web_socket_event_handler.h", - "sec_web_socket_key.h", "web_socket.cc", "web_socket.h", "web_socket_impl.cc", @@ -47,7 +44,6 @@ target(gtest_target_type, "websocket_test") { sources = [ "mock_websocket_channel.cc", "mock_websocket_channel.h", - "web_socket_impl_test.cc", "web_socket_test.cc", ] deps = [ diff --git a/cobalt/websocket/buffered_amount_tracker.cc b/cobalt/websocket/buffered_amount_tracker.cc deleted file mode 100644 index 970421d9ea8d..000000000000 --- a/cobalt/websocket/buffered_amount_tracker.cc +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2017 The Cobalt Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "cobalt/websocket/buffered_amount_tracker.h" - -#include - -#include "base/logging.h" - -namespace cobalt { -namespace websocket { - -std::size_t BufferedAmountTracker::Pop(const std::size_t number_bytes_to_pop) { - std::size_t size_left_pop = number_bytes_to_pop; - std::size_t payload_amount = 0; - while (!entries_.empty() && (size_left_pop > 0)) { - Entry& entry(entries_[0]); - - std::size_t potential_payload_delta = 0; - - // Cache this variable in case we do a |pop_front|. - const bool is_user_payload = entry.is_user_payload_; - - if (entry.message_size_ > size_left_pop) { - potential_payload_delta = size_left_pop; - entry.message_size_ -= size_left_pop; - } else { // entry.message_size <= size_left - potential_payload_delta += entry.message_size_; - entries_.pop_front(); - } - - if (is_user_payload) { - payload_amount += potential_payload_delta; - } - - size_left_pop -= potential_payload_delta; - } - - // std::min prevents an underflow due to a bug. - DCHECK_LE(payload_amount, total_payload_inflight_); - - total_payload_inflight_ -= std::min(total_payload_inflight_, payload_amount); - - // Sort of an overflow check... - DCHECK_LE(size_left_pop, number_bytes_to_pop); - - return payload_amount; -} - -} // namespace websocket -} // namespace cobalt diff --git a/cobalt/websocket/buffered_amount_tracker.h b/cobalt/websocket/buffered_amount_tracker.h deleted file mode 100644 index 8ac6000c81d9..000000000000 --- a/cobalt/websocket/buffered_amount_tracker.h +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright 2017 The Cobalt Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef COBALT_WEBSOCKET_BUFFERED_AMOUNT_TRACKER_H_ -#define COBALT_WEBSOCKET_BUFFERED_AMOUNT_TRACKER_H_ - -#include - -namespace cobalt { -namespace websocket { - -// This class is helper class for implementing the |bufferedAmount| attribute in -// the Websockets API at: -// https://www.w3.org/TR/websockets/#dom-websocket-bufferedamount. -class BufferedAmountTracker { - public: - BufferedAmountTracker() : total_payload_inflight_(0) {} - - // The payload in this context only applies to user messages. - // So even though a close message might have a "payload", for the purposes - // of the tracker, the |is_user_payload| will be false. - void Add(const bool is_user_payload, const std::size_t message_size) { - if (is_user_payload) { - total_payload_inflight_ += message_size; - } - entries_.push_back(Entry(is_user_payload, message_size)); - } - - // Returns the number of user payload bytes popped, these are bytes that were - // added with |is_user_payload| set to true. - std::size_t Pop(const std::size_t number_bytes_to_pop); - - std::size_t GetTotalBytesInflight() const { return total_payload_inflight_; } - - private: - struct Entry { - Entry(const bool is_user_payload, const std::size_t message_size) - : is_user_payload_(is_user_payload), message_size_(message_size) {} - - bool is_user_payload_; - std::size_t message_size_; - }; - - typedef std::deque Entries; - - std::size_t total_payload_inflight_; - Entries entries_; -}; - -} // namespace websocket -} // namespace cobalt - -#endif // COBALT_WEBSOCKET_BUFFERED_AMOUNT_TRACKER_H_ diff --git a/cobalt/websocket/buffered_amount_tracker_test.cc b/cobalt/websocket/buffered_amount_tracker_test.cc deleted file mode 100644 index 77290bd76f79..000000000000 --- a/cobalt/websocket/buffered_amount_tracker_test.cc +++ /dev/null @@ -1,144 +0,0 @@ -// Copyright 2017 The Cobalt Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "cobalt/websocket/buffered_amount_tracker.h" - -#include "testing/gtest/include/gtest/gtest.h" - -namespace cobalt { -namespace websocket { - -class BufferedAmountTrackerTest : public ::testing::Test { - public: - BufferedAmountTracker tracker_; -}; - -TEST_F(BufferedAmountTrackerTest, Construct) { - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 0ul); -} - -TEST_F(BufferedAmountTrackerTest, AddNonPayload) { - tracker_.Add(false, 5); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 0ul); -} - -TEST_F(BufferedAmountTrackerTest, AddPayload) { - tracker_.Add(true, 5); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 5ul); -} - -TEST_F(BufferedAmountTrackerTest, AddCombination) { - tracker_.Add(true, 5); - tracker_.Add(false, 3); - tracker_.Add(true, 11); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 16ul); -} - -TEST_F(BufferedAmountTrackerTest, AddCombination2) { - tracker_.Add(false, 2); - tracker_.Add(false, 3); - tracker_.Add(true, 11); - tracker_.Add(true, 1); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 12ul); -} - -TEST_F(BufferedAmountTrackerTest, AddZero) { - tracker_.Add(true, 0); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 0ul); - - tracker_.Add(false, 0); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 0ul); -} - -TEST_F(BufferedAmountTrackerTest, PopZero) { - std::size_t amount_poped = 0; - tracker_.Add(false, 5); - amount_poped = tracker_.Pop(0); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 0ul); - EXPECT_EQ(amount_poped, 0); - - tracker_.Add(false, 0); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 0ul); - EXPECT_EQ(amount_poped, 0); -} - -TEST_F(BufferedAmountTrackerTest, PopPayload) { - std::size_t amount_poped = 0; - tracker_.Add(true, 5); - amount_poped = tracker_.Pop(4); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 1ul); - EXPECT_EQ(amount_poped, 4); -} - -TEST_F(BufferedAmountTrackerTest, PopNonPayload) { - std::size_t amount_poped = 0; - tracker_.Add(false, 5); - amount_poped = tracker_.Pop(3); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 0ul); - EXPECT_EQ(amount_poped, 0); -} - -TEST_F(BufferedAmountTrackerTest, PopCombined) { - std::size_t amount_poped = 0; - tracker_.Add(false, 5); - tracker_.Add(true, 3); - amount_poped = tracker_.Pop(7); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 1ul); - EXPECT_EQ(amount_poped, 2); -} - -TEST_F(BufferedAmountTrackerTest, PopOverflow) { - std::size_t amount_poped = 0; - tracker_.Add(false, 2); - tracker_.Add(true, 2); - amount_poped = tracker_.Pop(16); - EXPECT_EQ(amount_poped, 2); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 0ul); -} - -TEST_F(BufferedAmountTrackerTest, MultipleOperations) { - tracker_.Add(false, 2); - tracker_.Add(true, 3); - tracker_.Add(false, 1); - tracker_.Add(false, 6); - tracker_.Add(true, 14); - tracker_.Add(false, 3); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 17ul); - - std::size_t amount_poped = 0; - tracker_.Add(false, 3); - amount_poped = tracker_.Pop(4); - EXPECT_EQ(amount_poped, 2); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 15ul); - tracker_.Add(false, 4); - amount_poped = tracker_.Pop(2); - EXPECT_EQ(amount_poped, 1); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 14ul); - tracker_.Add(true, 5); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 19ul); - amount_poped = tracker_.Pop(4); - EXPECT_EQ(amount_poped, 0); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 19ul); - tracker_.Add(false, 6); - amount_poped = tracker_.Pop(4); - EXPECT_EQ(amount_poped, 2); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 17ul); - tracker_.Add(true, 7); - amount_poped = tracker_.Pop(4); - EXPECT_EQ(amount_poped, 4); - EXPECT_EQ(tracker_.GetTotalBytesInflight(), 20ul); -} - -} // namespace websocket -} // namespace cobalt diff --git a/cobalt/websocket/cobalt_web_socket_event_handler.cc b/cobalt/websocket/cobalt_web_socket_event_handler.cc index 9caefd4523d2..e2cd5c0b04e7 100644 --- a/cobalt/websocket/cobalt_web_socket_event_handler.cc +++ b/cobalt/websocket/cobalt_web_socket_event_handler.cc @@ -24,12 +24,11 @@ namespace cobalt { namespace websocket { namespace { -typedef std::vector, size_t>> - FrameDataVector; +typedef std::vector FrameDataVector; std::size_t GetMessageLength(const FrameDataVector& frame_data) { std::size_t total_length = 0; for (const auto& i : frame_data) { - total_length += i.second; + total_length += i.size(); } return total_length; } @@ -42,12 +41,10 @@ std::size_t CombineFramesChunks(FrameDataVector::const_iterator begin, std::size_t bytes_available = buffer_length; for (FrameDataVector::const_iterator iterator = begin; iterator != end; ++iterator) { - const scoped_refptr& data = iterator->first; - - std::size_t frame_chunk_size = iterator->second; - + const auto& data = iterator->data(); + std::size_t frame_chunk_size = iterator->size(); if (bytes_available >= frame_chunk_size) { - memcpy(out_destination, data->data(), frame_chunk_size); + memcpy(out_destination, data, frame_chunk_size); out_destination += frame_chunk_size; bytes_written += frame_chunk_size; bytes_available -= frame_chunk_size; @@ -64,18 +61,18 @@ void CobaltWebSocketEventHandler::OnAddChannelResponse( const std::string& selected_subprotocol, const std::string& extensions) { creator_->OnHandshakeComplete(selected_subprotocol); } + void CobaltWebSocketEventHandler::OnDataFrame(bool fin, WebSocketMessageType type, base::span payload) { + std::string message(payload.data(), payload.size()); if (message_type_ == net::WebSocketFrameHeader::kOpCodeControlUnused) { message_type_ = type; } if (type != net::WebSocketFrameHeader::kOpCodeContinuation) { DCHECK_EQ(message_type_, type); } -#ifndef COBALT_PENDING_CLEAN_UP - frame_data_.push_back(std::make_pair(std::move(payload), buffer_size)); -#endif + frame_data_.push_back(std::move(message)); if (fin) { std::size_t message_length = GetMessageLength(frame_data_); scoped_refptr buf = diff --git a/cobalt/websocket/cobalt_web_socket_event_handler.h b/cobalt/websocket/cobalt_web_socket_event_handler.h index 801f34ee3cb5..55841298b194 100644 --- a/cobalt/websocket/cobalt_web_socket_event_handler.h +++ b/cobalt/websocket/cobalt_web_socket_event_handler.h @@ -21,8 +21,6 @@ #include #include "base/basictypes.h" -#include "cobalt/websocket/web_socket_frame_container.h" -#include "cobalt/websocket/web_socket_handshake_helper.h" #include "net/base/io_buffer.h" #include "net/websockets/websocket_event_interface.h" #include "net/websockets/websocket_frame_parser.h" @@ -134,8 +132,7 @@ class CobaltWebSocketEventHandler : public net::WebSocketEventInterface { // move more thing WebSocketImpl* creator_; // This vector should store data of data frames from the same message. - typedef std::vector, size_t>> - FrameDataVector; + typedef std::vector FrameDataVector; FrameDataVector frame_data_; WebSocketMessageType message_type_ = net::WebSocketFrameHeader::kOpCodeControlUnused; diff --git a/cobalt/websocket/sec_web_socket_key.h b/cobalt/websocket/sec_web_socket_key.h deleted file mode 100644 index ad6a7e84cf1f..000000000000 --- a/cobalt/websocket/sec_web_socket_key.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2017 The Cobalt Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#ifndef COBALT_WEBSOCKET_SEC_WEB_SOCKET_KEY_H_ -#define COBALT_WEBSOCKET_SEC_WEB_SOCKET_KEY_H_ - -#include -#include - -#include "base/base64.h" -#include "base/basictypes.h" -#include "base/logging.h" -#include "base/strings/string_piece.h" -#include "starboard/memory.h" - -namespace cobalt { -namespace websocket { - -struct SecWebSocketKey { - enum { kKeySizeInBytes = 16 }; - typedef char SecWebSocketKeyBytes[kKeySizeInBytes]; - - SecWebSocketKey() { - memset(&key_bytes[0], 0, kKeySizeInBytes); - base::StringPiece key_stringpiece(key_bytes, sizeof(key_bytes)); - base::Base64Encode(key_stringpiece, &key_base64_encoded); - } - - explicit SecWebSocketKey(const SecWebSocketKeyBytes& key) { - memcpy(&key_bytes[0], &key[0], sizeof(key_bytes)); - base::StringPiece key_stringpiece(key_bytes, sizeof(key_bytes)); - base::Base64Encode(key_stringpiece, &key_base64_encoded); - } - - const std::string& GetKeyEncodedInBase64() const { - DCHECK_GT(key_base64_encoded.size(), 0ull); - return key_base64_encoded; - } - const SecWebSocketKeyBytes& GetRawKeyBytes() const { return key_bytes; } - - private: - SecWebSocketKeyBytes key_bytes; - std::string key_base64_encoded; -}; - -} // namespace websocket -} // namespace cobalt - -#endif // COBALT_WEBSOCKET_SEC_WEB_SOCKET_KEY_H_ diff --git a/cobalt/websocket/web_socket_event_interface.cc b/cobalt/websocket/web_socket_event_interface.cc deleted file mode 100644 index a0d0318bf84d..000000000000 --- a/cobalt/websocket/web_socket_event_interface.cc +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2018 The Cobalt Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "cobalt/websocket/web_socket_event_interface.h" - -#include - -#include "base/logging.h" -#include "cobalt/websocket/web_socket_impl.h" - -namespace cobalt { -namespace websocket { -namespace { -std::size_t GetMessageLength(const FrameDataVector& frame_data) { - std::size_t total_length = 0; - for (const auto& i : frame_data) { - total_length += i.second; - } - return total_length; -} -std::size_t CombineFramesChunks(FrameDataVector::const_iterator begin, - FrameDataVector::const_iterator end, - char* out_destination, - std::size_t buffer_length) { - DCHECK(out_destination); - std::size_t bytes_written = 0; - std::size_t bytes_available = buffer_length; - for (FrameDataVector::const_iterator iterator = begin; iterator != end; - ++iterator) { - const scoped_refptr& data = iterator->first; - - std::size_t frame_chunk_size = iterator->second; - - if (bytes_available >= frame_chunk_size) { - memcpy(out_destination, data->data(), frame_chunk_size); - out_destination += frame_chunk_size; - bytes_written += frame_chunk_size; - bytes_available -= frame_chunk_size; - } - } - - DCHECK_EQ(bytes_written, buffer_length); - return bytes_written; -} -} // namespace - -void CobaltWebSocketEventHandler::OnAddChannelResponse( - const std::string& selected_subprotocol, const std::string& extensions) { - creator_->OnHandshakeComplete(selected_subprotocol); -} -void CobaltWebSocketEventHandler::OnDataFrame(bool fin, - WebSocketMessageType type, - scoped_refptr buffer, - size_t buffer_size) { - if (message_type_ == net::WebSocketFrameHeader::kOpCodeControlUnused) { - message_type_ = type; - } - DCHECK_EQ(message_type_, type); - frame_data_.push_back(std::make_pair(std::move(buffer), buffer_size)); - if (fin) { - std::size_t message_length = GetMessageLength(frame_data_); - scoped_refptr buf = - base::WrapRefCounted(new net::IOBufferWithSize(message_length)); - CombineFramesChunks(frame_data_.begin(), frame_data_.end(), buf->data, - message_length); - frame_data_.clear(); - - bool is_text_message = - message_type_ == net::WebSocketFrameHeader::kOpCodeText; - if (is_text_message && buf && (buf->size() > 0)) { - base::StringPiece payload_string_piece(buf->data(), buf->size()); - if (!base::IsStringUTF8(payload_string_piece)) { - WebSocketImpl::CloseInfo close_info(net::kWebSocketErrorProtocolError); - creator_->TrampolineClose(close_info); - return; - } - } - creator_->OnWebSocketReceivedData(is_text_message, std::move(buf)); - } -} - -void CobaltWebSocketEventHandler::OnClosingHandshake() { - creator_->OnClose(true, net::kWebSocketNormalClosure, - "Received close handshake initiation."); -} - -void CobaltWebSocketEventHandler::OnFailChannel(const std::string& message) { - DLOG(WARNING) << "WebSocket channel failed due to: " << message; - creator_->OnClose(true, net::kWebSocketErrorAbnormalClosure, message); -} - -void OnDropChannel(bool was_clean, uint16_t code, const std::string& reason) { - creator_->OnClose(was_clean, code, reason); -} - -void CobaltWebSocketEventHandler::OnSSLCertificateError( - std::unique_ptr ssl_error_callbacks, - const GURL& url, const SSLInfo& ssl_info, bool fatal) { - // TODO: determine if there are circumstances we want to continue - // the request. - DLOG(WARNING) << "SSL cert failure occurred, cancelling connection"; - ssl_error_callbacks->CancelSSLRequest(-1, ssl_info); -} -int CobaltWebSocketEventHandler::OnAuthRequired( - scoped_refptr auth_info, - scoped_refptr response_headers, - const net::HostPortPair& host_port_pair, - base::OnceCallback callback, - base::Optional* credentials) { - NOTIMPLEMENTED() - return net::OK; -} - -} // namespace websocket -} // namespace cobalt diff --git a/cobalt/websocket/web_socket_event_interface.h b/cobalt/websocket/web_socket_event_interface.h deleted file mode 100644 index 7a629de39cbc..000000000000 --- a/cobalt/websocket/web_socket_event_interface.h +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright 2017 The Cobalt Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef COBALT_WEBSOCKET_WEB_SOCKET_EVENT_INTERFACE_H_ -#define COBALT_WEBSOCKET_WEB_SOCKET_EVENT_INTERFACE_H_ - -#include -#include -#include -#include - -#include "base/basictypes.h" -#include "cobalt/websocket/web_socket_frame_container.h" -#include "cobalt/websocket/web_socket_handshake_helper.h" -#include "net/base/io_buffer.h" -#include "net/websockets/websocket_frame_parser.h" - -namespace cobalt { -namespace websocket { - -class WebSocketImpl; - -// This is the event interface that net takes to process all the events that -// happen during websocket connections. -// The lifetime of a connection process sometimes is shorter than the websocket -// so it has to be another class different from WebSocketChannel's creator -// which is our WebSocketImpl. -class CobaltWebSocketEventHandler : public net::WebSocketEventInterface { - public: - explicit CobaltWebSocketEventHandler(WebSocketImpl* creator) - : creator_(creator) { - DCHECK(creator); - } - - // Called when a URLRequest is created for handshaking. - void OnCreateURLRequest(net::URLRequest* request) override {} - - // Called in response to an AddChannelRequest. This means that a response has - // been received from the remote server. - void OnAddChannelResponse(const std::string& selected_subprotocol, - const std::string& extensions) override; - - // Called when a data frame has been received from the remote host and needs - // to be forwarded to the renderer process. - void OnDataFrame(bool fin, WebSocketMessageType type, - scoped_refptr buffer, - size_t buffer_size) override; - - // Called to provide more send quota for this channel to the renderer - // process. Currently the quota units are always bytes of message body - // data. In future it might depend on the type of multiplexing in use. - void OnFlowControl(int64_t quota) override {} - - // Called when the remote server has Started the WebSocket Closing - // Handshake. The client should not attempt to send any more messages after - // receiving this message. It will be followed by OnDropChannel() when the - // closing handshake is complete. - void OnClosingHandshake() override; - - // Called when the channel has been dropped, either due to a network close, a - // network error, or a protocol error. This may or may not be preceded by a - // call to OnClosingHandshake(). - // - // Warning: Both the |code| and |reason| are passed through to Javascript, so - // callers must take care not to provide details that could be useful to - // attackers attempting to use WebSocketMessageType to probe networks. - // - // |was_clean| should be true if the closing handshake completed successfully. - // - // The channel should not be used again after OnDropChannel() has been - // called. - // - // This function deletes the Channel. - void OnDropChannel(bool was_clean, uint16_t code, - const std::string& reason) override; - - // Called when the browser fails the channel, as specified in the spec. - // - // The channel should not be used again after OnFailChannel() has been - // called. - // - // This function deletes the Channel. - void OnFailChannel(const std::string& message) override; - - // Called when the browser starts the WebSocket Opening Handshake. - void OnStartOpeningHandshake( - std::unique_ptr request) override {} - - // Called when the browser finishes the WebSocket Opening Handshake. - void OnFinishOpeningHandshake( - std::unique_ptr response) override {} - - // Called on SSL Certificate Error during the SSL handshake. Should result in - // a call to either ssl_error_callbacks->ContinueSSLRequest() or - // ssl_error_callbacks->CancelSSLRequest(). Normally the implementation of - // this method will delegate to content::SSLManager::OnSSLCertificateError to - // make the actual decision. The callbacks must not be called after the - // WebSocketChannel has been destroyed. - void OnSSLCertificateError( - std::unique_ptr - ssl_error_callbacks, - const GURL& url, const net::SSLInfo& ssl_info, bool fatal) override; - - // Called when authentication is required. Returns a net error. The opening - // handshake is blocked when this function returns ERR_IO_PENDING. - // In that case calling |callback| resumes the handshake. |callback| can be - // called during the opening handshake. An implementation can rewrite - // |*credentials| (in the sync case) or provide new credentials (in the - // async case). - // Providing null credentials (nullopt in the sync case and nullptr in the - // async case) cancels authentication. Otherwise the new credentials are set - // and the opening handshake will be retried with the credentials. - int OnAuthRequired( - scoped_refptr auth_info, - scoped_refptr response_headers, - const net::HostPortPair& host_port_pair, - base::OnceCallback callback, - base::Optional* credentials) override; - - protected: - CobaltWebSocketEventHandler() {} - - private: - // For now, we forward most event to the impl to handle. We might want to - // move more thing - WebSocketImpl* creator_; - // This vector should store data of data frames from the same message. - typedef std::vector, size_t>> - FrameDataVector; - FrameDataVector frame_data_; - WebSocketMessageType message_type_ = - net::WebSocketFrameHeader::kOpCodeControlUnused; - DISALLOW_COPY_AND_ASSIGN(CobaltWebSocketEventHandler); -}; - -} // namespace websocket -} // namespace cobalt - -#endif // COBALT_WEBSOCKET_WEB_SOCKET_EVENT_INTERFACE_H_ diff --git a/cobalt/websocket/web_socket_frame_container.cc b/cobalt/websocket/web_socket_frame_container.cc deleted file mode 100644 index f11d19ef83f1..000000000000 --- a/cobalt/websocket/web_socket_frame_container.cc +++ /dev/null @@ -1,196 +0,0 @@ -// Copyright 2017 The Cobalt Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "cobalt/websocket/web_socket_frame_container.h" - -namespace { - -bool IsFinalChunk(const net::WebSocketFrameChunk* const chunk) { - return (chunk && chunk->final_chunk); -} - -bool IsContinuationOpCode(net::WebSocketFrameHeader::OpCode op_code) { - switch (op_code) { - case net::WebSocketFrameHeader::kOpCodeContinuation: - return true; - case net::WebSocketFrameHeader::kOpCodeText: - case net::WebSocketFrameHeader::kOpCodeBinary: - case net::WebSocketFrameHeader::kOpCodePing: - case net::WebSocketFrameHeader::kOpCodePong: - case net::WebSocketFrameHeader::kOpCodeClose: - return false; - default: - NOTREACHED() << "Invalid op_code " << op_code; - } - - return false; -} - -bool IsDataOpCode(net::WebSocketFrameHeader::OpCode op_code) { - switch (op_code) { - case net::WebSocketFrameHeader::kOpCodeText: - case net::WebSocketFrameHeader::kOpCodeBinary: - return true; - case net::WebSocketFrameHeader::kOpCodePing: - case net::WebSocketFrameHeader::kOpCodePong: - case net::WebSocketFrameHeader::kOpCodeClose: - case net::WebSocketFrameHeader::kOpCodeContinuation: - return false; - default: - NOTREACHED() << "Invalid op_code " << op_code; - } - - return false; -} - -inline bool IsControlOpCode(net::WebSocketFrameHeader::OpCode op_code) { - switch (op_code) { - case net::WebSocketFrameHeader::kOpCodePing: - case net::WebSocketFrameHeader::kOpCodePong: - case net::WebSocketFrameHeader::kOpCodeClose: - return true; - case net::WebSocketFrameHeader::kOpCodeText: - case net::WebSocketFrameHeader::kOpCodeBinary: - case net::WebSocketFrameHeader::kOpCodeContinuation: - return false; - default: - NOTREACHED() << "Invalid op_code " << op_code; - } - - return false; -} - -} // namespace - -namespace cobalt { -namespace websocket { - -void WebSocketFrameContainer::clear() { - for (const_iterator it = chunks_.begin(); it != chunks_.end(); ++it) { - delete *it; - } - chunks_.clear(); - frame_completed_ = false; - payload_size_bytes_ = 0; - expected_payload_size_bytes_ = 0; -} - -bool WebSocketFrameContainer::IsControlFrame() const { - const net::WebSocketFrameHeader* header = GetHeader(); - if (!header) { - return false; - } - return IsControlOpCode(header->opcode); -} - -bool WebSocketFrameContainer::IsContinuationFrame() const { - const net::WebSocketFrameHeader* header = GetHeader(); - if (!header) { - return false; - } - return IsContinuationOpCode(header->opcode); -} - -bool WebSocketFrameContainer::IsDataFrame() const { - const net::WebSocketFrameHeader* header = GetHeader(); - if (!header) { - return false; - } - return IsDataOpCode(header->opcode); -} - -WebSocketFrameContainer::ErrorCode WebSocketFrameContainer::Take( - const net::WebSocketFrameChunk* chunk) { - DCHECK(chunk); - - WebSocketFrameContainer::ErrorCode error_code = kErrorNone; - - do { - if (IsFrameComplete()) { - error_code = kErrorFrameAlreadyComplete; - break; - } - - const net::WebSocketFrameHeader* const chunk_header = chunk->header.get(); - - bool first_chunk = chunks_.empty(); - bool has_frame_header = (chunk_header != NULL); - if (first_chunk) { - if (has_frame_header) { - if (chunk_header->payload_length > kMaxFramePayloadInBytes) { - error_code = kErrorMaxFrameSizeViolation; - break; - } else { - COMPILE_ASSERT(kuint32max > kMaxFramePayloadInBytes, - max_frame_too_big); - expected_payload_size_bytes_ = - static_cast(chunk_header->payload_length); - } - } else { - error_code = kErrorFirstChunkMissingHeader; - break; - } - } else { - if (has_frame_header) { - error_code = kErrorHasExtraHeader; - break; - } - } - - // Cases when this should succeed: - // 1. first_chunk has a header (both are true) - // 2. non first_chunk has does not have header (both are false) - // Fun fact: boolean equality is the same as a boolean XNOR. - DCHECK_EQ(first_chunk, has_frame_header); - - net::IOBufferWithSize* data = chunk->data.get(); - - if (data) { - int chunk_data_size = data->size(); - std::size_t new_payload_size = payload_size_bytes_ + chunk_data_size; - - if (new_payload_size > kMaxFramePayloadInBytes) { - // This can only happen if the header "lied" about the payload_length. - error_code = kErrorMaxFrameSizeViolation; - break; - } - - if (chunk->final_chunk) { - if (new_payload_size < expected_payload_size_bytes_) { - error_code = kErrorPayloadSizeSmallerThanHeader; - break; - } else if (new_payload_size > expected_payload_size_bytes_) { - error_code = kErrorPayloadSizeLargerThanHeader; - break; - } - } - - payload_size_bytes_ += chunk_data_size; - } - - chunks_.push_back(chunk); - frame_completed_ |= IsFinalChunk(chunk); - } while (0); - - if (error_code != kErrorNone) { - // We didn't take ownership, so let's delete it, so that the caller can - // always assume that this code takes ownership of the pointer. - delete chunk; - } - - return error_code; -} - -} // namespace websocket -} // namespace cobalt diff --git a/cobalt/websocket/web_socket_frame_container.h b/cobalt/websocket/web_socket_frame_container.h deleted file mode 100644 index 7553abdef83e..000000000000 --- a/cobalt/websocket/web_socket_frame_container.h +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright 2017 The Cobalt Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#ifndef COBALT_WEBSOCKET_WEB_SOCKET_FRAME_CONTAINER_H_ -#define COBALT_WEBSOCKET_WEB_SOCKET_FRAME_CONTAINER_H_ - -#include -#include -#include -#include - -#include "base/basictypes.h" -#include "base/notreached.h" -#include "net/base/io_buffer.h" -#include "net/websockets/websocket_frame.h" - -namespace cobalt { -namespace websocket { - -const size_t kMaxFramePayloadInBytes = 4 * 1024 * 1024; - -class WebSocketFrameContainer { - public: - typedef std::deque WebSocketFrameChunks; - typedef WebSocketFrameChunks::iterator iterator; - typedef WebSocketFrameChunks::const_iterator const_iterator; - - WebSocketFrameContainer() - : frame_completed_(false), - payload_size_bytes_(0), - expected_payload_size_bytes_(0) {} - ~WebSocketFrameContainer() { clear(); } - - void clear(); - - const net::WebSocketFrameHeader* GetHeader() const { - if (empty()) { - return NULL; - } - - return (*begin())->header.get(); - } - - bool IsControlFrame() const; - bool IsDataFrame() const; - bool IsContinuationFrame() const; - - enum ErrorCode { - kErrorNone, - kErrorMaxFrameSizeViolation, - kErrorFirstChunkMissingHeader, - kErrorHasExtraHeader, - kErrorFrameAlreadyComplete, - kErrorPayloadSizeSmallerThanHeader, - kErrorPayloadSizeLargerThanHeader - }; - - bool IsFrameComplete() const { return frame_completed_; } - - bool IsFinalFrame() const { - const net::WebSocketFrameHeader* const header = GetHeader(); - if (!header) { - return false; - } - return (header && header->final); - } - - // Note that this takes always ownership of the chunk. - // Should only be called if IsFrameComplete() is false. - // Note that if there is an error produced in the function, it will - // leave the state of this object unchanged. - ErrorCode Take(const net::WebSocketFrameChunk* chunk); - - iterator begin() { return chunks_.begin(); } - iterator end() { return chunks_.end(); } - const_iterator begin() const { return chunks_.begin(); } - const_iterator end() const { return chunks_.end(); } - const_iterator cbegin() const { return chunks_.begin(); } - const_iterator cend() const { return chunks_.end(); } - bool empty() const { return begin() == end(); } - - std::size_t GetCurrentPayloadSizeBytes() const { return payload_size_bytes_; } - std::size_t GetChunkCount() const { return chunks_.size(); } - - void swap(WebSocketFrameContainer& other) { - std::swap(frame_completed_, other.frame_completed_); - std::swap(payload_size_bytes_, other.payload_size_bytes_); - std::swap(expected_payload_size_bytes_, other.expected_payload_size_bytes_); - chunks_.swap(other.chunks_); - } - - // Returns false if op_code was not found, and returns true otherwise. - bool GetFrameOpCode(net::WebSocketFrameHeader::OpCode* op_code) const { - DCHECK(op_code); - if (empty()) { - return false; - } - - const net::WebSocketFrameChunk* first_chunk(*begin()); - DCHECK(first_chunk); - const std::unique_ptr& first_chunk_header = - first_chunk->header; - if (!first_chunk_header) { - NOTREACHED() << "No header found in the first chunk."; - return false; - } - - *op_code = first_chunk_header->opcode; - return true; - } - - private: - // Note: If you add a field, please remember to update swap() above. - bool frame_completed_; - std::size_t payload_size_bytes_; - std::size_t expected_payload_size_bytes_; - WebSocketFrameChunks chunks_; -}; - -} // namespace websocket -} // namespace cobalt - -#endif // COBALT_WEBSOCKET_WEB_SOCKET_FRAME_CONTAINER_H_ diff --git a/cobalt/websocket/web_socket_frame_container_test.cc b/cobalt/websocket/web_socket_frame_container_test.cc deleted file mode 100644 index 92986a016d44..000000000000 --- a/cobalt/websocket/web_socket_frame_container_test.cc +++ /dev/null @@ -1,437 +0,0 @@ -// Copyright 2017 The Cobalt Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "cobalt/websocket/web_socket_frame_container.h" - -#include - -#include "base/memory/ref_counted.h" -#include "net/base/io_buffer.h" -#include "testing/gtest/include/gtest/gtest.h" - -namespace cobalt { -namespace websocket { - -class WebSocketFrameContainerTest : public ::testing::Test { - protected: - WebSocketFrameContainer frame_container_; -}; - -TEST_F(WebSocketFrameContainerTest, TestConstruction) { - EXPECT_TRUE(frame_container_.empty()); - EXPECT_TRUE(frame_container_.begin() == frame_container_.end()); - EXPECT_TRUE(frame_container_.cbegin() == frame_container_.cend()); - EXPECT_EQ(frame_container_.GetChunkCount(), 0UL); - EXPECT_EQ(frame_container_.GetCurrentPayloadSizeBytes(), 0UL); - EXPECT_EQ(frame_container_.GetHeader(), - static_cast(NULL)); - - EXPECT_FALSE(frame_container_.IsControlFrame()); - EXPECT_FALSE(frame_container_.IsDataFrame()); - EXPECT_FALSE(frame_container_.IsContinuationFrame()); - - net::WebSocketFrameHeader::OpCode op = net::WebSocketFrameHeader::kOpCodeText; - EXPECT_EQ(frame_container_.GetFrameOpCode(&op), false); -} - -TEST_F(WebSocketFrameContainerTest, TestClear) { - net::WebSocketFrameChunk* chunk1 = new net::WebSocketFrameChunk(); - chunk1->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk1->header->final = true; - chunk1->header->payload_length = 0; - chunk1->header->opcode = net::WebSocketFrameHeader::kOpCodeBinary; - chunk1->final_chunk = true; - WebSocketFrameContainer::ErrorCode error_code1 = - frame_container_.Take(chunk1); - EXPECT_EQ(error_code1, WebSocketFrameContainer::kErrorNone); - - frame_container_.clear(); - - EXPECT_TRUE(frame_container_.empty()); - EXPECT_TRUE(frame_container_.begin() == frame_container_.end()); - EXPECT_TRUE(frame_container_.cbegin() == frame_container_.cend()); - EXPECT_EQ(frame_container_.GetChunkCount(), 0UL); - EXPECT_EQ(frame_container_.GetCurrentPayloadSizeBytes(), 0UL); - EXPECT_EQ(frame_container_.GetHeader(), - static_cast(NULL)); - - EXPECT_FALSE(frame_container_.IsControlFrame()); - EXPECT_FALSE(frame_container_.IsDataFrame()); - EXPECT_FALSE(frame_container_.IsContinuationFrame()); - - net::WebSocketFrameHeader::OpCode op = net::WebSocketFrameHeader::kOpCodeText; - EXPECT_EQ(frame_container_.GetFrameOpCode(&op), false); -} - -TEST_F(WebSocketFrameContainerTest, TestFirstChunkMissingHeader) { - net::WebSocketFrameChunk* chunk = new net::WebSocketFrameChunk(); - WebSocketFrameContainer::ErrorCode error_code = frame_container_.Take(chunk); - EXPECT_EQ(error_code, WebSocketFrameContainer::kErrorFirstChunkMissingHeader); -} - -TEST_F(WebSocketFrameContainerTest, TestHasExtraHeader) { - net::WebSocketFrameChunk* chunk1 = new net::WebSocketFrameChunk(); - chunk1->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk1->header->final = true; - chunk1->header->payload_length = 0; - chunk1->header->opcode = net::WebSocketFrameHeader::kOpCodeBinary; - WebSocketFrameContainer::ErrorCode error_code1 = - frame_container_.Take(chunk1); - EXPECT_EQ(error_code1, WebSocketFrameContainer::kErrorNone); - - net::WebSocketFrameChunk* chunk2 = new net::WebSocketFrameChunk(); - chunk2->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk2->header->final = true; - chunk2->header->payload_length = 0; - chunk2->header->opcode = net::WebSocketFrameHeader::kOpCodeBinary; - WebSocketFrameContainer::ErrorCode error_code2 = - frame_container_.Take(chunk2); - EXPECT_EQ(error_code2, WebSocketFrameContainer::kErrorHasExtraHeader); -} - -TEST_F(WebSocketFrameContainerTest, TestFrameAlreadyCompleteNoHeader) { - net::WebSocketFrameChunk* chunk1 = new net::WebSocketFrameChunk(); - chunk1->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk1->header->final = true; - chunk1->header->payload_length = 0; - chunk1->header->opcode = net::WebSocketFrameHeader::kOpCodeBinary; - chunk1->final_chunk = true; - WebSocketFrameContainer::ErrorCode error_code1 = - frame_container_.Take(chunk1); - EXPECT_EQ(error_code1, WebSocketFrameContainer::kErrorNone); - - net::WebSocketFrameChunk* chunk2 = new net::WebSocketFrameChunk(); - chunk2->final_chunk = true; - WebSocketFrameContainer::ErrorCode error_code2 = - frame_container_.Take(chunk2); - EXPECT_EQ(error_code2, WebSocketFrameContainer::kErrorFrameAlreadyComplete); -} - -TEST_F(WebSocketFrameContainerTest, TestFrameAlreadyCompleteHeader) { - net::WebSocketFrameChunk* chunk1 = new net::WebSocketFrameChunk(); - chunk1->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk1->header->final = true; - chunk1->header->payload_length = 0; - chunk1->header->opcode = net::WebSocketFrameHeader::kOpCodeBinary; - chunk1->final_chunk = true; - WebSocketFrameContainer::ErrorCode error_code1 = - frame_container_.Take(chunk1); - EXPECT_EQ(error_code1, WebSocketFrameContainer::kErrorNone); - - net::WebSocketFrameChunk* chunk2 = new net::WebSocketFrameChunk(); - chunk2->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk2->header->final = true; - chunk2->header->payload_length = 0; - chunk2->header->opcode = net::WebSocketFrameHeader::kOpCodeBinary; - chunk2->final_chunk = true; - WebSocketFrameContainer::ErrorCode error_code2 = - frame_container_.Take(chunk2); - EXPECT_EQ(error_code2, WebSocketFrameContainer::kErrorFrameAlreadyComplete); -} - -TEST_F(WebSocketFrameContainerTest, TestFrameTooBig) { - net::WebSocketFrameChunk* chunk1 = new net::WebSocketFrameChunk(); - chunk1->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk1->header->final = true; - chunk1->header->payload_length = 40 * 1024 * 1024; - chunk1->header->opcode = net::WebSocketFrameHeader::kOpCodeText; - WebSocketFrameContainer::ErrorCode error_code1 = - frame_container_.Take(chunk1); - EXPECT_EQ(error_code1, WebSocketFrameContainer::kErrorMaxFrameSizeViolation); -} - -TEST_F(WebSocketFrameContainerTest, TestFrameTooBigLieAboutSize) { - net::WebSocketFrameChunk* chunk = new net::WebSocketFrameChunk(); - chunk->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk->header->final = true; - chunk->header->payload_length = 40; - chunk->header->opcode = net::WebSocketFrameHeader::kOpCodePing; - chunk->data = base::WrapRefCounted( - new net::IOBufferWithSize(40 * 1024 * 1024)); - WebSocketFrameContainer::ErrorCode error_code = frame_container_.Take(chunk); - EXPECT_EQ(error_code, WebSocketFrameContainer::kErrorMaxFrameSizeViolation); -} - -TEST_F(WebSocketFrameContainerTest, PayloadTooSmall) { - net::WebSocketFrameChunk* chunk1 = new net::WebSocketFrameChunk(); - chunk1->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk1->header->final = false; - chunk1->header->payload_length = 50; - chunk1->header->opcode = net::WebSocketFrameHeader::kOpCodeBinary; - chunk1->data = base::WrapRefCounted( - new net::IOBufferWithSize(20)); - EXPECT_EQ(frame_container_.Take(chunk1), WebSocketFrameContainer::kErrorNone); - - net::WebSocketFrameChunk* chunk2 = new net::WebSocketFrameChunk(); - chunk2->data = base::WrapRefCounted( - new net::IOBufferWithSize(18)); - EXPECT_EQ(frame_container_.Take(chunk2), WebSocketFrameContainer::kErrorNone); - net::WebSocketFrameChunk* chunk3 = new net::WebSocketFrameChunk(); - chunk3->final_chunk = true; - chunk3->data = - base::WrapRefCounted(new net::IOBufferWithSize(2)); - EXPECT_EQ(frame_container_.Take(chunk3), - WebSocketFrameContainer::kErrorPayloadSizeSmallerThanHeader); -} - -TEST_F(WebSocketFrameContainerTest, FrameComplete) { - EXPECT_FALSE(frame_container_.IsFrameComplete()); - - net::WebSocketFrameChunk* chunk1 = new net::WebSocketFrameChunk(); - chunk1->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk1->header->final = false; - chunk1->header->payload_length = 50; - chunk1->header->opcode = net::WebSocketFrameHeader::kOpCodeBinary; - chunk1->data = base::WrapRefCounted( - new net::IOBufferWithSize(20)); - EXPECT_EQ(frame_container_.Take(chunk1), WebSocketFrameContainer::kErrorNone); - EXPECT_FALSE(frame_container_.IsFrameComplete()); - - net::WebSocketFrameChunk* chunk2 = new net::WebSocketFrameChunk(); - chunk2->data = base::WrapRefCounted( - new net::IOBufferWithSize(18)); - EXPECT_EQ(frame_container_.Take(chunk2), WebSocketFrameContainer::kErrorNone); - EXPECT_FALSE(frame_container_.IsFrameComplete()); - - net::WebSocketFrameChunk* chunk3 = new net::WebSocketFrameChunk(); - chunk3->final_chunk = true; - chunk3->data = base::WrapRefCounted( - new net::IOBufferWithSize(12)); - EXPECT_EQ(frame_container_.Take(chunk3), WebSocketFrameContainer::kErrorNone); - - EXPECT_TRUE(frame_container_.IsFrameComplete()); -} - -TEST_F(WebSocketFrameContainerTest, PayloadTooBig) { - net::WebSocketFrameChunk* chunk1 = new net::WebSocketFrameChunk(); - chunk1->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk1->header->final = false; - chunk1->header->payload_length = 50; - chunk1->header->opcode = net::WebSocketFrameHeader::kOpCodeBinary; - chunk1->data = base::WrapRefCounted( - new net::IOBufferWithSize(20)); - EXPECT_EQ(frame_container_.Take(chunk1), WebSocketFrameContainer::kErrorNone); - - net::WebSocketFrameChunk* chunk2 = new net::WebSocketFrameChunk(); - chunk2->data = base::WrapRefCounted( - new net::IOBufferWithSize(18)); - EXPECT_EQ(frame_container_.Take(chunk2), WebSocketFrameContainer::kErrorNone); - - net::WebSocketFrameChunk* chunk3 = new net::WebSocketFrameChunk(); - chunk3->data = base::WrapRefCounted( - new net::IOBufferWithSize(22)); - chunk3->final_chunk = true; - EXPECT_EQ(frame_container_.Take(chunk3), - WebSocketFrameContainer::kErrorPayloadSizeLargerThanHeader); -} - -TEST_F(WebSocketFrameContainerTest, TestIsControlFrame) { - struct ControlFrameTestCase { - net::WebSocketFrameHeader::OpCode op_code; - } control_frame_test_cases[] = { - {net::WebSocketFrameHeader::kOpCodePing}, - {net::WebSocketFrameHeader::kOpCodePong}, - {net::WebSocketFrameHeader::kOpCodeClose}, - }; - - for (std::size_t i(0); i != ARRAYSIZE_UNSAFE(control_frame_test_cases); ++i) { - const ControlFrameTestCase& test_case(control_frame_test_cases[i]); - WebSocketFrameContainer frame_container; - net::WebSocketFrameChunk* chunk = new net::WebSocketFrameChunk(); - chunk->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk->header->final = true; - chunk->header->payload_length = 0; - chunk->header->opcode = test_case.op_code; - EXPECT_EQ(frame_container.Take(chunk), WebSocketFrameContainer::kErrorNone); - - EXPECT_FALSE(frame_container.IsContinuationFrame()); - EXPECT_TRUE(frame_container.IsControlFrame()); - EXPECT_FALSE(frame_container.IsDataFrame()); - } -} - -TEST_F(WebSocketFrameContainerTest, TestIsTextDataFrame) { - net::WebSocketFrameChunk* chunk = new net::WebSocketFrameChunk(); - chunk->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk->header->final = true; - chunk->header->payload_length = 0; - chunk->header->opcode = net::WebSocketFrameHeader::kOpCodeText; - EXPECT_EQ(frame_container_.Take(chunk), WebSocketFrameContainer::kErrorNone); - - EXPECT_FALSE(frame_container_.IsContinuationFrame()); - EXPECT_FALSE(frame_container_.IsControlFrame()); - EXPECT_TRUE(frame_container_.IsDataFrame()); -} - -TEST_F(WebSocketFrameContainerTest, TestIsBinaryDataFrame) { - net::WebSocketFrameChunk* chunk = new net::WebSocketFrameChunk(); - chunk->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk->header->final = true; - chunk->header->payload_length = 0; - chunk->header->opcode = net::WebSocketFrameHeader::kOpCodeBinary; - EXPECT_EQ(frame_container_.Take(chunk), WebSocketFrameContainer::kErrorNone); - - EXPECT_FALSE(frame_container_.IsContinuationFrame()); - EXPECT_FALSE(frame_container_.IsControlFrame()); - EXPECT_TRUE(frame_container_.IsDataFrame()); -} - -TEST_F(WebSocketFrameContainerTest, TestIsContinuationFrame) { - net::WebSocketFrameChunk* chunk = new net::WebSocketFrameChunk(); - chunk->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk->header->final = true; - chunk->header->payload_length = 0; - chunk->header->opcode = net::WebSocketFrameHeader::kOpCodeContinuation; - EXPECT_EQ(frame_container_.Take(chunk), WebSocketFrameContainer::kErrorNone); - - EXPECT_TRUE(frame_container_.IsContinuationFrame()); - EXPECT_FALSE(frame_container_.IsControlFrame()); - EXPECT_FALSE(frame_container_.IsDataFrame()); -} - -TEST_F(WebSocketFrameContainerTest, TestIsFrameComplete) { - net::WebSocketFrameChunk* chunk1 = new net::WebSocketFrameChunk(); - chunk1->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk1->header->final = false; - chunk1->header->payload_length = 0; - chunk1->final_chunk = false; - chunk1->header->opcode = net::WebSocketFrameHeader::kOpCodeContinuation; - EXPECT_EQ(frame_container_.Take(chunk1), WebSocketFrameContainer::kErrorNone); - - EXPECT_FALSE(frame_container_.IsFrameComplete()); - - net::WebSocketFrameChunk* chunk2 = new net::WebSocketFrameChunk(); - chunk2->final_chunk = true; - EXPECT_EQ(frame_container_.Take(chunk2), WebSocketFrameContainer::kErrorNone); - - EXPECT_TRUE(frame_container_.IsFrameComplete()); -} - -TEST_F(WebSocketFrameContainerTest, TestGetHeader) { - EXPECT_EQ(frame_container_.GetHeader(), - static_cast(NULL)); - - net::WebSocketFrameChunk* chunk1 = new net::WebSocketFrameChunk(); - chunk1->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk1->header->final = false; - chunk1->header->payload_length = 0; - chunk1->final_chunk = false; - chunk1->header->opcode = net::WebSocketFrameHeader::kOpCodeContinuation; - - EXPECT_EQ(frame_container_.Take(chunk1), WebSocketFrameContainer::kErrorNone); - - EXPECT_NE(frame_container_.GetHeader(), - static_cast(NULL)); -} - -TEST_F(WebSocketFrameContainerTest, FinalFrameTest) { - { - WebSocketFrameContainer frame_container; - net::WebSocketFrameChunk* chunk1 = new net::WebSocketFrameChunk(); - chunk1->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk1->header->final = false; - chunk1->header->payload_length = 0; - chunk1->header->opcode = net::WebSocketFrameHeader::kOpCodeText; - EXPECT_EQ(frame_container.Take(chunk1), - WebSocketFrameContainer::kErrorNone); - EXPECT_FALSE(frame_container.IsFinalFrame()); - } - { - WebSocketFrameContainer frame_container; - net::WebSocketFrameChunk* chunk1 = new net::WebSocketFrameChunk(); - chunk1->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk1->header->final = true; - chunk1->header->payload_length = 0; - chunk1->header->opcode = net::WebSocketFrameHeader::kOpCodeText; - EXPECT_EQ(frame_container.Take(chunk1), - WebSocketFrameContainer::kErrorNone); - EXPECT_TRUE(frame_container.IsFinalFrame()); - } -} - -TEST_F(WebSocketFrameContainerTest, CheckChunkCount) { - EXPECT_EQ(frame_container_.GetChunkCount(), 0UL); - - net::WebSocketFrameChunk* chunk1 = new net::WebSocketFrameChunk(); - chunk1->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk1->header->final = false; - chunk1->header->payload_length = 0; - chunk1->header->opcode = net::WebSocketFrameHeader::kOpCodeText; - EXPECT_EQ(frame_container_.Take(chunk1), WebSocketFrameContainer::kErrorNone); - - EXPECT_EQ(frame_container_.GetChunkCount(), 1UL); - - net::WebSocketFrameChunk* chunk2 = new net::WebSocketFrameChunk(); - EXPECT_EQ(frame_container_.Take(chunk2), WebSocketFrameContainer::kErrorNone); - - EXPECT_EQ(frame_container_.GetChunkCount(), 2UL); - - net::WebSocketFrameChunk* chunk3 = new net::WebSocketFrameChunk(); - EXPECT_EQ(frame_container_.Take(chunk3), WebSocketFrameContainer::kErrorNone); - - EXPECT_EQ(frame_container_.GetChunkCount(), 3UL); -} - -TEST_F(WebSocketFrameContainerTest, CheckPayloadSize) { - EXPECT_EQ(frame_container_.GetCurrentPayloadSizeBytes(), 0UL); - - net::WebSocketFrameChunk* chunk1 = new net::WebSocketFrameChunk(); - chunk1->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk1->header->final = false; - chunk1->header->payload_length = 50; - chunk1->header->opcode = net::WebSocketFrameHeader::kOpCodeBinary; - chunk1->data = base::WrapRefCounted( - new net::IOBufferWithSize(20)); - EXPECT_EQ(frame_container_.Take(chunk1), WebSocketFrameContainer::kErrorNone); - - EXPECT_EQ(frame_container_.GetCurrentPayloadSizeBytes(), 20UL); - - net::WebSocketFrameChunk* chunk2 = new net::WebSocketFrameChunk(); - chunk2->data = base::WrapRefCounted( - new net::IOBufferWithSize(18)); - EXPECT_EQ(frame_container_.Take(chunk2), WebSocketFrameContainer::kErrorNone); - - EXPECT_EQ(frame_container_.GetCurrentPayloadSizeBytes(), 38UL); - - net::WebSocketFrameChunk* chunk3 = new net::WebSocketFrameChunk(); - chunk3->data = base::WrapRefCounted( - new net::IOBufferWithSize(12)); - EXPECT_EQ(frame_container_.Take(chunk3), WebSocketFrameContainer::kErrorNone); - - EXPECT_EQ(frame_container_.GetCurrentPayloadSizeBytes(), 50UL); -} - -} // namespace websocket -} // namespace cobalt diff --git a/cobalt/websocket/web_socket_handshake_helper.cc b/cobalt/websocket/web_socket_handshake_helper.cc deleted file mode 100644 index cc388f937907..000000000000 --- a/cobalt/websocket/web_socket_handshake_helper.cc +++ /dev/null @@ -1,336 +0,0 @@ -// Copyright 2013 The Chromium Authors. All rights reserved. -/* Modifications: Copyright 2017 The Cobalt Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "cobalt/websocket/web_socket_handshake_helper.h" - -#include -#include - -#include "base/logging.h" -#include "base/strings/string_number_conversions.h" -#include "base/strings/string_piece.h" -#include "base/strings/string_util.h" -#include "net/http/http_request_headers.h" -#include "net/http/http_status_code.h" -#include "net/websockets/websocket_extension.h" -#include "net/websockets/websocket_extension_parser.h" -#include "net/websockets/websocket_frame.h" -#include "net/websockets/websocket_handshake_challenge.h" -#include "net/websockets/websocket_handshake_constants.h" -#include "starboard/system.h" - -namespace { -// Following enum and anonymous functions are adapted from Chromium net source, -// commit id: 7321c9e7ee80ef15b65c2f39646a5a2d22a9c950 in -// src/net/websockets/websocket_basic_handshake_stream.cc. - -enum GetHeaderResult { - GET_HEADER_OK, - GET_HEADER_MISSING, - GET_HEADER_MULTIPLE, -}; - -GetHeaderResult GetSingleHeaderValue(const net::HttpResponseHeaders* headers, - const base::StringPiece& name, - std::string* value) { - size_t iter = 0; - size_t num_values = 0; - std::string temp_value; - std::string name_string = name.as_string(); - while (headers->EnumerateHeader(&iter, name_string, &temp_value)) { - if (++num_values > 1) return GET_HEADER_MULTIPLE; - *value = temp_value; - } - return num_values > 0 ? GET_HEADER_OK : GET_HEADER_MISSING; -} - -std::string MissingHeaderMessage(const std::string& header_name) { - return std::string("'") + header_name + "' header is missing"; -} - -std::string MultipleHeaderValuesMessage(const std::string& header_name) { - return std::string("'") + header_name + - "' header must not appear more than once in a response"; -} - -bool ValidateHeaderHasSingleValue(GetHeaderResult result, - const std::string& header_name, - std::string* failure_message) { - if (result == GET_HEADER_MISSING) { - *failure_message = MissingHeaderMessage(header_name); - return false; - } - if (result == GET_HEADER_MULTIPLE) { - *failure_message = MultipleHeaderValuesMessage(header_name); - return false; - } - DCHECK_EQ(result, GET_HEADER_OK); - return true; -} - -bool ValidateUpgrade(const net::HttpResponseHeaders* headers, - std::string* failure_message) { - std::string value; - GetHeaderResult result = - GetSingleHeaderValue(headers, net::websockets::kUpgrade, &value); - if (!ValidateHeaderHasSingleValue(result, net::websockets::kUpgrade, - failure_message)) { - return false; - } - - if (!base::EqualsCaseInsensitiveASCII(value, - net::websockets::kWebSocketLowercase)) { - *failure_message = "'Upgrade' header value is not 'WebSocket': " + value; - return false; - } - return true; -} - -bool ValidateSecWebSocketAccept(const net::HttpResponseHeaders* headers, - const std::string& expected, - std::string* failure_message) { - std::string actual; - GetHeaderResult result = GetSingleHeaderValue( - headers, net::websockets::kSecWebSocketAccept, &actual); - if (!ValidateHeaderHasSingleValue( - result, net::websockets::kSecWebSocketAccept, failure_message)) { - return false; - } - - if (expected != actual) { - *failure_message = "Incorrect 'Sec-WebSocket-Accept' header value"; - return false; - } - return true; -} - -bool ValidateConnection(const net::HttpResponseHeaders* headers, - std::string* failure_message) { - // Connection header is permitted to contain other tokens. - if (!headers->HasHeader(net::HttpRequestHeaders::kConnection)) { - *failure_message = - MissingHeaderMessage(net::HttpRequestHeaders::kConnection); - return false; - } - if (!headers->HasHeaderValue(net::HttpRequestHeaders::kConnection, - net::websockets::kUpgrade)) { - *failure_message = "'Connection' header value must contain 'Upgrade'"; - return false; - } - return true; -} - -bool ValidateSubProtocol( - const net::HttpResponseHeaders* headers, - const std::vector& requested_sub_protocols, - std::string* sub_protocol, std::string* failure_message) { - size_t iter = 0; - std::string value; - std::set requested_set(requested_sub_protocols.begin(), - requested_sub_protocols.end()); - int count = 0; - bool has_multiple_protocols = false; - bool has_invalid_protocol = false; - - while (!has_invalid_protocol || !has_multiple_protocols) { - std::string temp_value; - if (!headers->EnumerateHeader(&iter, net::websockets::kSecWebSocketProtocol, - &temp_value)) - break; - value = temp_value; - if (requested_set.count(value) == 0) has_invalid_protocol = true; - if (++count > 1) has_multiple_protocols = true; - } - - if (has_multiple_protocols) { - *failure_message = - MultipleHeaderValuesMessage(net::websockets::kSecWebSocketProtocol); - return false; - } else if (count > 0 && requested_sub_protocols.size() == 0) { - *failure_message = std::string( - "Response must not include 'Sec-WebSocket-Protocol' " - "header if not present in request: ") + - value; - return false; - } else if (has_invalid_protocol) { - *failure_message = "'Sec-WebSocket-Protocol' header value '" + value + - "' in response does not match any of sent values"; - return false; - } else if (requested_sub_protocols.size() > 0 && count == 0) { - *failure_message = - "Sent non-empty 'Sec-WebSocket-Protocol' header " - "but no response was received"; - return false; - } - *sub_protocol = value; - return true; -} - -bool ValidateExtensions(const net::HttpResponseHeaders* headers, - std::string* failure_message) { - size_t iter = 0; - std::string header_value; - while (headers->EnumerateHeader( - &iter, net::websockets::kSecWebSocketExtensions, &header_value)) { - net::WebSocketExtensionParser parser; - if (!parser.Parse(header_value)) { - *failure_message = - "'Sec-WebSocket-Extensions' header value is " - "rejected by the parser: " + - header_value; - return false; - } - - const std::vector& extensions = - parser.extensions(); - if (extensions.empty() == false) { - *failure_message = "Cobalt does not support any websocket extensions"; - return false; - } - } - return true; -} - -cobalt::websocket::SecWebSocketKey GenerateRandomSecWebSocketKey() { - using cobalt::websocket::SecWebSocketKey; - SecWebSocketKey::SecWebSocketKeyBytes random_data; - SbSystemGetRandomData(&random_data, - sizeof(SecWebSocketKey::SecWebSocketKeyBytes)); - cobalt::websocket::SecWebSocketKey key(random_data); - return key; -} - -} // namespace - -namespace cobalt { -namespace websocket { - -void WebSocketHandshakeHelper::GenerateHandshakeRequest( - const GURL& connect_url, const std::string& origin, - const std::vector& desired_sub_protocols, - std::string* handshake_request) { - DCHECK(handshake_request); - GenerateSecWebSocketKey(); - - int effective_port = connect_url.IntPort(); - std::string host_header(connect_url.host()); - if (effective_port != url::PORT_UNSPECIFIED) { - host_header += ":" + connect_url.port(); - } - - std::string& header_string(*handshake_request); - header_string.clear(); - header_string.reserve(256); // This avoids reallocations for most cases. - - // Note: Concatenating string literals and std::string objects are separated - // to avoid creating unnecessary std::string objects. - header_string += "GET "; - header_string += connect_url.path(); - if (connect_url.has_query()) { - header_string += "?"; - header_string += connect_url.query(); - } - header_string += " HTTP/1.1\r\n"; - header_string += "Host:"; - header_string += host_header; - header_string += "\r\n"; - header_string += - "Connection:Upgrade\r\n" - "Pragma:no-cache\r\n" - "Cache-Control:no-cache\r\n" - "Upgrade:websocket\r\n" - "Sec-WebSocket-Extensions:\r\n" - "Sec-WebSocket-Version:13\r\n"; - header_string += "Origin:"; - header_string += origin; - header_string += "\r\n"; - header_string += "Sec-WebSocket-Key:"; - header_string += sec_websocket_key_.GetKeyEncodedInBase64(); - header_string += "\r\n"; - header_string += "User-Agent:"; - header_string += user_agent_; - header_string += "\r\n"; - - if (!desired_sub_protocols.empty()) { - header_string += "Sec-WebSocket-Protocol:"; - header_string += base::JoinString(desired_sub_protocols, ","); - header_string += "\r\n"; - } - - header_string += "\r\n"; - - requested_sub_protocols_ = desired_sub_protocols; - - const std::string& sec_websocket_key_base64( - sec_websocket_key_.GetKeyEncodedInBase64()); - handshake_challenge_response_ = - net::ComputeSecWebSocketAccept(sec_websocket_key_base64); -} - -WebSocketHandshakeHelper::WebSocketHandshakeHelper( - const base::StringPiece user_agent) - : sec_websocket_key_generator_function_(&GenerateRandomSecWebSocketKey), - user_agent_(user_agent.data(), user_agent.size()) {} - -WebSocketHandshakeHelper::WebSocketHandshakeHelper( - const base::StringPiece user_agent, - SecWebSocketKeyGeneratorFunction sec_websocket_key_generator_function) - : sec_websocket_key_generator_function_( - sec_websocket_key_generator_function), - user_agent_(user_agent.data(), user_agent.size()) {} - -bool WebSocketHandshakeHelper::IsResponseValid( - const net::HttpResponseHeaders& headers, std::string* failure_message) { - DCHECK(failure_message); - int response_code = headers.response_code(); - - // Check response code first. - if (response_code != net::HTTP_SWITCHING_PROTOCOLS) { - *failure_message = - "Invalid response code " + base::IntToString(response_code); - return false; - } - - if (!ValidateUpgrade(&headers, failure_message)) { - return false; - } - if (!ValidateSecWebSocketAccept(&headers, handshake_challenge_response_, - failure_message)) { - return false; - } - if (!ValidateConnection(&headers, failure_message)) { - return false; - } - if (!ValidateSubProtocol(&headers, requested_sub_protocols_, - &selected_subprotocol_, failure_message)) { - return false; - } - // Cobalt does not support extensions, so we just make sure that none are - // being selected. - if (!ValidateExtensions(&headers, failure_message)) { - return false; - } - - failure_message->clear(); - return true; -} - -void WebSocketHandshakeHelper::GenerateSecWebSocketKey() { - sec_websocket_key_ = sec_websocket_key_generator_function_(); -} - -} // namespace websocket -} // namespace cobalt diff --git a/cobalt/websocket/web_socket_handshake_helper.h b/cobalt/websocket/web_socket_handshake_helper.h deleted file mode 100644 index 61e46b6ab903..000000000000 --- a/cobalt/websocket/web_socket_handshake_helper.h +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2017 The Cobalt Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef COBALT_WEBSOCKET_WEB_SOCKET_HANDSHAKE_HELPER_H_ -#define COBALT_WEBSOCKET_WEB_SOCKET_HANDSHAKE_HELPER_H_ - -#include -#include - -#include "base/gtest_prod_util.h" -#include "base/strings/string_piece.h" -#include "cobalt/websocket/sec_web_socket_key.h" -#include "net/http/http_response_headers.h" -#include "net/websockets/websocket_frame.h" -#include "url/gurl.h" - -namespace cobalt { -namespace websocket { - -class WebSocketHandshakeHelper { - public: - typedef SecWebSocketKey (*SecWebSocketKeyGeneratorFunction)(); - - explicit WebSocketHandshakeHelper(const base::StringPiece user_agent); - - // Overriding the key-generation function is useful for testing. - WebSocketHandshakeHelper( - const base::StringPiece user_agent, - SecWebSocketKeyGeneratorFunction sec_websocket_key_generator_function); - - void GenerateHandshakeRequest( - const GURL& connect_url, const std::string& origin, - const std::vector& desired_sub_protocols, - std::string* handshake_request); - - bool IsResponseValid(const net::HttpResponseHeaders& headers, - std::string* failure_message); - - const std::string& GetSelectedSubProtocol() const { - return selected_subprotocol_; - } - - private: - // Having key generator function passed is slightly slower, but very useful - // for testing. - SecWebSocketKeyGeneratorFunction sec_websocket_key_generator_function_; - - std::string user_agent_; - std::string handshake_challenge_response_; - SecWebSocketKey sec_websocket_key_; - std::vector requested_sub_protocols_; - std::string selected_subprotocol_; - - void GenerateSecWebSocketKey(); - - FRIEND_TEST_ALL_PREFIXES(WebSocketHandshakeHelperTest, null_key); - - DISALLOW_COPY_AND_ASSIGN(WebSocketHandshakeHelper); -}; - -} // namespace websocket -} // namespace cobalt - -#endif // COBALT_WEBSOCKET_WEB_SOCKET_HANDSHAKE_HELPER_H_ diff --git a/cobalt/websocket/web_socket_handshake_helper_test.cc b/cobalt/websocket/web_socket_handshake_helper_test.cc deleted file mode 100644 index a2616f4bbd6d..000000000000 --- a/cobalt/websocket/web_socket_handshake_helper_test.cc +++ /dev/null @@ -1,564 +0,0 @@ -// Copyright 2013 The Chromium Authors. All rights reserved. -/* Modifications: Copyright 2017 The Cobalt Authors. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "cobalt/websocket/web_socket_handshake_helper.h" - -#include "base/strings/string_piece.h" -#include "net/http/http_request_headers.h" -#include "net/http/http_util.h" -#include "net/websockets/websocket_frame.h" -#include "starboard/memory.h" -#include "testing/gtest/include/gtest/gtest.h" -#include "url/gurl.h" - -namespace { - -using cobalt::websocket::SecWebSocketKey; - -SecWebSocketKey NullMaskingKeyGenerator() { - SecWebSocketKey::SecWebSocketKeyBytes key_bytes; - memset(&key_bytes, 0, sizeof(key_bytes)); - return SecWebSocketKey(key_bytes); -} - -std::vector RequestHeadersToVector( - const net::HttpRequestHeaders &headers) { - net::HttpRequestHeaders::Iterator it(headers); - std::vector result; - while (it.GetNext()) - result.push_back( - net::HttpRequestHeaders::HeaderKeyValuePair(it.name(), it.value())); - return result; -} - -const char kTestUserAgent[] = - "Mozilla/5.0 (X11; Linux x86_64) Cobalt/9.27875-debug (unlike Gecko) " - "Starboard/2"; - -} // namespace - -namespace net { - -bool operator==(const HttpRequestHeaders::HeaderKeyValuePair &lhs, - const HttpRequestHeaders::HeaderKeyValuePair &rhs) { - return (lhs.key == rhs.key) && (lhs.value == rhs.value); -} - -std::ostream &operator<<(std::ostream &os, - const HttpRequestHeaders::HeaderKeyValuePair &obj) { - os << obj.key << ":" << obj.value; - return os; -} - -} // namespace net - -namespace cobalt { -namespace websocket { - -class WebSocketHandshakeHelperTest : public ::testing::Test { - public: - WebSocketHandshakeHelperTest() - : handshake_helper_(kTestUserAgent, &NullMaskingKeyGenerator) {} - - WebSocketHandshakeHelper handshake_helper_; - std::vector sub_protocols_; -}; - -TEST_F(WebSocketHandshakeHelperTest, null_key) { - EXPECT_EQ(static_cast(SecWebSocketKey::kKeySizeInBytes), 16); - - handshake_helper_.GenerateSecWebSocketKey(); - std::string null_key(SecWebSocketKey::kKeySizeInBytes, '\0'); - EXPECT_EQ(memcmp(null_key.data(), - handshake_helper_.sec_websocket_key_.GetRawKeyBytes(), - SecWebSocketKey::kKeySizeInBytes), - 0); -} - -TEST_F(WebSocketHandshakeHelperTest, HandshakeInfo) { - std::string handshake_request; - GURL localhost_websocket_endpoint("ws://localhost/"); - handshake_helper_.GenerateHandshakeRequest(localhost_websocket_endpoint, - "http://localhost", sub_protocols_, - &handshake_request); - - char end_of_line[] = "\r\n"; - std::size_t first_line_end = handshake_request.find(end_of_line); - ASSERT_NE(first_line_end, std::string::npos); - net::HttpRequestHeaders headers; - base::StringPiece handshake_request_stringpiece( - handshake_request.data() + first_line_end + sizeof(end_of_line) - 1); - headers.AddHeadersFromString(handshake_request_stringpiece); - std::vector request_headers = - RequestHeadersToVector(headers); - - typedef net::HttpRequestHeaders::HeaderKeyValuePair HeaderKeyValuePair; - ASSERT_EQ(10u, request_headers.size()); - EXPECT_EQ(HeaderKeyValuePair("Host", "localhost"), request_headers[0]); - EXPECT_EQ(HeaderKeyValuePair("Connection", "Upgrade"), request_headers[1]); - EXPECT_EQ(HeaderKeyValuePair("Pragma", "no-cache"), request_headers[2]); - EXPECT_EQ(HeaderKeyValuePair("Cache-Control", "no-cache"), - request_headers[3]); - EXPECT_EQ(HeaderKeyValuePair("Upgrade", "websocket"), request_headers[4]); - EXPECT_EQ(HeaderKeyValuePair("Sec-WebSocket-Extensions", ""), - request_headers[5]); - EXPECT_EQ(HeaderKeyValuePair("Sec-WebSocket-Version", "13"), - request_headers[6]); - EXPECT_EQ(HeaderKeyValuePair("Origin", "http://localhost"), - request_headers[7]); - EXPECT_EQ("Sec-WebSocket-Key", request_headers[8].key); - EXPECT_EQ(HeaderKeyValuePair("User-Agent", kTestUserAgent), - request_headers[9]); -} - -TEST_F(WebSocketHandshakeHelperTest, HandshakeWithPort) { - std::string handshake_request; - GURL localhost_websocket_endpoint("ws://localhost:4541/"); - handshake_helper_.GenerateHandshakeRequest(localhost_websocket_endpoint, - "http://localhost", sub_protocols_, - &handshake_request); - - char end_of_line[] = "\r\n"; - std::size_t first_line_end = handshake_request.find(end_of_line); - ASSERT_NE(first_line_end, std::string::npos); - net::HttpRequestHeaders headers; - base::StringPiece handshake_request_stringpiece( - handshake_request.data() + first_line_end + sizeof(end_of_line) - 1); - headers.AddHeadersFromString(handshake_request_stringpiece); - std::vector request_headers = - RequestHeadersToVector(headers); - - typedef net::HttpRequestHeaders::HeaderKeyValuePair HeaderKeyValuePair; - ASSERT_EQ(10u, request_headers.size()); - EXPECT_EQ(HeaderKeyValuePair("Host", "localhost:4541"), request_headers[0]); - EXPECT_EQ(HeaderKeyValuePair("Connection", "Upgrade"), request_headers[1]); - EXPECT_EQ(HeaderKeyValuePair("Pragma", "no-cache"), request_headers[2]); - EXPECT_EQ(HeaderKeyValuePair("Cache-Control", "no-cache"), - request_headers[3]); - EXPECT_EQ(HeaderKeyValuePair("Upgrade", "websocket"), request_headers[4]); - EXPECT_EQ(HeaderKeyValuePair("Sec-WebSocket-Extensions", ""), - request_headers[5]); - EXPECT_EQ(HeaderKeyValuePair("Sec-WebSocket-Version", "13"), - request_headers[6]); - EXPECT_EQ(HeaderKeyValuePair("Origin", "http://localhost"), - request_headers[7]); - EXPECT_EQ("Sec-WebSocket-Key", request_headers[8].key); - EXPECT_EQ(HeaderKeyValuePair("User-Agent", kTestUserAgent), - request_headers[9]); -} - -TEST_F(WebSocketHandshakeHelperTest, HandshakePath) { - std::string handshake_request; - GURL localhost_websocket_endpoint("ws://localhost:4541/abc-def"); - handshake_helper_.GenerateHandshakeRequest(localhost_websocket_endpoint, - "http://localhost", sub_protocols_, - &handshake_request); - - char end_of_line[] = "\r\n"; - std::size_t first_line_end = handshake_request.find(end_of_line); - ASSERT_NE(first_line_end, std::string::npos); - - std::string http_GET_line = handshake_request.substr(0, first_line_end); - EXPECT_EQ(http_GET_line, "GET /abc-def HTTP/1.1"); - net::HttpRequestHeaders headers; - base::StringPiece handshake_request_stringpiece( - handshake_request.data() + first_line_end + sizeof(end_of_line) - 1); - headers.AddHeadersFromString(handshake_request_stringpiece); - std::vector request_headers = - RequestHeadersToVector(headers); - - typedef net::HttpRequestHeaders::HeaderKeyValuePair HeaderKeyValuePair; - ASSERT_EQ(10u, request_headers.size()); - EXPECT_EQ(HeaderKeyValuePair("Host", "localhost:4541"), request_headers[0]); - EXPECT_EQ(HeaderKeyValuePair("Connection", "Upgrade"), request_headers[1]); - EXPECT_EQ(HeaderKeyValuePair("Pragma", "no-cache"), request_headers[2]); - EXPECT_EQ(HeaderKeyValuePair("Cache-Control", "no-cache"), - request_headers[3]); - EXPECT_EQ(HeaderKeyValuePair("Upgrade", "websocket"), request_headers[4]); - EXPECT_EQ(HeaderKeyValuePair("Sec-WebSocket-Extensions", ""), - request_headers[5]); - EXPECT_EQ(HeaderKeyValuePair("Sec-WebSocket-Version", "13"), - request_headers[6]); - EXPECT_EQ(HeaderKeyValuePair("Origin", "http://localhost"), - request_headers[7]); - EXPECT_EQ("Sec-WebSocket-Key", request_headers[8].key); - EXPECT_EQ(HeaderKeyValuePair("User-Agent", kTestUserAgent), - request_headers[9]); -} - -TEST_F(WebSocketHandshakeHelperTest, HandshakePathWithQuery) { - std::string handshake_request; - GURL localhost_websocket_endpoint("ws://localhost:4541/abc?one=1&two=2"); - handshake_helper_.GenerateHandshakeRequest(localhost_websocket_endpoint, - "http://localhost", sub_protocols_, - &handshake_request); - - char end_of_line[] = "\r\n"; - std::size_t first_line_end = handshake_request.find(end_of_line); - ASSERT_NE(first_line_end, std::string::npos); - - std::string http_GET_line = handshake_request.substr(0, first_line_end); - EXPECT_EQ(http_GET_line, "GET /abc?one=1&two=2 HTTP/1.1"); - net::HttpRequestHeaders headers; - base::StringPiece handshake_request_stringpiece( - handshake_request.data() + first_line_end + sizeof(end_of_line) - 1); - headers.AddHeadersFromString(handshake_request_stringpiece); - std::vector request_headers = - RequestHeadersToVector(headers); - - typedef net::HttpRequestHeaders::HeaderKeyValuePair HeaderKeyValuePair; - ASSERT_EQ(10u, request_headers.size()); - EXPECT_EQ(HeaderKeyValuePair("Host", "localhost:4541"), request_headers[0]); - EXPECT_EQ(HeaderKeyValuePair("Connection", "Upgrade"), request_headers[1]); - EXPECT_EQ(HeaderKeyValuePair("Pragma", "no-cache"), request_headers[2]); - EXPECT_EQ(HeaderKeyValuePair("Cache-Control", "no-cache"), - request_headers[3]); - EXPECT_EQ(HeaderKeyValuePair("Upgrade", "websocket"), request_headers[4]); - EXPECT_EQ(HeaderKeyValuePair("Sec-WebSocket-Extensions", ""), - request_headers[5]); - EXPECT_EQ(HeaderKeyValuePair("Sec-WebSocket-Version", "13"), - request_headers[6]); - EXPECT_EQ(HeaderKeyValuePair("Origin", "http://localhost"), - request_headers[7]); - EXPECT_EQ("Sec-WebSocket-Key", request_headers[8].key); - EXPECT_EQ(HeaderKeyValuePair("User-Agent", kTestUserAgent), - request_headers[9]); -} - -TEST_F(WebSocketHandshakeHelperTest, HandshakePathWithDesiredProtocols) { - sub_protocols_.push_back("chat"); - sub_protocols_.push_back("superchat"); - - std::string handshake_request; - GURL localhost_websocket_endpoint("ws://localhost/"); - handshake_helper_.GenerateHandshakeRequest(localhost_websocket_endpoint, - "http://localhost", sub_protocols_, - &handshake_request); - - char end_of_line[] = "\r\n"; - std::size_t first_line_end = handshake_request.find(end_of_line); - ASSERT_NE(first_line_end, std::string::npos); - - std::string http_GET_line = handshake_request.substr(0, first_line_end); - EXPECT_EQ(http_GET_line, "GET / HTTP/1.1"); - net::HttpRequestHeaders headers; - base::StringPiece handshake_request_stringpiece( - handshake_request.data() + first_line_end + sizeof(end_of_line) - 1); - headers.AddHeadersFromString(handshake_request_stringpiece); - std::vector request_headers = - RequestHeadersToVector(headers); - - typedef net::HttpRequestHeaders::HeaderKeyValuePair HeaderKeyValuePair; - ASSERT_EQ(11u, request_headers.size()); - EXPECT_EQ(HeaderKeyValuePair("Host", "localhost"), request_headers[0]); - EXPECT_EQ(HeaderKeyValuePair("Connection", "Upgrade"), request_headers[1]); - EXPECT_EQ(HeaderKeyValuePair("Pragma", "no-cache"), request_headers[2]); - EXPECT_EQ(HeaderKeyValuePair("Cache-Control", "no-cache"), - request_headers[3]); - EXPECT_EQ(HeaderKeyValuePair("Upgrade", "websocket"), request_headers[4]); - EXPECT_EQ(HeaderKeyValuePair("Sec-WebSocket-Extensions", ""), - request_headers[5]); - EXPECT_EQ(HeaderKeyValuePair("Sec-WebSocket-Version", "13"), - request_headers[6]); - EXPECT_EQ(HeaderKeyValuePair("Origin", "http://localhost"), - request_headers[7]); - EXPECT_EQ("Sec-WebSocket-Key", request_headers[8].key); - EXPECT_EQ(HeaderKeyValuePair("User-Agent", kTestUserAgent), - request_headers[9]); - EXPECT_EQ(HeaderKeyValuePair("Sec-WebSocket-Protocol", "chat,superchat"), - request_headers[10]); -} - -TEST_F(WebSocketHandshakeHelperTest, CheckValidResponseCode) { - char response_on_wire[] = - "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept:\r\n" - "\r\n"; - - std::string transformed_headers = net::HttpUtil::AssembleRawHeaders( - response_on_wire, arraysize(response_on_wire)); - scoped_refptr responseHeaders = - new net::HttpResponseHeaders(transformed_headers); - std::string error; - EXPECT_TRUE(handshake_helper_.IsResponseValid(*responseHeaders, &error)); - EXPECT_TRUE(error.empty()); - EXPECT_EQ(error, ""); -} - -TEST_F(WebSocketHandshakeHelperTest, CheckInValidResponseCode) { - char response_on_wire[] = - "HTTP/1.1 200 OK\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept: Ei4axqShP74o6nOtmL3e5uRpei8=\r\n"; - std::string transformed_headers = net::HttpUtil::AssembleRawHeaders( - response_on_wire, arraysize(response_on_wire)); - scoped_refptr responseHeaders = - new net::HttpResponseHeaders(transformed_headers); - std::string error; - EXPECT_FALSE(handshake_helper_.IsResponseValid(*responseHeaders, &error)); - EXPECT_EQ(error, "Invalid response code 200"); -} - -TEST_F(WebSocketHandshakeHelperTest, CheckValidUpgradeHeader) { - char response_on_wire[] = - "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept:\r\n"; - std::string transformed_headers = net::HttpUtil::AssembleRawHeaders( - response_on_wire, arraysize(response_on_wire)); - scoped_refptr responseHeaders = - new net::HttpResponseHeaders(transformed_headers); - std::string error; - EXPECT_TRUE(handshake_helper_.IsResponseValid(*responseHeaders, &error)); - EXPECT_TRUE(error.empty()); -} - -TEST_F(WebSocketHandshakeHelperTest, CheckMissingUpgradeHeader) { - char response_on_wire[] = - "HTTP/1.1 101 Switching Protocols\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept:\r\n"; - std::string transformed_headers = net::HttpUtil::AssembleRawHeaders( - response_on_wire, arraysize(response_on_wire)); - scoped_refptr responseHeaders = - new net::HttpResponseHeaders(transformed_headers); - std::string error; - EXPECT_FALSE(handshake_helper_.IsResponseValid(*responseHeaders, &error)); - EXPECT_EQ(error, "'Upgrade' header is missing"); -} - -TEST_F(WebSocketHandshakeHelperTest, CheckMultipleUpgradeHeaders) { - char response_on_wire[] = - "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: WebSocket\r\n" - "Upgrade: HyperSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept:\r\n"; - std::string transformed_headers = net::HttpUtil::AssembleRawHeaders( - response_on_wire, arraysize(response_on_wire)); - scoped_refptr responseHeaders = - new net::HttpResponseHeaders(transformed_headers); - std::string error; - EXPECT_FALSE(handshake_helper_.IsResponseValid(*responseHeaders, &error)); - EXPECT_EQ(error, - "'Upgrade' header must not appear more than once in a response"); -} - -TEST_F(WebSocketHandshakeHelperTest, CheckValidSecWebSocketAccept) { - // To verify Sec-WebSocket-Accept values for testing, python was used. - // In [1]: import struct - // In [2]: import hashlib - // In [3]: import base64 - // In [5]: h = hashlib.sha1(base64.b64encode(struct.pack('QQ', 0, 0)) + - // b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11") - // In [6]: base64.b64encode(h.digest()) - // Out[6]: 'ICX+Yqv66kxgM0FcWaLWlFLwTAI=' - - char response_on_wire[] = - "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept:ICX+Yqv66kxgM0FcWaLWlFLwTAI=\r\n"; - - std::string handshake_request; - GURL localhost_websocket_endpoint("ws://localhost:4541/abc?one=1&two=2"); - handshake_helper_.GenerateHandshakeRequest(localhost_websocket_endpoint, - "http://localhost", sub_protocols_, - &handshake_request); - - std::string transformed_headers = net::HttpUtil::AssembleRawHeaders( - response_on_wire, arraysize(response_on_wire)); - scoped_refptr responseHeaders = - new net::HttpResponseHeaders(transformed_headers); - std::string error; - EXPECT_TRUE(handshake_helper_.IsResponseValid(*responseHeaders, &error)); - EXPECT_TRUE(error.empty()); - EXPECT_EQ(error, ""); -} - -TEST_F(WebSocketHandshakeHelperTest, CheckMissingSecWebSocketAccept) { - char response_on_wire[] = - "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n"; - std::string transformed_headers = net::HttpUtil::AssembleRawHeaders( - response_on_wire, arraysize(response_on_wire)); - scoped_refptr responseHeaders = - new net::HttpResponseHeaders(transformed_headers); - std::string error; - EXPECT_FALSE(handshake_helper_.IsResponseValid(*responseHeaders, &error)); - EXPECT_EQ(error, "'Sec-WebSocket-Accept' header is missing"); -} - -TEST_F(WebSocketHandshakeHelperTest, CheckMultipleSecWebSocketAccept) { - char response_on_wire[] = - "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept:\r\n" - "Sec-WebSocket-Accept:\r\n"; - std::string transformed_headers = net::HttpUtil::AssembleRawHeaders( - response_on_wire, arraysize(response_on_wire)); - scoped_refptr responseHeaders = - new net::HttpResponseHeaders(transformed_headers); - std::string error; - EXPECT_FALSE(handshake_helper_.IsResponseValid(*responseHeaders, &error)); - EXPECT_EQ(error, - "'Sec-WebSocket-Accept' header must not appear more than once in a " - "response"); -} - -TEST_F(WebSocketHandshakeHelperTest, CheckValidConnection) { - char response_on_wire[] = - "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept:\r\n"; - std::string transformed_headers = net::HttpUtil::AssembleRawHeaders( - response_on_wire, arraysize(response_on_wire)); - scoped_refptr responseHeaders = - new net::HttpResponseHeaders(transformed_headers); - std::string error; - EXPECT_TRUE(handshake_helper_.IsResponseValid(*responseHeaders, &error)); - EXPECT_TRUE(error.empty()); -} - -TEST_F(WebSocketHandshakeHelperTest, CheckMissingConnectionHeader) { - char response_on_wire[] = - "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: WebSocket\r\n" - "Sec-WebSocket-Accept:\r\n"; - std::string transformed_headers = net::HttpUtil::AssembleRawHeaders( - response_on_wire, arraysize(response_on_wire)); - scoped_refptr responseHeaders = - new net::HttpResponseHeaders(transformed_headers); - std::string error; - EXPECT_FALSE(handshake_helper_.IsResponseValid(*responseHeaders, &error)); - EXPECT_EQ(error, "'Connection' header is missing"); -} - -TEST_F(WebSocketHandshakeHelperTest, CheckValidInvalidSubprotocol) { - char response_on_wire[] = - "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept:\r\n" - "Sec-WebSocket-Protocol:chat\r\n"; - std::string transformed_headers = net::HttpUtil::AssembleRawHeaders( - response_on_wire, arraysize(response_on_wire)); - scoped_refptr responseHeaders = - new net::HttpResponseHeaders(transformed_headers); - std::string error; - EXPECT_FALSE(handshake_helper_.IsResponseValid(*responseHeaders, &error)); - EXPECT_FALSE(error.empty()); - EXPECT_EQ(error, - "Response must not include 'Sec-WebSocket-Protocol' header if not " - "present in request: chat"); -} - -TEST_F(WebSocketHandshakeHelperTest, CheckValidInvalidDifferentSubprotocol) { - sub_protocols_.push_back("superchat"); - - std::string handshake_request; - GURL localhost_websocket_endpoint("ws://localhost/"); - handshake_helper_.GenerateHandshakeRequest(localhost_websocket_endpoint, - "http://localhost", sub_protocols_, - &handshake_request); - char response_on_wire[] = - "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept:\r\n" - "Sec-WebSocket-Protocol:chat\r\n"; - std::string transformed_headers = net::HttpUtil::AssembleRawHeaders( - response_on_wire, arraysize(response_on_wire)); - scoped_refptr responseHeaders = - new net::HttpResponseHeaders(transformed_headers); - std::string error; - EXPECT_FALSE(handshake_helper_.IsResponseValid(*responseHeaders, &error)); - EXPECT_FALSE(error.empty()); - EXPECT_EQ(error, "Incorrect 'Sec-WebSocket-Accept' header value"); -} - -TEST_F(WebSocketHandshakeHelperTest, CheckSubprotocolHeaderIsOptional) { - char response_on_wire[] = - "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept:\r\n"; - std::string transformed_headers = net::HttpUtil::AssembleRawHeaders( - response_on_wire, arraysize(response_on_wire)); - scoped_refptr responseHeaders = - new net::HttpResponseHeaders(transformed_headers); - std::string error; - EXPECT_TRUE(handshake_helper_.IsResponseValid(*responseHeaders, &error)); - EXPECT_TRUE(error.empty()); -} - -TEST_F(WebSocketHandshakeHelperTest, CheckValidExtensionHeader) { - char response_on_wire[] = - "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept:\r\n" - "Sec-WebSocket-Extensions:mux\r\n"; - std::string transformed_headers = net::HttpUtil::AssembleRawHeaders( - response_on_wire, arraysize(response_on_wire)); - scoped_refptr responseHeaders = - new net::HttpResponseHeaders(transformed_headers); - std::string error; - EXPECT_FALSE(handshake_helper_.IsResponseValid(*responseHeaders, &error)); - EXPECT_EQ(error, "Cobalt does not support any websocket extensions"); -} - -TEST_F(WebSocketHandshakeHelperTest, CheckExtensionHeaderIsOptional) { - char response_on_wire[] = - "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept:\r\n"; - std::string transformed_headers = net::HttpUtil::AssembleRawHeaders( - response_on_wire, arraysize(response_on_wire)); - scoped_refptr responseHeaders = - new net::HttpResponseHeaders(transformed_headers); - std::string error; - EXPECT_TRUE(handshake_helper_.IsResponseValid(*responseHeaders, &error)); - EXPECT_TRUE(error.empty()); -} - -TEST_F(WebSocketHandshakeHelperTest, CheckMultipleExtensionHeadersAreOK) { - char response_on_wire[] = - "HTTP/1.1 101 Switching Protocols\r\n" - "Upgrade: WebSocket\r\n" - "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept:\r\n" - "Sec-WebSocket-Extensions:mux\r\n" - "Sec-WebSocket-Extensions:demux\r\n"; - std::string transformed_headers = net::HttpUtil::AssembleRawHeaders( - response_on_wire, arraysize(response_on_wire)); - scoped_refptr responseHeaders = - new net::HttpResponseHeaders(transformed_headers); - std::string error; - EXPECT_FALSE(handshake_helper_.IsResponseValid(*responseHeaders, &error)); - EXPECT_EQ(error, "Cobalt does not support any websocket extensions"); -} - -} // namespace websocket -} // namespace cobalt diff --git a/cobalt/websocket/web_socket_impl.cc b/cobalt/websocket/web_socket_impl.cc index 896726c33079..c6dfdd3aa5b0 100644 --- a/cobalt/websocket/web_socket_impl.cc +++ b/cobalt/websocket/web_socket_impl.cc @@ -156,17 +156,15 @@ void WebSocketImpl::TrampolineClose(const CloseInfo &close_info) { void WebSocketImpl::OnHandshakeComplete( const std::string &selected_subprotocol) { + if (websocket_channel_->ReadFrames() != + net::WebSocketChannel::CHANNEL_ALIVE) { + LOG(ERROR) << "Channel is closed before reading completes."; + } owner_task_runner_->PostTask( FROM_HERE, base::Bind(&WebSocketImpl::OnWebSocketConnected, this, selected_subprotocol)); } -void WebSocketImpl::OnFlowControl(int64_t quota) { - DCHECK_GE(current_quota_, 0); - current_quota_ += quota; - ProcessSendQueue(); -} - void WebSocketImpl::OnWebSocketConnected( const std::string &selected_subprotocol) { DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); @@ -205,8 +203,9 @@ void WebSocketImpl::OnClose(bool was_clean, int error_code, std::uint16_t close_code = static_cast(error_code); - DLOG(INFO) << "WebSocket is closing." << " code[" << close_code << "] reason[" - << close_reason << "]" << " was_clean: " << was_clean; + DLOG(INFO) << "WebSocket is closing." + << " code[" << close_code << "] reason[" << close_reason << "]" + << " was_clean: " << was_clean; // Queue the deletion of |websocket_channel_|. We would do it here, but this // function may be called as a callback *by* |websocket_channel_|; @@ -263,44 +262,11 @@ void WebSocketImpl::SendOnDelegateThread( DLOG(WARNING) << "Attempt to send over a closed channel."; return; } - SendQueueMessage new_message = {io_buffer, length, op_code}; - send_queue_.push(std::move(new_message)); - ProcessSendQueue(); -} - -void WebSocketImpl::ProcessSendQueue() { - DCHECK(delegate_task_runner_->RunsTasksInCurrentSequence()); - while (current_quota_ > 0 && !send_queue_.empty()) { - SendQueueMessage message = send_queue_.front(); - size_t current_message_length = message.length - sent_size_of_top_message_; - bool final = false; - bool continuation = sent_size_of_top_message_ > 0 ? true : false; - if (current_quota_ < static_cast(current_message_length)) { - // quota is not enough to send the top message. - scoped_refptr new_io_buffer( - new net::IOBuffer(static_cast(current_quota_))); - memcpy(new_io_buffer->data(), - message.io_buffer->data() + sent_size_of_top_message_, - current_quota_); - sent_size_of_top_message_ += current_quota_; - message.io_buffer = new_io_buffer; - current_message_length = current_quota_; - current_quota_ = 0; - } else { - // Sent all of the remaining top message. - final = true; - send_queue_.pop(); - sent_size_of_top_message_ = 0; - current_quota_ -= current_message_length; - } - auto channel_state = websocket_channel_->SendFrame( - final, - continuation ? net::WebSocketFrameHeader::kOpCodeContinuation - : message.op_code, - message.io_buffer, current_message_length); - if (channel_state == net::WebSocketChannel::CHANNEL_DELETED) { - websocket_channel_.reset(); - } + SendQueueMessage message = {io_buffer, length, op_code}; + auto channel_state = websocket_channel_->SendFrame( + true, message.op_code, message.io_buffer, message.length); + if (channel_state == net::WebSocketChannel::CHANNEL_DELETED) { + websocket_channel_.reset(); } } diff --git a/cobalt/websocket/web_socket_impl.h b/cobalt/websocket/web_socket_impl.h index d1473ebfd2c2..489541fb141c 100644 --- a/cobalt/websocket/web_socket_impl.h +++ b/cobalt/websocket/web_socket_impl.h @@ -26,11 +26,7 @@ #include "base/sequence_checker.h" #include "base/task/sequenced_task_runner.h" #include "cobalt/network/network_module.h" -#include "cobalt/websocket/buffered_amount_tracker.h" #include "cobalt/websocket/cobalt_web_socket_event_handler.h" -#include "cobalt/websocket/web_socket_frame_container.h" -#include "cobalt/websocket/web_socket_handshake_helper.h" -#include "cobalt/websocket/web_socket_message_container.h" #include "net/url_request/url_request_context_getter.h" #include "net/websockets/websocket_channel.h" #include "net/websockets/websocket_errors.h" @@ -89,8 +85,6 @@ class WebSocketImpl : public base::RefCountedThreadSafe { void OnHandshakeComplete(const std::string& selected_subprotocol); - void OnFlowControl(int64_t quota); - struct CloseInfo { CloseInfo(const net::WebSocketError code, const std::string& reason) : code(code), reason(reason) {} @@ -114,7 +108,6 @@ class WebSocketImpl : public base::RefCountedThreadSafe { bool SendHelper(const net::WebSocketFrameHeader::OpCode op_code, const char* data, std::size_t length, std::string* error_message); - void ProcessSendQueue(); void OnWebSocketConnected(const std::string& selected_subprotocol); void OnWebSocketDisconnected(bool was_clean, uint16 code, @@ -132,9 +125,6 @@ class WebSocketImpl : public base::RefCountedThreadSafe { std::string origin_; GURL connect_url_; - // Data buffering and flow control. - // Should only be modified on delegate(network) thread. - int64_t current_quota_ = 0; struct SendQueueMessage { scoped_refptr io_buffer; size_t length; diff --git a/cobalt/websocket/web_socket_impl_test.cc b/cobalt/websocket/web_socket_impl_test.cc deleted file mode 100644 index 344504f76e21..000000000000 --- a/cobalt/websocket/web_socket_impl_test.cc +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2019 The Cobalt Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "cobalt/websocket/web_socket_impl.h" - -#include -#include -#include - -#include "base/memory/ref_counted.h" -#include "cobalt/base/polymorphic_downcast.h" -#include "cobalt/base/task_runner_util.h" -#include "cobalt/dom/dom_settings.h" -#include "cobalt/dom/testing/stub_environment_settings.h" -#include "cobalt/network/network_module.h" -#include "cobalt/script/script_exception.h" -#include "cobalt/script/testing/mock_exception_state.h" -#include "cobalt/web/context.h" -#include "cobalt/web/dom_exception.h" -#include "cobalt/web/testing/stub_web_context.h" -#include "cobalt/websocket/mock_websocket_channel.h" -#include "cobalt/websocket/web_socket.h" -#include "testing/gtest/include/gtest/gtest.h" - -using cobalt::script::testing::MockExceptionState; -using ::testing::_; -using ::testing::DefaultValue; -using ::testing::Return; -using ::testing::SaveArg; -using ::testing::StrictMock; - -namespace cobalt { -namespace websocket { -namespace { -// These limits are copied from net::WebSocketChannel implementation. -const int kDefaultSendQuotaHighWaterMark = 1 << 17; -const int k800KB = 800; -const int kTooMuch = kDefaultSendQuotaHighWaterMark + 1; -const int kWayTooMuch = kDefaultSendQuotaHighWaterMark * 2 + 1; -const int k512KB = 512; -} // namespace - -class WebSocketImplTest : public ::testing::Test { - public: - web::EnvironmentSettings* settings() const { - return web_context_->environment_settings(); - } - void AddQuota(int quota) { - base::task_runner_util::PostBlockingTask( - network_task_runner_, FROM_HERE, - base::Bind(&WebSocketImpl::OnFlowControl, websocket_impl_, quota)); - } - - protected: - WebSocketImplTest() : web_context_(new web::testing::StubWebContext()) { - web_context_->SetupEnvironmentSettings( - new dom::testing::StubEnvironmentSettings()); - web_context_->environment_settings()->set_creation_url( - GURL("https://127.0.0.1:1234")); - web_context_->SetupFinished(); - std::vector sub_protocols; - sub_protocols.push_back("chat"); - network_task_runner_ = web_context_->network_module() - ->url_request_context_getter() - ->GetNetworkTaskRunner(); - } - - void SetUp() override { - websocket_impl_ = - new WebSocketImpl(web_context_->network_module(), nullptr); - // Setting this was usually done by WebSocketImpl::Connect, but since we do - // not do Connect for every test, we have to make sure its task runner is - // set. - websocket_impl_->delegate_task_runner_ = network_task_runner_; - // The holder is only created to be base::Passed() on the next line, it will - // be empty so do not use it later. - base::task_runner_util::PostBlockingTask( - network_task_runner_, FROM_HERE, - base::Bind( - [](scoped_refptr websocket_impl, - web::Context* web_context, - MockWebSocketChannel** mock_channel_slot) { - *mock_channel_slot = new MockWebSocketChannel( - websocket_impl, web_context->network_module()); - websocket_impl->websocket_channel_ = - std::unique_ptr(*mock_channel_slot); - }, - websocket_impl_, web_context_.get(), &mock_channel_)); - } - - void TearDown() override { - base::task_runner_util::PostBlockingTask( - network_task_runner_, FROM_HERE, - base::Bind(&WebSocketImpl::OnClose, websocket_impl_, true /*was_clan*/, - net::kWebSocketNormalClosure /*error_code*/, - "" /*close_reason*/)); - } - - std::unique_ptr web_context_; - scoped_refptr network_task_runner_; - scoped_refptr websocket_impl_; - MockWebSocketChannel* mock_channel_; - StrictMock exception_state_; -}; - -TEST_F(WebSocketImplTest, DISABLED_NormalSizeRequest) { - // Normally the high watermark quota is given at websocket connection success. - AddQuota(kDefaultSendQuotaHighWaterMark); - - { - base::AutoLock scoped_lock(mock_channel_->lock()); - // mock_channel_ is created and used on network thread. - EXPECT_CALL( - *mock_channel_, - MockSendFrame(true, net::WebSocketFrameHeader::kOpCodeText, _, k800KB)) - .Times(1) - .WillOnce(Return(net::WebSocketChannel::CHANNEL_ALIVE)); - } - - char data[k800KB]; - int32 buffered_amount = 0; - std::string error; - websocket_impl_->SendText(data, k800KB, &buffered_amount, &error); -} - -TEST_F(WebSocketImplTest, DISABLED_LargeRequest) { - AddQuota(kDefaultSendQuotaHighWaterMark); - - // mock_channel_ is created and used on network thread. - { - base::AutoLock scoped_lock(mock_channel_->lock()); - EXPECT_CALL(*mock_channel_, - MockSendFrame(true, net::WebSocketFrameHeader::kOpCodeText, _, - kDefaultSendQuotaHighWaterMark)) - .Times(1) - .WillOnce(Return(net::WebSocketChannel::CHANNEL_ALIVE)); - } - - char data[kDefaultSendQuotaHighWaterMark]; - int32 buffered_amount = 0; - std::string error; - websocket_impl_->SendText(data, kDefaultSendQuotaHighWaterMark, - &buffered_amount, &error); -} - -TEST_F(WebSocketImplTest, DISABLED_OverLimitRequest) { - AddQuota(kDefaultSendQuotaHighWaterMark); - - // mock_channel_ is created and used on network thread. - { - base::AutoLock scoped_lock(mock_channel_->lock()); - EXPECT_CALL(*mock_channel_, - MockSendFrame(false, net::WebSocketFrameHeader::kOpCodeText, _, - kDefaultSendQuotaHighWaterMark)) - .Times(1) - .WillRepeatedly(Return(net::WebSocketChannel::CHANNEL_ALIVE)); - - EXPECT_CALL( - *mock_channel_, - MockSendFrame(false, net::WebSocketFrameHeader::kOpCodeContinuation, _, - kDefaultSendQuotaHighWaterMark)) - .Times(1) - .WillRepeatedly(Return(net::WebSocketChannel::CHANNEL_ALIVE)); - - EXPECT_CALL(*mock_channel_, - MockSendFrame( - true, net::WebSocketFrameHeader::kOpCodeContinuation, _, 1)) - .Times(1) - .WillOnce(Return(net::WebSocketChannel::CHANNEL_ALIVE)); - } - - char data[kWayTooMuch]; - int32 buffered_amount = 0; - std::string error; - websocket_impl_->SendText(data, kWayTooMuch, &buffered_amount, &error); - - AddQuota(kDefaultSendQuotaHighWaterMark); - AddQuota(kDefaultSendQuotaHighWaterMark); -} - -TEST_F(WebSocketImplTest, DISABLED_ReuseSocketForLargeRequest) { - AddQuota(kDefaultSendQuotaHighWaterMark); - - // mock_channel_ is created and used on network thread. - { - base::AutoLock scoped_lock(mock_channel_->lock()); - EXPECT_CALL(*mock_channel_, - MockSendFrame(false, net::WebSocketFrameHeader::kOpCodeBinary, - _, kDefaultSendQuotaHighWaterMark)) - .Times(1) - .WillOnce(Return(net::WebSocketChannel::CHANNEL_ALIVE)); - EXPECT_CALL(*mock_channel_, - MockSendFrame( - true, net::WebSocketFrameHeader::kOpCodeContinuation, _, 1)) - .Times(1) - .WillOnce(Return(net::WebSocketChannel::CHANNEL_ALIVE)); - EXPECT_CALL(*mock_channel_, - MockSendFrame(false, net::WebSocketFrameHeader::kOpCodeText, _, - k512KB - 1)) - .Times(1) - .WillOnce(Return(net::WebSocketChannel::CHANNEL_ALIVE)); - EXPECT_CALL( - *mock_channel_, - MockSendFrame(true, net::WebSocketFrameHeader::kOpCodeContinuation, _, - kTooMuch - (k512KB - 1))) - .Times(1) - .WillOnce(Return(net::WebSocketChannel::CHANNEL_ALIVE)); - } - - char data[kTooMuch]; - int32 buffered_amount = 0; - std::string error; - websocket_impl_->SendBinary(data, kTooMuch, &buffered_amount, &error); - websocket_impl_->SendText(data, kTooMuch, &buffered_amount, &error); - - AddQuota(k512KB); - AddQuota(kDefaultSendQuotaHighWaterMark); -} - -} // namespace websocket -} // namespace cobalt diff --git a/cobalt/websocket/web_socket_message_container.cc b/cobalt/websocket/web_socket_message_container.cc deleted file mode 100644 index 4260c8be74f6..000000000000 --- a/cobalt/websocket/web_socket_message_container.cc +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2017 The Cobalt Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "cobalt/websocket/web_socket_message_container.h" - -#include "base/basictypes.h" - -namespace cobalt { -namespace websocket { - -std::size_t CombineFramesChunks(WebSocketFrameContainer::const_iterator begin, - WebSocketFrameContainer::const_iterator end, - char *out_destination, - std::size_t buffer_length) { - DCHECK(out_destination); - std::size_t bytes_written = 0; - std::size_t bytes_available = buffer_length; - for (WebSocketFrameContainer::const_iterator iterator = begin; - iterator != end; ++iterator) { - const scoped_refptr &data((*iterator)->data); - - if (data) { - std::size_t frame_chunk_size = data->size(); - - if (bytes_available >= frame_chunk_size) { - memcpy(out_destination, data->data(), frame_chunk_size); - out_destination += frame_chunk_size; - bytes_written += frame_chunk_size; - bytes_available -= frame_chunk_size; - } - } - } - - DCHECK_LE(bytes_written, buffer_length); - return bytes_written; -} - -bool WebSocketMessageContainer::Take(WebSocketFrameContainer *frame_container) { - DCHECK(frame_container); - DCHECK(!IsMessageComplete()); - DCHECK(frame_container->IsFrameComplete()); - if (!frame_container) { - return false; - } - if (frame_container->empty()) { - return true; - } - - bool is_first_frame = frames_.empty(); - bool is_continuation_frame = frame_container->IsContinuationFrame(); - - if (is_first_frame) { - if (is_continuation_frame) { - return false; - } - } else { - // All frames after the first one must be continuation frames. - if (!is_continuation_frame) { - return false; - } - } - - frames_.push_back(WebSocketFrameContainer()); - - WebSocketFrameContainer &last_object(frames_.back()); - last_object.swap(*frame_container); - DCHECK(last_object.IsFrameComplete()); - - payload_size_bytes_ += last_object.GetCurrentPayloadSizeBytes(); - message_completed_ |= last_object.IsFinalFrame(); - - return true; -} - -scoped_refptr -WebSocketMessageContainer::GetMessageAsIOBuffer() const { - scoped_refptr buf; - - DCHECK_LE(kMaxMessagePayloadInBytes, static_cast(kint32max)); - DCHECK_LE(payload_size_bytes_, kMaxMessagePayloadInBytes); - DCHECK_GE(payload_size_bytes_, 0UL); - - if ((payload_size_bytes_ > 0) && - (payload_size_bytes_ <= kMaxMessagePayloadInBytes)) { - buf = base::WrapRefCounted( - new net::IOBufferWithSize(static_cast(payload_size_bytes_))); - - std::size_t total_bytes_written = 0; - - char *data_pointer = buf->data(); - std::size_t size_remaining = buf->size(); - - for (WebSocketFrames::const_iterator iterator = frames_.begin(); - iterator != frames_.end(); ++iterator) { - const WebSocketFrameContainer &frame_container(*iterator); - - std::size_t bytes_written = - CombineFramesChunks(frame_container.begin(), frame_container.end(), - data_pointer, size_remaining); - - DCHECK_LE(bytes_written, size_remaining); - - size_remaining -= bytes_written; - data_pointer += bytes_written; - - total_bytes_written += bytes_written; - } - - DCHECK_EQ(total_bytes_written, payload_size_bytes_); - } - - return buf; -} - -} // namespace websocket -} // namespace cobalt diff --git a/cobalt/websocket/web_socket_message_container.h b/cobalt/websocket/web_socket_message_container.h deleted file mode 100644 index c7cb65bf7e87..000000000000 --- a/cobalt/websocket/web_socket_message_container.h +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright 2017 The Cobalt Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -#ifndef COBALT_WEBSOCKET_WEB_SOCKET_MESSAGE_CONTAINER_H_ -#define COBALT_WEBSOCKET_WEB_SOCKET_MESSAGE_CONTAINER_H_ - -#include - -#include "base/memory/ref_counted.h" -#include "cobalt/websocket/web_socket_frame_container.h" -#include "net/base/io_buffer.h" -#include "net/websockets/websocket_frame.h" - -namespace cobalt { -namespace websocket { - -const size_t kMaxMessagePayloadInBytes = 4 * 1024 * 1024; -COMPILE_ASSERT(kMaxMessagePayloadInBytes >= kMaxFramePayloadInBytes, - max_message_size_must_be_greater_than_max_payload_size); - -class WebSocketMessageContainer { - public: - typedef std::deque WebSocketFrames; - - WebSocketMessageContainer() - : message_completed_(false), payload_size_bytes_(0) {} - ~WebSocketMessageContainer() { clear(); } - - void clear() { - message_completed_ = false; - payload_size_bytes_ = 0; - frames_.clear(); - } - - bool GetMessageOpCode(net::WebSocketFrameHeader::OpCode *op_code) const { - DCHECK(op_code); - if (empty()) { - return false; - } - - return frames_.begin()->GetFrameOpCode(op_code); - } - - bool IsMessageComplete() const { return message_completed_; } - - // Returns true if and only if it a text message. - // Note: It is valid to call this function on uncompleted messages. - bool IsTextMessage() const { - net::WebSocketFrameHeader::OpCode message_op_code = - net::WebSocketFrameHeader::kOpCodeContinuation; - - bool success = GetMessageOpCode(&message_op_code); - if (!success) { - DLOG(INFO) << "Unable to retrieve the message op code. Empty message?"; - return false; - } - - DCHECK_NE(message_op_code, net::WebSocketFrameHeader::kOpCodePing); - DCHECK_NE(message_op_code, net::WebSocketFrameHeader::kOpCodePong); - DCHECK_NE(message_op_code, net::WebSocketFrameHeader::kOpCodeClose); - - return (message_op_code == net::WebSocketFrameHeader::kOpCodeText); - } - - // Should only be called if IsMessageComplete() is false, and - // |frame_container| is a full frame. - bool Take(WebSocketFrameContainer *frame_container); - - std::size_t GetCurrentPayloadSizeBytes() const { return payload_size_bytes_; } - - scoped_refptr GetMessageAsIOBuffer() const; - - const WebSocketFrames &GetFrames() const { return frames_; } - bool empty() const { return frames_.empty(); } - - private: - bool message_completed_; - std::size_t payload_size_bytes_; - WebSocketFrames frames_; -}; - -} // namespace websocket -} // namespace cobalt - -#endif // COBALT_WEBSOCKET_WEB_SOCKET_MESSAGE_CONTAINER_H_ diff --git a/cobalt/websocket/web_socket_message_container_test.cc b/cobalt/websocket/web_socket_message_container_test.cc deleted file mode 100644 index 2a8f4963c886..000000000000 --- a/cobalt/websocket/web_socket_message_container_test.cc +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright 2017 The Cobalt Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "cobalt/websocket/web_socket_message_container.h" - -#include - -#include "testing/gtest/include/gtest/gtest.h" - -namespace { - -const char payload_filler = '?'; - -} // namespace - -namespace cobalt { -namespace websocket { - -class WebSocketMessageContainerTest : public ::testing::Test { - public: - WebSocketMessageContainerTest() { - PopulateBinaryFrame(2); // Create a final binary frame with 2 byte payload. - PopulateTextFrame(); - PopulateContinuationFrame(3); - } - - protected: - void PopulateBinaryFrame(int payload_size) { - net::WebSocketFrameChunk* chunk = new net::WebSocketFrameChunk(); - chunk->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk->header->final = true; - chunk->header->payload_length = payload_size; - chunk->header->opcode = net::WebSocketFrameHeader::kOpCodeBinary; - chunk->data = base::WrapRefCounted( - new net::IOBufferWithSize(payload_size)); - chunk->final_chunk = true; - WebSocketFrameContainer::ErrorCode error_code = - final_binary_frame_.Take(chunk); - EXPECT_EQ(error_code, WebSocketFrameContainer::kErrorNone); - } - - void PopulateTextFrame() { - net::WebSocketFrameChunk* chunk = new net::WebSocketFrameChunk(); - chunk->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk->header->final = false; - chunk->header->payload_length = 0; - chunk->header->opcode = net::WebSocketFrameHeader::kOpCodeText; - chunk->final_chunk = true; - WebSocketFrameContainer::ErrorCode error_code = - nonfinal_text_frame_.Take(chunk); - EXPECT_EQ(error_code, WebSocketFrameContainer::kErrorNone); - } - - void PopulateContinuationFrame(int payload_size) { - net::WebSocketFrameChunk* chunk = new net::WebSocketFrameChunk(); - chunk->header = std::unique_ptr( - new net::WebSocketFrameHeader()); - chunk->header->final = true; - chunk->header->payload_length = payload_size; - chunk->header->opcode = net::WebSocketFrameHeader::kOpCodeContinuation; - chunk->data = base::WrapRefCounted( - new net::IOBufferWithSize(payload_size)); - - net::IOBufferWithSize& data_array(*chunk->data.get()); - memset(data_array.data(), payload_filler, data_array.size()); - - chunk->final_chunk = true; - WebSocketFrameContainer::ErrorCode error_code = - final_continuation_frame_.Take(chunk); - EXPECT_EQ(error_code, WebSocketFrameContainer::kErrorNone); - } - - WebSocketFrameContainer final_binary_frame_; // Final frame - - WebSocketFrameContainer nonfinal_text_frame_; // Not a final frame. - WebSocketFrameContainer final_continuation_frame_; - - WebSocketMessageContainer message_container_; -}; - -TEST_F(WebSocketMessageContainerTest, Construct) { - EXPECT_TRUE(message_container_.empty()); - EXPECT_FALSE(message_container_.IsMessageComplete()); - EXPECT_EQ(message_container_.GetCurrentPayloadSizeBytes(), 0UL); -} - -TEST_F(WebSocketMessageContainerTest, Clear) { - EXPECT_TRUE(message_container_.Take(&final_binary_frame_)); - message_container_.clear(); - EXPECT_TRUE(message_container_.empty()); - EXPECT_FALSE(message_container_.IsMessageComplete()); - EXPECT_EQ(message_container_.GetCurrentPayloadSizeBytes(), 0UL); -} - -TEST_F(WebSocketMessageContainerTest, GetMessageOpCode) { - net::WebSocketFrameHeader::OpCode op; - EXPECT_FALSE(message_container_.GetMessageOpCode(&op)); - EXPECT_TRUE(message_container_.Take(&final_binary_frame_)); - EXPECT_TRUE(message_container_.GetMessageOpCode(&op)); - EXPECT_EQ(op, net::WebSocketFrameHeader::kOpCodeBinary); -} - -TEST_F(WebSocketMessageContainerTest, TakeMultipleFrames) { - EXPECT_EQ(message_container_.GetFrames().size(), 0U); - EXPECT_TRUE(message_container_.Take(&nonfinal_text_frame_)); - EXPECT_EQ(message_container_.GetFrames().size(), 1U); - EXPECT_TRUE(message_container_.Take(&final_continuation_frame_)); - EXPECT_EQ(message_container_.GetFrames().size(), 2U); -} - -TEST_F(WebSocketMessageContainerTest, IsMessageComplete) { - EXPECT_TRUE(message_container_.Take(&final_binary_frame_)); - EXPECT_TRUE(message_container_.IsMessageComplete()); -} - -TEST_F(WebSocketMessageContainerTest, IsMessageComplete2) { - EXPECT_TRUE(message_container_.Take(&nonfinal_text_frame_)); - EXPECT_FALSE(message_container_.IsMessageComplete()); - EXPECT_TRUE(message_container_.Take(&final_continuation_frame_)); - EXPECT_TRUE(message_container_.IsMessageComplete()); -} - -TEST_F(WebSocketMessageContainerTest, IsTextMessage) { - EXPECT_TRUE(message_container_.Take(&nonfinal_text_frame_)); - EXPECT_TRUE(message_container_.IsTextMessage()); -} - -TEST_F(WebSocketMessageContainerTest, IsTextMessage2) { - EXPECT_TRUE(message_container_.Take(&final_binary_frame_)); - EXPECT_FALSE(message_container_.IsTextMessage()); -} - -TEST_F(WebSocketMessageContainerTest, PayloadSize) { - EXPECT_EQ(message_container_.GetCurrentPayloadSizeBytes(), 0UL); - EXPECT_TRUE(message_container_.Take(&nonfinal_text_frame_)); - EXPECT_EQ(message_container_.GetCurrentPayloadSizeBytes(), 0UL); - EXPECT_TRUE(message_container_.Take(&final_continuation_frame_)); - EXPECT_EQ(message_container_.GetCurrentPayloadSizeBytes(), 3UL); -} - -TEST_F(WebSocketMessageContainerTest, TakeStartsWithContinuationFrame) { - EXPECT_FALSE(message_container_.Take(&final_continuation_frame_)); -} - -TEST_F(WebSocketMessageContainerTest, GetIOBuffer) { - EXPECT_TRUE(message_container_.Take(&nonfinal_text_frame_)); - EXPECT_TRUE(message_container_.Take(&final_continuation_frame_)); - scoped_refptr payload = - message_container_.GetMessageAsIOBuffer(); - - int payload_size = payload->size(); - EXPECT_GE(payload_size, 0); - - char* data = payload->data(); - for (int i = 0; i != payload_size; ++i) { - DCHECK_EQ(payload_filler, data[i]); - } -} - -} // namespace websocket -} // namespace cobalt