diff --git a/Cargo.toml b/Cargo.toml index 988e7507bef..b4e9d7cc8bf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -164,6 +164,7 @@ suspicious_operation_groupings = { level = "allow", priority = 1 } use_self = { level = "allow", priority = 1 } while_float = { level = "allow", priority = 1 } needless_pass_by_ref_mut = { level = "allow", priority = 1 } +too_long_first_doc_paragraph = { level = "allow", priority = 1 } # cargo-lints: cargo_common_metadata = { level = "allow", priority = 1 } # style-lints: diff --git a/DIRECTORY.md b/DIRECTORY.md index f4e1fa0e58c..df0d5e2f7e3 100644 --- a/DIRECTORY.md +++ b/DIRECTORY.md @@ -53,8 +53,10 @@ * [Decimal To Hexadecimal](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/decimal_to_hexadecimal.rs) * [Hexadecimal To Binary](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/hexadecimal_to_binary.rs) * [Hexadecimal To Decimal](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/hexadecimal_to_decimal.rs) + * [Length Conversion](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/length_conversion.rs) * [Octal To Binary](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/octal_to_binary.rs) * [Octal To Decimal](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/octal_to_decimal.rs) + * [Rgb Cmyk Conversion](https://github.com/TheAlgorithms/Rust/blob/master/src/conversions/rgb_cmyk_conversion.rs) * Data Structures * [Avl Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/avl_tree.rs) * [B Tree](https://github.com/TheAlgorithms/Rust/blob/master/src/data_structures/b_tree.rs) @@ -156,6 +158,7 @@ * [Cholesky](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/cholesky.rs) * [K Means](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/k_means.rs) * [Linear Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/linear_regression.rs) + * [Logistic Regression](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/logistic_regression.rs) * Loss Function * [Average Margin Ranking Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/average_margin_ranking_loss.rs) * [Hinge Loss](https://github.com/TheAlgorithms/Rust/blob/master/src/machine_learning/loss_function/hinge_loss.rs) diff --git a/src/conversions/length_conversion.rs b/src/conversions/length_conversion.rs new file mode 100644 index 00000000000..4a056ed3052 --- /dev/null +++ b/src/conversions/length_conversion.rs @@ -0,0 +1,94 @@ +/// Author : https://github.com/ali77gh +/// Conversion of length units. +/// +/// Available Units: +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Millimeter +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Centimeter +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Meter +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Kilometer +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Inch +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Foot +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Yard +/// -> Wikipedia reference: https://en.wikipedia.org/wiki/Mile + +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub enum LengthUnit { + Millimeter, + Centimeter, + Meter, + Kilometer, + Inch, + Foot, + Yard, + Mile, +} + +fn unit_to_meter_multiplier(from: LengthUnit) -> f64 { + match from { + LengthUnit::Millimeter => 0.001, + LengthUnit::Centimeter => 0.01, + LengthUnit::Meter => 1.0, + LengthUnit::Kilometer => 1000.0, + LengthUnit::Inch => 0.0254, + LengthUnit::Foot => 0.3048, + LengthUnit::Yard => 0.9144, + LengthUnit::Mile => 1609.34, + } +} + +fn unit_to_meter(input: f64, from: LengthUnit) -> f64 { + input * unit_to_meter_multiplier(from) +} + +fn meter_to_unit(input: f64, to: LengthUnit) -> f64 { + input / unit_to_meter_multiplier(to) +} + +/// This function will convert a value in unit of [from] to value in unit of [to] +/// by first converting it to meter and than convert it to destination unit +pub fn length_conversion(input: f64, from: LengthUnit, to: LengthUnit) -> f64 { + meter_to_unit(unit_to_meter(input, from), to) +} + +#[cfg(test)] +mod length_conversion_tests { + use std::collections::HashMap; + + use super::LengthUnit::*; + use super::*; + + #[test] + fn zero_to_zero() { + let units = vec![ + Millimeter, Centimeter, Meter, Kilometer, Inch, Foot, Yard, Mile, + ]; + + for u1 in units.clone() { + for u2 in units.clone() { + assert_eq!(length_conversion(0f64, u1, u2), 0f64); + } + } + } + + #[test] + fn length_of_one_meter() { + let meter_in_different_units = HashMap::from([ + (Millimeter, 1000f64), + (Centimeter, 100f64), + (Kilometer, 0.001f64), + (Inch, 39.37007874015748f64), + (Foot, 3.280839895013123f64), + (Yard, 1.0936132983377078f64), + (Mile, 0.0006213727366498068f64), + ]); + for (input_unit, input_value) in &meter_in_different_units { + for (target_unit, target_value) in &meter_in_different_units { + assert!( + num_traits::abs( + length_conversion(*input_value, *input_unit, *target_unit) - *target_value + ) < 0.0000001 + ); + } + } + } +} diff --git a/src/conversions/mod.rs b/src/conversions/mod.rs index af02e16a631..a83c46bf600 100644 --- a/src/conversions/mod.rs +++ b/src/conversions/mod.rs @@ -4,13 +4,17 @@ mod decimal_to_binary; mod decimal_to_hexadecimal; mod hexadecimal_to_binary; mod hexadecimal_to_decimal; +mod length_conversion; mod octal_to_binary; mod octal_to_decimal; +mod rgb_cmyk_conversion; pub use self::binary_to_decimal::binary_to_decimal; pub use self::binary_to_hexadecimal::binary_to_hexadecimal; pub use self::decimal_to_binary::decimal_to_binary; pub use self::decimal_to_hexadecimal::decimal_to_hexadecimal; pub use self::hexadecimal_to_binary::hexadecimal_to_binary; pub use self::hexadecimal_to_decimal::hexadecimal_to_decimal; +pub use self::length_conversion::length_conversion; pub use self::octal_to_binary::octal_to_binary; pub use self::octal_to_decimal::octal_to_decimal; +pub use self::rgb_cmyk_conversion::rgb_to_cmyk; diff --git a/src/conversions/rgb_cmyk_conversion.rs b/src/conversions/rgb_cmyk_conversion.rs new file mode 100644 index 00000000000..30a8bc9bd84 --- /dev/null +++ b/src/conversions/rgb_cmyk_conversion.rs @@ -0,0 +1,60 @@ +/// Author : https://github.com/ali77gh\ +/// References:\ +/// RGB: https://en.wikipedia.org/wiki/RGB_color_model\ +/// CMYK: https://en.wikipedia.org/wiki/CMYK_color_model\ + +/// This function Converts RGB to CMYK format +/// +/// ### Params +/// * `r` - red +/// * `g` - green +/// * `b` - blue +/// +/// ### Returns +/// (C, M, Y, K) +pub fn rgb_to_cmyk(rgb: (u8, u8, u8)) -> (u8, u8, u8, u8) { + // Safety: no need to check if input is positive and less than 255 because it's u8 + + // change scale from [0,255] to [0,1] + let (r, g, b) = ( + rgb.0 as f64 / 255f64, + rgb.1 as f64 / 255f64, + rgb.2 as f64 / 255f64, + ); + + match 1f64 - r.max(g).max(b) { + 1f64 => (0, 0, 0, 100), // pure black + k => ( + (100f64 * (1f64 - r - k) / (1f64 - k)) as u8, // c + (100f64 * (1f64 - g - k) / (1f64 - k)) as u8, // m + (100f64 * (1f64 - b - k) / (1f64 - k)) as u8, // y + (100f64 * k) as u8, // k + ), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! test_rgb_to_cmyk { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (rgb, cmyk) = $tc; + assert_eq!(rgb_to_cmyk(rgb), cmyk); + } + )* + } + } + + test_rgb_to_cmyk! { + white: ((255, 255, 255), (0, 0, 0, 0)), + gray: ((128, 128, 128), (0, 0, 0, 49)), + black: ((0, 0, 0), (0, 0, 0, 100)), + red: ((255, 0, 0), (0, 100, 100, 0)), + green: ((0, 255, 0), (100, 0, 100, 0)), + blue: ((0, 0, 255), (100, 100, 0, 0)), + } +} diff --git a/src/data_structures/segment_tree.rs b/src/data_structures/segment_tree.rs index 1a55dc8a47e..f569381967e 100644 --- a/src/data_structures/segment_tree.rs +++ b/src/data_structures/segment_tree.rs @@ -1,185 +1,224 @@ -use std::cmp::min; +//! A module providing a Segment Tree data structure for efficient range queries +//! and updates. It supports operations like finding the minimum, maximum, +//! and sum of segments in an array. + use std::fmt::Debug; use std::ops::Range; -/// This data structure implements a segment-tree that can efficiently answer range (interval) queries on arrays. -/// It represents this array as a binary tree of merged intervals. From top to bottom: [aggregated value for the overall array], then [left-hand half, right hand half], etc. until [each individual value, ...] -/// It is generic over a reduction function for each segment or interval: basically, to describe how we merge two intervals together. -/// Note that this function should be commutative and associative -/// It could be `std::cmp::min(interval_1, interval_2)` or `std::cmp::max(interval_1, interval_2)`, or `|a, b| a + b`, `|a, b| a * b` -pub struct SegmentTree { - len: usize, // length of the represented - tree: Vec, // represents a binary tree of intervals as an array (as a BinaryHeap does, for instance) - merge: fn(T, T) -> T, // how we merge two values together +/// Custom error types representing possible errors that can occur during operations on the `SegmentTree`. +#[derive(Debug, PartialEq, Eq)] +pub enum SegmentTreeError { + /// Error indicating that an index is out of bounds. + IndexOutOfBounds, + /// Error indicating that a range provided for a query is invalid. + InvalidRange, +} + +/// A structure representing a Segment Tree. This tree can be used to efficiently +/// perform range queries and updates on an array of elements. +pub struct SegmentTree +where + T: Debug + Default + Ord + Copy, + F: Fn(T, T) -> T, +{ + /// The length of the input array for which the segment tree is built. + size: usize, + /// A vector representing the segment tree. + nodes: Vec, + /// A merging function defined as a closure or callable type. + merge_fn: F, } -impl SegmentTree { - /// Builds a SegmentTree from an array and a merge function - pub fn from_vec(arr: &[T], merge: fn(T, T) -> T) -> Self { - let len = arr.len(); - let mut buf: Vec = vec![T::default(); 2 * len]; - // Populate the tree bottom-up, from right to left - buf[len..(2 * len)].clone_from_slice(&arr[0..len]); // last len pos is the bottom of the tree -> every individual value - for i in (1..len).rev() { - // a nice property of this "flat" representation of a tree: the parent of an element at index i is located at index i/2 - buf[i] = merge(buf[2 * i], buf[2 * i + 1]); +impl SegmentTree +where + T: Debug + Default + Ord + Copy, + F: Fn(T, T) -> T, +{ + /// Creates a new `SegmentTree` from the provided slice of elements. + /// + /// # Arguments + /// + /// * `arr`: A slice of elements of type `T` to initialize the segment tree. + /// * `merge`: A merging function that defines how to merge two elements of type `T`. + /// + /// # Returns + /// + /// A new `SegmentTree` instance populated with the given elements. + pub fn from_vec(arr: &[T], merge: F) -> Self { + let size = arr.len(); + let mut buffer: Vec = vec![T::default(); 2 * size]; + + // Populate the leaves of the tree + buffer[size..(2 * size)].clone_from_slice(arr); + for idx in (1..size).rev() { + buffer[idx] = merge(buffer[2 * idx], buffer[2 * idx + 1]); } + SegmentTree { - len, - tree: buf, - merge, + size, + nodes: buffer, + merge_fn: merge, } } - /// Query the range (exclusive) - /// returns None if the range is out of the array's boundaries (eg: if start is after the end of the array, or start > end, etc.) - /// return the aggregate of values over this range otherwise - pub fn query(&self, range: Range) -> Option { - let mut l = range.start + self.len; - let mut r = min(self.len, range.end) + self.len; - let mut res = None; - // Check Wikipedia or other detailed explanations here for how to navigate the tree bottom-up to limit the number of operations - while l < r { - if l % 2 == 1 { - res = Some(match res { - None => self.tree[l], - Some(old) => (self.merge)(old, self.tree[l]), + /// Queries the segment tree for the result of merging the elements in the given range. + /// + /// # Arguments + /// + /// * `range`: A range specified as `Range`, indicating the start (inclusive) + /// and end (exclusive) indices of the segment to query. + /// + /// # Returns + /// + /// * `Ok(Some(result))` if the query was successful and there are elements in the range, + /// * `Ok(None)` if the range is empty, + /// * `Err(SegmentTreeError::InvalidRange)` if the provided range is invalid. + pub fn query(&self, range: Range) -> Result, SegmentTreeError> { + if range.start >= self.size || range.end > self.size { + return Err(SegmentTreeError::InvalidRange); + } + + let mut left = range.start + self.size; + let mut right = range.end + self.size; + let mut result = None; + + // Iterate through the segment tree to accumulate results + while left < right { + if left % 2 == 1 { + result = Some(match result { + None => self.nodes[left], + Some(old) => (self.merge_fn)(old, self.nodes[left]), }); - l += 1; + left += 1; } - if r % 2 == 1 { - r -= 1; - res = Some(match res { - None => self.tree[r], - Some(old) => (self.merge)(old, self.tree[r]), + if right % 2 == 1 { + right -= 1; + result = Some(match result { + None => self.nodes[right], + Some(old) => (self.merge_fn)(old, self.nodes[right]), }); } - l /= 2; - r /= 2; + left /= 2; + right /= 2; } - res + + Ok(result) } - /// Updates the value at index `idx` in the original array with a new value `val` - pub fn update(&mut self, idx: usize, val: T) { - // change every value where `idx` plays a role, bottom -> up - // 1: change in the right-hand side of the tree (bottom row) - let mut idx = idx + self.len; - self.tree[idx] = val; - - // 2: then bubble up - idx /= 2; - while idx != 0 { - self.tree[idx] = (self.merge)(self.tree[2 * idx], self.tree[2 * idx + 1]); - idx /= 2; + /// Updates the value at the specified index in the segment tree. + /// + /// # Arguments + /// + /// * `idx`: The index (0-based) of the element to update. + /// * `val`: The new value of type `T` to set at the specified index. + /// + /// # Returns + /// + /// * `Ok(())` if the update was successful, + /// * `Err(SegmentTreeError::IndexOutOfBounds)` if the index is out of bounds. + pub fn update(&mut self, idx: usize, val: T) -> Result<(), SegmentTreeError> { + if idx >= self.size { + return Err(SegmentTreeError::IndexOutOfBounds); + } + + let mut index = idx + self.size; + if self.nodes[index] == val { + return Ok(()); } + + self.nodes[index] = val; + while index > 1 { + index /= 2; + self.nodes[index] = (self.merge_fn)(self.nodes[2 * index], self.nodes[2 * index + 1]); + } + + Ok(()) } } #[cfg(test)] mod tests { use super::*; - use quickcheck::TestResult; - use quickcheck_macros::quickcheck; use std::cmp::{max, min}; #[test] fn test_min_segments() { let vec = vec![-30, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; - let min_seg_tree = SegmentTree::from_vec(&vec, min); - assert_eq!(Some(-5), min_seg_tree.query(4..7)); - assert_eq!(Some(-30), min_seg_tree.query(0..vec.len())); - assert_eq!(Some(-30), min_seg_tree.query(0..2)); - assert_eq!(Some(-4), min_seg_tree.query(1..3)); - assert_eq!(Some(-5), min_seg_tree.query(1..7)); + let mut min_seg_tree = SegmentTree::from_vec(&vec, min); + assert_eq!(min_seg_tree.query(4..7), Ok(Some(-5))); + assert_eq!(min_seg_tree.query(0..vec.len()), Ok(Some(-30))); + assert_eq!(min_seg_tree.query(0..2), Ok(Some(-30))); + assert_eq!(min_seg_tree.query(1..3), Ok(Some(-4))); + assert_eq!(min_seg_tree.query(1..7), Ok(Some(-5))); + assert_eq!(min_seg_tree.update(5, 10), Ok(())); + assert_eq!(min_seg_tree.update(14, -8), Ok(())); + assert_eq!(min_seg_tree.query(4..7), Ok(Some(3))); + assert_eq!( + min_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(min_seg_tree.query(5..5), Ok(None)); + assert_eq!( + min_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + min_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) + ); } #[test] fn test_max_segments() { - let val_at_6 = 6; - let vec = vec![1, 2, -4, 7, 3, -5, val_at_6, 11, -20, 9, 14, 15, 5, 2, -8]; + let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; let mut max_seg_tree = SegmentTree::from_vec(&vec, max); - assert_eq!(Some(15), max_seg_tree.query(0..vec.len())); - let max_4_to_6 = 6; - assert_eq!(Some(max_4_to_6), max_seg_tree.query(4..7)); - let delta = 2; - max_seg_tree.update(6, val_at_6 + delta); - assert_eq!(Some(val_at_6 + delta), max_seg_tree.query(4..7)); + assert_eq!(max_seg_tree.query(0..vec.len()), Ok(Some(15))); + assert_eq!(max_seg_tree.query(3..5), Ok(Some(7))); + assert_eq!(max_seg_tree.query(4..8), Ok(Some(11))); + assert_eq!(max_seg_tree.query(8..10), Ok(Some(9))); + assert_eq!(max_seg_tree.query(9..12), Ok(Some(15))); + assert_eq!(max_seg_tree.update(4, 10), Ok(())); + assert_eq!(max_seg_tree.update(14, -8), Ok(())); + assert_eq!(max_seg_tree.query(3..5), Ok(Some(10))); + assert_eq!( + max_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(max_seg_tree.query(5..5), Ok(None)); + assert_eq!( + max_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + max_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) + ); } #[test] fn test_sum_segments() { - let val_at_6 = 6; - let vec = vec![1, 2, -4, 7, 3, -5, val_at_6, 11, -20, 9, 14, 15, 5, 2, -8]; + let vec = vec![1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]; let mut sum_seg_tree = SegmentTree::from_vec(&vec, |a, b| a + b); - for (i, val) in vec.iter().enumerate() { - assert_eq!(Some(*val), sum_seg_tree.query(i..(i + 1))); - } - let sum_4_to_6 = sum_seg_tree.query(4..7); - assert_eq!(Some(4), sum_4_to_6); - let delta = 3; - sum_seg_tree.update(6, val_at_6 + delta); + assert_eq!(sum_seg_tree.query(0..vec.len()), Ok(Some(38))); + assert_eq!(sum_seg_tree.query(1..4), Ok(Some(5))); + assert_eq!(sum_seg_tree.query(4..7), Ok(Some(4))); + assert_eq!(sum_seg_tree.query(6..9), Ok(Some(-3))); + assert_eq!(sum_seg_tree.query(9..vec.len()), Ok(Some(37))); + assert_eq!(sum_seg_tree.update(5, 10), Ok(())); + assert_eq!(sum_seg_tree.update(14, -8), Ok(())); + assert_eq!(sum_seg_tree.query(4..7), Ok(Some(19))); assert_eq!( - sum_4_to_6.unwrap() + delta, - sum_seg_tree.query(4..7).unwrap() + sum_seg_tree.update(15, 100), + Err(SegmentTreeError::IndexOutOfBounds) + ); + assert_eq!(sum_seg_tree.query(5..5), Ok(None)); + assert_eq!( + sum_seg_tree.query(10..16), + Err(SegmentTreeError::InvalidRange) + ); + assert_eq!( + sum_seg_tree.query(15..20), + Err(SegmentTreeError::InvalidRange) ); - } - - // Some properties over segment trees: - // When asking for the range of the overall array, return the same as iter().min() or iter().max(), etc. - // When asking for an interval containing a single value, return this value, no matter the merge function - - #[quickcheck] - fn check_overall_interval_min(array: Vec) -> TestResult { - let seg_tree = SegmentTree::from_vec(&array, min); - TestResult::from_bool(array.iter().min().copied() == seg_tree.query(0..array.len())) - } - - #[quickcheck] - fn check_overall_interval_max(array: Vec) -> TestResult { - let seg_tree = SegmentTree::from_vec(&array, max); - TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len())) - } - - #[quickcheck] - fn check_overall_interval_sum(array: Vec) -> TestResult { - let seg_tree = SegmentTree::from_vec(&array, max); - TestResult::from_bool(array.iter().max().copied() == seg_tree.query(0..array.len())) - } - - #[quickcheck] - fn check_single_interval_min(array: Vec) -> TestResult { - let seg_tree = SegmentTree::from_vec(&array, min); - for (i, value) in array.into_iter().enumerate() { - let res = seg_tree.query(i..(i + 1)); - if res != Some(value) { - return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); - } - } - TestResult::passed() - } - - #[quickcheck] - fn check_single_interval_max(array: Vec) -> TestResult { - let seg_tree = SegmentTree::from_vec(&array, max); - for (i, value) in array.into_iter().enumerate() { - let res = seg_tree.query(i..(i + 1)); - if res != Some(value) { - return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); - } - } - TestResult::passed() - } - - #[quickcheck] - fn check_single_interval_sum(array: Vec) -> TestResult { - let seg_tree = SegmentTree::from_vec(&array, max); - for (i, value) in array.into_iter().enumerate() { - let res = seg_tree.query(i..(i + 1)); - if res != Some(value) { - return TestResult::error(format!("Expected {:?}, got {:?}", Some(value), res)); - } - } - TestResult::passed() } } diff --git a/src/dynamic_programming/fibonacci.rs b/src/dynamic_programming/fibonacci.rs index a77f0aedc0f..a7db9f9562c 100644 --- a/src/dynamic_programming/fibonacci.rs +++ b/src/dynamic_programming/fibonacci.rs @@ -158,7 +158,7 @@ fn matrix_multiply(multiplier: &[Vec], multiplicand: &[Vec]) -> Vec< // of columns as the multiplicand has rows. let mut result: Vec> = vec![]; let mut temp; - // Using variable to compare lenghts of rows in multiplicand later + // Using variable to compare lengths of rows in multiplicand later let row_right_length = multiplicand[0].len(); for row_left in 0..multiplier.len() { if multiplier[row_left].len() != multiplicand.len() { @@ -180,6 +180,33 @@ fn matrix_multiply(multiplier: &[Vec], multiplicand: &[Vec]) -> Vec< result } +/// Binary lifting fibonacci +/// +/// Following properties of F(n) could be deduced from the matrix formula above: +/// +/// F(2n) = F(n) * (2F(n+1) - F(n)) +/// F(2n+1) = F(n+1)^2 + F(n)^2 +/// +/// Therefore F(n) and F(n+1) can be derived from F(n>>1) and F(n>>1 + 1), which +/// has a smaller constant in both time and space compared to matrix fibonacci. +pub fn binary_lifting_fibonacci(n: u32) -> u128 { + // the state always stores F(k), F(k+1) for some k, initially F(0), F(1) + let mut state = (0u128, 1u128); + + for i in (0..u32::BITS - n.leading_zeros()).rev() { + // compute F(2k), F(2k+1) from F(k), F(k+1) + state = ( + state.0 * (2 * state.1 - state.0), + state.0 * state.0 + state.1 * state.1, + ); + if n & (1 << i) != 0 { + state = (state.1, state.0 + state.1); + } + } + + state.0 +} + /// nth_fibonacci_number_modulo_m(n, m) returns the nth fibonacci number modulo the specified m /// i.e. F(n) % m pub fn nth_fibonacci_number_modulo_m(n: i64, m: i64) -> i128 { @@ -195,7 +222,7 @@ pub fn nth_fibonacci_number_modulo_m(n: i64, m: i64) -> i128 { fn get_pisano_sequence_and_period(m: i64) -> (i128, Vec) { let mut a = 0; let mut b = 1; - let mut lenght: i128 = 0; + let mut length: i128 = 0; let mut pisano_sequence: Vec = vec![a, b]; // Iterating through all the fib numbers to get the sequence @@ -213,12 +240,12 @@ fn get_pisano_sequence_and_period(m: i64) -> (i128, Vec) { // This is a less elegant way to do it. pisano_sequence.pop(); pisano_sequence.pop(); - lenght = pisano_sequence.len() as i128; + length = pisano_sequence.len() as i128; break; } } - (lenght, pisano_sequence) + (length, pisano_sequence) } /// last_digit_of_the_sum_of_nth_fibonacci_number(n) returns the last digit of the sum of n fibonacci numbers. @@ -251,6 +278,7 @@ pub fn last_digit_of_the_sum_of_nth_fibonacci_number(n: i64) -> i64 { #[cfg(test)] mod tests { + use super::binary_lifting_fibonacci; use super::classical_fibonacci; use super::fibonacci; use super::last_digit_of_the_sum_of_nth_fibonacci_number; @@ -328,7 +356,7 @@ mod tests { } #[test] - /// Check that the itterative and recursive fibonacci + /// Check that the iterative and recursive fibonacci /// produce the same value. Both are combinatorial ( F(0) = F(1) = 1 ) fn test_iterative_and_recursive_equivalence() { assert_eq!(fibonacci(0), recursive_fibonacci(0)); @@ -398,6 +426,24 @@ mod tests { ); } + #[test] + fn test_binary_lifting_fibonacci() { + assert_eq!(binary_lifting_fibonacci(0), 0); + assert_eq!(binary_lifting_fibonacci(1), 1); + assert_eq!(binary_lifting_fibonacci(2), 1); + assert_eq!(binary_lifting_fibonacci(3), 2); + assert_eq!(binary_lifting_fibonacci(4), 3); + assert_eq!(binary_lifting_fibonacci(5), 5); + assert_eq!(binary_lifting_fibonacci(10), 55); + assert_eq!(binary_lifting_fibonacci(20), 6765); + assert_eq!(binary_lifting_fibonacci(21), 10946); + assert_eq!(binary_lifting_fibonacci(100), 354224848179261915075); + assert_eq!( + binary_lifting_fibonacci(184), + 127127879743834334146972278486287885163 + ); + } + #[test] fn test_nth_fibonacci_number_modulo_m() { assert_eq!(nth_fibonacci_number_modulo_m(5, 10), 5); diff --git a/src/dynamic_programming/is_subsequence.rs b/src/dynamic_programming/is_subsequence.rs index 07950fd4519..22b43c387b1 100644 --- a/src/dynamic_programming/is_subsequence.rs +++ b/src/dynamic_programming/is_subsequence.rs @@ -1,33 +1,71 @@ -// Given two strings str1 and str2, return true if str1 is a subsequence of str2, or false otherwise. -// A subsequence of a string is a new string that is formed from the original string -// by deleting some (can be none) of the characters without disturbing the relative -// positions of the remaining characters. -// (i.e., "ace" is a subsequence of "abcde" while "aec" is not). -pub fn is_subsequence(str1: &str, str2: &str) -> bool { - let mut it1 = 0; - let mut it2 = 0; +//! A module for checking if one string is a subsequence of another string. +//! +//! A subsequence is formed by deleting some (can be none) of the characters +//! from the original string without disturbing the relative positions of the +//! remaining characters. This module provides a function to determine if +//! a given string is a subsequence of another string. - let byte1 = str1.as_bytes(); - let byte2 = str2.as_bytes(); +/// Checks if `sub` is a subsequence of `main`. +/// +/// # Arguments +/// +/// * `sub` - A string slice that may be a subsequence. +/// * `main` - A string slice that is checked against. +/// +/// # Returns +/// +/// Returns `true` if `sub` is a subsequence of `main`, otherwise returns `false`. +pub fn is_subsequence(sub: &str, main: &str) -> bool { + let mut sub_iter = sub.chars().peekable(); + let mut main_iter = main.chars(); - while it1 < str1.len() && it2 < str2.len() { - if byte1[it1] == byte2[it2] { - it1 += 1; + while let Some(&sub_char) = sub_iter.peek() { + match main_iter.next() { + Some(main_char) if main_char == sub_char => { + sub_iter.next(); + } + None => return false, + _ => {} } - - it2 += 1; } - it1 == str1.len() + true } #[cfg(test)] mod tests { use super::*; - #[test] - fn test() { - assert!(is_subsequence("abc", "ahbgdc")); - assert!(!is_subsequence("axc", "ahbgdc")); + macro_rules! subsequence_tests { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (sub, main, expected) = $test_case; + assert_eq!(is_subsequence(sub, main), expected); + } + )* + }; + } + + subsequence_tests! { + test_empty_subsequence: ("", "ahbgdc", true), + test_empty_strings: ("", "", true), + test_non_empty_sub_empty_main: ("abc", "", false), + test_subsequence_found: ("abc", "ahbgdc", true), + test_subsequence_not_found: ("axc", "ahbgdc", false), + test_longer_sub: ("abcd", "abc", false), + test_single_character_match: ("a", "ahbgdc", true), + test_single_character_not_match: ("x", "ahbgdc", false), + test_subsequence_at_start: ("abc", "abchello", true), + test_subsequence_at_end: ("cde", "abcde", true), + test_same_characters: ("aaa", "aaaaa", true), + test_interspersed_subsequence: ("ace", "abcde", true), + test_different_chars_in_subsequence: ("aceg", "abcdef", false), + test_single_character_in_main_not_match: ("a", "b", false), + test_single_character_in_main_match: ("b", "b", true), + test_subsequence_with_special_chars: ("a1!c", "a1!bcd", true), + test_case_sensitive: ("aBc", "abc", false), + test_subsequence_with_whitespace: ("hello world", "h e l l o w o r l d", true), } } diff --git a/src/dynamic_programming/mod.rs b/src/dynamic_programming/mod.rs index 76059465899..f28fc7c615c 100644 --- a/src/dynamic_programming/mod.rs +++ b/src/dynamic_programming/mod.rs @@ -20,6 +20,7 @@ mod word_break; pub use self::coin_change::coin_change; pub use self::egg_dropping::egg_drop; +pub use self::fibonacci::binary_lifting_fibonacci; pub use self::fibonacci::classical_fibonacci; pub use self::fibonacci::fibonacci; pub use self::fibonacci::last_digit_of_the_sum_of_nth_fibonacci_number; diff --git a/src/machine_learning/logistic_regression.rs b/src/machine_learning/logistic_regression.rs new file mode 100644 index 00000000000..fc020a795ac --- /dev/null +++ b/src/machine_learning/logistic_regression.rs @@ -0,0 +1,92 @@ +use super::optimization::gradient_descent; +use std::f64::consts::E; + +/// Returns the wieghts after performing Logistic regression on the input data points. +pub fn logistic_regression( + data_points: Vec<(Vec, f64)>, + iterations: usize, + learning_rate: f64, +) -> Option> { + if data_points.is_empty() { + return None; + } + + let num_features = data_points[0].0.len() + 1; + let mut params = vec![0.0; num_features]; + + let derivative_fn = |params: &[f64]| derivative(params, &data_points); + + gradient_descent(derivative_fn, &mut params, learning_rate, iterations as i32); + + Some(params) +} + +fn derivative(params: &[f64], data_points: &[(Vec, f64)]) -> Vec { + let num_features = params.len(); + let mut gradients = vec![0.0; num_features]; + + for (features, y_i) in data_points { + let z = params[0] + + params[1..] + .iter() + .zip(features) + .map(|(p, x)| p * x) + .sum::(); + let prediction = 1.0 / (1.0 + E.powf(-z)); + + gradients[0] += prediction - y_i; + for (i, x_i) in features.iter().enumerate() { + gradients[i + 1] += (prediction - y_i) * x_i; + } + } + + gradients +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_logistic_regression_simple() { + let data = vec![ + (vec![0.0], 0.0), + (vec![1.0], 0.0), + (vec![2.0], 0.0), + (vec![3.0], 1.0), + (vec![4.0], 1.0), + (vec![5.0], 1.0), + ]; + + let result = logistic_regression(data, 10000, 0.05); + assert!(result.is_some()); + + let params = result.unwrap(); + assert!((params[0] + 17.65).abs() < 1.0); + assert!((params[1] - 7.13).abs() < 1.0); + } + + #[test] + fn test_logistic_regression_extreme_data() { + let data = vec![ + (vec![-100.0], 0.0), + (vec![-10.0], 0.0), + (vec![0.0], 0.0), + (vec![10.0], 1.0), + (vec![100.0], 1.0), + ]; + + let result = logistic_regression(data, 10000, 0.05); + assert!(result.is_some()); + + let params = result.unwrap(); + assert!((params[0] + 6.20).abs() < 1.0); + assert!((params[1] - 5.5).abs() < 1.0); + } + + #[test] + fn test_logistic_regression_no_data() { + let result = logistic_regression(vec![], 5000, 0.1); + assert_eq!(result, None); + } +} diff --git a/src/machine_learning/mod.rs b/src/machine_learning/mod.rs index c77fd65116b..534326d2121 100644 --- a/src/machine_learning/mod.rs +++ b/src/machine_learning/mod.rs @@ -1,12 +1,14 @@ mod cholesky; mod k_means; mod linear_regression; +mod logistic_regression; mod loss_function; mod optimization; pub use self::cholesky::cholesky; pub use self::k_means::k_means; pub use self::linear_regression::linear_regression; +pub use self::logistic_regression::logistic_regression; pub use self::loss_function::average_margin_ranking_loss; pub use self::loss_function::hng_loss; pub use self::loss_function::huber_loss; diff --git a/src/machine_learning/optimization/gradient_descent.rs b/src/machine_learning/optimization/gradient_descent.rs index 6701a688d15..fd322a23ff3 100644 --- a/src/machine_learning/optimization/gradient_descent.rs +++ b/src/machine_learning/optimization/gradient_descent.rs @@ -23,7 +23,7 @@ /// A reference to the optimized parameter vector `x`. pub fn gradient_descent( - derivative_fn: fn(&[f64]) -> Vec, + derivative_fn: impl Fn(&[f64]) -> Vec, x: &mut Vec, learning_rate: f64, num_iterations: i32, diff --git a/src/searching/linear_search.rs b/src/searching/linear_search.rs index c2995754509..d38b224d0a6 100644 --- a/src/searching/linear_search.rs +++ b/src/searching/linear_search.rs @@ -1,6 +1,15 @@ -use std::cmp::PartialEq; - -pub fn linear_search(item: &T, arr: &[T]) -> Option { +/// Performs a linear search on the given array, returning the index of the first occurrence of the item. +/// +/// # Arguments +/// +/// * `item` - A reference to the item to search for in the array. +/// * `arr` - A slice of items to search within. +/// +/// # Returns +/// +/// * `Some(usize)` - The index of the first occurrence of the item, if found. +/// * `None` - If the item is not found in the array. +pub fn linear_search(item: &T, arr: &[T]) -> Option { for (i, data) in arr.iter().enumerate() { if item == data { return Some(i); @@ -14,36 +23,54 @@ pub fn linear_search(item: &T, arr: &[T]) -> Option { mod tests { use super::*; - #[test] - fn search_strings() { - let index = linear_search(&"a", &["a", "b", "c", "d", "google", "zoo"]); - assert_eq!(index, Some(0)); - } - - #[test] - fn search_ints() { - let index = linear_search(&4, &[1, 2, 3, 4]); - assert_eq!(index, Some(3)); - - let index = linear_search(&3, &[1, 2, 3, 4]); - assert_eq!(index, Some(2)); - - let index = linear_search(&2, &[1, 2, 3, 4]); - assert_eq!(index, Some(1)); - - let index = linear_search(&1, &[1, 2, 3, 4]); - assert_eq!(index, Some(0)); - } - - #[test] - fn not_found() { - let index = linear_search(&5, &[1, 2, 3, 4]); - assert_eq!(index, None); + macro_rules! test_cases { + ($($name:ident: $tc:expr,)*) => { + $( + #[test] + fn $name() { + let (item, arr, expected) = $tc; + if let Some(expected_index) = expected { + assert_eq!(arr[expected_index], item); + } + assert_eq!(linear_search(&item, arr), expected); + } + )* + } } - #[test] - fn empty() { - let index = linear_search(&1, &[]); - assert_eq!(index, None); + test_cases! { + empty: ("a", &[] as &[&str], None), + one_item_found: ("a", &["a"], Some(0)), + one_item_not_found: ("b", &["a"], None), + search_strings_asc_start: ("a", &["a", "b", "c", "d", "google", "zoo"], Some(0)), + search_strings_asc_middle: ("google", &["a", "b", "c", "d", "google", "zoo"], Some(4)), + search_strings_asc_last: ("zoo", &["a", "b", "c", "d", "google", "zoo"], Some(5)), + search_strings_asc_not_found: ("x", &["a", "b", "c", "d", "google", "zoo"], None), + search_strings_desc_start: ("zoo", &["zoo", "google", "d", "c", "b", "a"], Some(0)), + search_strings_desc_middle: ("google", &["zoo", "google", "d", "c", "b", "a"], Some(1)), + search_strings_desc_last: ("a", &["zoo", "google", "d", "c", "b", "a"], Some(5)), + search_strings_desc_not_found: ("x", &["zoo", "google", "d", "c", "b", "a"], None), + search_ints_asc_start: (1, &[1, 2, 3, 4], Some(0)), + search_ints_asc_middle: (3, &[1, 2, 3, 4], Some(2)), + search_ints_asc_end: (4, &[1, 2, 3, 4], Some(3)), + search_ints_asc_not_found: (5, &[1, 2, 3, 4], None), + search_ints_desc_start: (4, &[4, 3, 2, 1], Some(0)), + search_ints_desc_middle: (3, &[4, 3, 2, 1], Some(1)), + search_ints_desc_end: (1, &[4, 3, 2, 1], Some(3)), + search_ints_desc_not_found: (5, &[4, 3, 2, 1], None), + with_gaps_0: (0, &[1, 3, 8, 11], None), + with_gaps_1: (1, &[1, 3, 8, 11], Some(0)), + with_gaps_2: (2, &[1, 3, 8, 11], None), + with_gaps_3: (3, &[1, 3, 8, 11], Some(1)), + with_gaps_4: (4, &[1, 3, 8, 10], None), + with_gaps_5: (5, &[1, 3, 8, 10], None), + with_gaps_6: (6, &[1, 3, 8, 10], None), + with_gaps_7: (7, &[1, 3, 8, 11], None), + with_gaps_8: (8, &[1, 3, 8, 11], Some(2)), + with_gaps_9: (9, &[1, 3, 8, 11], None), + with_gaps_10: (10, &[1, 3, 8, 11], None), + with_gaps_11: (11, &[1, 3, 8, 11], Some(3)), + with_gaps_12: (12, &[1, 3, 8, 11], None), + with_gaps_13: (13, &[1, 3, 8, 11], None), } } diff --git a/src/string/anagram.rs b/src/string/anagram.rs index b81b7804707..9ea37dc4f6f 100644 --- a/src/string/anagram.rs +++ b/src/string/anagram.rs @@ -1,10 +1,68 @@ -pub fn check_anagram(s: &str, t: &str) -> bool { - sort_string(s) == sort_string(t) +use std::collections::HashMap; + +/// Custom error type representing an invalid character found in the input. +#[derive(Debug, PartialEq)] +pub enum AnagramError { + NonAlphabeticCharacter, } -fn sort_string(s: &str) -> Vec { - let mut res: Vec = s.to_ascii_lowercase().chars().collect::>(); - res.sort_unstable(); +/// Checks if two strings are anagrams, ignoring spaces and case sensitivity. +/// +/// # Arguments +/// +/// * `s` - First input string. +/// * `t` - Second input string. +/// +/// # Returns +/// +/// * `Ok(true)` if the strings are anagrams. +/// * `Ok(false)` if the strings are not anagrams. +/// * `Err(AnagramError)` if either string contains non-alphabetic characters. +pub fn check_anagram(s: &str, t: &str) -> Result { + let s_cleaned = clean_string(s)?; + let t_cleaned = clean_string(t)?; + + Ok(char_count(&s_cleaned) == char_count(&t_cleaned)) +} + +/// Cleans the input string by removing spaces and converting to lowercase. +/// Returns an error if any non-alphabetic character is found. +/// +/// # Arguments +/// +/// * `s` - Input string to clean. +/// +/// # Returns +/// +/// * `Ok(String)` containing the cleaned string (no spaces, lowercase). +/// * `Err(AnagramError)` if the string contains non-alphabetic characters. +fn clean_string(s: &str) -> Result { + s.chars() + .filter(|c| !c.is_whitespace()) + .map(|c| { + if c.is_alphabetic() { + Ok(c.to_ascii_lowercase()) + } else { + Err(AnagramError::NonAlphabeticCharacter) + } + }) + .collect() +} + +/// Computes the histogram of characters in a string. +/// +/// # Arguments +/// +/// * `s` - Input string. +/// +/// # Returns +/// +/// * A `HashMap` where the keys are characters and values are their count. +fn char_count(s: &str) -> HashMap { + let mut res = HashMap::new(); + for c in s.chars() { + *res.entry(c).or_insert(0) += 1; + } res } @@ -12,16 +70,42 @@ fn sort_string(s: &str) -> Vec { mod tests { use super::*; - #[test] - fn test_check_anagram() { - assert!(check_anagram("", "")); - assert!(check_anagram("A", "a")); - assert!(check_anagram("anagram", "nagaram")); - assert!(check_anagram("abcde", "edcba")); - assert!(check_anagram("sIlEnT", "LiStEn")); - - assert!(!check_anagram("", "z")); - assert!(!check_anagram("a", "z")); - assert!(!check_anagram("rat", "car")); + macro_rules! test_cases { + ($($name:ident: $test_case:expr,)*) => { + $( + #[test] + fn $name() { + let (s, t, expected) = $test_case; + assert_eq!(check_anagram(s, t), expected); + assert_eq!(check_anagram(t, s), expected); + } + )* + } + } + + test_cases! { + empty_strings: ("", "", Ok(true)), + empty_and_non_empty: ("", "Ted Morgan", Ok(false)), + single_char_same: ("z", "Z", Ok(true)), + single_char_diff: ("g", "h", Ok(false)), + valid_anagram_lowercase: ("cheater", "teacher", Ok(true)), + valid_anagram_with_spaces: ("madam curie", "radium came", Ok(true)), + valid_anagram_mixed_cases: ("Satan", "Santa", Ok(true)), + valid_anagram_with_spaces_and_mixed_cases: ("Anna Madrigal", "A man and a girl", Ok(true)), + new_york_times: ("New York Times", "monkeys write", Ok(true)), + church_of_scientology: ("Church of Scientology", "rich chosen goofy cult", Ok(true)), + mcdonalds_restaurants: ("McDonald's restaurants", "Uncle Sam's standard rot", Err(AnagramError::NonAlphabeticCharacter)), + coronavirus: ("coronavirus", "carnivorous", Ok(true)), + synonym_evil: ("evil", "vile", Ok(true)), + synonym_gentleman: ("a gentleman", "elegant man", Ok(true)), + antigram: ("restful", "fluster", Ok(true)), + sentences: ("William Shakespeare", "I am a weakish speller", Ok(true)), + part_of_speech_adj_to_verb: ("silent", "listen", Ok(true)), + anagrammatized: ("Anagrams", "Ars magna", Ok(true)), + non_anagram: ("rat", "car", Ok(false)), + invalid_anagram_with_special_char: ("hello!", "world", Err(AnagramError::NonAlphabeticCharacter)), + invalid_anagram_with_numeric_chars: ("test123", "321test", Err(AnagramError::NonAlphabeticCharacter)), + invalid_anagram_with_symbols: ("check@anagram", "check@nagaram", Err(AnagramError::NonAlphabeticCharacter)), + non_anagram_length_mismatch: ("abc", "abcd", Ok(false)), } } diff --git a/src/string/shortest_palindrome.rs b/src/string/shortest_palindrome.rs index 80f52395194..f72a97119dd 100644 --- a/src/string/shortest_palindrome.rs +++ b/src/string/shortest_palindrome.rs @@ -1,72 +1,119 @@ -/* -The function shortest_palindrome expands the given string to shortest palindrome by adding a shortest prefix. -KMP. Source:https://www.scaler.com/topics/data-structures/kmp-algorithm/ -Prefix Functions and KPM. Source:https://oi-wiki.org/string/kmp/ -*/ +//! This module provides functions for finding the shortest palindrome +//! that can be formed by adding characters to the left of a given string. +//! References +//! +//! - [KMP](https://www.scaler.com/topics/data-structures/kmp-algorithm/) +//! - [Prefix Functions and KPM](https://oi-wiki.org/string/kmp/) +/// Finds the shortest palindrome that can be formed by adding characters +/// to the left of the given string `s`. +/// +/// # Arguments +/// +/// * `s` - A string slice that holds the input string. +/// +/// # Returns +/// +/// Returns a new string that is the shortest palindrome, formed by adding +/// the necessary characters to the beginning of `s`. pub fn shortest_palindrome(s: &str) -> String { if s.is_empty() { return "".to_string(); } - let p_chars: Vec = s.chars().collect(); - let suffix = raw_suffix_function(&p_chars); + let original_chars: Vec = s.chars().collect(); + let suffix_table = compute_suffix(&original_chars); - let mut s_chars: Vec = s.chars().rev().collect(); - // The prefix of the original string matches the suffix of the flipped string. - let dp = invert_suffix_function(&p_chars, &s_chars, &suffix); + let mut reversed_chars: Vec = s.chars().rev().collect(); + // The prefix of the original string matches the suffix of the reversed string. + let prefix_match = compute_prefix_match(&original_chars, &reversed_chars, &suffix_table); - s_chars.append(&mut p_chars[dp[p_chars.len() - 1]..p_chars.len()].to_vec()); - s_chars.iter().collect() + reversed_chars.append(&mut original_chars[prefix_match[original_chars.len() - 1]..].to_vec()); + reversed_chars.iter().collect() } -pub fn raw_suffix_function(p_chars: &[char]) -> Vec { - let mut suffix = vec![0; p_chars.len()]; - for i in 1..p_chars.len() { +/// Computes the suffix table used for the KMP (Knuth-Morris-Pratt) string +/// matching algorithm. +/// +/// # Arguments +/// +/// * `chars` - A slice of characters for which the suffix table is computed. +/// +/// # Returns +/// +/// Returns a vector of `usize` representing the suffix table. Each element +/// at index `i` indicates the longest proper suffix which is also a proper +/// prefix of the substring `chars[0..=i]`. +pub fn compute_suffix(chars: &[char]) -> Vec { + let mut suffix = vec![0; chars.len()]; + for i in 1..chars.len() { let mut j = suffix[i - 1]; - while j > 0 && p_chars[j] != p_chars[i] { + while j > 0 && chars[j] != chars[i] { j = suffix[j - 1]; } - suffix[i] = j + if p_chars[j] == p_chars[i] { 1 } else { 0 }; + suffix[i] = j + if chars[j] == chars[i] { 1 } else { 0 }; } suffix } -pub fn invert_suffix_function(p_chars: &[char], s_chars: &[char], suffix: &[usize]) -> Vec { - let mut dp = vec![0; p_chars.len()]; - dp[0] = if p_chars[0] == s_chars[0] { 1 } else { 0 }; - for i in 1..p_chars.len() { - let mut j = dp[i - 1]; - while j > 0 && s_chars[i] != p_chars[j] { +/// Computes the prefix matches of the original string against its reversed +/// version using the suffix table. +/// +/// # Arguments +/// +/// * `original` - A slice of characters representing the original string. +/// * `reversed` - A slice of characters representing the reversed string. +/// * `suffix` - A slice containing the suffix table computed for the original string. +/// +/// # Returns +/// +/// Returns a vector of `usize` where each element at index `i` indicates the +/// length of the longest prefix of `original` that matches a suffix of +/// `reversed[0..=i]`. +pub fn compute_prefix_match(original: &[char], reversed: &[char], suffix: &[usize]) -> Vec { + let mut match_table = vec![0; original.len()]; + match_table[0] = if original[0] == reversed[0] { 1 } else { 0 }; + for i in 1..original.len() { + let mut j = match_table[i - 1]; + while j > 0 && reversed[i] != original[j] { j = suffix[j - 1]; } - dp[i] = j + if s_chars[i] == p_chars[j] { 1 } else { 0 }; + match_table[i] = j + if reversed[i] == original[j] { 1 } else { 0 }; } - dp + match_table } #[cfg(test)] mod tests { - use crate::string::shortest_palindrome; + use super::*; + use crate::string::is_palindrome; + macro_rules! test_shortest_palindrome { - ($($name:ident: $inputs:expr,)*) => { - $( - #[test] - fn $name() { - use crate::string::is_palindrome; - let (s, expected) = $inputs; - assert!(is_palindrome(expected)); - assert_eq!(shortest_palindrome(s), expected); - assert_eq!(shortest_palindrome(expected), expected); - } - )* + ($($name:ident: $inputs:expr,)*) => { + $( + #[test] + fn $name() { + let (input, expected) = $inputs; + assert!(is_palindrome(expected)); + assert_eq!(shortest_palindrome(input), expected); + assert_eq!(shortest_palindrome(expected), expected); + } + )* } } + test_shortest_palindrome! { empty: ("", ""), extend_left_1: ("aacecaaa", "aaacecaaa"), extend_left_2: ("abcd", "dcbabcd"), unicode_1: ("അ", "അ"), unicode_2: ("a牛", "牛a牛"), + single_char: ("x", "x"), + already_palindrome: ("racecar", "racecar"), + extend_left_3: ("abcde", "edcbabcde"), + extend_left_4: ("abca", "acbabca"), + long_string: ("abcdefg", "gfedcbabcdefg"), + repetitive: ("aaaaa", "aaaaa"), + complex: ("abacdfgdcaba", "abacdgfdcabacdfgdcaba"), } }