From a275bc8e9b2f364c01d010b9447ce059c3c63c0a Mon Sep 17 00:00:00 2001
From: Valentin Obst <kernel@valentinobst.de>
Date: Sun, 4 Feb 2024 14:20:40 +0100
Subject: [PATCH 01/13] net/tcp: add logging to BIC

Signed-off-by: Valentin Obst <kernel@valentinobst.de>
---
 net/ipv4/tcp_bic.c | 51 ++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 49 insertions(+), 2 deletions(-)

diff --git a/net/ipv4/tcp_bic.c b/net/ipv4/tcp_bic.c
index 58358bf92e1b8a..757033a32fd78b 100644
--- a/net/ipv4/tcp_bic.c
+++ b/net/ipv4/tcp_bic.c
@@ -16,6 +16,7 @@
 #include <linux/mm.h>
 #include <linux/module.h>
+#include <linux/timekeeping.h>
 #include <net/tcp.h>

 #define BICTCP_BETA_SCALE    1024	/* Scale factor beta calculation
@@ -55,6 +56,7 @@ struct bictcp {
 	u32	epoch_start;	/* beginning of an epoch */
 #define ACK_RATIO_SHIFT	4
 	u32	delayed_ack;	/* estimate the ratio of Packets/ACKs << 4 */
+	u64	start_time;
 };

 static inline void bictcp_reset(struct bictcp *ca)
@@ -65,6 +67,7 @@ static inline void bictcp_reset(struct bictcp *ca)
 	ca->last_time = 0;
 	ca->epoch_start = 0;
 	ca->delayed_ack = 2 << ACK_RATIO_SHIFT;
+	ca->start_time = ktime_get_boot_fast_ns();
 }

 static void bictcp_init(struct sock *sk)
@@ -75,6 +78,19 @@ static void bictcp_init(struct sock *sk)

 	if (initial_ssthresh)
 		tcp_sk(sk)->snd_ssthresh = initial_ssthresh;
+
+	pr_info("Socket created: start %llu\n", ca->start_time);
+}
+
+static void bictcp_release(struct sock *sk)
+{
+	struct bictcp *ca = inet_csk_ca(sk);
+
+	pr_info(
+		"Socket destroyed: start %llu, end %llu\n",
+		ca->start_time,
+		ktime_get_boot_fast_ns()
+	);
 }

 /*
@@ -147,11 +163,23 @@ static void bictcp_cong_avoid(struct sock *sk, u32 ack, u32 acked)

 	if (tcp_in_slow_start(tp)) {
 		acked = tcp_slow_start(tp, acked);
-		if (!acked)
+		if (!acked) {
+			pr_info(
+				"New cwnd: %u, time %llu, ssthresh %u, start %llu, ss 1\n",
+				tp->snd_cwnd, ktime_get_boot_fast_ns(),
+				tp->snd_ssthresh, ca->start_time
+			);
 			return;
+		}
 	}
 	bictcp_update(ca, tcp_snd_cwnd(tp));
 	tcp_cong_avoid_ai(tp, ca->cnt, acked);
+
+	pr_info(
+		"New cwnd: %u, time %llu, ssthresh %u, start %llu, ss 0\n",
+		tp->snd_cwnd, ktime_get_boot_fast_ns(),
+		tp->snd_ssthresh, ca->start_time
+	);
 }

 /*
@@ -163,6 +191,12 @@ static u32 bictcp_recalc_ssthresh(struct sock *sk)
 	const struct tcp_sock *tp = tcp_sk(sk);
 	struct bictcp *ca = inet_csk_ca(sk);

+	pr_info(
+		"Enter fast retransmit: time %llu, start %llu\n",
+		ktime_get_boot_fast_ns(),
+		ca->start_time
+	);
+
 	ca->epoch_start = 0;	/* end of epoch */

 	/* Wmax and fast convergence */
@@ -180,8 +214,20 @@ static u32 bictcp_recalc_ssthresh(struct sock *sk)

 static void bictcp_state(struct sock *sk, u8 new_state)
 {
-	if (new_state == TCP_CA_Loss)
+	if (new_state == TCP_CA_Loss) {
+		struct bictcp *ca = inet_csk_ca(sk);
+		u64 tmp = ca->start_time;
+
+		pr_info(
+			"Retransmission timeout fired: time %llu, start %llu\n",
+			ktime_get_boot_fast_ns(),
+			ca->start_time
+		);
+
 		bictcp_reset(inet_csk_ca(sk));
+
+		ca->start_time = tmp;
+	}
 }

 /* Track delayed acknowledgment ratio using sliding window
@@ -201,6 +247,7 @@ static void bictcp_acked(struct sock *sk, const struct ack_sample *sample)

 static struct tcp_congestion_ops bictcp __read_mostly = {
 	.init		= bictcp_init,
+	.release	= bictcp_release,
 	.ssthresh	= bictcp_recalc_ssthresh,
 	.cong_avoid	= bictcp_cong_avoid,
 	.set_state	= bictcp_state,

From 12e579b1d2b36da960a79ee6e66c38817a1b6a50 Mon Sep 17 00:00:00 2001
From: Valentin Obst <kernel@valentinobst.de>
Date: Sat, 27 Jan 2024 17:52:54 +0100
Subject: [PATCH 02/13] scripts/gen_rust_analyzer: search modules in net/

When generating the `rust-project.json`, also search for modules in the
`net/` folder.
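This lets rust-analyzer resolve the `kernel` crate dependency for Rust
modules that live under `net/`, such as the `net/ipv4/tcp_bic_rust.rs`
module added later in this series.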
Signed-off-by: Valentin Obst <kernel@valentinobst.de>
---
 scripts/generate_rust_analyzer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/generate_rust_analyzer.py b/scripts/generate_rust_analyzer.py
index fc52bc41d3e7bd..4a687e36091eb5 100755
--- a/scripts/generate_rust_analyzer.py
+++ b/scripts/generate_rust_analyzer.py
@@ -116,7 +116,7 @@ def is_root_crate(build_file, target):
     # Then, the rest outside of `rust/`.
     #
     # We explicitly mention the top-level folders we want to cover.
-    extra_dirs = map(lambda dir: srctree / dir, ("samples", "drivers"))
+    extra_dirs = map(lambda dir: srctree / dir, ("samples", "drivers", "net"))
     if external_src is not None:
         extra_dirs = [external_src]
     for folder in extra_dirs:

From bd2d900c32b8c762ba7a28b1ceebd37f83fa6327 Mon Sep 17 00:00:00 2001
From: Wedson Almeida Filho
Date: Fri, 29 Sep 2023 17:58:08 -0300
Subject: [PATCH 03/13] rust: introduce `InPlaceModule`

This allows modules to be initialised in-place in pinned memory, which
enables the usage of pinned types (e.g., mutexes, spinlocks, driver
registrations, etc.) in modules without any extra allocations.

Drivers that don't need this may continue to implement `Module` without
any changes.
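For illustration, a module can now hold a pinned lock directly in its
module struct, with no heap allocation. Sketch only; `MyModule` is
hypothetical, and the usual `module!` declaration is omitted:

    use kernel::prelude::*;
    use kernel::sync::Mutex;

    #[pin_data]
    struct MyModule {
        #[pin]
        state: Mutex<u32>,
    }

    impl kernel::InPlaceModule for MyModule {
        fn init(_module: &'static ThisModule) -> impl PinInit<Self, Error> {
            kernel::try_pin_init!(Self {
                state <- kernel::new_mutex!(0),
            })
        }
    }

The mutex is constructed directly inside the static `__MOD` storage;
nothing is boxed.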
Signed-off-by: Wedson Almeida Filho
[kernel@valentinobst.de: remove feature return_position_impl_trait_in_trait
 as it is now stabilised]
[kernel@valentinobst.de: remove `Send` trait bound on `Module` and
 `InPlaceModule`]
---
 rust/kernel/lib.rs    | 23 +++++++++++++++++++++++
 rust/macros/module.rs | 18 ++++++------------
 2 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/rust/kernel/lib.rs b/rust/kernel/lib.rs
index 1e5f229b82638e..e10c40f9fee3d6 100644
--- a/rust/kernel/lib.rs
+++ b/rust/kernel/lib.rs
@@ -75,6 +75,29 @@ pub trait Module: Sized + Sync {
     fn init(module: &'static ThisModule) -> error::Result<Self>;
 }

+/// A module that is pinned and initialised in-place.
+pub trait InPlaceModule: Sync {
+    /// Creates an initialiser for the module.
+    ///
+    /// It is called when the module is loaded.
+    fn init(module: &'static ThisModule) -> impl init::PinInit<Self, error::Error>;
+}
+
+impl<T: Module> InPlaceModule for T {
+    fn init(module: &'static ThisModule) -> impl init::PinInit<Self, error::Error> {
+        let initer = move |slot: *mut Self| {
+            let m = <Self as Module>::init(module)?;
+
+            // SAFETY: `slot` is valid for write per the contract with `pin_init_from_closure`.
+            unsafe { slot.write(m) };
+            Ok(())
+        };
+
+        // SAFETY: On success, `initer` always fully initialises an instance of `Self`.
+        unsafe { init::pin_init_from_closure(initer) }
+    }
+}
+
 /// Equivalent to `THIS_MODULE` in the C API.
 ///
 /// C header: [`include/linux/export.h`](srctree/include/linux/export.h)
diff --git a/rust/macros/module.rs b/rust/macros/module.rs
index d62d8710d77ab0..9152bd691c5a6d 100644
--- a/rust/macros/module.rs
+++ b/rust/macros/module.rs
@@ -208,7 +208,7 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream {
             #[used]
             static __IS_RUST_MODULE: () = ();

-            static mut __MOD: Option<{type_}> = None;
+            static mut __MOD: core::mem::MaybeUninit<{type_}> = core::mem::MaybeUninit::uninit();

             // SAFETY: `__this_module` is constructed by the kernel at load time and will not be
             // freed until the module is unloaded.
@@ -270,23 +270,17 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream {
             }}

             fn __init() -> core::ffi::c_int {{
-                match <{type_} as kernel::Module>::init(&THIS_MODULE) {{
-                    Ok(m) => {{
-                        unsafe {{
-                            __MOD = Some(m);
-                        }}
-                        return 0;
-                    }}
-                    Err(e) => {{
-                        return e.to_errno();
-                    }}
+                let initer = <{type_} as kernel::InPlaceModule>::init(&THIS_MODULE);
+                match unsafe {{ initer.__pinned_init(__MOD.as_mut_ptr()) }} {{
+                    Ok(()) => 0,
+                    Err(e) => e.to_errno(),
                 }}
             }}

             fn __exit() {{
                 unsafe {{
                     // Invokes `drop()` on `__MOD`, which should be used for cleanup.
-                    __MOD = None;
+                    __MOD.assume_init_drop();
                 }}
             }}

From a13e70088cd5fad8e3f1fe1254b123db80486aea Mon Sep 17 00:00:00 2001
From: Wedson Almeida Filho
Date: Fri, 29 Sep 2023 17:58:09 -0300
Subject: [PATCH 04/13] rust: init: introduce `Opaque::try_ffi_init`

We'll need it, for example, when calling `register_filesystem` to
initialise a file system registration.
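A minimal sketch of the intended use, with a made-up C API
(`bindings::foo` and `foo_register` are hypothetical; the real users
appear later in this series):

    use kernel::error::{to_result, Error};
    use kernel::init::PinInit;
    use kernel::types::Opaque;

    fn register() -> impl PinInit<Opaque<bindings::foo>, Error> {
        Opaque::try_ffi_init(|slot: *mut bindings::foo| {
            // SAFETY: `slot` is valid for write; on success, `foo_register`
            // fully initialises `*slot` (hypothetical contract).
            to_result(unsafe { bindings::foo_register(slot) })
        })
    }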
Signed-off-by: Wedson Almeida Filho
---
 rust/kernel/types.rs | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/rust/kernel/types.rs b/rust/kernel/types.rs
index 8aabe348b19473..23824782ad7230 100644
--- a/rust/kernel/types.rs
+++ b/rust/kernel/types.rs
@@ -240,14 +240,22 @@ impl<T> Opaque<T> {
     /// uninitialized. Additionally, access to the inner `T` requires `unsafe`, so the caller needs
     /// to verify at that point that the inner value is valid.
     pub fn ffi_init(init_func: impl FnOnce(*mut T)) -> impl PinInit<Self> {
+        Self::try_ffi_init(move |slot| {
+            init_func(slot);
+            Ok(())
+        })
+    }
+
+    /// Similar to [`Self::ffi_init`], except that the closure can fail.
+    ///
+    /// To avoid leaks on failure, the closure must drop any fields it has initialised before the
+    /// failure.
+    pub fn try_ffi_init<E>(
+        init_func: impl FnOnce(*mut T) -> Result<(), E>,
+    ) -> impl PinInit<Self, E> {
         // SAFETY: We contain a `MaybeUninit`, so it is OK for the `init_func` to not fully
         // initialize the `T`.
-        unsafe {
-            init::pin_init_from_closure::<_, ::core::convert::Infallible>(move |slot| {
-                init_func(Self::raw_get(slot));
-                Ok(())
-            })
-        }
+        unsafe { init::pin_init_from_closure(|slot| init_func(Self::raw_get(slot))) }
     }

     /// Returns a raw pointer to the opaque data.

From 08fc8c35f996d1b6552360e05399d5a10dabca29 Mon Sep 17 00:00:00 2001
From: Valentin Obst <kernel@valentinobst.de>
Date: Sun, 4 Feb 2024 14:24:39 +0100
Subject: [PATCH 05/13] rust/kernel: add time primitives for net/tcp

In net/tcp, time values are usually 32-bit-wide unsigned integers in
units of jiffies, microseconds, or milliseconds. Add types, constants,
and functions to work with 32-bit time values. This is, for example,
used in the CUBIC and BIC CCAs.
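Note that the 32-bit microsecond clock wraps roughly every 71.6 minutes
(2^32 us) and the 32-bit millisecond clock roughly every 49.7 days, so
intervals should be computed with wrapping arithmetic. Illustrative
snippet:

    use kernel::time;

    let t0: time::Usecs32 = time::ktime_get_boot_fast_us32();
    /* ... */
    let elapsed_us = time::ktime_get_boot_fast_us32().wrapping_sub(t0);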
Signed-off-by: Valentin Obst <kernel@valentinobst.de>
---
 rust/kernel/time.rs | 71 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 71 insertions(+)

diff --git a/rust/kernel/time.rs b/rust/kernel/time.rs
index 25a896eed4689f..875fafd44f51ee 100644
--- a/rust/kernel/time.rs
+++ b/rust/kernel/time.rs
@@ -8,9 +8,43 @@
 /// The time unit of Linux kernel. One jiffy equals (1/HZ) second.
 pub type Jiffies = core::ffi::c_ulong;

+/// Jiffies, but with a fixed width of 32 bits.
+pub type Jiffies32 = u32;
+
 /// The millisecond time unit.
 pub type Msecs = core::ffi::c_uint;

+/// Milliseconds per second.
+pub const MSEC_PER_SEC: Msecs = 1000;
+
+/// The millisecond time unit with a fixed width of 32 bits.
+///
+/// This is used in networking.
+pub type Msecs32 = u32;
+
+/// The microsecond time unit.
+pub type Usecs = u64;
+
+/// Microseconds per millisecond.
+pub const USEC_PER_MSEC: Usecs = 1000;
+
+/// Microseconds per second.
+pub const USEC_PER_SEC: Usecs = 1_000_000;
+
+/// The microsecond time unit with a fixed width of 32 bits.
+///
+/// This is used in networking.
+pub type Usecs32 = u32;
+
+/// The nanosecond time unit.
+pub type Nsecs = u64;
+
+/// Nanoseconds per microsecond.
+pub const NSEC_PER_USEC: Nsecs = 1000;
+
+/// Nanoseconds per millisecond.
+pub const NSEC_PER_MSEC: Nsecs = 1_000_000;
+
 /// Converts milliseconds to jiffies.
 #[inline]
 pub fn msecs_to_jiffies(msecs: Msecs) -> Jiffies {
@@ -18,3 +52,40 @@ pub fn msecs_to_jiffies(msecs: Msecs) -> Jiffies {
     // matter what the argument is.
     unsafe { bindings::__msecs_to_jiffies(msecs) }
 }
+
+/// Converts jiffies to milliseconds.
+#[inline]
+pub fn jiffies_to_msecs(jiffies: Jiffies) -> Msecs {
+    // SAFETY: The `jiffies_to_msecs` function is always safe to call no
+    // matter what the argument is.
+    unsafe { bindings::jiffies_to_msecs(jiffies) }
+}
+
+/// Returns the current time in 32-bit jiffies.
+#[inline]
+pub fn jiffies32() -> Jiffies32 {
+    // SAFETY: It is always atomic to read the lower 32 bits of jiffies.
+    unsafe { bindings::jiffies as u32 }
+}
+
+/// Returns the time elapsed since system boot, in nanoseconds. Does include
+/// the time the system was suspended.
+#[inline]
+pub fn ktime_get_boot_fast_ns() -> Nsecs {
+    // SAFETY: FFI call without safety requirements.
+    unsafe { bindings::ktime_get_boot_fast_ns() }
+}
+
+/// Returns the time elapsed since system boot, in 32-bit microseconds. Does
+/// include the time the system was suspended.
+#[inline]
+pub fn ktime_get_boot_fast_us32() -> Usecs32 {
+    (ktime_get_boot_fast_ns() / NSEC_PER_USEC) as Usecs32
+}
+
+/// Returns the time elapsed since system boot, in 32-bit milliseconds. Does
+/// include the time the system was suspended.
+#[inline]
+pub fn ktime_get_boot_fast_ms32() -> Msecs32 {
+    (ktime_get_boot_fast_ns() / NSEC_PER_MSEC) as Msecs32
+}

From a01311bd5ffd22da71f5c98d83c6c89d56837439 Mon Sep 17 00:00:00 2001
From: Valentin Obst <kernel@valentinobst.de>
Date: Sun, 4 Feb 2024 14:58:03 +0100
Subject: [PATCH 06/13] rust/kernel: add `field_size` macro

Add a macro to determine the size of a structure field at compile time.
This is used by the CCA abstractions to ensure that the private data of
every CCA will fit into the space that the kernel provides for it.
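Sketch of the intended use (the CCA abstractions in a later patch
contain the real check): fail the build when an algorithm's private
data does not fit into the `icsk_ca_priv` scratch area:

    build_assert!(
        core::mem::size_of::<T::Data>()
            <= field_size!(bindings::inet_connection_sock, icsk_ca_priv)
    );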
Signed-off-by: Valentin Obst <kernel@valentinobst.de>
---
 rust/kernel/types.rs | 33 +++++++++++++++++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/rust/kernel/types.rs b/rust/kernel/types.rs
index 23824782ad7230..91ba53a783516a 100644
--- a/rust/kernel/types.rs
+++ b/rust/kernel/types.rs
@@ -398,3 +398,36 @@ pub enum Either<L, R> {
     /// Constructs an instance of [`Either`] containing a value of type `R`.
     Right(R),
 }
+
+/// Returns the size of a struct field in bytes.
+///
+/// This macro can be used in const contexts.
+///
+/// # Examples
+///
+/// ```
+/// use kernel::field_size;
+///
+/// struct Foo {
+///     bar: u64,
+///     baz: [i8; 100],
+/// }
+///
+/// assert_eq!(field_size!(Foo, bar), 8);
+/// assert_eq!(field_size!(Foo, baz), 100);
+/// ```
+// Link: https://stackoverflow.com/a/70222282
+#[macro_export]
+macro_rules! field_size {
+    ($t:ty, $field:ident) => {{
+        let m = core::mem::MaybeUninit::<$t>::uninit();
+        // SAFETY: It is OK to form a raw pointer to a field of an
+        // uninitialised value inside of `addr_of!`; no reference is created
+        // and the memory is not read.
+        let p = unsafe { core::ptr::addr_of!((*m.as_ptr()).$field) };
+
+        const fn size_of_raw<T>(_: *const T) -> usize {
+            core::mem::size_of::<T>()
+        }
+        size_of_raw(p)
+    }};
+}

From 5c84c6c31babd57ff0882770509d8ad6865400d9 Mon Sep 17 00:00:00 2001
From: Valentin Obst <kernel@valentinobst.de>
Date: Fri, 1 Mar 2024 23:28:50 +0100
Subject: [PATCH 07/13] EDITME: cover title for tcp-cca-rfc

# Describe the purpose of this series. The information you put here
# will be used by the project maintainer to make a decision whether
# your patches should be reviewed, and in what priority order. Please be
# very detailed and link to any relevant discussions or sites that the
# maintainer can review to better understand your proposed changes. If you
# only have a single patch in your series, the contents of the cover
# letter will be appended to the "under-the-cut" portion of the patch.

# Lines starting with # will be removed from the cover letter. You can
# use them to add notes or reminders to yourself. If you want to use
# markdown headers in your cover letter, start the line with ">#".

# You can add trailers to the cover letter. Any email addresses found in
# these trailers will be added to the addresses specified/generated
# during the b4 send stage. You can also run "b4 prep --auto-to-cc" to
# auto-populate the To: and Cc: trailers based on the code being
# modified.

Signed-off-by: Valentin Obst <kernel@valentinobst.de>

--- b4-submit-tracking ---
# This section is used internally by b4 prep for tracking purposes.
{
  "series": {
    "revision": 1,
    "change-id": "20240301-tcp-cca-rfc-1a0fbcaa533c",
    "prefixes": []
  }
}

From ded5104d0b208b55b308711125c563e21dbfcc56 Mon Sep 17 00:00:00 2001
From: Valentin Obst <kernel@valentinobst.de>
Date: Fri, 1 Mar 2024 23:02:10 +0100
Subject: [PATCH 08/13] rust/net: add `sock`, `tcp_sock` and icsk wrappers
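Add minimal wrappers around `struct sock`, `struct tcp_sock`, and
`struct inet_connection_sock` that expose the fields and helpers needed
by congestion control algorithms. Several of the C helpers wrapped here
(`tcp_sk`, `inet_csk`, `tcp_in_slow_start`, `tcp_is_cwnd_limited`,
`tcp_snd_cwnd`) are `static inline` functions, so `bindgen` cannot bind
them directly; each therefore gets an out-of-line, exported
`rust_helper_*` shim in rust/helpers.c.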
Signed-off-by: Valentin Obst <kernel@valentinobst.de>
---
 net/ipv4/Kconfig                |  16 ++++
 rust/bindings/bindings_helper.h |   1 +
 rust/helpers.c                  |  31 +++++++
 rust/kernel/net.rs              |   4 +
 rust/kernel/net/sock.rs         | 143 ++++++++++++++++++++++++++++++++
 rust/kernel/net/tcp.rs          | 128 ++++++++++++++++++++++++++++
 6 files changed, 323 insertions(+)
 create mode 100644 rust/kernel/net/sock.rs
 create mode 100644 rust/kernel/net/tcp.rs

diff --git a/net/ipv4/Kconfig b/net/ipv4/Kconfig
index 8e94ed7c56a0ea..0566dda5a6e3a6 100644
--- a/net/ipv4/Kconfig
+++ b/net/ipv4/Kconfig
@@ -466,6 +466,22 @@ config INET_DIAG_DESTROY
	  had been disconnected.
	  If unsure, say N.

+config RUST_SOCK_ABSTRACTIONS
+	bool "INET: Rust sock abstractions"
+	depends on RUST
+	help
+	  Adds Rust abstractions for working with `struct sock`s.
+
+	  If unsure, say N.
+
+config RUST_TCP_ABSTRACTIONS
+	bool "TCP: Rust abstractions"
+	depends on RUST_SOCK_ABSTRACTIONS
+	help
+	  Adds support for writing Rust kernel modules that integrate with TCP.
+
+	  If unsure, say N.
+
 menuconfig TCP_CONG_ADVANCED
	bool "TCP: advanced congestion control"
	help
diff --git a/rust/bindings/bindings_helper.h b/rust/bindings/bindings_helper.h
index 65b98831b97560..978885d6336272 100644
--- a/rust/bindings/bindings_helper.h
+++ b/rust/bindings/bindings_helper.h
@@ -17,6 +17,7 @@
 #include <linux/slab.h>
 #include <linux/wait.h>
 #include <linux/workqueue.h>
+#include <net/tcp.h>

 /* `bindgen` gets confused at certain things. */
 const size_t RUST_CONST_HELPER_ARCH_SLAB_MINALIGN = ARCH_SLAB_MINALIGN;
diff --git a/rust/helpers.c b/rust/helpers.c
index 70e59efd92bc43..ae88a4291bb08f 100644
--- a/rust/helpers.c
+++ b/rust/helpers.c
@@ -31,6 +31,7 @@
 #include <linux/spinlock.h>
 #include <linux/wait.h>
 #include <linux/workqueue.h>
+#include <net/tcp.h>

 __noreturn void rust_helper_BUG(void)
 {
@@ -157,6 +158,36 @@ void rust_helper_init_work_with_key(struct work_struct *work, work_func_t func,
 }
 EXPORT_SYMBOL_GPL(rust_helper_init_work_with_key);

+bool rust_helper_tcp_in_slow_start(const struct tcp_sock *tp)
+{
+	return tcp_in_slow_start(tp);
+}
+EXPORT_SYMBOL_GPL(rust_helper_tcp_in_slow_start);
+
+bool rust_helper_tcp_is_cwnd_limited(const struct sock *sk)
+{
+	return tcp_is_cwnd_limited(sk);
+}
+EXPORT_SYMBOL_GPL(rust_helper_tcp_is_cwnd_limited);
+
+struct tcp_sock *rust_helper_tcp_sk(struct sock *sk)
+{
+	return tcp_sk(sk);
+}
+EXPORT_SYMBOL_GPL(rust_helper_tcp_sk);
+
+u32 rust_helper_tcp_snd_cwnd(const struct tcp_sock *tp)
+{
+	return tcp_snd_cwnd(tp);
+}
+EXPORT_SYMBOL_GPL(rust_helper_tcp_snd_cwnd);
+
+struct inet_connection_sock *rust_helper_inet_csk(const struct sock *sk)
+{
+	return inet_csk(sk);
+}
+EXPORT_SYMBOL_GPL(rust_helper_inet_csk);
+
 /*
  * `bindgen` binds the C `size_t` type as the Rust `usize` type, so we can
  * use it in contexts where Rust expects a `usize` like slice (array) indices.
diff --git a/rust/kernel/net.rs b/rust/kernel/net.rs
index fe415cb369d3ac..a17555940d6418 100644
--- a/rust/kernel/net.rs
+++ b/rust/kernel/net.rs
@@ -4,3 +4,7 @@

 #[cfg(CONFIG_RUST_PHYLIB_ABSTRACTIONS)]
 pub mod phy;
+#[cfg(CONFIG_RUST_SOCK_ABSTRACTIONS)]
+pub mod sock;
+#[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
+pub mod tcp;
diff --git a/rust/kernel/net/sock.rs b/rust/kernel/net/sock.rs
new file mode 100644
index 00000000000000..37f5c05545c3b6
--- /dev/null
+++ b/rust/kernel/net/sock.rs
@@ -0,0 +1,143 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+//! Representation of a C `struct sock`.
+//!
+//! C header: [`include/net/sock.h`](srctree/include/net/sock.h)
+
+#[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
+use crate::net::tcp::{self, InetConnectionSock, TcpSock};
+use crate::types::Opaque;
+use core::convert::TryFrom;
+use core::ptr::addr_of;
+
+/// Representation of a C `struct sock`.
+///
+/// Not intended to be used directly by modules. Abstractions should provide a
+/// safe interface to only those operations that are OK to use for the module.
+///
+/// # Invariants
+///
+/// Referencing a `sock` using this struct asserts that you are in
+/// a context where all safe methods defined on this struct are indeed safe to
+/// call.
+#[repr(transparent)]
+pub(crate) struct Sock {
+    sk: Opaque<bindings::sock>,
+}
+
+impl Sock {
+    /// Returns a raw pointer to the wrapped `struct sock`.
+    ///
+    /// It is up to the caller to use it correctly.
+    #[inline]
+    pub(crate) fn raw_sk_mut(&mut self) -> *mut bindings::sock {
+        self.sk.get()
+    }
+
+    /// Returns the socket's pacing rate in bytes per second.
+    #[inline]
+    pub(crate) fn sk_pacing_rate(&self) -> u64 {
+        // NOTE: C uses READ_ONCE for this field, thus `read_volatile`.
+        // SAFETY: The struct invariant ensures that we may access
+        // this field without additional synchronization. It is a C unsigned
+        // long, so we can always convert it to a u64 without loss.
+        unsafe { addr_of!((*self.sk.get()).sk_pacing_rate).read_volatile() as u64 }
+    }
+
+    /// Returns the socket's pacing status.
+    #[inline]
+    pub(crate) fn sk_pacing_status(&self) -> Result<Pacing, ()> {
+        // SAFETY: The struct invariant ensures that we may access
+        // this field without additional synchronization.
+        unsafe { Pacing::try_from(*addr_of!((*self.sk.get()).sk_pacing_status)) }
+    }
+
+    /// Returns the socket's maximum GSO segment size to build.
+    #[inline]
+    pub(crate) fn sk_gso_max_size(&self) -> u32 {
+        // SAFETY: The struct invariant ensures that we may access
+        // this field without additional synchronization. It is an unsigned
+        // int, and we are guaranteed that this will always fit into a u32.
+        unsafe { *addr_of!((*self.sk.get()).sk_gso_max_size) as u32 }
+    }
+
+    /// Returns the [`TcpSock`] that contains this `Sock`.
+    ///
+    /// # Safety
+    ///
+    /// `sk` must be valid for `tcp_sk`.
+    #[inline]
+    #[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
+    pub(crate) unsafe fn tcp_sk<'a>(&'a self) -> &'a TcpSock {
+        // SAFETY:
+        // - Downcasting via `tcp_sk` is OK by the function's precondition.
+        // - The cast is OK since `TcpSock` is transparent to `struct tcp_sock`.
+        unsafe { &*(bindings::tcp_sk(self.sk.get()) as *const TcpSock) }
+    }
+
+    /// Returns the [`TcpSock`] that contains this `Sock`.
+    ///
+    /// # Safety
+    ///
+    /// `sk` must be valid for `tcp_sk`.
+    #[inline]
+    #[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
+    pub(crate) unsafe fn tcp_sk_mut<'a>(&'a mut self) -> &'a mut TcpSock {
+        // SAFETY:
+        // - Downcasting via `tcp_sk` is OK by the function's precondition.
+        // - The cast is OK since `TcpSock` is transparent to `struct tcp_sock`.
+        unsafe { &mut *(bindings::tcp_sk(self.sk.get()) as *mut TcpSock) }
+    }
+
+    /// Returns the [`InetConnectionSock`] view of this socket.
+    ///
+    /// # Safety
+    ///
+    /// `sk` must be valid for `inet_csk`.
+    #[inline]
+    #[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
+    pub(crate) unsafe fn inet_csk<'a>(&'a self) -> &'a InetConnectionSock {
+        // SAFETY:
+        // - Calling `inet_csk` is OK by the function's precondition.
+        // - The cast is OK since `InetConnectionSock` is transparent to
+        //   `struct inet_connection_sock`.
+        unsafe { &*(bindings::inet_csk(self.sk.get()) as *const InetConnectionSock) }
+    }
+
+    /// Tests if the connection's sending rate is limited by the cwnd.
+    ///
+    /// # Safety
+    ///
+    /// `sk` must be valid for `tcp_is_cwnd_limited`.
+    #[inline]
+    #[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
+    pub(crate) unsafe fn tcp_is_cwnd_limited(&self) -> bool {
+        // SAFETY: Calling `tcp_is_cwnd_limited` is OK by the function's
+        // precondition.
+        unsafe { bindings::tcp_is_cwnd_limited(self.sk.get()) }
+    }
+}
+
+/// The socket's pacing status.
+#[repr(u32)]
+#[allow(missing_docs)]
+pub enum Pacing {
+    r#None = bindings::sk_pacing_SK_PACING_NONE,
+    Needed = bindings::sk_pacing_SK_PACING_NEEDED,
+    Fq = bindings::sk_pacing_SK_PACING_FQ,
+}
+
+// TODO: Replace with automatically generated code by bindgen when it becomes
+// possible.
+impl TryFrom<u32> for Pacing {
+    type Error = ();
+
+    fn try_from(val: u32) -> Result<Self, Self::Error> {
+        match val {
+            x if x == Pacing::r#None as u32 => Ok(Pacing::r#None),
+            x if x == Pacing::Needed as u32 => Ok(Pacing::Needed),
+            x if x == Pacing::Fq as u32 => Ok(Pacing::Fq),
+            _ => Err(()),
+        }
+    }
+}
diff --git a/rust/kernel/net/tcp.rs b/rust/kernel/net/tcp.rs
new file mode 100644
index 00000000000000..410095f01852dc
--- /dev/null
+++ b/rust/kernel/net/tcp.rs
@@ -0,0 +1,128 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+//! Transmission Control Protocol (TCP).
+
+use crate::time;
+use crate::types::Opaque;
+use core::{num, ptr};
+
+/// Representation of a `struct inet_connection_sock`.
+///
+/// # Invariants
+///
+/// Referencing an `inet_connection_sock` using this struct asserts that you
+/// are in a context where all safe methods defined on this struct are indeed
+/// safe to call.
+///
+/// C header: [`include/net/inet_connection_sock.h`](srctree/include/net/inet_connection_sock.h)
+#[repr(transparent)]
+pub struct InetConnectionSock {
+    icsk: Opaque<bindings::inet_connection_sock>,
+}
+
+/// Representation of a `struct tcp_sock`.
+///
+/// # Invariants
+///
+/// Referencing a `tcp_sock` using this struct asserts that you are in
+/// a context where all safe methods defined on this struct are indeed safe to
+/// call.
+///
+/// C header: [`include/linux/tcp.h`](srctree/include/linux/tcp.h)
+#[repr(transparent)]
+pub struct TcpSock {
+    tp: Opaque<bindings::tcp_sock>,
+}
+
+impl TcpSock {
+    /// Returns true iff `snd_cwnd < snd_ssthresh`.
+    #[inline]
+    pub fn in_slow_start(&self) -> bool {
+        // SAFETY: The struct invariant ensures that we may call this function
+        // without additional synchronization.
+        unsafe { bindings::tcp_in_slow_start(self.tp.get()) }
+    }
+
+    /// Performs the standard slow start increment of cwnd.
+    ///
+    /// If this causes the socket to exit slow start, any leftover ACKs are
+    /// returned.
+    #[inline]
+    pub fn slow_start(&mut self, acked: u32) -> u32 {
+        // SAFETY: The struct invariant ensures that we may call this function
+        // without additional synchronization.
+        unsafe { bindings::tcp_slow_start(self.tp.get(), acked) }
+    }
+
+    /// Performs the standard increase of cwnd during congestion avoidance.
+    ///
+    /// The increase per ACK is upper bounded by `1 / w`.
+    #[inline]
+    pub fn cong_avoid_ai(&mut self, w: num::NonZeroU32, acked: u32) {
+        // SAFETY: The struct invariant ensures that we may call this function
+        // without additional synchronization.
+        unsafe { bindings::tcp_cong_avoid_ai(self.tp.get(), w.get(), acked) };
+    }
+
+    /// Returns the connection's current cwnd.
+    #[inline]
+    pub fn snd_cwnd(&self) -> u32 {
+        // SAFETY: The struct invariant ensures that we may call this function
+        // without additional synchronization.
+        unsafe { bindings::tcp_snd_cwnd(self.tp.get()) }
+    }

+    /// Returns the connection's current ssthresh.
+    #[inline]
+    pub fn snd_ssthresh(&self) -> u32 {
+        // SAFETY: The struct invariant ensures that we may access
+        // this field without additional synchronization.
+        unsafe { *ptr::addr_of!((*self.tp.get()).snd_ssthresh) }
+    }
+
+    /// Returns the sequence number of the next byte that will be sent.
+    #[inline]
+    pub fn snd_nxt(&self) -> u32 {
+        // SAFETY: The struct invariant ensures that we may access
+        // this field without additional synchronization.
+        unsafe { *ptr::addr_of!((*self.tp.get()).snd_nxt) }
+    }
+
+    /// Returns the sequence number of the first unacknowledged byte.
+    #[inline]
+    pub fn snd_una(&self) -> u32 {
+        // SAFETY: The struct invariant ensures that we may access
+        // this field without additional synchronization.
+        unsafe { *ptr::addr_of!((*self.tp.get()).snd_una) }
+    }
+
+    /// Returns the time when the last packet was received or sent.
+    #[inline]
+    pub fn tcp_mstamp(&self) -> time::Usecs {
+        // SAFETY: The struct invariant ensures that we may access
+        // this field without additional synchronization.
+        unsafe { *ptr::addr_of!((*self.tp.get()).tcp_mstamp) }
+    }
+
+    /// Sets the connection's ssthresh.
+    #[inline]
+    pub fn set_snd_ssthresh(&mut self, new: u32) {
+        // SAFETY: The struct invariant ensures that we may access
+        // this field without additional synchronization.
+        unsafe { *ptr::addr_of_mut!((*self.tp.get()).snd_ssthresh) = new };
+    }
+
+    /// Returns the timestamp of the last sent data packet in 32-bit jiffies.
+    #[inline]
+    pub fn lsndtime(&self) -> time::Jiffies32 {
+        // SAFETY: The struct invariant ensures that we may access
+        // this field without additional synchronization.
+        unsafe { *ptr::addr_of!((*self.tp.get()).lsndtime) as time::Jiffies32 }
+    }
+}
+
+/// Tests if `sqn_1` comes after `sqn_2`.
+#[inline]
+pub fn after(sqn_1: u32, sqn_2: u32) -> bool {
+    (sqn_2.wrapping_sub(sqn_1) as i32) < 0
+}

From f410fa041a1334d4340b5c06d9dc5f470e6ae63a Mon Sep 17 00:00:00 2001
From: Valentin Obst <kernel@valentinobst.de>
Date: Fri, 1 Mar 2024 23:13:51 +0100
Subject: [PATCH 09/13] rust/net: add CCA abstractions
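Add abstractions for implementing pluggable TCP congestion control
algorithms (CCAs) in Rust: an `Algorithm` trait mirroring the callbacks
of `struct tcp_congestion_ops`, a `Sock` wrapper that exposes only the
operations a callback may perform on the socket, and a pinned
`Registration` that registers the algorithm on module load and
unregisters it on drop. A `module_cca!` macro declares a kernel module
that consists of a single algorithm.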
Signed-off-by: Valentin Obst <kernel@valentinobst.de>
---
 rust/helpers.c              |   6 +
 rust/kernel/net/sock.rs     |  40 +++
 rust/kernel/net/tcp.rs      |  17 +
 rust/kernel/net/tcp/cong.rs | 645 ++++++++++++++++++++++++++++++++++++
 4 files changed, 708 insertions(+)
 create mode 100644 rust/kernel/net/tcp/cong.rs

diff --git a/rust/helpers.c b/rust/helpers.c
index ae88a4291bb08f..fc01594fafbe8a 100644
--- a/rust/helpers.c
+++ b/rust/helpers.c
@@ -188,6 +188,12 @@ struct inet_connection_sock *rust_helper_inet_csk(const struct sock *sk)
	return inet_csk(sk);
 }
 EXPORT_SYMBOL_GPL(rust_helper_inet_csk);

+void *rust_helper_inet_csk_ca(struct sock *sk)
+{
+	return inet_csk_ca(sk);
+}
+EXPORT_SYMBOL_GPL(rust_helper_inet_csk_ca);
+
 /*
  * `bindgen` binds the C `size_t` type as the Rust `usize` type, so we can
  * use it in contexts where Rust expects a `usize` like slice (array) indices.
diff --git a/rust/kernel/net/sock.rs b/rust/kernel/net/sock.rs
index 37f5c05545c3b6..c4fc539303a88b 100644
--- a/rust/kernel/net/sock.rs
+++ b/rust/kernel/net/sock.rs
@@ -89,6 +89,46 @@ impl Sock {
         unsafe { &mut *(bindings::tcp_sk(self.sk.get()) as *mut TcpSock) }
     }

+    /// Returns the [private data] of the instance of the CCA used by this
+    /// socket.
+    ///
+    /// [private data]: tcp::cong::Algorithm::Data
+    ///
+    /// # Safety
+    ///
+    /// - `sk` must be valid for `inet_csk_ca`,
+    /// - `sk` must use the CCA `T`, the `init` CB of the CCA must have been
+    ///   called, and the `release` CB of the CCA must not have been called.
+    #[inline]
+    #[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
+    pub(crate) unsafe fn inet_csk_ca<'a, T: tcp::cong::Algorithm + ?Sized>(
+        &'a self,
+    ) -> &'a T::Data {
+        // SAFETY: By the function's preconditions, calling `inet_csk_ca` is OK
+        // and the returned pointer points to a valid instance of `T::Data`.
+        unsafe { &*(bindings::inet_csk_ca(self.sk.get()) as *const T::Data) }
+    }
+
+    /// Returns the [private data] of the instance of the CCA used by this
+    /// socket.
+    ///
+    /// [private data]: tcp::cong::Algorithm::Data
+    ///
+    /// # Safety
+    ///
+    /// - `sk` must be valid for `inet_csk_ca`,
+    /// - `sk` must use the CCA `T`, the `init` CB of the CCA must have been
+    ///   called, and the `release` CB of the CCA must not have been called.
+    #[inline]
+    #[cfg(CONFIG_RUST_TCP_ABSTRACTIONS)]
+    pub(crate) unsafe fn inet_csk_ca_mut<'a, T: tcp::cong::Algorithm + ?Sized>(
+        &'a mut self,
+    ) -> &'a mut T::Data {
+        // SAFETY: By the function's preconditions, calling `inet_csk_ca` is OK
+        // and the returned pointer points to a valid instance of `T::Data`.
+        unsafe { &mut *(bindings::inet_csk_ca(self.sk.get()) as *mut T::Data) }
+    }
+
     /// Returns the [`InetConnectionSock`] view of this socket.
     ///
     /// # Safety
diff --git a/rust/kernel/net/tcp.rs b/rust/kernel/net/tcp.rs
index 410095f01852dc..62002c777411a5 100644
--- a/rust/kernel/net/tcp.rs
+++ b/rust/kernel/net/tcp.rs
@@ -6,6 +6,8 @@
 use crate::time;
 use crate::types::Opaque;
 use core::{num, ptr};

+pub mod cong;
+
 /// Representation of a `struct inet_connection_sock`.
 ///
 /// # Invariants
@@ -20,6 +22,21 @@ pub struct InetConnectionSock {
     icsk: Opaque<bindings::inet_connection_sock>,
 }

+impl InetConnectionSock {
+    /// Returns the congestion control state of this socket.
+    #[inline]
+    pub fn ca_state(&self) -> Result<cong::State, ()> {
+        const CA_STATE_MASK: u8 = 0b11111;
+        // TODO: Replace code to access the bit field with automatically
+        // generated code by bindgen when it becomes possible.
+        // SAFETY: By the type invariants, it is okay to read `icsk_ca_state`,
+        // which is the first member of the bitfield and is five bits wide.
+        cong::State::try_from(unsafe {
+            *(ptr::addr_of!((*self.icsk.get())._bitfield_1).cast::<u8>()) & CA_STATE_MASK
+        })
+    }
+}
+
 /// Representation of a `struct tcp_sock`.
 ///
 /// # Invariants
diff --git a/rust/kernel/net/tcp/cong.rs b/rust/kernel/net/tcp/cong.rs
new file mode 100644
index 00000000000000..1640c85ffe7e4f
--- /dev/null
+++ b/rust/kernel/net/tcp/cong.rs
@@ -0,0 +1,645 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+//! Congestion control algorithms (CCAs).
+//!
+//! Abstractions for implementing pluggable CCAs in Rust.
+
+use crate::bindings;
+use crate::error::{self, Error, VTABLE_DEFAULT_ERROR};
+use crate::init::PinInit;
+use crate::net::sock;
+use crate::prelude::{pr_err, vtable};
+use crate::str::CStr;
+use crate::time;
+use crate::types::Opaque;
+use crate::ThisModule;
+use crate::{build_assert, build_error, field_size, try_pin_init};
+use core::convert::TryFrom;
+use core::marker::PhantomData;
+use core::pin::Pin;
+use macros::{pin_data, pinned_drop};
+
+use super::{InetConnectionSock, TcpSock};
+
+/// Congestion control algorithm (CCA).
+///
+/// A CCA is implemented as a set of callbacks that are invoked whenever
+/// specific events occur in a connection. Each socket has its own instance of
+/// some CCA. Every instance of a CCA has its own private data that is stored in
+/// the socket and is mutated by the callbacks.
+///
+/// Callbacks that operate on the same instance are guaranteed to run
+/// sequentially, and each callback has exclusive mutable access to the private
+/// data of the instance it operates on.
+#[vtable]
+pub trait Algorithm {
+    /// Private data. Each socket has its own instance.
+    type Data: Default + Send + Sync;
+
+    /// Name of the algorithm.
+    const NAME: &'static CStr;
+
+    /// Called when entering the CWR, Recovery, or Loss states from the Open or
+    /// Disorder states. Returns the new slow start threshold.
+    fn ssthresh(sk: &mut Sock<'_, Self>) -> u32;
+
+    /// Called when one of the events in [`Event`] occurs.
+    fn cwnd_event(_sk: &mut Sock<'_, Self>, _ev: Event) {
+        build_error!(VTABLE_DEFAULT_ERROR);
+    }
+
+    /// Called towards the end of processing an ACK if a cwnd increase is
+    /// possible. Performs a new cwnd calculation and sets it on the socket.
+    // Note: In fact, one of `cong_avoid` and `cong_control` is required.
+    // (see `tcp_validate_congestion_control`)
+    fn cong_avoid(sk: &mut Sock<'_, Self>, ack: u32, acked: u32);
+
+    /// Called before the sender's congestion state is changed.
+    fn set_state(_sk: &mut Sock<'_, Self>, _new_state: State) {
+        build_error!(VTABLE_DEFAULT_ERROR);
+    }
+
+    /// Called when removing ACKed packets from the retransmission queue. Can
+    /// be used for packet ACK accounting.
+    fn pkts_acked(_sk: &mut Sock<'_, Self>, _sample: &AckSample) {
+        build_error!(VTABLE_DEFAULT_ERROR);
+    }
+
+    /// Called to undo a recent cwnd reduction that was found to have been
+    /// unnecessary. Returns the new value of cwnd.
+    fn undo_cwnd(sk: &mut Sock<'_, Self>) -> u32;
+
+    /// Initializes the private data.
+    ///
+    /// When this function is called, [`sk.inet_csk_ca()`] will contain a value
+    /// returned by `Self::Data::default()`.
+    ///
+    /// Only implement this function when you need to perform additional setup
+    /// tasks.
+    ///
+    /// [`sk.inet_csk_ca()`]: Sock::inet_csk_ca
+    fn init(_sk: &mut Sock<'_, Self>) {
+        build_error!(VTABLE_DEFAULT_ERROR);
+    }
+
+    /// Cleans up the private data.
+    ///
+    /// After this function returns, [`sk.inet_csk_ca()`] will be dropped.
+    ///
+    /// Only implement this function when you need to perform additional
+    /// cleanup tasks.
+    ///
+    /// [`sk.inet_csk_ca()`]: Sock::inet_csk_ca
+    fn release(_sk: &mut Sock<'_, Self>) {
+        build_error!(VTABLE_DEFAULT_ERROR);
+    }
+}
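+
+// NOTE: Because callbacks on one socket never run concurrently and have
+// exclusive access to the private data, an `Algorithm` can mutate its
+// `Data` without any locking. A sketch of the pattern (hypothetical
+// fragment of an `impl Algorithm`):
+//
+//     #[derive(Default)]
+//     struct MyData {
+//         acks: u32,
+//     }
+//
+//     fn pkts_acked(sk: &mut Sock<'_, Self>, sample: &AckSample) {
+//         sk.inet_csk_ca_mut().acks += sample.pkts_acked();
+//     }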
+
+pub mod reno {
+    //! TCP Reno congestion control.
+    //!
+    //! Algorithms may choose to invoke these callbacks instead of providing
+    //! their own implementation. This is convenient as a new CCA might have
+    //! the same logic as an existing one in some of its callbacks.
+    use super::{Algorithm, Sock};
+    use crate::bindings;
+
+    /// Implementation of [`undo_cwnd`] that returns `max(snd_cwnd, prior_cwnd)`,
+    /// where `prior_cwnd` is the value of cwnd before the last reduction.
+    ///
+    /// [`undo_cwnd`]: super::Algorithm::undo_cwnd
+    #[inline]
+    pub fn undo_cwnd<T: Algorithm + ?Sized>(sk: &mut Sock<'_, T>) -> u32 {
+        // SAFETY:
+        // - `sk` has been passed to the callback that invoked us,
+        // - it is OK to pass it to the callback of the Reno algorithm as it
+        //   will never touch the private data.
+        unsafe { bindings::tcp_reno_undo_cwnd(sk.sk.raw_sk_mut()) }
+    }
+}
+
+/// Representation of the `struct sock *` that is passed to the callbacks of
+/// the CCA.
+///
+/// Every callback receives a pointer to the socket that it is operating on.
+/// There are certain operations that callbacks are allowed to perform on the
+/// socket, and this type just exposes methods for performing those. This
+/// prevents callbacks from performing arbitrary manipulations on the socket.
+// TODO: Currently all callbacks can perform all operations. However, this
+// might be too permissive, e.g., the `pkts_acked` callback should probably not
+// be changing cwnd...
+///
+/// # Invariants
+///
+/// The wrapped `sk` must have been obtained as the argument to a callback of
+/// the congestion algorithm `T` (other than the `init` cb) and may only be used
+/// for the duration of that callback. In particular:
+///
+/// - `sk` points to a valid `struct sock`.
+/// - `tcp_sk(sk)` points to a valid `struct tcp_sock`.
+/// - The socket uses the CCA `T`.
+/// - `inet_csk_ca(sk)` points to a valid instance of `T::Data`, which belongs
+///   to the instance of the algorithm used by this socket. A callback has
+///   exclusive, mutable access to this data.
+pub struct Sock<'a, T: Algorithm + ?Sized> {
+    sk: &'a mut sock::Sock,
+    _pd: PhantomData<T>,
+}
+
+impl<'a, T: Algorithm + ?Sized> Sock<'a, T> {
+    /// Creates a new `Sock`.
+    ///
+    /// # Safety
+    ///
+    /// - `sk` must have been obtained as the argument to a callback of the
+    ///   congestion algorithm `T`.
+    /// - The CCA's private data must have been initialised.
+    /// - The returned value must not live longer than the duration of the
+    ///   callback.
+    unsafe fn new(sk: *mut bindings::sock) -> Self {
+        // INVARIANTS: Satisfied by the function's preconditions.
+        Self {
+            // SAFETY:
+            // - The cast is OK since `sock::Sock` is transparent to
+            //   `struct sock`.
+            // - Dereferencing `sk` is OK since the pointers passed to CCA CBs
+            //   are valid.
+            // - By the function's preconditions, the produced `Self` value
+            //   will only live for the duration of the callback; thus, the
+            //   wrapped reference will always be valid.
+            sk: unsafe { &mut *(sk as *mut sock::Sock) },
+            _pd: PhantomData,
+        }
+    }
+
+    /// Returns the [`TcpSock`] that contains this `Sock`.
+    #[inline]
+    pub fn tcp_sk<'b>(&'b self) -> &'b TcpSock {
+        // SAFETY: By the type invariants, `sk` is valid for `tcp_sk`.
+        unsafe { self.sk.tcp_sk() }
+    }
+
+    /// Returns the [`TcpSock`] that contains this `Sock`.
+    #[inline]
+    pub fn tcp_sk_mut<'b>(&'b mut self) -> &'b mut TcpSock {
+        // SAFETY: By the type invariants, `sk` is valid for `tcp_sk`.
+        unsafe { self.sk.tcp_sk_mut() }
+    }
+
+    /// Returns the [private data] of the instance of the CCA used by this
+    /// socket.
+    ///
+    /// [private data]: Algorithm::Data
+    #[inline]
+    pub fn inet_csk_ca<'b>(&'b self) -> &'b T::Data {
+        // SAFETY: By the type invariants, `sk` is valid for `inet_csk_ca`,
+        // it uses the algorithm `T`, and the private data is valid.
+        unsafe { self.sk.inet_csk_ca::<T>() }
+    }
+
+    /// Returns the [private data] of the instance of the CCA used by this
+    /// socket.
+    ///
+    /// [private data]: Algorithm::Data
+    #[inline]
+    pub fn inet_csk_ca_mut<'b>(&'b mut self) -> &'b mut T::Data {
+        // SAFETY: By the type invariants, `sk` is valid for `inet_csk_ca`,
+        // it uses the algorithm `T`, and the private data is valid.
+        unsafe { self.sk.inet_csk_ca_mut::<T>() }
+    }
+
+    /// Returns the [`InetConnectionSock`] of this socket.
+    #[inline]
+    pub fn inet_csk<'b>(&'b self) -> &'b InetConnectionSock {
+        // SAFETY: By the type invariants, `sk` is valid for `inet_csk`.
+        unsafe { self.sk.inet_csk() }
+    }
+
+    /// Tests if the connection's sending rate is limited by the cwnd.
+    // NOTE: This feels like it should be a method on `TcpSock`, but C defines
+    // it on `struct sock`, so there is not much we can do about it. (At least,
+    // if we don't want to reimplement the function, or perform the conversion
+    // from `struct tcp_sock` to `struct sock` just to have C reverse it right
+    // away.)
+    #[inline]
+    pub fn tcp_is_cwnd_limited(&self) -> bool {
+        // SAFETY: By the type invariants, `sk` is valid for
+        // `tcp_is_cwnd_limited`.
+        unsafe { self.sk.tcp_is_cwnd_limited() }
+    }
+
+    /// Returns the socket's pacing rate in bytes per second.
+    #[inline]
+    pub fn sk_pacing_rate(&self) -> u64 {
+        self.sk.sk_pacing_rate()
+    }
+
+    /// Returns the socket's pacing status.
+    #[inline]
+    pub fn sk_pacing_status(&self) -> Result<sock::Pacing, ()> {
+        self.sk.sk_pacing_status()
+    }
+
+    /// Returns the socket's maximum GSO segment size to build.
+    #[inline]
+    pub fn sk_gso_max_size(&self) -> u32 {
+        self.sk.sk_gso_max_size()
+    }
+}
+
+/// Representation of the `struct ack_sample *` that is passed to the
+/// `pkts_acked` callback of the CCA.
+///
+/// # Invariants
+///
+/// - `sample` points to a valid `struct ack_sample`,
+/// - all fields of `sample` can be read without additional synchronization.
+pub struct AckSample {
+    sample: *const bindings::ack_sample,
+}
+
+impl AckSample {
+    /// Creates a new `AckSample`.
+    ///
+    /// # Safety
+    ///
+    /// `sample` must have been obtained as the argument to the `pkts_acked`
+    /// callback.
+    unsafe fn new(sample: *const bindings::ack_sample) -> Self {
+        // INVARIANTS: Satisfied by the function's precondition.
+        Self { sample }
+    }
+
+    /// Returns the number of packets that were ACKed.
+    #[inline]
+    pub fn pkts_acked(&self) -> u32 {
+        // SAFETY: By the type invariants, it is OK to read any field.
+        unsafe { (*self.sample).pkts_acked }
+    }
+
+    /// Returns the RTT measurement of this ACK sample.
+    // Note: Some samples might not include an RTT measurement. This is
+    // indicated by a negative value for `rtt_us`; we return `None` in that
+    // case.
+    #[inline]
+    pub fn rtt_us(&self) -> Option<time::Usecs32> {
+        // SAFETY: By the type invariants, it is OK to read any field.
+        match unsafe { (*self.sample).rtt_us } {
+            t if t < 0 => None,
+            t => Some(t as time::Usecs32),
+        }
+    }
+}
+
+/// States of the TCP sender state machine.
+///
+/// The TCP sender's congestion state indicating normal or abnormal situations
+/// in the last round of packets sent. The state is driven by the ACK
+/// information and timer events.
+#[repr(u8)]
+pub enum State {
+    /// Nothing bad has been observed recently. No apparent reordering, packet
+    /// loss, or ECN marks.
+    Open = bindings::tcp_ca_state_TCP_CA_Open as u8,
+    /// The sender enters the disordered state when it has received DUPACKs or
+    /// SACKs in the last round of packets sent. This could be due to packet
+    /// loss or reordering but needs further information to confirm packets
+    /// have been lost.
+    Disorder = bindings::tcp_ca_state_TCP_CA_Disorder as u8,
+    /// The sender enters the Congestion Window Reduction (CWR) state when it
+    /// has received ACKs with ECN-ECE marks, or has experienced congestion
+    /// or packet discard on the sender host (e.g. qdisc).
+    Cwr = bindings::tcp_ca_state_TCP_CA_CWR as u8,
+    /// The sender is in fast recovery and retransmitting lost packets,
+    /// typically triggered by ACK events.
+    Recovery = bindings::tcp_ca_state_TCP_CA_Recovery as u8,
+    /// The sender is in loss recovery triggered by retransmission timeout.
+    Loss = bindings::tcp_ca_state_TCP_CA_Loss as u8,
+}
+
+// TODO: Replace with automatically generated code by bindgen when it becomes
+// possible.
+impl TryFrom<u8> for State {
+    type Error = ();
+
+    fn try_from(val: u8) -> Result<Self, Self::Error> {
+        match val {
+            x if x == State::Open as u8 => Ok(State::Open),
+            x if x == State::Disorder as u8 => Ok(State::Disorder),
+            x if x == State::Cwr as u8 => Ok(State::Cwr),
+            x if x == State::Recovery as u8 => Ok(State::Recovery),
+            x if x == State::Loss as u8 => Ok(State::Loss),
+            _ => Err(()),
+        }
+    }
+}
+
+/// Events passed to the congestion control interface.
+#[repr(u32)]
+pub enum Event {
+    /// First transmit when no packets are in flight.
+    TxStart = bindings::tcp_ca_event_CA_EVENT_TX_START,
+    /// Congestion window restart.
+    CwndRestart = bindings::tcp_ca_event_CA_EVENT_CWND_RESTART,
+    /// End of congestion recovery.
+    CompleteCwr = bindings::tcp_ca_event_CA_EVENT_COMPLETE_CWR,
+    /// Loss timeout.
+    Loss = bindings::tcp_ca_event_CA_EVENT_LOSS,
+    /// ECT set, but not CE marked.
+    EcnNoCe = bindings::tcp_ca_event_CA_EVENT_ECN_NO_CE,
+    /// Received CE marked IP packet.
+    EcnIsCe = bindings::tcp_ca_event_CA_EVENT_ECN_IS_CE,
+}
+
+// TODO: Replace with automatically generated code by bindgen when it becomes
+// possible.
+impl TryFrom<bindings::tcp_ca_event> for Event {
+    type Error = ();
+
+    fn try_from(ev: bindings::tcp_ca_event) -> Result<Self, Self::Error> {
+        match ev {
+            x if x == Event::TxStart as u32 => Ok(Event::TxStart),
+            x if x == Event::CwndRestart as u32 => Ok(Event::CwndRestart),
+            x if x == Event::CompleteCwr as u32 => Ok(Event::CompleteCwr),
+            x if x == Event::Loss as u32 => Ok(Event::Loss),
+            x if x == Event::EcnNoCe as u32 => Ok(Event::EcnNoCe),
+            x if x == Event::EcnIsCe as u32 => Ok(Event::EcnIsCe),
+            _ => Err(()),
+        }
+    }
+}
+
+#[pin_data(PinnedDrop)]
+struct Registration<T: Algorithm + ?Sized> {
+    #[pin]
+    ops: Opaque<bindings::tcp_congestion_ops>,
+    _pd: PhantomData<T>,
+}
+
+// SAFETY: `Registration` doesn't provide any `&self` methods, so it is safe to
+// pass references to it around.
+unsafe impl<T: Algorithm + ?Sized> Sync for Registration<T> {}
+
+// SAFETY: Both registration and unregistration are implemented in C and safe
+// to be performed from any thread, so `Registration` is `Send`.
+unsafe impl<T: Algorithm + ?Sized> Send for Registration<T> {}
+
+impl<T: Algorithm + ?Sized> Registration<T> {
+    const NAME_FIELD: [i8; 16] = Self::gen_name_field::<16>();
+    // Maximal size of the private data.
+    const ICSK_CA_PRIV_SIZE: usize = field_size!(bindings::inet_connection_sock, icsk_ca_priv);
+    const DATA_SIZE: usize = core::mem::size_of::<T::Data>();
+
+    fn new(module: &'static ThisModule) -> impl PinInit<Self, Error> {
+        try_pin_init!(Self {
+            _pd: PhantomData,
+            ops <- Opaque::try_ffi_init(|ops_ptr: *mut bindings::tcp_congestion_ops| {
+                // SAFETY: `try_ffi_init` guarantees that `ops_ptr` is valid
+                // for write.
+                unsafe { ops_ptr.write(bindings::tcp_congestion_ops::default()) };
+
+                // SAFETY: `try_ffi_init` guarantees that `ops_ptr` is valid
+                // for write, and it has just been initialised above, so it's
+                // also valid for read.
+                let ops = unsafe { &mut *ops_ptr };
+
+                ops.ssthresh = Some(Self::ssthresh_cb);
+                ops.cong_avoid = Some(Self::cong_avoid_cb);
+                ops.undo_cwnd = Some(Self::undo_cwnd_cb);
+                if T::HAS_SET_STATE {
+                    ops.set_state = Some(Self::set_state_cb);
+                }
+                if T::HAS_PKTS_ACKED {
+                    ops.pkts_acked = Some(Self::pkts_acked_cb);
+                }
+                if T::HAS_CWND_EVENT {
+                    ops.cwnd_event = Some(Self::cwnd_event_cb);
+                }
+
+                // Even though it is not mandated by the C side, we
+                // unconditionally set these CBs to ensure that it is always
+                // safe to access the CCA's private data.
+                // Future work could allow the CCA to declare whether it wants
+                // to be able to use the private data.
+                ops.init = Some(Self::init_cb);
+                ops.release = Some(Self::release_cb);
+
+                ops.owner = module.0;
+                ops.name = Self::NAME_FIELD;
+
+                // SAFETY: Pointers stored in `ops` are static, so they will
+                // live for as long as the registration is active (it is undone
+                // in `drop`).
+                error::to_result(unsafe { bindings::tcp_register_congestion_control(ops_ptr) })
+            }),
+        })
+    }
+
+    const fn gen_name_field<const N: usize>() -> [i8; N] {
+        let mut name_field: [i8; N] = [0; N];
+        let mut i = 0;
+
+        while i < T::NAME.len_with_nul() {
+            name_field[i] = T::NAME.as_bytes_with_nul()[i] as i8;
+            i += 1;
+        }
+
+        name_field
+    }
+
+    unsafe extern "C" fn cwnd_event_cb(sk: *mut bindings::sock, ev: bindings::tcp_ca_event) {
+        // SAFETY:
+        // - `sk` was passed to a callback of the CCA `T`.
+        // - `Data` is guaranteed to be initialized since the `init_cb` took
+        //   care of it.
+        // - This value will be dropped at the end of the callback.
+        let mut sk = unsafe { Sock::new(sk) };
+        match Event::try_from(ev) {
+            Ok(ev) => T::cwnd_event(&mut sk, ev),
+            Err(_) => pr_err!("cwnd_event: event was {}", ev),
+        }
+    }
+
+    unsafe extern "C" fn init_cb(sk: *mut bindings::sock) {
+        // Fail the build if the module-defined private data is larger than
+        // the storage that the kernel provides.
+        build_assert!(Self::DATA_SIZE <= Self::ICSK_CA_PRIV_SIZE);
+
+        // SAFETY:
+        // - The `sk` that is passed to this callback is valid for
+        //   `inet_csk_ca`.
+        // - We just checked that there is enough space for the cast to be
+        //   okay.
+        let ca = unsafe { bindings::inet_csk_ca(sk) as *mut T::Data };
+
+        unsafe { ca.write(T::Data::default()) };
+
+        if T::HAS_INIT {
+            // SAFETY:
+            // - `sk` was passed to a callback of the CCA `T`.
+            // - We just initialized the `Data`.
+            // - This value will be dropped at the end of the callback.
+            let mut sk = unsafe { Sock::new(sk) };
+            T::init(&mut sk)
+        }
+    }
+
+    unsafe extern "C" fn release_cb(sk: *mut bindings::sock) {
+        if T::HAS_RELEASE {
+            // SAFETY:
+            // - `sk` was passed to a callback of the CCA `T`.
+            // - `Data` is guaranteed to be initialized since the `init_cb`
+            //   took care of it.
+            // - This value will be dropped at the end of the callback.
+            let mut sk = unsafe { Sock::new(sk) };
+            T::release(&mut sk)
+        }
+
+        // We have to manually dispose of the private data that we stored with
+        // the kernel.
+        // SAFETY:
+        // - The `sk` passed to callbacks is valid for `inet_csk_ca`.
+        // - `Data` is guaranteed to be initialized since the `init_cb` took
+        //   care of it.
+        // - After we return, no other callback will be invoked with this
+        //   socket.
+        unsafe { core::ptr::drop_in_place(bindings::inet_csk_ca(sk) as *mut T::Data) };
+    }
+
+    unsafe extern "C" fn ssthresh_cb(sk: *mut bindings::sock) -> u32 {
+        // SAFETY:
+        // - `sk` was passed to a callback of the CCA `T`.
+        // - `Data` is guaranteed to be initialized since the `init_cb` took
+        //   care of it.
+        // - This value will be dropped at the end of the callback.
+        let mut sk = unsafe { Sock::new(sk) };
+        T::ssthresh(&mut sk)
+    }
+
+    unsafe extern "C" fn cong_avoid_cb(sk: *mut bindings::sock, ack: u32, acked: u32) {
+        // SAFETY:
+        // - `sk` was passed to a callback of the CCA `T`.
+        // - `Data` is guaranteed to be initialized since the `init_cb` took
+        //   care of it.
+        // - This value will be dropped at the end of the callback.
+        let mut sk = unsafe { Sock::new(sk) };
+        T::cong_avoid(&mut sk, ack, acked)
+    }
+
+    unsafe extern "C" fn set_state_cb(sk: *mut bindings::sock, new_state: u8) {
+        // SAFETY:
+        // - `sk` was passed to a callback of the CCA `T`.
+        // - `Data` is guaranteed to be initialized since the `init_cb` took
+        //   care of it.
+        // - This value will be dropped at the end of the callback.
+        let mut sk = unsafe { Sock::new(sk) };
+        match State::try_from(new_state) {
+            Ok(new_state) => T::set_state(&mut sk, new_state),
+            Err(_) => pr_err!("set_state: new_state was {}", new_state),
+        }
+    }
+
+    unsafe extern "C" fn pkts_acked_cb(
+        sk: *mut bindings::sock,
+        sample: *const bindings::ack_sample,
+    ) {
+        // SAFETY:
+        // - `sk` was passed to a callback of the CCA `T`.
+        // - `Data` is guaranteed to be initialized since the `init_cb` took
+        //   care of it.
+        // - This value will be dropped at the end of the callback.
+        let mut sk = unsafe { Sock::new(sk) };
+        // SAFETY:
+        // - `sample` points to a valid `struct ack_sample`.
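+        // - It was obtained as the argument to this `pkts_acked` callback,
+        //   which is exactly the precondition of `AckSample::new`.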
+        let sample = unsafe { AckSample::new(sample) };
+        T::pkts_acked(&mut sk, &sample)
+    }
+
+    unsafe extern "C" fn undo_cwnd_cb(sk: *mut bindings::sock) -> u32 {
+        // SAFETY:
+        // - `sk` was passed to a callback of the CCA `T`.
+        // - `Data` is guaranteed to be initialized since the `init_cb` took
+        //   care of it.
+        // - This value will be dropped at the end of the callback.
+        let mut sk = unsafe { Sock::new(sk) };
+        T::undo_cwnd(&mut sk)
+    }
+}
+
+#[pinned_drop]
+impl<T: Algorithm + ?Sized> PinnedDrop for Registration<T> {
+    fn drop(self: Pin<&mut Self>) {
+        // SAFETY:
+        // - The fact that `Self` exists implies that a previous call to
+        //   `tcp_register_congestion_control` with `self.ops.get()` was
+        //   successful.
+        unsafe { bindings::tcp_unregister_congestion_control(self.ops.get()) };
+    }
+}
+
+/// Kernel module that implements a single CCA `T`.
+#[pin_data]
+pub struct Module<T: Algorithm + ?Sized> {
+    #[pin]
+    reg: Registration<T>,
+}
+
+impl<T: Algorithm + ?Sized> crate::InPlaceModule for Module<T> {
+    fn init(module: &'static ThisModule) -> impl PinInit<Self, Error> {
+        try_pin_init!(Self {
+            reg <- Registration::<T>::new(module),
+        })
+    }
+}
+
+/// Defines a kernel module that implements a single congestion control
+/// algorithm.
+///
+/// # Examples
+///
+/// To start experimenting with your own congestion control algorithm,
+/// implement the [`Algorithm`] trait and use this macro to declare the module
+/// to the rest of the kernel. That's it!
+///
+/// ```ignore
+/// use kernel::{c_str, module_cca};
+/// use kernel::prelude::*;
+/// use kernel::net::tcp::cong::*;
+/// use core::num::NonZeroU32;
+///
+/// struct MyCca {}
+///
+/// #[vtable]
+/// impl Algorithm for MyCca {
+///     type Data = ();
+///
+///     const NAME: &'static CStr = c_str!("my_cca");
+///
+///     fn undo_cwnd(sk: &mut Sock<'_, Self>) -> u32 {
+///         reno::undo_cwnd(sk)
+///     }
+///
+///     fn ssthresh(_sk: &mut Sock<'_, Self>) -> u32 {
+///         2
+///     }
+///
+///     fn cong_avoid(sk: &mut Sock<'_, Self>, _ack: u32, acked: u32) {
+///         sk.tcp_sk_mut().cong_avoid_ai(NonZeroU32::new(1).unwrap(), acked)
+///     }
+/// }
+///
+/// module_cca! {
+///     type: MyCca,
+///     name: "my_cca",
+///     author: "Rust for Linux Contributors",
+///     description: "Sample congestion control algorithm implemented in Rust.",
+///     license: "GPL v2",
+/// }
+/// ```
+#[macro_export]
+macro_rules! module_cca {
+    (type: $type:ty, $($f:tt)*) => {
+        type ModuleType = $crate::net::tcp::cong::Module<$type>;
+        $crate::macros::module! {
+            type: ModuleType,
+            $($f)*
+        }
+    }
+}
+pub use module_cca;

From 19a81369cc4f032c5e87dd705e75496a7bb8a29a Mon Sep 17 00:00:00 2001
From: Valentin Obst <kernel@valentinobst.de>
Date: Sun, 11 Feb 2024 22:12:28 +0100
Subject: [PATCH 10/13] samples/rust: add minimal CCA

Add an example that uses the `module_cca` macro and the `Algorithm`
trait to implement a minimal CCA.

IMPORTANT: This CCA is not compliant with the relevant RFCs and must not
be used outside of test environments.

Signed-off-by: Valentin Obst <kernel@valentinobst.de>
---
 samples/rust/Kconfig     | 11 +++++++++++
 samples/rust/Makefile    |  1 +
 samples/rust/rust_cca.rs | 35 +++++++++++++++++++++++++++++++++++
 3 files changed, 47 insertions(+)
 create mode 100644 samples/rust/rust_cca.rs

diff --git a/samples/rust/Kconfig b/samples/rust/Kconfig
index b0f74a81c8f9ad..8b9cc6cc7d301d 100644
--- a/samples/rust/Kconfig
+++ b/samples/rust/Kconfig
@@ -30,6 +30,17 @@ config SAMPLE_RUST_PRINT

	  If unsure, say N.

+config SAMPLE_RUST_CCA
+	tristate "Congestion control algorithm"
+	depends on RUST_TCP_ABSTRACTIONS
+	help
+	  This option builds the Rust congestion control algorithm sample.
+
+	  To compile this as a module, choose M here:
+	  the module will be called rust_cca.
+
+	  If unsure, say N.
+
 config SAMPLE_RUST_HOSTPROGS
	bool "Host programs"
	help
diff --git a/samples/rust/Makefile b/samples/rust/Makefile
index 03086dabbea44f..ee0b9bb7b6ab21 100644
--- a/samples/rust/Makefile
+++ b/samples/rust/Makefile
@@ -2,5 +2,6 @@

 obj-$(CONFIG_SAMPLE_RUST_MINIMAL)		+= rust_minimal.o
 obj-$(CONFIG_SAMPLE_RUST_PRINT)			+= rust_print.o
+obj-$(CONFIG_SAMPLE_RUST_CCA)			+= rust_cca.o

 subdir-$(CONFIG_SAMPLE_RUST_HOSTPROGS)		+= hostprogs
diff --git a/samples/rust/rust_cca.rs b/samples/rust/rust_cca.rs
new file mode 100644
index 00000000000000..4c092582112b07
--- /dev/null
+++ b/samples/rust/rust_cca.rs
@@ -0,0 +1,35 @@
+//! Congestion control algorithm example.
+use core::num::NonZeroU32;
+use kernel::net::tcp::cong::*;
+use kernel::prelude::*;
+use kernel::{c_str, module_cca};
+
+struct MyCca {}
+
+#[vtable]
+impl Algorithm for MyCca {
+    type Data = ();
+
+    const NAME: &'static CStr = c_str!("my_cca");
+
+    fn undo_cwnd(sk: &mut Sock<'_, Self>) -> u32 {
+        reno::undo_cwnd(sk)
+    }
+
+    fn ssthresh(_sk: &mut Sock<'_, Self>) -> u32 {
+        2
+    }
+
+    fn cong_avoid(sk: &mut Sock<'_, Self>, _ack: u32, acked: u32) {
+        sk.tcp_sk_mut()
+            .cong_avoid_ai(NonZeroU32::new(1).unwrap(), acked)
+    }
+}
+
+module_cca! {
+    type: MyCca,
+    name: "my_cca",
+    author: "Rust for Linux Contributors",
+    description: "Sample congestion control algorithm implemented in Rust.",
+    license: "GPL v2",
+}

From 29c6b402117f224ff58d94b69f4e0c0e663dd16e Mon Sep 17 00:00:00 2001
From: Valentin Obst <kernel@valentinobst.de>
Date: Sun, 11 Feb 2024 22:13:51 +0100
Subject: [PATCH 11/13] net/tcp: add Rust implementation of BIC

Reimplement the Binary Increase Congestion (BIC) control algorithm in
Rust. BIC is one of the smallest CCAs in the kernel, and it mainly
serves as a minimal example for a real-world algorithm.

Signed-off-by: Valentin Obst <kernel@valentinobst.de>
---
 net/ipv4/Kconfig         |  13 ++
 net/ipv4/Makefile        |   1 +
 net/ipv4/tcp_bic_rust.rs | 312 +++++++++++++++++++++++++++++++++++++++
 3 files changed, 326 insertions(+)
 create mode 100644 net/ipv4/tcp_bic_rust.rs

diff --git a/net/ipv4/Kconfig b/net/ipv4/Kconfig
index 0566dda5a6e3a6..00cc70f54dc59c 100644
--- a/net/ipv4/Kconfig
+++ b/net/ipv4/Kconfig
@@ -509,6 +509,15 @@ config TCP_CONG_BIC
	  increase provides TCP friendliness.
	  See http://www.csc.ncsu.edu/faculty/rhee/export/bitcp/

+config TCP_CONG_BIC_RUST
+	tristate "Binary Increase Congestion (BIC) control (Rust rewrite)"
+	depends on RUST_TCP_ABSTRACTIONS
+	help
+	  Rust rewrite of the original implementation of Binary Increase
+	  Congestion (BIC) control.
+
+	  If unsure, say N.
+
 config TCP_CONG_CUBIC
 	tristate "CUBIC TCP"
 	default y
@@ -704,6 +713,9 @@ choice
 	config DEFAULT_BIC
 		bool "Bic" if TCP_CONG_BIC=y
 
+	config DEFAULT_BIC_RUST
+		bool "Bic (Rust)" if TCP_CONG_BIC_RUST=y
+
 	config DEFAULT_CUBIC
 		bool "Cubic" if TCP_CONG_CUBIC=y
 
@@ -745,6 +757,7 @@ config TCP_CONG_CUBIC
 config DEFAULT_TCP_CONG
 	string
 	default "bic" if DEFAULT_BIC
+	default "bic_rust" if DEFAULT_BIC_RUST
 	default "cubic" if DEFAULT_CUBIC
 	default "htcp" if DEFAULT_HTCP
 	default "hybla" if DEFAULT_HYBLA
diff --git a/net/ipv4/Makefile b/net/ipv4/Makefile
index ec36d2ec059e80..f93213a62c58ec 100644
--- a/net/ipv4/Makefile
+++ b/net/ipv4/Makefile
@@ -46,6 +46,7 @@ obj-$(CONFIG_INET_UDP_DIAG) += udp_diag.o
 obj-$(CONFIG_INET_RAW_DIAG) += raw_diag.o
 obj-$(CONFIG_TCP_CONG_BBR) += tcp_bbr.o
 obj-$(CONFIG_TCP_CONG_BIC) += tcp_bic.o
+obj-$(CONFIG_TCP_CONG_BIC_RUST) += tcp_bic_rust.o
 obj-$(CONFIG_TCP_CONG_CDG) += tcp_cdg.o
 obj-$(CONFIG_TCP_CONG_CUBIC) += tcp_cubic.o
 obj-$(CONFIG_TCP_CONG_DCTCP) += tcp_dctcp.o
diff --git a/net/ipv4/tcp_bic_rust.rs b/net/ipv4/tcp_bic_rust.rs
new file mode 100644
index 00000000000000..adbcd03d3b1dcd
--- /dev/null
+++ b/net/ipv4/tcp_bic_rust.rs
@@ -0,0 +1,312 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Binary Increase Congestion control (BIC). Based on:
+//! Binary Increase Congestion Control (BIC) for Fast Long-Distance
+//! Networks - Lisong Xu, Khaled Harfoush, and Injong Rhee
+//! IEEE INFOCOM 2004, Hong Kong, China, 2004, pp. 2514-2524 vol.4
+//! doi: 10.1109/INFCOM.2004.1354672
+//! Link: https://doi.org/10.1109/INFCOM.2004.1354672
+//! Link: https://web.archive.org/web/20160417213452/http://netsrv.csc.ncsu.edu/export/bitcp.pdf
+
+use core::cmp::{max, min};
+use core::num::NonZeroU32;
+use kernel::c_str;
+use kernel::net::tcp::cong::{self, module_cca};
+use kernel::prelude::*;
+use kernel::time;
+
+const ACK_RATIO_SHIFT: u32 = 4;
+
+// TODO: Convert to module parameters once they are available.
+/// The initial value of ssthresh for new connections. Setting this to `None`
+/// implies `i32::MAX`.
+const INITIAL_SSTHRESH: Option<u32> = None;
+/// If cwnd is larger than this threshold, BIC engages; otherwise normal TCP
+/// increase/decrease will be performed.
+const LOW_WINDOW: u32 = 14;
+/// In binary search, go to point: `cwnd + (W_max - cwnd) / BICTCP_B`.
+// TODO: Convert to `NonZeroU32::new(x).unwrap()` once `const_option` is
+// stabilised.
+// SAFETY: This will panic at compile time when passing zero.
+const BICTCP_B: NonZeroU32 = unsafe { NonZeroU32::new_unchecked(4) };
+/// The maximum increment, i.e., `S_max`. This is used during additive
+/// increase. After crossing `W_max`, slow start is performed until passing
+/// `MAX_INCREMENT * (BICTCP_B - 1)`.
+// TODO: Convert to `NonZeroU32::new(x).unwrap()` once `const_option` is
+// stabilised.
+// SAFETY: This will panic at compile time when passing zero.
+const MAX_INCREMENT: NonZeroU32 = unsafe { NonZeroU32::new_unchecked(16) };
+/// The number of RTTs it takes to get from `W_max - BICTCP_B` to `W_max` (and
+/// from `W_max` to `W_max + BICTCP_B`). This is not part of the original paper
+/// and results in a slow additive increase across `W_max`.
+const SMOOTH_PART: u32 = 20;
+/// Whether to use fast convergence. This is a heuristic to increase the
+/// release of bandwidth by existing flows to speed up the convergence to a
+/// steady state when a new flow joins the link.
+const FAST_CONVERGENCE: bool = true;
+/// Factor for multiplicative decrease.
In fast retransmit we have:
+/// `cwnd = cwnd * BETA/BETA_SCALE`
+/// and if fast convergence is active:
+/// `W_max = cwnd * (1 + BETA/BETA_SCALE)/2`
+/// instead of `W_max = cwnd`.
+const BETA: u32 = 819;
+/// Used to calculate beta in [0, 1] with integer arithmetic.
+// TODO: Convert to `NonZeroU32::new(x).unwrap()` once `const_option` is
+// stabilised.
+// SAFETY: This will panic at compile time when passing zero.
+const BETA_SCALE: NonZeroU32 = unsafe { NonZeroU32::new_unchecked(1024) };
+/// The minimum amount of time that has to pass between two updates of the
+/// cwnd.
+const MIN_UPDATE_INTERVAL: time::Msecs32 = time::MSEC_PER_SEC / 32;
+
+module_cca! {
+    type: Bic,
+    name: "tcp_bic_rust",
+    author: "Rust for Linux Contributors",
+    description: "Binary Increase Congestion control (BIC) algorithm, Rust implementation",
+    license: "GPL v2",
+}
+
+struct Bic {}
+
+#[vtable]
+impl cong::Algorithm for Bic {
+    type Data = BicState;
+
+    const NAME: &'static CStr = c_str!("bic_rust");
+
+    fn pkts_acked(sk: &mut cong::Sock<'_, Self>, sample: &cong::AckSample) {
+        if let Ok(cong::State::Open) = sk.inet_csk().ca_state() {
+            let ca = sk.inet_csk_ca_mut();
+
+            // Track delayed acknowledgment ratio using sliding window:
+            // ratio = (15*ratio + sample) / 16
+            ca.delayed_ack = ca.delayed_ack.wrapping_add(
+                sample
+                    .pkts_acked()
+                    .wrapping_sub(ca.delayed_ack >> ACK_RATIO_SHIFT),
+            );
+        }
+    }
+
+    fn ssthresh(sk: &mut cong::Sock<'_, Self>) -> u32 {
+        let cwnd = sk.tcp_sk().snd_cwnd();
+        let ca = sk.inet_csk_ca_mut();
+
+        pr_info!(
+            // TODO: remove
+            "Enter fast retransmit: time {}, start {}",
+            time::ktime_get_boot_fast_ns(),
+            ca.start_time
+        );
+
+        // Epoch has ended.
+        ca.epoch_start = 0;
+        ca.last_max_cwnd = if cwnd < ca.last_max_cwnd && FAST_CONVERGENCE {
+            (cwnd * (BETA_SCALE.get() + BETA)) / (2 * BETA_SCALE.get())
+        } else {
+            cwnd
+        };
+
+        if cwnd <= LOW_WINDOW {
+            // Act like normal TCP.
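+            // That is, halve cwnd on loss, but never drop below two segments.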
+ max(cwnd >> 1, 2) + } else { + max((cwnd * BETA) / BETA_SCALE, 2) + } + } + + fn cong_avoid(sk: &mut cong::Sock<'_, Self>, _ack: u32, mut acked: u32) { + if !sk.tcp_is_cwnd_limited() { + return; + } + + let tp = sk.tcp_sk_mut(); + + if tp.in_slow_start() { + acked = tp.slow_start(acked); + if acked == 0 { + pr_info!( + // TODO: remove + "New cwnd {}, time {}, ssthresh {}, start {}, ss 1", + sk.tcp_sk().snd_cwnd(), + time::ktime_get_boot_fast_ns(), + sk.tcp_sk().snd_ssthresh(), + sk.inet_csk_ca().start_time + ); + return; + } + } + + let cwnd = tp.snd_cwnd(); + let cnt = sk.inet_csk_ca_mut().update(cwnd); + sk.tcp_sk_mut().cong_avoid_ai(cnt, acked); + + pr_info!( + // TODO: remove + "New cwnd {}, time {}, ssthresh {}, start {}, ss 0", + sk.tcp_sk().snd_cwnd(), + time::ktime_get_boot_fast_ns(), + sk.tcp_sk().snd_ssthresh(), + sk.inet_csk_ca().start_time + ); + } + + fn set_state(sk: &mut cong::Sock<'_, Self>, new_state: cong::State) { + if matches!(new_state, cong::State::Loss) { + pr_info!( + // TODO: remove + "Retransmission timeout fired: time {}, start {}", + time::ktime_get_boot_fast_ns(), + sk.inet_csk_ca().start_time + ); + sk.inet_csk_ca_mut().reset() + } + } + + fn undo_cwnd(sk: &mut cong::Sock<'_, Self>) -> u32 { + pr_info!( + // TODO: remove + "Undo cwnd reduction: time {}, start {}", + time::ktime_get_boot_fast_ns(), + sk.inet_csk_ca().start_time + ); + + cong::reno::undo_cwnd(sk) + } + + fn init(sk: &mut cong::Sock<'_, Self>) { + if let Some(ssthresh) = INITIAL_SSTHRESH { + sk.tcp_sk_mut().set_snd_ssthresh(ssthresh); + } + + // TODO: remove + pr_info!("Socket created: start {}", sk.inet_csk_ca().start_time); + } + + // TODO: remove + fn release(sk: &mut cong::Sock<'_, Self>) { + pr_info!( + "Socket destroyed: start {}, end {}", + sk.inet_csk_ca().start_time, + time::ktime_get_boot_fast_ns() + ); + } +} + +/// Internal state of each instance of the algorithm. +struct BicState { + /// During congestion avoidance, cwnd is increased at most every `cnt` + /// acknowledged packets, i.e., the average increase per acknowledged packet + /// is proportional to `1 / cnt`. + // NOTE: The C impl initialises this to zero. It then ensures that zero is + // never passed to `cong_avoid_ai`, which could divide by it. Make it + // explicit in the types that zero is not a valid value. + cnt: NonZeroU32, + /// Last maximum `snd_cwnd`, i.e, `W_max`. + last_max_cwnd: u32, + /// The last `snd_cwnd`. + last_cwnd: u32, + /// Time when `last_cwnd` was updated. + last_time: time::Msecs32, + /// Records the beginning of an epoch. + epoch_start: time::Msecs32, + /// Estimates the ratio of `packets/ACK << 4`. This allows us to adjust cwnd + /// per packet when a receiver is sending a single ACK for multiple received + /// packets. + delayed_ack: u32, + /// Time when algorithm was initialised. + // TODO: remove + start_time: time::Nsecs, +} + +impl Default for BicState { + fn default() -> Self { + Self { + // NOTE: Initialising this to 1 deviates from the C code. It does + // not change the behaviour of the algorithm. + cnt: NonZeroU32::MIN, + last_max_cwnd: 0, + last_cwnd: 0, + last_time: 0, + epoch_start: 0, + delayed_ack: 2 << ACK_RATIO_SHIFT, + // TODO: remove + start_time: time::ktime_get_boot_fast_ns(), + } + } +} + +impl BicState { + /// Compute congestion window to use. Returns the new `cnt`. + /// + /// This governs the behavior of the algorithm during congestion avoidance. 
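+    ///
+    /// For example (illustrative): with `cnt = 4`, `cong_avoid_ai` increases
+    /// cwnd by one segment for every four newly ACKed packets, i.e., cwnd
+    /// grows by roughly `cwnd / cnt` segments per RTT.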
+ fn update(&mut self, cwnd: u32) -> NonZeroU32 { + let now = time::ktime_get_boot_fast_ms32(); + + // Do nothing if we are invoked too frequently. + if self.last_cwnd == cwnd && now.wrapping_sub(self.last_time) <= MIN_UPDATE_INTERVAL { + return self.cnt; + } + + self.last_cwnd = cwnd; + self.last_time = now; + + // Record the beginning of an epoch. + if self.epoch_start == 0 { + self.epoch_start = now; + } + + // Start off like normal TCP. + if cwnd <= LOW_WINDOW { + self.cnt = NonZeroU32::new(cwnd).unwrap_or(NonZeroU32::MIN); + return self.cnt; + } + + let mut new_cnt = if cwnd < self.last_max_cwnd { + // binary increase + let dist: u32 = (self.last_max_cwnd - cwnd) / BICTCP_B; + + if dist > MAX_INCREMENT.get() { + // additive increase + cwnd / MAX_INCREMENT + } else if dist <= 1 { + // careful additive increase + (cwnd * SMOOTH_PART) / BICTCP_B + } else { + // binary search + cwnd / dist + } + } else { + if cwnd < self.last_max_cwnd + BICTCP_B.get() { + // careful additive increase + (cwnd * SMOOTH_PART) / BICTCP_B + } else if cwnd < self.last_max_cwnd + MAX_INCREMENT.get() * (BICTCP_B.get() - 1) { + // slow start + (cwnd * (BICTCP_B.get() - 1)) / (cwnd - self.last_max_cwnd) + } else { + // linear increase + cwnd / MAX_INCREMENT + } + }; + + // If in initial slow start or link utilization is very low. + if self.last_max_cwnd == 0 { + new_cnt = min(new_cnt, 20); + } + + // Account for estimated packets/ACK to ensure that we increase per + // packet. + new_cnt = (new_cnt << ACK_RATIO_SHIFT) / self.delayed_ack; + + self.cnt = NonZeroU32::new(new_cnt).unwrap_or(NonZeroU32::MIN); + + self.cnt + } + + fn reset(&mut self) { + // TODO: remove + let tmp = self.start_time; + + *self = Self::default(); + + // TODO: remove + self.start_time = tmp; + } +} From 67c08f7f23772c28afe70cbdb2430891d91f477e Mon Sep 17 00:00:00 2001 From: Valentin Obst Date: Fri, 1 Mar 2024 23:16:05 +0100 Subject: [PATCH 12/13] rust/net: add implementation of HyStart Signed-off-by: Valentin Obst --- rust/kernel/lib.rs | 1 + rust/kernel/net/tcp/cong.rs | 2 + rust/kernel/net/tcp/cong/hystart.rs | 265 ++++++++++++++++++++++++++++ 3 files changed, 268 insertions(+) create mode 100644 rust/kernel/net/tcp/cong/hystart.rs diff --git a/rust/kernel/lib.rs b/rust/kernel/lib.rs index e10c40f9fee3d6..b90282b9962c2d 100644 --- a/rust/kernel/lib.rs +++ b/rust/kernel/lib.rs @@ -13,6 +13,7 @@ #![no_std] #![feature(allocator_api)] +#![feature(associated_type_bounds)] #![feature(coerce_unsized)] #![feature(dispatch_from_dyn)] #![feature(new_uninit)] diff --git a/rust/kernel/net/tcp/cong.rs b/rust/kernel/net/tcp/cong.rs index 1640c85ffe7e4f..a08ec04b946621 100644 --- a/rust/kernel/net/tcp/cong.rs +++ b/rust/kernel/net/tcp/cong.rs @@ -21,6 +21,8 @@ use macros::{pin_data, pinned_drop}; use super::{InetConnectionSock, TcpSock}; +pub mod hystart; + /// Congestion control algorithm (CCA). /// /// A CCA is implemented as a set of callbacks that are invoked whenever diff --git a/rust/kernel/net/tcp/cong/hystart.rs b/rust/kernel/net/tcp/cong/hystart.rs new file mode 100644 index 00000000000000..5bc847902c5f14 --- /dev/null +++ b/rust/kernel/net/tcp/cong/hystart.rs @@ -0,0 +1,265 @@ +// SPDX-License-Identifier: GPL-2.0-only + +//! HyStart slow start algorithm. +//! +//! Based on: +//! Sangtae Ha, Injong Rhee, +//! Taming the elephants: New TCP slow start, +//! Computer Networks, Volume 55, Issue 9, 2011, Pages 2092-2110, +//! 
ISSN 1389-1286,
+
+use crate::net::sock;
+use crate::net::tcp::{self, cong};
+use crate::time;
+use crate::{pr_err, pr_info};
+use core::cmp::min;
+
+/// The heuristic that is used to find the exit point for slow start.
+pub enum HystartDetect {
+    /// Exits slow start when the length of so-called ACK-trains becomes equal
+    /// to the estimated minimum forward path one-way delay.
+    AckTrain = 1,
+    /// Exits slow start when the estimated RTT increase between two
+    /// consecutive rounds exceeds a threshold that is based on the last RTT.
+    Delay = 2,
+    /// Combine both algorithms.
+    Both = 3,
+}
+
+/// Internal state of the [`HyStart`] algorithm.
+pub struct HyStartState {
+    /// Number of ACKs already sampled to determine the RTT of this round.
+    sample_cnt: u8,
+    /// Whether the slow start exit point was found.
+    found: bool,
+    /// Time when the current round has started.
+    round_start: time::Usecs32,
+    /// Sequence number of the byte that marks the end of the current round.
+    end_seq: u32,
+    /// Time when the last ACK was received in this round.
+    last_ack: time::Usecs32,
+    /// The minimum RTT of the current round.
+    curr_rtt: time::Usecs32,
+    /// Estimate of the minimum forward path one-way delay of the link.
+    pub delay_min: Option<time::Usecs32>,
+    /// Time when the connection was created.
+    // TODO: remove
+    pub start_time: time::Usecs32,
+}
+
+impl Default for HyStartState {
+    fn default() -> Self {
+        Self {
+            sample_cnt: 0,
+            found: false,
+            round_start: 0,
+            end_seq: 0,
+            last_ack: 0,
+            curr_rtt: 0,
+            delay_min: None,
+            // TODO: remove
+            start_time: time::ktime_get_boot_fast_us32(),
+        }
+    }
+}
+
+impl HyStartState {
+    /// Returns true iff the algorithm `T` is in hybrid slow start.
+    #[inline]
+    pub fn in_hystart<T: HyStart>(&self, cwnd: u32) -> bool {
+        !self.found && cwnd >= T::LOW_WINDOW
+    }
+}
+
+/// Implement this trait on [`Algorithm::Data`] to use [`HyStart`] for your
+/// CCA.
+///
+/// [`Algorithm::Data`]: cong::Algorithm::Data
+pub trait HasHyStartState {
+    /// Returns the private data of the HyStart algorithm.
+    fn hy(&self) -> &HyStartState;
+
+    /// Returns the private data of the HyStart algorithm.
+    fn hy_mut(&mut self) -> &mut HyStartState;
+}
+
+/// Implement this trait on your [`Algorithm`] to use HyStart. You still need
+/// to invoke the [`reset`] and [`update`] methods at the right places.
+///
+/// [`Algorithm`]: cong::Algorithm
+/// [`reset`]: HyStart::reset
+/// [`update`]: HyStart::update
+pub trait HyStart: cong::Algorithm<Data: HasHyStartState> {
+    // TODO: Those constants should be configurable via module parameters.
+    /// Which heuristic to use for deciding when it is time to exit slow start.
+    const DETECT: HystartDetect;
+
+    /// Lower bound for cwnd during hybrid slow start.
+    const LOW_WINDOW: u32;
+
+    /// Max spacing between ACKs in an ACK-train.
+    const ACK_DELTA: time::Usecs32;
+
+    /// Number of ACKs to sample at the beginning of each round to estimate the
+    /// RTT of this round.
+    const MIN_SAMPLES: u8 = 8;
+
+    /// Lower bound on the increase in RTT between two consecutive rounds that
+    /// is needed to trigger an exit from slow start.
+    const DELAY_MIN: time::Usecs32 = 4000;
+
+    /// Upper bound on the increase in RTT between two consecutive rounds that
+    /// is needed to trigger an exit from slow start.
+    const DELAY_MAX: time::Usecs32 = 16000;
+
+    /// Corresponds to the function eta from the paper. Returns the increase in
+    /// RTT between consecutive rounds that triggers an exit from slow start.
+    /// `t` is the RTT of the last round.
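+    ///
+    /// For example (illustrative): a last-round RTT of `t = 100_000`us yields
+    /// a threshold of `t >> 3 = 12_500`us; RTTs below 32ms clamp to
+    /// `DELAY_MIN` and RTTs above 128ms clamp to `DELAY_MAX`.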
+ fn delay_thresh(mut t: time::Usecs32) -> time::Usecs32 { + t >>= 3; + + if t < Self::DELAY_MIN { + Self::DELAY_MIN + } else if t > Self::DELAY_MAX { + Self::DELAY_MAX + } else { + t + } + } + + /// TODO + fn ack_delay(sk: &cong::Sock<'_, Self>) -> time::Usecs32 { + (match sk.sk_pacing_rate() { + 0 => 0, + rate => min( + time::USEC_PER_MSEC, + ((sk.sk_gso_max_size() as u64) * 4 * time::USEC_PER_SEC) / rate, + ), + } as time::Usecs32) + } + + /// Called in slow start at the beginning of a new round of incoming ACKs. + fn reset(sk: &mut cong::Sock<'_, Self>) { + let tp = sk.tcp_sk(); + let now = tp.tcp_mstamp() as time::Usecs32; + let snd_nxt = tp.snd_nxt(); + + let hy = sk.inet_csk_ca_mut().hy_mut(); + + hy.round_start = now; + hy.last_ack = now; + hy.end_seq = snd_nxt; + hy.curr_rtt = u32::MAX; + hy.sample_cnt = 0; + } + + /// Called in slow start to decide if it is time to exit slow start. Sets + /// [`HyStartState`] `found` to true when it is time to exit. + fn update(sk: &mut cong::Sock<'_, Self>, delay: time::Usecs32) { + // Start of a new round. + if tcp::after(sk.tcp_sk().snd_una(), sk.inet_csk_ca().hy().end_seq) { + Self::reset(sk); + } + let hy = sk.inet_csk_ca().hy(); + let Some(delay_min) = hy.delay_min else { + // This should not happen. + pr_err!("hystart: update: delay_min was None"); + return; + }; + + if matches!(Self::DETECT, HystartDetect::Both | HystartDetect::AckTrain) { + let tp = sk.tcp_sk(); + let now = tp.tcp_mstamp() as time::Usecs32; + + // Is this ACK part of a train? + // NOTE: I don't get it. C is doing this as a signed comparison but + // for: + // -- `0 <= now < ca->last_ack <= 0x7F..F` this means it always + // passes, + // -- `ca->last_ack = 0x80..0` and `0 <= new <= 0x7F..F` it also + // always passes, + // -- `0x80..00 < ca->last_ack` and `now < 0x80.0` (big enough) + // also always passes. + // If I understand the paper correctly, this is not what is + // intended. What we really want here is the unsigned version I + // guess, please correct me if I am wrong. + // Commit: c54b4b7655447c1f24f6d50779c22eba9ee0fd24 + // Purposefully introduced the cast ... am I just stupid? + // Link: https://godbolt.org/z/E7ocxae69 + if now.wrapping_sub(hy.last_ack) <= Self::ACK_DELTA { + let threshold = if let Ok(sock::Pacing::r#None) = sk.sk_pacing_status() { + (delay_min + Self::ack_delay(sk)) >> 1 + } else { + delay_min + Self::ack_delay(sk) + }; + + // Does the length of this ACK-train indicate it is time to + // exit slow start? + // NOTE: C is a bit weird here ... `threshold` is unsigned but + // the lhs is still cast to signed, even though the usual + // arithmetic conversions will immediately cast it back to + // unsigned; thus, I guess we can just do everything unsigned. + if now.wrapping_sub(hy.round_start) > threshold { + // TODO: change to debug + pr_info!( + "hystart_ack_train ({}us > {}us) delay_min {}us (+ ack_delay {}us) cwnd {}, start {}us", + now.wrapping_sub(hy.round_start), + threshold, + delay_min, + Self::ack_delay(sk), + tp.snd_cwnd(), + hy.start_time + ); + + let tp = sk.tcp_sk_mut(); + + tp.set_snd_ssthresh(tp.snd_cwnd()); + + sk.inet_csk_ca_mut().hy_mut().found = true; + + // TODO: Update net stats. + } + + sk.inet_csk_ca_mut().hy_mut().last_ack = now; + } + } + + if matches!(Self::DETECT, HystartDetect::Both | HystartDetect::Delay) { + let hy = sk.inet_csk_ca_mut().hy_mut(); + + // The paper only takes the min RTT of the first `MIN_SAMPLES` + // ACKs in a round, but it does no harm to consider later ACKs as + // well. 
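+            // Running minimum: `curr_rtt` was reset to `u32::MAX` at the
+            // start of this round, so the first sample always updates it.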
+            if hy.curr_rtt > delay {
+                hy.curr_rtt = delay
+            }
+
+            if hy.sample_cnt < Self::MIN_SAMPLES {
+                hy.sample_cnt += 1;
+            } else {
+                // Does the increase in RTT indicate it is time to exit slow
+                // start?
+                if hy.curr_rtt > delay_min + Self::delay_thresh(delay_min) {
+                    hy.found = true;
+
+                    // TODO: change to debug
+                    let curr_rtt = hy.curr_rtt;
+                    let start_time = hy.start_time;
+                    pr_info!(
+                        "hystart_delay: {}us > {}us, delay_min {}us (+ delay_thresh {}us), cwnd {}, start {}us",
+                        curr_rtt,
+                        delay_min + Self::delay_thresh(delay_min),
+                        delay_min,
+                        Self::delay_thresh(delay_min),
+                        sk.tcp_sk().snd_cwnd(),
+                        start_time,
+                    );
+                    // TODO: Update net stats.
+
+                    let tp = sk.tcp_sk_mut();
+
+                    tp.set_snd_ssthresh(tp.snd_cwnd());
+                }
+            }
+        }
+    }
+}
From 2253d8e4b114740d3d1878f7bcb9471afe208503 Mon Sep 17 00:00:00 2001
From: Valentin Obst
Date: Sun, 18 Feb 2024 22:01:21 +0100
Subject: [PATCH 13/13] net/tcp: add Rust implementation of CUBIC

CUBIC has been the default CCA since 2.6.18.

Missing features compared to the C implementation:
- configuration via module parameters,
- exporting callbacks to BPF programs as kfuncs.

Changes compared to the C implementation:
- uses only SI units for time, i.e., no jiffies and `BICTCP_HZ`.

Signed-off-by: Valentin Obst
---
 net/ipv4/Kconfig           |  13 +
 net/ipv4/Makefile          |   1 +
 net/ipv4/tcp_cubic_rust.rs | 510 +++++++++++++++++++++++++++++++++++++
 3 files changed, 524 insertions(+)
 create mode 100644 net/ipv4/tcp_cubic_rust.rs

diff --git a/net/ipv4/Kconfig b/net/ipv4/Kconfig
index 00cc70f54dc59c..d2ea9811c131c7 100644
--- a/net/ipv4/Kconfig
+++ b/net/ipv4/Kconfig
@@ -526,6 +526,15 @@ config TCP_CONG_CUBIC
 	  among other techniques. See
 	  http://www.csc.ncsu.edu/faculty/rhee/export/bitcp/cubic-paper.pdf
 
+config TCP_CONG_CUBIC_RUST
+	tristate "CUBIC TCP (Rust rewrite)"
+	depends on RUST_TCP_ABSTRACTIONS
+	help
+	  Rust rewrite of the original implementation of TCP CUBIC congestion
+	  control.
+
+	  If unsure, say N.
+
 config TCP_CONG_WESTWOOD
 	tristate "TCP Westwood+"
 	default m
@@ -719,6 +728,9 @@ choice
 	config DEFAULT_CUBIC
 		bool "Cubic" if TCP_CONG_CUBIC=y
 
+	config DEFAULT_CUBIC_RUST
+		bool "Cubic (Rust)" if TCP_CONG_CUBIC_RUST=y
+
 	config DEFAULT_HTCP
 		bool "Htcp" if TCP_CONG_HTCP=y
 
@@ -759,6 +771,7 @@ config DEFAULT_TCP_CONG
 	default "bic" if DEFAULT_BIC
 	default "bic_rust" if DEFAULT_BIC_RUST
 	default "cubic" if DEFAULT_CUBIC
+	default "cubic_rust" if DEFAULT_CUBIC_RUST
 	default "htcp" if DEFAULT_HTCP
 	default "hybla" if DEFAULT_HYBLA
 	default "vegas" if DEFAULT_VEGAS
diff --git a/net/ipv4/Makefile b/net/ipv4/Makefile
index f93213a62c58ec..8aecd5fa55e96d 100644
--- a/net/ipv4/Makefile
+++ b/net/ipv4/Makefile
@@ -49,6 +49,7 @@ obj-$(CONFIG_TCP_CONG_BIC) += tcp_bic.o
 obj-$(CONFIG_TCP_CONG_BIC_RUST) += tcp_bic_rust.o
 obj-$(CONFIG_TCP_CONG_CDG) += tcp_cdg.o
 obj-$(CONFIG_TCP_CONG_CUBIC) += tcp_cubic.o
+obj-$(CONFIG_TCP_CONG_CUBIC_RUST) += tcp_cubic_rust.o
 obj-$(CONFIG_TCP_CONG_DCTCP) += tcp_dctcp.o
 obj-$(CONFIG_TCP_CONG_WESTWOOD) += tcp_westwood.o
 obj-$(CONFIG_TCP_CONG_HSTCP) += tcp_highspeed.o
diff --git a/net/ipv4/tcp_cubic_rust.rs b/net/ipv4/tcp_cubic_rust.rs
new file mode 100644
index 00000000000000..f93c15f27d77be
--- /dev/null
+++ b/net/ipv4/tcp_cubic_rust.rs
@@ -0,0 +1,510 @@
+// SPDX-License-Identifier: GPL-2.0-only
+
+//! TCP CUBIC congestion control algorithm.
+//!
+//! Based on:
+//! Sangtae Ha, Injong Rhee, and Lisong Xu. 2008.
+//! CUBIC: A New TCP-Friendly High-Speed TCP Variant.
+//! SIGOPS Oper. Syst. Rev. 42, 5 (July 2008), 64–74.
+//!
+//!
CUBIC is also described in [RFC9438](https://www.rfc-editor.org/rfc/rfc9438).
+
+use core::cmp::{max, min};
+use core::num::NonZeroU32;
+use kernel::c_str;
+use kernel::net::tcp;
+use kernel::net::tcp::cong::{self, hystart, hystart::HystartDetect, module_cca};
+use kernel::prelude::*;
+use kernel::time;
+
+const BICTCP_BETA_SCALE: u32 = 1024;
+
+// TODO: Convert to module parameters once they are available. Currently these
+// are the defaults from the C implementation.
+// TODO: Use `NonZeroU32` where appropriate.
+/// Whether to use fast convergence. This is a heuristic to increase the
+/// release of bandwidth by existing flows to speed up the convergence to a
+/// steady state when a new flow joins the link.
+const FAST_CONVERGENCE: bool = true;
+/// The factor for multiplicative decrease of cwnd upon a loss event. Will be
+/// divided by `BICTCP_BETA_SCALE`, approximately 0.7.
+const BETA: u32 = 717;
+/// The initial value of ssthresh for new connections. Setting this to `None`
+/// implies `i32::MAX`.
+const INITIAL_SSTHRESH: Option<u32> = None;
+/// The parameter `C` that scales the cubic term is defined as
+/// `BIC_SCALE/2^10`. (For C: Dimension: Time^-2, Unit: s^-2).
+const BIC_SCALE: u32 = 41;
+/// In environments where CUBIC grows cwnd less aggressively than normal TCP,
+/// enabling this option causes it to behave like normal TCP instead. This is
+/// the case in short RTT and/or low bandwidth delay product networks.
+const TCP_FRIENDLINESS: bool = true;
+/// Whether to use the [HyStart] slow start algorithm.
+///
+/// [HyStart]: hystart::HyStart
+const HYSTART: bool = true;
+
+impl hystart::HyStart for Cubic {
+    /// Which mechanism to use for deciding when it is time to exit slow start.
+    const DETECT: HystartDetect = HystartDetect::Both;
+    /// Lower bound for cwnd during hybrid slow start.
+    const LOW_WINDOW: u32 = 16;
+    /// Spacing between ACKs indicating an ACK-train.
+    /// (Dimension: Time. Unit: us).
+    const ACK_DELTA: time::Usecs32 = 2000;
+}
+
+// TODO: Those are computed based on the module parameters in the init. Even
+// with module parameters available this will be a bit tricky to do in Rust.
+/// Factor of `8/3 * (1 + beta) / (1 - beta)` that is used in various
+/// calculations. (Dimension: none)
+const BETA_SCALE: u32 = ((8 * (BICTCP_BETA_SCALE + BETA)) / 3) / (BICTCP_BETA_SCALE - BETA);
+/// Factor of `2^10*C/SRTT` where `SRTT = 100ms` that is used in various
+/// calculations. (Dimension: Time^-3, Unit: s^-3).
+const CUBE_RTT_SCALE: u32 = BIC_SCALE * 10;
+/// Factor of `SRTT/C` where `SRTT = 100ms` and `C` from above.
+/// (Dimension: Time^3. Unit: (ms)^3)
+// Note: C uses a custom time unit of 2^-10 s called `BICTCP_HZ`. This
+// implementation consistently uses milliseconds instead.
+const CUBE_FACTOR: u64 = 1_000_000_000 * (1u64 << 10) / (CUBE_RTT_SCALE as u64);
+
+module_cca!
{
+    type: Cubic,
+    name: "tcp_cubic_rust",
+    author: "Rust for Linux Contributors",
+    description: "TCP CUBIC congestion control algorithm, Rust implementation",
+    license: "GPL v2",
+}
+
+struct Cubic {}
+
+#[vtable]
+impl cong::Algorithm for Cubic {
+    type Data = CubicState;
+
+    const NAME: &'static CStr = c_str!("cubic_rust");
+
+    fn init(sk: &mut cong::Sock<'_, Self>) {
+        if HYSTART {
+            <Self as hystart::HyStart>::reset(sk)
+        } else if let Some(ssthresh) = INITIAL_SSTHRESH {
+            sk.tcp_sk_mut().set_snd_ssthresh(ssthresh);
+        }
+
+        // TODO: remove
+        pr_info!(
+            "init: socket created: start {}us",
+            sk.inet_csk_ca().hystart_state.start_time
+        );
+    }
+
+    // TODO: remove
+    fn release(sk: &mut cong::Sock<'_, Self>) {
+        pr_info!(
+            "release: socket destroyed: start {}us, end {}us",
+            sk.inet_csk_ca().hystart_state.start_time,
+            time::ktime_get_boot_fast_us32(),
+        );
+    }
+
+    fn cwnd_event(sk: &mut cong::Sock<'_, Self>, ev: cong::Event) {
+        if matches!(ev, cong::Event::TxStart) {
+            // Here we cannot avoid jiffies as the `lsndtime` field is measured
+            // in jiffies.
+            let now = time::jiffies32();
+            let delta: time::Jiffies32 = now.wrapping_sub(sk.tcp_sk().lsndtime());
+
+            if (delta as i32) <= 0 {
+                return;
+            }
+
+            let ca = sk.inet_csk_ca_mut();
+            // Ok, let's switch to SI units.
+            let now = time::ktime_get_boot_fast_ms32();
+            let delta = time::jiffies_to_msecs(delta as time::Jiffies);
+            // TODO: remove
+            pr_debug!("cwnd_event: TxStart, now {}ms, delta {}ms", now, delta);
+            // We were application limited, i.e., idle, for a while. If we are
+            // in congestion avoidance, shift `epoch_start` by the time we were
+            // idle to keep cwnd growth on the cubic curve.
+            ca.epoch_start = ca.epoch_start.map(|mut epoch_start| {
+                epoch_start = epoch_start.wrapping_add(delta);
+                if tcp::after(epoch_start, now) {
+                    epoch_start = now;
+                }
+                epoch_start
+            });
+        }
+    }
+
+    fn set_state(sk: &mut cong::Sock<'_, Self>, new_state: cong::State) {
+        if matches!(new_state, cong::State::Loss) {
+            pr_info!(
+                // TODO: remove
+                "set_state: Loss, time {}us, start {}us",
+                time::ktime_get_boot_fast_us32(),
+                sk.inet_csk_ca().hystart_state.start_time
+            );
+            sk.inet_csk_ca_mut().reset();
+            <Self as hystart::HyStart>::reset(sk);
+        }
+    }
+
+    fn pkts_acked(sk: &mut cong::Sock<'_, Self>, sample: &cong::AckSample) {
+        // Some samples do not include RTTs.
+        let Some(rtt_us) = sample.rtt_us() else {
+            // TODO: remove
+            pr_debug!(
+                "pkts_acked: no RTT sample, start {}us",
+                sk.inet_csk_ca().hystart_state.start_time,
+            );
+            return;
+        };
+
+        let epoch_start = sk.inet_csk_ca().epoch_start;
+        // For some time after exiting fast recovery the samples might still be
+        // inaccurate.
+        if epoch_start.is_some_and(|epoch_start| {
+            time::ktime_get_boot_fast_ms32().wrapping_sub(epoch_start) < time::MSEC_PER_SEC
+        }) {
+            // TODO: remove
+            pr_debug!(
+                "pkts_acked: {}ms - {}ms < 1s, too close to epoch_start",
+                time::ktime_get_boot_fast_ms32(),
+                epoch_start.unwrap()
+            );
+            return;
+        }
+
+        let delay = max(1, rtt_us);
+        let cwnd = sk.tcp_sk().snd_cwnd();
+        let in_slow_start = sk.tcp_sk().in_slow_start();
+        let ca = sk.inet_csk_ca_mut();
+
+        // TODO: remove
+        pr_debug!(
+            "pkts_acked: delay {}us, cwnd {}, ss {}",
+            delay,
+            cwnd,
+            in_slow_start
+        );
+
+        // First call after reset or the delay decreased.
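+        // (`delay_min` tracks the smallest RTT observed so far and is `None`
+        // right after a reset.)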
+        if ca.hystart_state.delay_min.is_none()
+            || ca
+                .hystart_state
+                .delay_min
+                .is_some_and(|delay_min| delay_min > delay)
+        {
+            ca.hystart_state.delay_min = Some(delay);
+        }
+
+        if in_slow_start && HYSTART && ca.hystart_state.in_hystart::<Self>(cwnd) {
+            hystart::HyStart::update(sk, delay);
+        }
+    }
+
+    fn ssthresh(sk: &mut cong::Sock<'_, Self>) -> u32 {
+        let cwnd = sk.tcp_sk().snd_cwnd();
+        let ca = sk.inet_csk_ca_mut();
+
+        pr_info!(
+            // TODO: remove
+            "ssthresh: time {}us, start {}us",
+            time::ktime_get_boot_fast_us32(),
+            ca.hystart_state.start_time
+        );
+
+        // Epoch has ended.
+        ca.epoch_start = None;
+        ca.last_max_cwnd = if cwnd < ca.last_max_cwnd && FAST_CONVERGENCE {
+            (cwnd * (BICTCP_BETA_SCALE + BETA)) / (2 * BICTCP_BETA_SCALE)
+        } else {
+            cwnd
+        };
+
+        max((cwnd * BETA) / BICTCP_BETA_SCALE, 2)
+    }
+
+    fn undo_cwnd(sk: &mut cong::Sock<'_, Self>) -> u32 {
+        pr_info!(
+            // TODO: remove
+            "undo_cwnd: time {}us, start {}us",
+            time::ktime_get_boot_fast_us32(),
+            sk.inet_csk_ca().hystart_state.start_time
+        );
+
+        cong::reno::undo_cwnd(sk)
+    }
+
+    fn cong_avoid(sk: &mut cong::Sock<'_, Self>, _ack: u32, mut acked: u32) {
+        if !sk.tcp_is_cwnd_limited() {
+            return;
+        }
+
+        let tp = sk.tcp_sk_mut();
+
+        if tp.in_slow_start() {
+            acked = tp.slow_start(acked);
+            if acked == 0 {
+                pr_info!(
+                    // TODO: remove
+                    "cong_avoid: new cwnd {}, time {}us, ssthresh {}, start {}us, ss 1",
+                    sk.tcp_sk().snd_cwnd(),
+                    time::ktime_get_boot_fast_us32(),
+                    sk.tcp_sk().snd_ssthresh(),
+                    sk.inet_csk_ca().hystart_state.start_time
+                );
+                return;
+            }
+        }
+
+        let cwnd = tp.snd_cwnd();
+        let cnt = sk.inet_csk_ca_mut().update(cwnd, acked);
+        sk.tcp_sk_mut().cong_avoid_ai(cnt, acked);
+
+        pr_info!(
+            // TODO: remove
+            "cong_avoid: new cwnd {}, time {}us, ssthresh {}, start {}us, ss 0",
+            sk.tcp_sk().snd_cwnd(),
+            time::ktime_get_boot_fast_us32(),
+            sk.tcp_sk().snd_ssthresh(),
+            sk.inet_csk_ca().hystart_state.start_time
+        );
+    }
+}
+
+#[allow(non_snake_case)]
+struct CubicState {
+    /// Increase cwnd by one step after `cnt` ACKs.
+    cnt: NonZeroU32,
+    /// W__last_max.
+    last_max_cwnd: u32,
+    /// Value of cwnd before it was updated the last time.
+    last_cwnd: u32,
+    /// Time when `last_cwnd` was updated.
+    last_time: time::Msecs32,
+    /// Value of cwnd where the plateau of the cubic function is located.
+    origin_point: u32,
+    /// Time it takes to reach `origin_point`, measured from the beginning of
+    /// an epoch.
+    K: time::Msecs32,
+    /// Time when the current epoch has started. `None` when not in congestion
+    /// avoidance.
+    epoch_start: Option<time::Msecs32>,
+    /// Number of packets that have been ACKed in the current epoch.
+    ack_cnt: u32,
+    /// Estimate for the cwnd of TCP Reno.
+    tcp_cwnd: u32,
+    /// State of the HyStart slow start algorithm.
+    hystart_state: hystart::HyStartState,
+}
+
+impl hystart::HasHyStartState for CubicState {
+    fn hy(&self) -> &hystart::HyStartState {
+        &self.hystart_state
+    }
+
+    fn hy_mut(&mut self) -> &mut hystart::HyStartState {
+        &mut self.hystart_state
+    }
+}
+
+impl Default for CubicState {
+    fn default() -> Self {
+        Self {
+            // NOTE: Initializing this to 1 deviates from the C code. It does
+            // not change the behavior.
+            cnt: NonZeroU32::MIN,
+            last_max_cwnd: 0,
+            last_cwnd: 0,
+            last_time: 0,
+            origin_point: 0,
+            K: 0,
+            epoch_start: None,
+            ack_cnt: 0,
+            tcp_cwnd: 0,
+            hystart_state: hystart::HyStartState::default(),
+        }
+    }
+}
+
+impl CubicState {
+    /// Checks if the current CUBIC increase is less aggressive than normal
+    /// TCP, i.e., if we are in the TCP-friendly region.
If so, returns a `cnt` that makes cwnd grow at the speed of normal TCP.
+    #[inline]
+    fn tcp_friendliness(&mut self, cnt: u32, cwnd: u32) -> u32 {
+        if !TCP_FRIENDLINESS {
+            return cnt;
+        }
+
+        // Estimate cwnd of normal TCP.
+        // cwnd/3 * (1 + BETA)/(1 - BETA)
+        let delta = (cwnd * BETA_SCALE) >> 3;
+        // W__tcp(t) = W__tcp(t__0) + (acks(t) - acks(t__0)) / delta
+        while self.ack_cnt > delta {
+            self.ack_cnt -= delta;
+            self.tcp_cwnd += 1;
+        }
+
+        // TODO: remove
+        pr_info!(
+            "tcp_friendliness: tcp_cwnd {}, cwnd {}, start {}us",
+            self.tcp_cwnd,
+            cwnd,
+            self.hystart_state.start_time,
+        );
+
+        // We are slower than normal TCP.
+        if self.tcp_cwnd > cwnd {
+            let delta = self.tcp_cwnd - cwnd;
+
+            min(cnt, cwnd / delta)
+        } else {
+            cnt
+        }
+    }
+
+    /// Returns the new value of `cnt` to keep the window growing on the cubic
+    /// curve.
+    fn update(&mut self, cwnd: u32, acked: u32) -> NonZeroU32 {
+        let now: time::Msecs32 = time::ktime_get_boot_fast_ms32();
+
+        self.ack_cnt += acked;
+
+        if self.last_cwnd == cwnd && now.wrapping_sub(self.last_time) <= time::MSEC_PER_SEC / 32 {
+            return self.cnt;
+        }
+
+        // We can update the CUBIC function at most once every ms.
+        if self.epoch_start.is_some() && now == self.last_time {
+            let cnt = self.tcp_friendliness(self.cnt.get(), cwnd);
+
+            // SAFETY: 2 != 0. QED.
+            self.cnt = unsafe { NonZeroU32::new_unchecked(max(2, cnt)) };
+
+            return self.cnt;
+        }
+
+        self.last_cwnd = cwnd;
+        self.last_time = now;
+
+        if self.epoch_start.is_none() {
+            self.epoch_start = Some(now);
+            self.ack_cnt = acked;
+            self.tcp_cwnd = cwnd;
+
+            if self.last_max_cwnd <= cwnd {
+                self.K = 0;
+                self.origin_point = cwnd;
+            } else {
+                // K = (SRTT/C * (W__max - cwnd))^1/3
+                self.K = cubic_root(CUBE_FACTOR * ((self.last_max_cwnd - cwnd) as u64));
+                self.origin_point = self.last_max_cwnd;
+            }
+        }
+
+        // PANIC: This is always `Some`.
+        let epoch_start: time::Msecs32 = self.epoch_start.unwrap();
+        let Some(delay_min) = self.hystart_state.delay_min else {
+            pr_err!("update: delay_min was None");
+            return self.cnt;
+        };
+
+        // NOTE: Addition might overflow after 50 days without a loss, C uses a
+        // `u64` here.
+        let t: time::Msecs32 =
+            now.wrapping_sub(epoch_start) + (delay_min / (time::USEC_PER_MSEC as time::Usecs32));
+        let offs: time::Msecs32 = if t < self.K { self.K - t } else { t - self.K };
+
+        // Calculate c/rtt * (t-K)^3 and change units to seconds.
+        // Widen type to prevent overflow.
+        let offs = offs as u64;
+        let delta = (((CUBE_RTT_SCALE as u64 * offs * offs * offs) >> 10) / 1_000_000_000) as u32;
+        // Calculate the full cubic function c/rtt * (t - K)^3 + W__max.
+        let target = if t < self.K {
+            self.origin_point - delta
+        } else {
+            self.origin_point + delta
+        };
+
+        // TODO: remove
+        pr_info!(
+            "update: now {}ms, epoch_start {}ms, t {}ms, K {}ms, |t - K| {}ms, last_max_cwnd {}, origin_point {}, target {}, start {}us",
+            now,
+            epoch_start,
+            t,
+            self.K,
+            offs,
+            self.last_max_cwnd,
+            self.origin_point,
+            target,
+            self.hystart_state.start_time,
+        );
+
+        let mut cnt = if target > cwnd {
+            cwnd / (target - cwnd)
+        } else {
+            // Effectively keeps cwnd constant for the next RTT.
+            100 * cwnd
+        };
+
+        // In initial epoch or after timeout we grow at a minimum rate.
+        if self.last_max_cwnd == 0 {
+            cnt = min(cnt, 20);
+        }
+
+        // SAFETY: 2 != 0. QED.
+ self.cnt = unsafe { NonZeroU32::new_unchecked(max(2, self.tcp_friendliness(cnt, cwnd))) }; + + self.cnt + } + + fn reset(&mut self) { + // TODO: remove + let tmp = self.hystart_state.start_time; + + *self = Self::default(); + + // TODO: remove + self.hystart_state.start_time = tmp; + } +} + +/// Calculate the cubic root of `a` using a table lookup followed by one +/// Newton-Raphson iteration. +// E[ |(cubic_root(x) - x.cbrt()) / x.cbrt()| ] = 0.71% for x in 1..1_000_000. +// E[ |(cubic_root(x) - x.cbrt()) / x.cbrt()| ] = 8.87% for x in 1..63. +// Where everything is `f64` and `.cbrt` is Rust's builtin. No overflow panics +// in this domain. +const fn cubic_root(a: u64) -> u32 { + const V: [u8; 64] = [ + 0, 54, 54, 54, 118, 118, 118, 118, 123, 129, 134, 138, 143, 147, 151, 156, 157, 161, 164, + 168, 170, 173, 176, 179, 181, 185, 187, 190, 192, 194, 197, 199, 200, 202, 204, 206, 209, + 211, 213, 215, 217, 219, 221, 222, 224, 225, 227, 229, 231, 232, 234, 236, 237, 239, 240, + 242, 244, 245, 246, 248, 250, 251, 252, 254, + ]; + + let mut b = fls64(a) as u32; + if b < 7 { + return ((V[a as usize] as u32) + 35) >> 6; + } + + b = ((b * 84) >> 8) - 1; + let shift = a >> (b * 3); + + let mut x = (((V[shift as usize] as u32) + 10) << b) >> 6; + x = 2 * x + (a / ((x * (x - 1)) as u64)) as u32; + + (x * 341) >> 10 +} + +/// Find last set bit in a 64-bit word. +/// +/// The last (most significant) bit is at position 64. +#[inline] +const fn fls64(x: u64) -> u8 { + (64 - x.leading_zeros()) as u8 +}
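+
+// Illustrative spot checks for `cubic_root`, hand-computed against the code
+// above (not part of the original algorithm): `cubic_root(27) == 3` via the
+// table path, and `cubic_root(1 << 30) == 1 << 10` via the Newton step.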