diff --git a/Cargo.lock b/Cargo.lock index 00cc915..af39fef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -418,6 +418,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.38" @@ -1410,9 +1416,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "matchit" @@ -1436,6 +1442,15 @@ version = "2.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d0d8b92cd8358e8d229c11df9358decae64d137c5be540952c5ca7b25aea768" +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + [[package]] name = "mime" version = "0.3.17" @@ -1525,6 +1540,19 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.5.0", + "cfg-if", + "cfg_aliases", + "libc", + "memoffset", +] + [[package]] name = "nom" version = "7.1.3" @@ -3057,6 +3085,7 @@ dependencies = [ "jsonschema", "lazy_static", "log", + "nix", "petgraph", "prometheus", "prost", diff --git a/Cargo.toml b/Cargo.toml index 0f6094d..628a4e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ derive_more = { version = "0.99.17" } reqwest = { version = "0.12.5", features = ["json"] } jsonschema = { version = "0.18.0" } url = { version = "2.5.2" } +nix = { version = "0.29.0", features = ["net"] } [build-dependencies] tonic-build = "0.11.0" diff --git a/src/cgw_errors.rs b/src/cgw_errors.rs index a8d2175..55b121a 100644 --- a/src/cgw_errors.rs +++ b/src/cgw_errors.rs @@ -15,6 +15,8 @@ pub enum Error { RemoteDiscoveryFailedInfras(Vec), + Tcp(String), + Tls(String), Redis(String), diff --git a/src/main.rs b/src/main.rs index 87920dd..d3439a1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -28,6 +28,7 @@ extern crate lazy_static; use cgw_app_args::AppArgs; use cgw_runtime::cgw_initialize_runtimes; +use nix::sys::socket::{setsockopt, sockopt}; use tokio::{ net::TcpListener, runtime::{Builder, Handle, Runtime}, @@ -50,6 +51,14 @@ use cgw_tls::cgw_tls_create_acceptor; use crate::cgw_errors::{Error, Result}; +use tokio::net::TcpStream; + +use std::os::unix::io::AsFd; + +const CGW_TCP_KEEPALIVE_TIMEOUT: u32 = 30; +const CGW_TCP_KEEPALIVE_COUNT: u32 = 3; +const CGW_TCP_KEEPALIVE_INTERVAL: u32 = 10; + #[derive(Copy, Clone)] enum AppCoreLogLevel { /// Print debug-level messages and above @@ -159,6 +168,65 @@ impl AppCore { } } +async fn cgw_set_tcp_keepalive_options(stream: TcpStream) -> Result { + // Convert Tokio's TcpStream to std::net::TcpStream + let std_stream = match stream.into_std() { + Ok(stream) => stream, + Err(e) => { + error!("Failed to convert Tokio TcpStream into Std TcpStream"); + return Err(Error::Tcp(format!( + "Failed to convert Tokio TcpStream into Std TcpStream: {}", + e + ))); + } + }; + + // Get the raw file descriptor (socket) + let raw_fd = std_stream.as_fd(); + + // Set the socket option to enable TCP keepalive + if let Err(e) = setsockopt(&raw_fd, sockopt::KeepAlive, &true) { + error!("Failed to enable TCP keepalive: {}", e); + return Err(Error::Tcp("Failed to enable TCP keepalive".to_string())); + } + + // Set the TCP_KEEPIDLE option (keepalive time) + if let Err(e) = setsockopt(&raw_fd, sockopt::TcpKeepIdle, &CGW_TCP_KEEPALIVE_TIMEOUT) { + error!("Failed to set TCP_KEEPIDLE: {}", e); + return Err(Error::Tcp("Failed to set TCP_KEEPIDLE".to_string())); + } + + // Set the TCP_KEEPINTVL option (keepalive interval) + if let Err(e) = setsockopt(&raw_fd, sockopt::TcpKeepCount, &CGW_TCP_KEEPALIVE_COUNT) { + error!("Failed to set TCP_KEEPINTVL: {}", e); + return Err(Error::Tcp("Failed to set TCP_KEEPINTVL".to_string())); + } + + // Set the TCP_KEEPCNT option (keepalive probes count) + if let Err(e) = setsockopt( + &raw_fd, + sockopt::TcpKeepInterval, + &CGW_TCP_KEEPALIVE_INTERVAL, + ) { + error!("Failed to set TCP_KEEPCNT: {}", e); + return Err(Error::Tcp("Failed to set TCP_KEEPCNT".to_string())); + } + + // Convert the std::net::TcpStream back to Tokio's TcpStream + let stream = match TcpStream::from_std(std_stream) { + Ok(stream) => stream, + Err(e) => { + error!("Failed to convert Std TcpStream into Tokio TcpStream"); + return Err(Error::Tcp(format!( + "Failed to convert Std TcpStream into Tokio TcpStream: {}", + e + ))); + } + }; + + Ok(stream) +} + async fn server_loop(app_core: Arc) -> Result<()> { debug!("server_loop entry"); @@ -214,7 +282,18 @@ async fn server_loop(app_core: Arc) -> Result<()> { } }; - info!("ACK conn: {}", conn_idx); + let socket = match cgw_set_tcp_keepalive_options(socket).await { + Ok(s) => s, + Err(e) => { + error!( + "Failed to set TCP keepalive options. Error: {}", + e.to_string() + ); + break; + } + }; + + info!("ACK conn: {}, remote address: {}", conn_idx, remote_addr); app_core_clone.conn_ack_runtime_handle.spawn(async move { cgw_server_clone