diff --git a/all_files.vcxitems b/all_files.vcxitems index 240de11..1dcab1d 100644 --- a/all_files.vcxitems +++ b/all_files.vcxitems @@ -15,6 +15,7 @@ + diff --git a/cdc_rsync/base/BUILD b/cdc_rsync/base/BUILD index 655a739..d68b392 100644 --- a/cdc_rsync/base/BUILD +++ b/cdc_rsync/base/BUILD @@ -80,7 +80,15 @@ cc_library( cc_library( name = "socket", + srcs = ["socket.cc"], hdrs = ["socket.h"], + deps = [ + "//common:log", + "//common:platform", + "//common:status", + "//common:util", + "@com_google_absl//absl/status", + ], ) filegroup( diff --git a/cdc_rsync/base/socket.cc b/cdc_rsync/base/socket.cc new file mode 100644 index 0000000..e02b9ed --- /dev/null +++ b/cdc_rsync/base/socket.cc @@ -0,0 +1,65 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cdc_rsync/base/socket.h" + +#include "common/log.h" +#include "common/platform.h" +#include "common/status.h" +#include "common/util.h" + +#if PLATFORM_WINDOWS +#include +#endif + +namespace cdc_ft { + +// static +absl::Status Socket::Initialize() { +#if PLATFORM_WINDOWS + WSADATA wsaData; + const int result = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (result != 0) { + return MakeStatus("WSAStartup() failed: %s", Util::GetWin32Error(result)); + } + return absl::OkStatus(); +#elif PLATFORM_LINUX + return absl::OkStatus(); +#endif +} + +// static +absl::Status Socket::Shutdown() { +#if PLATFORM_WINDOWS + const int result = WSACleanup(); + if (result == SOCKET_ERROR) { + return MakeStatus("WSACleanup() failed: %s", + Util::GetWin32Error(WSAGetLastError())); + } + return absl::OkStatus(); +#elif PLATFORM_LINUX + return absl::OkStatus(); +#endif +} + +SocketFinalizer::~SocketFinalizer() { + absl::Status status = Socket::Shutdown(); + if (!status.ok()) { + LOG_ERROR("Socket shutdown failed: %s", status.message()) + } +}; + +} // namespace cdc_ft diff --git a/cdc_rsync/base/socket.h b/cdc_rsync/base/socket.h index c156dab..9c6c876 100644 --- a/cdc_rsync/base/socket.h +++ b/cdc_rsync/base/socket.h @@ -26,6 +26,14 @@ class Socket { Socket() = default; virtual ~Socket() = default; + // Calls WSAStartup() on Windows, no-op on Linux. + // Must be called before using sockets. + static absl::Status Initialize(); + + // Calls WSACleanup() on Windows, no-op on Linux. + // Must be called after using sockets. + static absl::Status Shutdown(); + // Send data to the socket. virtual absl::Status Send(const void* buffer, size_t size) = 0; @@ -40,6 +48,12 @@ class Socket { size_t* bytes_received) = 0; }; +// Convenience class that calls Shutdown() on destruction. Logs on errors. +class SocketFinalizer { + public: + ~SocketFinalizer(); +}; + } // namespace cdc_ft #endif // CDC_RSYNC_BASE_SOCKET_H_ diff --git a/cdc_rsync/cdc_rsync_client.cc b/cdc_rsync/cdc_rsync_client.cc index 1dccd6c..525890c 100644 --- a/cdc_rsync/cdc_rsync_client.cc +++ b/cdc_rsync/cdc_rsync_client.cc @@ -263,6 +263,12 @@ absl::Status CdcRsyncClient::StartServer() { return SetTag(MakeStatus("Redeploy server"), Tag::kDeployServer); } + status = Socket::Initialize(); + if (!status.ok()) { + return WrapStatus(status, "Failed to initialize sockets"); + } + socket_finalizer_ = std::make_unique(); + assert(is_server_listening_); status = socket_.Connect(port); if (!status.ok()) { diff --git a/cdc_rsync/cdc_rsync_client.h b/cdc_rsync/cdc_rsync_client.h index 9e24b78..055fb69 100644 --- a/cdc_rsync/cdc_rsync_client.h +++ b/cdc_rsync/cdc_rsync_client.h @@ -123,6 +123,7 @@ class CdcRsyncClient { WinProcessFactory process_factory_; RemoteUtil remote_util_; PortManager port_manager_; + std::unique_ptr socket_finalizer_; ClientSocket socket_; MessagePump message_pump_{&socket_, MessagePump::PacketReceivedDelegate()}; ConsoleProgressPrinter printer_; diff --git a/cdc_rsync/client_socket.cc b/cdc_rsync/client_socket.cc index c124ed7..873d3f4 100644 --- a/cdc_rsync/client_socket.cc +++ b/cdc_rsync/client_socket.cc @@ -39,10 +39,10 @@ absl::Status MakeSocketStatus(const char* message) { } // namespace -struct SocketInfo { +struct ClientSocketInfo { SOCKET socket; - SocketInfo() : socket(INVALID_SOCKET) {} + ClientSocketInfo() : socket(INVALID_SOCKET) {} }; ClientSocket::ClientSocket() = default; @@ -50,12 +50,6 @@ ClientSocket::ClientSocket() = default; ClientSocket::~ClientSocket() { Disconnect(); } absl::Status ClientSocket::Connect(int port) { - WSADATA wsaData; - int result = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (result != 0) { - return MakeStatus("WSAStartup() failed: %i", result); - } - addrinfo hints; ZeroMemory(&hints, sizeof(hints)); hints.ai_family = AF_INET; @@ -64,14 +58,13 @@ absl::Status ClientSocket::Connect(int port) { // Resolve the server address and port. addrinfo* addr_infos = nullptr; - result = getaddrinfo("localhost", std::to_string(port).c_str(), &hints, - &addr_infos); + int result = getaddrinfo("localhost", std::to_string(port).c_str(), &hints, + &addr_infos); if (result != 0) { - WSACleanup(); return MakeStatus("getaddrinfo() failed: %i", result); } - socket_info_ = std::make_unique(); + socket_info_ = std::make_unique(); int count = 0; for (addrinfo* curr = addr_infos; curr; curr = curr->ai_next, count++) { socket_info_->socket = @@ -101,7 +94,6 @@ absl::Status ClientSocket::Connect(int port) { if (socket_info_->socket == INVALID_SOCKET) { socket_info_.reset(); - WSACleanup(); return MakeStatus("Unable to connect to port %i", port); } @@ -120,7 +112,6 @@ void ClientSocket::Disconnect() { } socket_info_.reset(); - WSACleanup(); } absl::Status ClientSocket::Send(const void* buffer, size_t size) { diff --git a/cdc_rsync/client_socket.h b/cdc_rsync/client_socket.h index ec6eb91..9be7835 100644 --- a/cdc_rsync/client_socket.h +++ b/cdc_rsync/client_socket.h @@ -45,7 +45,7 @@ class ClientSocket : public Socket { size_t* bytes_received) override; private: - std::unique_ptr socket_info_; + std::unique_ptr socket_info_; }; } // namespace cdc_ft diff --git a/cdc_rsync_server/BUILD b/cdc_rsync_server/BUILD index 5b80201..e6c9551 100644 --- a/cdc_rsync_server/BUILD +++ b/cdc_rsync_server/BUILD @@ -127,11 +127,17 @@ cc_library( name = "server_socket", srcs = ["server_socket.cc"], hdrs = ["server_socket.h"], - target_compatible_with = ["@platforms//os:linux"], + linkopts = select({ + "//tools:windows": [ + "/DEFAULTLIB:Ws2_32.lib", # Sockets, e.g. recv, send, WSA*. + ], + "//conditions:default": [], + }), deps = [ "//cdc_rsync/base:socket", "//common:log", "//common:status", + "//common:util", "@com_google_absl//absl/status", ], ) diff --git a/cdc_rsync_server/cdc_rsync_server.cc b/cdc_rsync_server/cdc_rsync_server.cc index b6f9e04..f7a01f2 100644 --- a/cdc_rsync_server/cdc_rsync_server.cc +++ b/cdc_rsync_server/cdc_rsync_server.cc @@ -148,10 +148,7 @@ PathFilter::Rule::Type ToInternalType( CdcRsyncServer::CdcRsyncServer() = default; -CdcRsyncServer::~CdcRsyncServer() { - message_pump_.reset(); - socket_.reset(); -} +CdcRsyncServer::~CdcRsyncServer() = default; bool CdcRsyncServer::CheckComponents( const std::vector& components) { @@ -173,8 +170,14 @@ bool CdcRsyncServer::CheckComponents( } absl::Status CdcRsyncServer::Run(int port) { + absl::Status status = Socket::Initialize(); + if (!status.ok()) { + return WrapStatus(status, "Failed to initialize sockets"); + } + socket_finalizer_ = std::make_unique(); + socket_ = std::make_unique(); - absl::Status status = socket_->StartListening(port); + status = socket_->StartListening(port); if (!status.ok()) { return WrapStatus(status, "Failed to start listening on port %i", port); } @@ -563,7 +566,7 @@ absl::Status CdcRsyncServer::HandleSendMissingFileData() { // Verify that there is no directory existing with the same name. if (path::Exists(filepath) && path::DirExists(filepath)) { assert(!diff_.extraneous_dirs.empty()); - absl::Status status = path::RemoveFile(filepath); + status = path::RemoveFile(filepath); if (!status.ok()) { return WrapStatus( status, "Failed to remove folder '%s' before creating file '%s'", diff --git a/cdc_rsync_server/cdc_rsync_server.h b/cdc_rsync_server/cdc_rsync_server.h index 0a58549..59c66fd 100644 --- a/cdc_rsync_server/cdc_rsync_server.h +++ b/cdc_rsync_server/cdc_rsync_server.h @@ -32,6 +32,7 @@ namespace cdc_ft { class MessagePump; class ServerSocket; +class SocketFinalizer; class CdcRsyncServer { public: @@ -90,6 +91,8 @@ class CdcRsyncServer { // Used to toggle decompression. void Thread_OnPackageReceived(PacketType type); + // The order determines the correct destruction order, so keep it! + std::unique_ptr socket_finalizer_; std::unique_ptr socket_; std::unique_ptr message_pump_; diff --git a/cdc_rsync_server/server_socket.cc b/cdc_rsync_server/server_socket.cc index 228d31c..f76be6e 100644 --- a/cdc_rsync_server/server_socket.cc +++ b/cdc_rsync_server/server_socket.cc @@ -14,20 +14,72 @@ #include "cdc_rsync_server/server_socket.h" +#include "common/log.h" +#include "common/platform.h" +#include "common/status.h" +#include "common/util.h" + +#if PLATFORM_WINDOWS + +#include + +#elif PLATFORM_LINUX + #include #include #include #include -#include "common/log.h" -#include "common/status.h" +#endif namespace cdc_ft { - namespace { -int kInvalidFd = -1; +#if PLATFORM_WINDOWS + +using SocketType = SOCKET; +using SockAddrType = SOCKADDR; +constexpr SocketType kInvalidSocket = INVALID_SOCKET; +constexpr int kSocketError = SOCKET_ERROR; +constexpr int kSendingEnd = SD_SEND; + +constexpr int kErrAgain = WSAEWOULDBLOCK; // There's no EAGAIN on Windows. +constexpr int kErrWouldBlock = WSAEWOULDBLOCK; +constexpr int kErrAddrInUse = WSAEADDRINUSE; + +int GetLastError() { return WSAGetLastError(); } +std::string GetErrorStr(int err) { return Util::GetWin32Error(err); } +void Close(SocketType* socket) { + if (*socket != kInvalidSocket) { + closesocket(*socket); + *socket = kInvalidSocket; + } +} + +// Not necessary on Windows. +#define HANDLE_EINTR(x) (x) + +#elif PLATFORM_LINUX + +using SocketType = int; +using SockAddrType = sockaddr; +constexpr SocketType kInvalidSocket = -1; +constexpr int kSocketError = -1; +constexpr int kSendingEnd = SHUT_WR; + +constexpr int kErrAgain = EAGAIN; +constexpr int kErrWouldBlock = EWOULDBLOCK; +constexpr int kErrAddrInUse = EADDRINUSE; + +int GetLastError() { return errno; } +std::string GetErrorStr(int err) { return strerror(err); } +void Close(SocketType* socket) { + if (*socket != kInvalidSocket) { + close(*socket); + *socket = kInvalidSocket; + } +} // Keep re-evaluating the expression |x| while it returns EINTR. #define HANDLE_EINTR(x) \ @@ -39,10 +91,22 @@ int kInvalidFd = -1; eintr_wrapper_result; \ }) +#endif + +std::string GetLastErrorStr() { return GetErrorStr(GetLastError()); } + } // namespace +struct ServerSocketInfo { + // Listening socket file descriptor (where new connections are accepted). + SocketType listen_sock = kInvalidSocket; + + // Connection socket file descriptor (where data is sent to/received from). + SocketType conn_sock = kInvalidSocket; +}; + ServerSocket::ServerSocket() - : Socket(), listen_sockfd_(kInvalidFd), conn_sockfd_(kInvalidFd) {} + : Socket(), socket_info_(std::make_unique()) {} ServerSocket::~ServerSocket() { Disconnect(); @@ -50,25 +114,26 @@ ServerSocket::~ServerSocket() { } absl::Status ServerSocket::StartListening(int port) { - if (listen_sockfd_ != kInvalidFd) { + if (socket_info_->listen_sock != kInvalidSocket) { return MakeStatus("Already listening"); } LOG_DEBUG("Open socket"); - listen_sockfd_ = socket(AF_INET, SOCK_STREAM, 0); - if (listen_sockfd_ < 0) { - listen_sockfd_ = kInvalidFd; - return MakeStatus("socket() failed: %s", strerror(errno)); + socket_info_->listen_sock = socket(AF_INET, SOCK_STREAM, 0); + if (socket_info_->listen_sock == kInvalidSocket) { + return MakeStatus("Creating listen socket failed: %s", GetLastErrorStr()); } // If the program terminates abnormally, the socket might remain in a // TIME_WAIT state and report "address already in use" on bind(). Setting // SO_REUSEADDR works around that. See // https://hea-www.harvard.edu/~fine/Tech/addrinuse.html - int enable = 1; - if (setsockopt(listen_sockfd_, SOL_SOCKET, SO_REUSEADDR, &enable, - sizeof(enable)) < 0) { - LOG_DEBUG("setsockopt() failed"); + const int enable = 1; + int result = + setsockopt(socket_info_->listen_sock, SOL_SOCKET, SO_REUSEADDR, + reinterpret_cast(&enable), sizeof(enable)); + if (result == kSocketError) { + LOG_DEBUG("Enabling address reusal failed"); } LOG_DEBUG("Bind socket"); @@ -77,46 +142,47 @@ absl::Status ServerSocket::StartListening(int port) { serv_addr.sin_family = AF_INET; serv_addr.sin_addr.s_addr = INADDR_ANY; serv_addr.sin_port = htons(port); - if (bind(listen_sockfd_, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) < - 0) { + + result = bind(socket_info_->listen_sock, + reinterpret_cast(&serv_addr), + sizeof(serv_addr)); + if (result == kSocketError) { + int err = GetLastError(); absl::Status status = - MakeStatus("bind() to port %i failed: %s", port, strerror(errno)); - if (errno == EADDRINUSE) { + MakeStatus("Binding to port %i failed: %s", port, GetErrorStr(err)); + if (err == kErrAddrInUse) { // Happens when two instances are run at the same time. Help callers to // print reasonable errors. status = SetTag(status, Tag::kAddressInUse); } - close(listen_sockfd_); - listen_sockfd_ = kInvalidFd; - + Close(&socket_info_->listen_sock); return status; } LOG_DEBUG("Listen"); - listen(listen_sockfd_, 1); + result = listen(socket_info_->listen_sock, 1); + if (result == kSocketError) { + int err = GetLastError(); + Close(&socket_info_->listen_sock); + return MakeStatus("Listening to socket failed: %s", GetErrorStr(err)); + } + return absl::OkStatus(); } void ServerSocket::StopListening() { - if (listen_sockfd_ != kInvalidFd) { - close(listen_sockfd_); - listen_sockfd_ = kInvalidFd; - } - + Close(&socket_info_->listen_sock); LOG_INFO("Stopped listening."); } absl::Status ServerSocket::WaitForConnection() { - if (conn_sockfd_ != kInvalidFd) { + if (socket_info_->conn_sock != kInvalidSocket) { return MakeStatus("Already connected"); } - sockaddr_in cli_addr; - socklen_t cli_len = sizeof(cli_addr); - conn_sockfd_ = accept(listen_sockfd_, (struct sockaddr*)&cli_addr, &cli_len); - if (conn_sockfd_ < 0) { - conn_sockfd_ = kInvalidFd; - return MakeStatus("accept() failed: %s", strerror(errno)); + socket_info_->conn_sock = accept(socket_info_->listen_sock, nullptr, nullptr); + if (socket_info_->conn_sock == kInvalidSocket) { + return MakeStatus("Accepting connection failed: %s", GetLastErrorStr()); } LOG_DEBUG("Client connected"); @@ -124,39 +190,36 @@ absl::Status ServerSocket::WaitForConnection() { } void ServerSocket::Disconnect() { - if (conn_sockfd_ != kInvalidFd) { - close(conn_sockfd_); - conn_sockfd_ = kInvalidFd; - } - + Close(&socket_info_->conn_sock); LOG_INFO("Disconnected"); } absl::Status ServerSocket::ShutdownSendingEnd() { - int result = shutdown(conn_sockfd_, SHUT_WR); - if (result != 0) { - return MakeStatus("shutdown() failed: %s", strerror(errno)); + int result = shutdown(socket_info_->conn_sock, kSendingEnd); + if (result == kSocketError) { + return MakeStatus("Socket shutdown failed: %s", GetLastErrorStr()); } return absl::OkStatus(); } absl::Status ServerSocket::Send(const void* buffer, size_t size) { - const uint8_t* curr_ptr = reinterpret_cast(buffer); - ssize_t bytes_left = size; + const char* curr_ptr = reinterpret_cast(buffer); + assert(size <= INT_MAX); + int bytes_left = static_cast(size); while (bytes_left > 0) { - ssize_t bytes_written = - HANDLE_EINTR(send(conn_sockfd_, curr_ptr, bytes_left, /*flags*/ 0)); + int bytes_written = HANDLE_EINTR( + send(socket_info_->conn_sock, curr_ptr, bytes_left, /*flags*/ 0)); if (bytes_written < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK) { + const int err = GetLastError(); + if (err == kErrAgain || err == kErrWouldBlock) { // Shouldn't happen as the socket should be blocking. LOG_DEBUG("Socket would block"); continue; } - return MakeStatus("write() to fd %i failed: %s", conn_sockfd_, - strerror(errno)); + return MakeStatus("Sending to socket failed: %s", GetErrorStr(err)); } bytes_left -= bytes_written; @@ -173,21 +236,22 @@ absl::Status ServerSocket::Receive(void* buffer, size_t size, return absl::OkStatus(); } - uint8_t* curr_ptr = reinterpret_cast(buffer); - ssize_t bytes_left = size; + char* curr_ptr = static_cast(buffer); + assert(size <= INT_MAX); + int bytes_left = size; while (bytes_left > 0) { - ssize_t bytes_read = - HANDLE_EINTR(recv(conn_sockfd_, curr_ptr, bytes_left, /*flags*/ 0)); + int bytes_read = HANDLE_EINTR( + recv(socket_info_->conn_sock, curr_ptr, bytes_left, /*flags*/ 0)); if (bytes_read < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK) { + const int err = GetLastError(); + if (err == kErrAgain || err == kErrWouldBlock) { // Shouldn't happen as the socket should be blocking. LOG_DEBUG("Socket would block"); continue; } - return MakeStatus("recv() from fd %i failed: %s", conn_sockfd_, - strerror(errno)); + return MakeStatus("Receiving from socket failed: %s", GetErrorStr(err)); } bytes_left -= bytes_read; @@ -196,7 +260,7 @@ absl::Status ServerSocket::Receive(void* buffer, size_t size, if (bytes_read == 0) { // EOF. Make sure we're not in the middle of a message. - if (bytes_left < static_cast(size)) { + if (bytes_left < static_cast(size)) { return MakeStatus("EOF after partial read"); } diff --git a/cdc_rsync_server/server_socket.h b/cdc_rsync_server/server_socket.h index 063ed56..636b949 100644 --- a/cdc_rsync_server/server_socket.h +++ b/cdc_rsync_server/server_socket.h @@ -50,11 +50,7 @@ class ServerSocket : public Socket { size_t* bytes_received) override; private: - // Listening socket file descriptor (where new connections are accepted). - int listen_sockfd_; - - // Connection socket file descriptor (where data is sent to/received from). - int conn_sockfd_; + std::unique_ptr socket_info_; }; } // namespace cdc_ft