From 875fdedeef3ebc8db0f90771eac35c33614eff49 Mon Sep 17 00:00:00 2001 From: Bastian Schmidt Date: Thu, 20 Jun 2024 14:27:17 +0200 Subject: [PATCH] Implement CellRntiRingBuffer and fix sending delay * Add RingBuffer to count most frequent RNTIs over multiple iterations * Move mathematical utilities from util to new module math_util * Add MatchingErrorHandling -> Reset traffic generator on matching error -> Make math functions return Result<> Debug CPU usage of rntimatcher.gen thread: In the traffic_generator thread, the TrafficPattern was, in every sending iteration, (a) cloned and (b) wrapped into a new Box instance, which led to highly frequent memory allocations. This induced a high delay on the traffic pattern sending and high CPU usage of ~80% for the traffic_generator thread. With this fix, the sending rate for pattern A increased from ~800 Kbit/s to ~12 Mbit/s!!! --- src/logic/mod.rs | 4 +- src/logic/rnti_matcher.rs | 238 ++++++++++++++++++++++------------ src/logic/traffic_patterns.rs | 25 ++-- src/main.rs | 1 + src/math_util.rs | 134 +++++++++++++++++++ src/util.rs | 203 +++++++++++------------------ 6 files changed, 384 insertions(+), 221 deletions(-) create mode 100644 src/math_util.rs diff --git a/src/logic/mod.rs b/src/logic/mod.rs index fe1ab68..8a5ceb3 100644 --- a/src/logic/mod.rs +++ b/src/logic/mod.rs @@ -184,6 +184,8 @@ impl RntiMatcherState { #[derive(Clone, Debug, PartialEq)] pub enum RntiMatchingErrorType { ExceededDciTimestampDelta, + ErrorGeneratingTrafficPatternFeatures, + ErrorFindingBestMatchingRnti, } impl WorkerState for RntiMatcherState { @@ -250,7 +252,7 @@ pub struct MessageCellInfo { #[derive(Clone, Debug, PartialEq, Default)] pub struct MessageRnti { /* cell_id -> ue_rnti */ - cell_rnti: HashMap, + cell_rnti: HashMap, } /* -------------- */ diff --git a/src/logic/rnti_matcher.rs b/src/logic/rnti_matcher.rs index 403af71..da36867 100644 --- a/src/logic/rnti_matcher.rs +++ b/src/logic/rnti_matcher.rs @@ -20,11 +20,16 @@ use
crate::ngscope::types::NgScopeCellDci; use crate::parse::{Arguments, FlattenedRntiMatchingArgs}; use crate::util::{ + CellRntiRingBuffer, + log_rnti_matching_traffic, print_debug, print_info, determine_process_id, +}; + +use crate::math_util::{ calculate_mean_variance, calculate_median, calculate_weighted_euclidean_distance, - log_rnti_matching_traffic, print_debug, print_info, standardize_feature_vec, - determine_process_id, calculate_weighted_euclidean_distance_matrix, + calculate_weighted_euclidean_distance_matrix, standardize_feature_vec, }; + pub const MATCHING_INTERVAL_MS: u64 = 1000; pub const MATCHING_TRAFFIC_PATTERN_TIME_OVERLAP_FACTOR: f64 = 1.1; pub const MATCHING_MAX_DCI_TIMESTAMP_DELTA_MS: u64 = 100; @@ -40,39 +45,53 @@ pub const BASIC_FILTER_MIN_TOTAL_UL_FACTOR: f64 = 0.005; pub const BASIC_FILTER_MAX_UL_PER_DCI: u64 = 5_000_000; pub const BASIC_FILTER_MIN_OCCURENCES_FACTOR: f64 = 0.005; - /* - * Feature vector, order matters: - * - * DCI count (occurences) - * Total UL bytes - * UL bytes median - * UL bytes mean - * UL bytes variance - * DCI timestamp delta median - * DCI timestamp delta mean - * DCI timestamp delta variance - * */ -/* not as good as all-weighted one */ +pub const RNTI_RING_BUFFER_SIZE: usize = 5; + + +/* + * Feature vector, order matters: + * + * DCI count (occurences) + * Total UL bytes + * UL bytes median + * UL bytes mean + * UL bytes variance + * DCI timestamp delta median + * DCI timestamp delta mean + * DCI timestamp delta variance + * */ // pub const MATCHING_WEIGHTINGS: [f64; 8] = [ // 0.5, /* DCI count (occurences) */ -// 0.0, /* Total UL bytes */ -// 0.5, /* UL bytes median */ -// 0.0, /* UL bytes mean */ -// 0.0, /* UL bytes variance */ -// 0.0, /* DCI time delta median */ -// 0.0, /* DCI time delta mean */ -// 0.0, /* DCI time delta variance */ +// 0.1, /* Total UL bytes */ +// 0.15, /* UL bytes median */ +// 0.025, /* UL bytes mean */ +// 0.025, /* UL bytes variance */ +// 0.15, /* DCI time delta median */ +// 0.025, /* 
DCI time delta mean */ +// 0.025, /* DCI time delta variance */ +// ]; + +/* on D, not so nice */ +// pub const MATCHING_WEIGHTINGS: [f64; 8] = [ +// 0.3, /* DCI count (occurences) */ +// 0.3, /* Total UL bytes */ +// 0.1, /* UL bytes median */ +// 0.2, /* UL bytes mean */ +// 0.025, /* UL bytes variance */ +// 0.025, /* DCI time delta median */ +// 0.025, /* DCI time delta mean */ +// 0.025, /* DCI time delta variance */ // ]; pub const MATCHING_WEIGHTINGS: [f64; 8] = [ 0.5, /* DCI count (occurences) */ - 0.1, /* Total UL bytes */ - 0.15, /* UL bytes median */ - 0.025, /* UL bytes mean */ - 0.025, /* UL bytes variance */ - 0.15, /* DCI time delta median */ - 0.025, /* DCI time delta mean */ - 0.025, /* DCI time delta variance */ + 0.3, /* Total UL bytes */ + 0.1, /* UL bytes median */ + 0.020, /* UL bytes mean */ + 0.020, /* UL bytes variance */ + 0.020, /* DCI time delta median */ + 0.020, /* DCI time delta mean */ + 0.020, /* DCI time delta variance */ ]; #[derive(Clone, Debug, PartialEq)] @@ -176,6 +195,7 @@ fn run(run_args: &mut RunArgs) -> Result<()> { )?); run_args.tx_gen_thread_handle = Some(tx_gen_thread.clone()); + let mut cell_rnti_ring_buffer: CellRntiRingBuffer = CellRntiRingBuffer::new(RNTI_RING_BUFFER_SIZE); let traffic_destination = matching_args.matching_traffic_destination; let traffic_pattern = matching_args.matching_traffic_pattern.generate_pattern(); let matching_log_file_path = &format!( @@ -213,7 +233,9 @@ fn run(run_args: &mut RunArgs) -> Result<()> { handle_collect_dci(latest_dcis, *traffic_collection) } RntiMatcherState::MatchingProcessDci(traffic_collection) => { - handle_process_dci(*traffic_collection, matching_log_file_path) + handle_process_dci(*traffic_collection, + matching_log_file_path, + &mut cell_rnti_ring_buffer) } RntiMatcherState::MatchingPublishRnti(rnti) => { tx_rnti.broadcast(rnti); @@ -222,7 +244,10 @@ fn run(run_args: &mut RunArgs) -> Result<()> { Box::new(RntiMatcherState::StartMatching), ) } - 
RntiMatcherState::MatchingError(error_type) => handle_matching_error(error_type), + RntiMatcherState::MatchingError(error_type) => handle_matching_error( + error_type, + &tx_gen_thread, + ), RntiMatcherState::SleepMs(time_ms, next_state) => { thread::sleep(Duration::from_millis(time_ms)); *next_state @@ -255,12 +280,18 @@ fn handle_start_matching( let start_timestamp_ms = chrono::Utc::now().timestamp_millis() as u64; let finish_timestamp_ms = start_timestamp_ms + (MATCHING_TRAFFIC_PATTERN_TIME_OVERLAP_FACTOR * pattern_total_ms as f64) as u64; + let traffic_pattern_features = match TrafficPatternFeatures::from_traffic_pattern(&traffic_pattern) { + Ok(features) => features, + Err(_) => { + return RntiMatcherState::MatchingError(RntiMatchingErrorType::ErrorGeneratingTrafficPatternFeatures); + } + }; let traffic_collection: TrafficCollection = TrafficCollection { cell_traffic: Default::default(), start_timestamp_ms, finish_timestamp_ms, - traffic_pattern_features: TrafficPatternFeatures::from_traffic_pattern(&traffic_pattern) + traffic_pattern_features, }; let _ = tx_gen_thread.send(LocalGeneratorState::SendPattern( @@ -293,6 +324,7 @@ fn handle_collect_dci( fn handle_process_dci( mut traffic_collection: TrafficCollection, log_file_path: &str, + cell_rnti_ring_buffer: &mut CellRntiRingBuffer, ) -> RntiMatcherState { // Check number of packets plausability: expected ms -> expected dcis let mut message_rnti: MessageRnti = MessageRnti::default(); @@ -302,15 +334,36 @@ fn handle_process_dci( traffic_collection.apply_basic_filter(); - message_rnti.cell_rnti = traffic_collection.find_best_matching_rnti(); + let best_matches = match traffic_collection.find_best_matching_rnti() { + Ok(matches) => matches, + Err(e) => { + print_info(&format!("[rntimatcher] Error during handle_process_dci: {:?}", e)); + return RntiMatcherState::MatchingError(RntiMatchingErrorType::ErrorFindingBestMatchingRnti) + } + }; + cell_rnti_ring_buffer.update(&best_matches); + print_debug(&format!("DEBUG 
[rntimatcher] cell_rnti_ring_buffer: {:#?}", cell_rnti_ring_buffer)); + message_rnti.cell_rnti = cell_rnti_ring_buffer.most_frequent(); RntiMatcherState::MatchingPublishRnti(message_rnti) } -fn handle_matching_error(error_type: RntiMatchingErrorType) -> RntiMatcherState { - print_info(&format!( - "[rntimatcher] error during RNTI matching: {:?}\n -> stopping pattern", - error_type - )); +fn handle_matching_error( + error_type: RntiMatchingErrorType, + tx_gen_thread: &SyncSender, +) -> RntiMatcherState { + + match error_type { + RntiMatchingErrorType::ExceededDciTimestampDelta => {}, + RntiMatchingErrorType::ErrorGeneratingTrafficPatternFeatures | + RntiMatchingErrorType::ErrorFindingBestMatchingRnti => { + print_info(&format!( + "[rntimatcher] error during RNTI matching: {:?}\n -> going back to Idle", + error_type + )); + let _ = tx_gen_thread.send(LocalGeneratorState::Idle); + } + } + RntiMatcherState::SleepMs( MATCHING_INTERVAL_MS, Box::new(RntiMatcherState::StartMatching), @@ -352,20 +405,32 @@ fn run_traffic_generator( )); loop { - if let Some(new_state) = check_rx_state(&rx_local_gen_state)? 
{ - gen_state = new_state; + match check_rx_state(&rx_local_gen_state) { + Ok(Some(new_state)) => gen_state = new_state, + Ok(None) => {}, + Err(e) => { + print_info(&format!("{}", e)); + break; + } } match gen_state { LocalGeneratorState::Idle => { - /* Idle here, because it shall not interfere the sendpattern */ + /* Sleep here, because it shall not interfere the sendpattern */ thread::sleep(Duration::from_millis(DEFAULT_WORKER_SLEEP_MS)); } LocalGeneratorState::Stop => { break; } - LocalGeneratorState::SendPattern(destination, pattern) => { - gen_state = gen_handle_send_pattern(&socket, &destination, *pattern.clone()); + LocalGeneratorState::SendPattern(ref destination, ref mut pattern) => { + match gen_handle_send_pattern(&socket, destination, pattern) { + Ok(Some(_)) => { /* stay in the state and keep sending */ }, + Ok(None) => gen_state = LocalGeneratorState::PatternSent, + Err(e) => { + print_info(&format!("[rntimatcher.gen] Error occured while sendig the pattern: {:?}", e)); + gen_state = LocalGeneratorState::Stop; + } + } } LocalGeneratorState::PatternSent => { print_info("[rntimatcher.gen] Finished sending pattern!"); @@ -401,15 +466,15 @@ fn check_rx_state( fn gen_handle_send_pattern( socket: &UdpSocket, destination: &str, - mut pattern: TrafficPattern, -) -> LocalGeneratorState { + pattern: &mut TrafficPattern, +) -> Result> { match pattern.messages.pop_front() { Some(msg) => { thread::sleep(Duration::from_millis(msg.time_ms as u64)); - let _ = socket.send_to(&msg.payload, destination); - LocalGeneratorState::SendPattern(destination.to_string(), Box::new(pattern)) + socket.send_to(&msg.payload, destination)?; + Ok(Some(())) } - None => LocalGeneratorState::PatternSent, + None => Ok(None) } } @@ -536,11 +601,12 @@ impl TrafficCollection { }) /* ZERO MEDIAN */ .filter(|(_, ue_traffic)| { - if ue_traffic.feature_ul_bytes_median_mean_variance().0 < 0.0 { - stats.zero_ul_median += 1; - false - } else { - true + match 
ue_traffic.feature_ul_bytes_median_mean_variance() { + Ok((median, _, _)) if median <= 0.0 => true, + _ => { + stats.zero_ul_median += 1; + false + } } }) .map(|(&rnti, _)| rnti) @@ -570,7 +636,7 @@ impl TrafficCollection { * cell_id -> { (rnti, distance ) } * * */ - pub fn find_best_matching_rnti(&self) -> HashMap { + pub fn find_best_matching_rnti(&self) -> Result> { let pattern_std_vec = &self.traffic_pattern_features.std_vec; let pattern_feature_vec = &self.traffic_pattern_features.std_feature_vec; /* Change this to use the functional approach */ @@ -592,10 +658,10 @@ impl UeTraffic { * DCI timestamp delta mean * DCI timestamp delta variance * */ - pub fn generate_standardized_feature_vec(&self, std_vec: &[(f64, f64)]) -> Vec { + pub fn generate_standardized_feature_vec(&self, std_vec: &[(f64, f64)]) -> Result> { let mut non_std_feature_vec = vec![]; - let (ul_median, ul_mean, ul_variance) = self.feature_ul_bytes_median_mean_variance(); - let (tx_median, tx_mean, tx_variance) = self.feature_dci_time_delta_median_mean_variance(); + let (ul_median, ul_mean, ul_variance) = self.feature_ul_bytes_median_mean_variance()?; + let (tx_median, tx_mean, tx_variance) = self.feature_dci_time_delta_median_mean_variance()?; non_std_feature_vec.push(self.feature_dci_count()); non_std_feature_vec.push(self.feature_total_ul_bytes()); @@ -606,7 +672,7 @@ impl UeTraffic { non_std_feature_vec.push(tx_mean); non_std_feature_vec.push(tx_variance); - standardize_feature_vec(&non_std_feature_vec, std_vec) + Ok(standardize_feature_vec(&non_std_feature_vec, std_vec)) } pub fn feature_total_ul_bytes(&self) -> f64 { @@ -617,7 +683,7 @@ impl UeTraffic { self.traffic.len() as f64 } - pub fn feature_dci_time_delta_median_mean_variance(&self) -> (f64, f64, f64) { + pub fn feature_dci_time_delta_median_mean_variance(&self) -> Result<(f64, f64, f64)> { let mut sorted_timestamps: Vec = self.traffic.keys().cloned().collect(); sorted_timestamps.sort_by(|a, b| a.partial_cmp(b).unwrap()); let 
timestamp_deltas: Vec = sorted_timestamps @@ -625,22 +691,22 @@ impl UeTraffic { .map(|window| (window[1] - window[0]) as f64) .collect(); - let (mean, variance) = calculate_mean_variance(×tamp_deltas); - let median = calculate_median(×tamp_deltas); + let (mean, variance) = calculate_mean_variance(×tamp_deltas)?; + let median = calculate_median(×tamp_deltas)?; - (median, mean, variance) + Ok((median, mean, variance)) } - pub fn feature_ul_bytes_median_mean_variance(&self) -> (f64, f64, f64) { + pub fn feature_ul_bytes_median_mean_variance(&self) -> Result<(f64, f64, f64)> { let ul_bytes: Vec = self .traffic .values() .map(|ul_dl_traffic| ul_dl_traffic.ul_bytes as f64) .collect(); - let (mean, variance) = calculate_mean_variance(&ul_bytes); - let median = calculate_median(&ul_bytes); + let (mean, variance) = calculate_mean_variance(&ul_bytes)?; + let median = calculate_median(&ul_bytes)?; - (median, mean, variance) + Ok((median, mean, variance)) } } @@ -657,53 +723,63 @@ fn feature_distance_functional( traffic: &HashMap, pattern_std_vec: &[(f64, f64)], pattern_feature_vec: &[f64] -) -> HashMap { +) -> Result> { traffic.iter() .map(|(&cell_id, cell_traffic)| { - let mut rnti_and_distance: Vec<(u16, f64)> = cell_traffic.traffic + let rnti_and_distance: Result> = cell_traffic.traffic .iter() .map(|(&rnti, ue_traffic)| { - ( - rnti, - calculate_weighted_euclidean_distance( + let std_feature_vec = ue_traffic.generate_standardized_feature_vec(pattern_std_vec)?; + let distance = calculate_weighted_euclidean_distance( pattern_feature_vec, - &ue_traffic.generate_standardized_feature_vec(pattern_std_vec), + &std_feature_vec, &MATCHING_WEIGHTINGS, - ), - ) + ); + Ok((rnti, distance)) }) - .collect(); + .collect::>>(); + let mut rnti_and_distance = rnti_and_distance?; rnti_and_distance.sort_by(|a, b| a.1.abs().partial_cmp(&b.1.abs()).unwrap()); - (cell_id, *rnti_and_distance.first().unwrap()) + Ok((cell_id, rnti_and_distance.first().unwrap().0)) }) - .collect() + .collect::>>() 
} fn feature_distance_matrices( traffic: &HashMap, pattern_std_vec: &[(f64, f64)], pattern_feature_vec: &[f64] -) -> HashMap { +) -> Result> { let num_features = pattern_std_vec.len(); let weightings_vector = DVector::from_row_slice(&MATCHING_WEIGHTINGS); + traffic.iter() .map(|(&cell_id, cell_traffic)| { - let standardized_feature_vecs: Vec> = cell_traffic.traffic + let standardized_feature_vecs: Result>> = cell_traffic.traffic .values() .map(|ue_traffic| { ue_traffic.generate_standardized_feature_vec(pattern_std_vec) + .map_err(|e| anyhow!(e)) }) .collect(); + let standardized_feature_vecs = standardized_feature_vecs?; let num_vectors = standardized_feature_vecs.len(); let data: Vec = standardized_feature_vecs.into_iter().flatten().collect(); let feature_matrix: DMatrix = DMatrix::from_row_slice(num_vectors, num_features, &data); + + // Uncomment and implement debug print if needed + // print_debug(&format!("DEBUG [rntimatcher] feature_matrix: {:.2}", feature_matrix)); + let pattern_feature_matrix = DMatrix::from_fn(num_vectors, num_features, |_, r| pattern_feature_vec[r]); let euclidean_distances = calculate_weighted_euclidean_distance_matrix( &pattern_feature_matrix, &feature_matrix, &weightings_vector); + + // Uncomment and implement debug print if needed + // print_debug(&format!("DEBUG [rntimatcher] distances: {:.2}", euclidean_distances)); let mut rnti_and_distance: Vec<(u16, f64)> = cell_traffic.traffic.keys() .cloned() @@ -712,9 +788,9 @@ fn feature_distance_matrices( rnti_and_distance.sort_by(|a, b| a.1.abs().partial_cmp(&b.1.abs()).unwrap()); - (cell_id, *rnti_and_distance.first().unwrap()) + Ok((cell_id, rnti_and_distance.first().unwrap().0)) }) - .collect() + .collect::>>() } diff --git a/src/logic/traffic_patterns.rs b/src/logic/traffic_patterns.rs index 9f0f347..1260f88 100644 --- a/src/logic/traffic_patterns.rs +++ b/src/logic/traffic_patterns.rs @@ -1,16 +1,17 @@ use std::collections::VecDeque; +use anyhow::Result; use clap::ValueEnum; use 
serde::{Deserialize, Serialize}; -use crate::util::{calculate_mean_variance, calculate_median, standardize_feature_vec}; +use crate::math_util::{calculate_mean_variance, calculate_median, standardize_feature_vec}; #[derive( Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug, Serialize, Deserialize, Default, )] pub enum RntiMatchingTrafficPatternType { #[default] - A, /* t: 24 sec, 1KB packets, 1ms interval => ? Mbit/s */ + A, /* t: 24 sec, 1KB packets, 1ms interval => ? Mbit/s */ B, /* t: 24 sec, 2KB packets, 5ms interval => ? Mbit/s */ C, /* t: 24 sec, 4KB packets, 5ms interval => ? Mbit/s */ D, /* t: 24 sec, 8KB packets, 5ms interval => ~ 5.8 Mbit/s */ @@ -103,14 +104,14 @@ impl RntiMatchingTrafficPatternType { impl TrafficPatternFeatures { - pub fn from_traffic_pattern(pattern: &TrafficPattern) -> TrafficPatternFeatures { - TrafficPatternFeatures { + pub fn from_traffic_pattern(pattern: &TrafficPattern) -> Result { + Ok(TrafficPatternFeatures { pattern_type: pattern.pattern_type, std_vec: pattern.std_vec.clone(), - std_feature_vec: pattern.generate_standardized_feature_vec(), + std_feature_vec: pattern.generate_standardized_feature_vec()?, total_ul_bytes: pattern.total_ul_bytes(), nof_packets: pattern.nof_packets(), - } + }) } } @@ -142,7 +143,7 @@ impl TrafficPattern { * DCI timestamp delta mean * DCI timestamp delta variance * */ - pub fn generate_standardized_feature_vec(&self) -> Vec { + pub fn generate_standardized_feature_vec(&self) -> Result> { let packet_sizes: Vec = self.messages .iter() .map(|t| t.payload.len() as f64) @@ -152,10 +153,10 @@ impl TrafficPattern { .map(|m| m.time_ms as f64) .collect::>(); - let (ul_mean, ul_variance) = calculate_mean_variance(&packet_sizes); - let ul_median = calculate_median(&packet_sizes); - let (tx_mean, tx_variance) = calculate_mean_variance(&time_deltas); - let tx_median = calculate_median(&time_deltas); + let (ul_mean, ul_variance) = calculate_mean_variance(&packet_sizes)?; + let ul_median = 
calculate_median(&packet_sizes)?; + let (tx_mean, tx_variance) = calculate_mean_variance(&time_deltas)?; + let tx_median = calculate_median(&time_deltas)?; let non_std_feature_vec: Vec = vec![ packet_sizes.len() as f64, @@ -168,7 +169,7 @@ impl TrafficPattern { tx_variance, ]; - standardize_feature_vec(&non_std_feature_vec, &self.std_vec) + Ok(standardize_feature_vec(&non_std_feature_vec, &self.std_vec)) } } diff --git a/src/main.rs b/src/main.rs index 94bee84..8b1eb8b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,6 +15,7 @@ mod logic; mod ngscope; mod parse; mod util; +mod math_util; use logic::cell_sink::{deploy_cell_sink, CellSinkArgs}; use logic::cell_source::{deploy_cell_source, CellSourceArgs}; diff --git a/src/math_util.rs b/src/math_util.rs new file mode 100644 index 0000000..02b1c2f --- /dev/null +++ b/src/math_util.rs @@ -0,0 +1,134 @@ + +use anyhow::{anyhow, Result}; +use nalgebra::{DVector, DMatrix}; + + +/* Feature Matching */ + +pub fn calculate_mean_variance(list: &[f64]) -> Result<(f64, f64)> { + let total_packets = list.len() as f64; + + if total_packets <= 0.0 { + return Err(anyhow!("Cannot determine mean/variance of 0 length array")); + } + + let mean = list.iter().sum::() / total_packets; + + let variance = list + .iter() + .map(|&item| { + let diff = item - mean; + diff * diff + }) + .sum::() + / total_packets; + + Ok((mean, variance)) +} + +pub fn calculate_median(list: &[f64]) -> Result { + let len: usize = list.len(); + if len == 0 { + return Err(anyhow!("Cannot determine median of 0 length array")); + } + let mut sorted_list: Vec = list.to_vec(); + sorted_list.sort_by(|a, b| a.partial_cmp(b).unwrap()); // Handle NaN values safely + if len % 2 == 0 { + // If the length is even, return the average of the two middle elements + let mid1 = sorted_list[len / 2 - 1]; + let mid2 = sorted_list[len / 2]; + Ok((mid1 + mid2) * 0.5) + } else { + // If the length is odd, return the middle element + Ok(sorted_list[len / 2]) + } +} + 
+#[allow(dead_code)] +pub fn calculate_weighted_manhattan_distance( + vec_a: &[f64], + vec_b: &[f64], + weightings: &[f64], +) -> f64 { + assert_eq!( + vec_a.len(), + vec_b.len(), + "Calcuting Euclidean distance: Vectors must have the same length" + ); + assert_eq!( + vec_a.len(), + weightings.len(), + "Calcuting Euclidean distance: Vectors and weightings must have the same length" + ); + + vec_a + .iter() + .zip(vec_b.iter()) + .zip(weightings.iter()) + .fold(0.0, |acc, ((&a, &b), &w)| acc + w * (a - b).abs()) +} + +pub fn calculate_weighted_euclidean_distance( + vec_a: &[f64], + vec_b: &[f64], + weightings: &[f64], +) -> f64 { + assert_eq!( + vec_a.len(), + vec_b.len(), + "Calcuting Euclidean distance: Vectors must have the same length" + ); + assert_eq!( + vec_a.len(), + weightings.len(), + "Calcuting Euclidean distance: Vectors and weightings must have the same length" + ); + + let sum_of_squared_diff: f64 = vec_a + .iter() + .zip(vec_b.iter()) + .zip(weightings.iter()) + .map(|((&a, &b), &w)| w * (a - b).powi(2)) + .sum(); + + sum_of_squared_diff.sqrt() +} + +#[allow(dead_code)] +pub fn calculate_weighted_manhattan_distance_matrix( + matr_a: &DMatrix, + matr_b: &DMatrix, + weightings: &DVector, +) -> DVector { + assert_eq!( + (matr_a.nrows(), matr_a.ncols()), + (matr_b.nrows(), matr_b.ncols()), + "Calcuting Euclidean distance: Matrices must have the same dimensions" + ); + let diff_matrix = (matr_a - matr_b).abs(); + diff_matrix * weightings +} + +pub fn calculate_weighted_euclidean_distance_matrix( + matr_a: &DMatrix, + matr_b: &DMatrix, + weightings: &DVector, +) -> DVector { + assert_eq!( + (matr_a.nrows(), matr_a.ncols()), + (matr_b.nrows(), matr_b.ncols()), + "Calcuting Euclidean distance: Matrices must have the same dimensions" + ); + let diff_matrix = matr_a - matr_b; + let weighted_squared_diff_vector = diff_matrix.component_mul(&diff_matrix) * weightings; + + weighted_squared_diff_vector.map(|x| x.sqrt()) +} + +pub fn 
standardize_feature_vec(feature_vec: &[f64], std_vec: &[(f64, f64)]) -> Vec { + feature_vec + .iter() + .zip(std_vec.iter()) + .map(|(&feature, &(mean, std_deviation))| (feature - mean) / std_deviation) + .collect() +} diff --git a/src/util.rs b/src/util.rs index 05b28d2..440f13a 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,16 +1,88 @@ #![allow(dead_code)] -use anyhow::{anyhow, Result}; -use casual_logger::{Level, Log}; -use lazy_static::lazy_static; -use nalgebra::{DMatrix, DVector}; +use std::hash::Hash; +use std::collections::{VecDeque, HashMap}; use std::fs::{File, OpenOptions}; use std::io::Write; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; +use anyhow::{anyhow, Result}; +use casual_logger::{Level, Log}; +use lazy_static::lazy_static; use crate::logic::rnti_matcher::TrafficCollection; +#[derive(Clone, Debug, Default)] +pub struct RingBuffer { + buffer: VecDeque, + size: usize, +} + +#[derive(Clone, Debug, Default)] +pub struct CellRntiRingBuffer { + cell_buffers: HashMap>, + size: usize, +} + +impl RingBuffer +where + T: Eq + Hash + Clone, +{ + pub fn new(size: usize) -> Self { + RingBuffer { + buffer: VecDeque::with_capacity(size), + size, + } + } + + pub fn add(&mut self, value: T) { + if self.buffer.len() == self.size { + self.buffer.pop_front(); // Remove the oldest value if the buffer is full + } + self.buffer.push_back(value); // Add the new value + } + + pub fn most_frequent(&self) -> Option { + let mut frequency_map = std::collections::HashMap::new(); + for value in self.buffer.iter() { + *frequency_map.entry(value).or_insert(0) += 1; + } + frequency_map + .into_iter() + .max_by_key(|&(_, count)| count) + .map(|(value, _)| value.clone()) + } +} + +impl CellRntiRingBuffer { + pub fn new(size: usize) -> CellRntiRingBuffer { + CellRntiRingBuffer { + size, + cell_buffers: HashMap::new(), + } + } + + pub fn update(&mut self, cell_rntis: &HashMap) { + for (&cell, &rnti) in cell_rntis.iter() { + let cell_buffer = self.cell_buffers + 
.entry(cell) + .or_insert_with(|| RingBuffer::new(self.size)); + cell_buffer.add(rnti); + } + } + + pub fn most_frequent(&self) -> HashMap { + let mut cell_rntis: HashMap = HashMap::new(); + for (&cell, cell_buffer) in self.cell_buffers.iter() { + if let Some(most_frequent_rnti) = cell_buffer.most_frequent() { + cell_rntis.insert(cell, most_frequent_rnti); + } + } + cell_rntis + } +} + + pub fn prepare_sigint_notifier() -> Result> { let notifier = Arc::new(AtomicBool::new(false)); let r = notifier.clone(); @@ -122,127 +194,4 @@ pub fn is_debug() -> bool { IS_DEBUG.load(Ordering::SeqCst) } -/* Feature Matching */ - -pub fn calculate_mean_variance(list: &[f64]) -> (f64, f64) { - let total_packets = list.len() as f64; - - if total_packets == 0.0 { - return (0.0, 0.0); - } - let mean = list.iter().sum::() / total_packets; - - let variance = list - .iter() - .map(|&item| { - let diff = item - mean; - diff * diff - }) - .sum::() - / total_packets; - - (mean, variance) -} - -pub fn calculate_median(list: &[f64]) -> f64 { - let mut sorted_list: Vec = list.to_vec(); - sorted_list.sort_by(|a, b| a.partial_cmp(b).unwrap()); // Handle NaN values safely - let len = sorted_list.len(); - if len % 2 == 0 { - // If the length is even, return the average of the two middle elements - let mid1 = sorted_list[len / 2 - 1]; - let mid2 = sorted_list[len / 2]; - (mid1 + mid2) * 0.5 - } else { - // If the length is odd, return the middle element - sorted_list[len / 2] - } -} - -pub fn calculate_weighted_manhattan_distance( - vec_a: &[f64], - vec_b: &[f64], - weightings: &[f64], -) -> f64 { - assert_eq!( - vec_a.len(), - vec_b.len(), - "Calcuting Euclidean distance: Vectors must have the same length" - ); - assert_eq!( - vec_a.len(), - weightings.len(), - "Calcuting Euclidean distance: Vectors and weightings must have the same length" - ); - - vec_a - .iter() - .zip(vec_b.iter()) - .zip(weightings.iter()) - .fold(0.0, |acc, ((&a, &b), &w)| acc + w * (a - b).abs()) -} - -pub fn 
calculate_weighted_euclidean_distance( - vec_a: &[f64], - vec_b: &[f64], - weightings: &[f64], -) -> f64 { - assert_eq!( - vec_a.len(), - vec_b.len(), - "Calcuting Euclidean distance: Vectors must have the same length" - ); - assert_eq!( - vec_a.len(), - weightings.len(), - "Calcuting Euclidean distance: Vectors and weightings must have the same length" - ); - - let sum_of_squared_diff: f64 = vec_a - .iter() - .zip(vec_b.iter()) - .zip(weightings.iter()) - .map(|((&a, &b), &w)| w * (a - b).powi(2)) - .sum(); - - sum_of_squared_diff.sqrt() -} - -pub fn calculate_weighted_manhattan_distance_matrix( - matr_a: &DMatrix, - matr_b: &DMatrix, - weightings: &DVector, -) -> DVector { - assert_eq!( - (matr_a.nrows(), matr_a.ncols()), - (matr_b.nrows(), matr_b.ncols()), - "Calcuting Euclidean distance: Matrices must have the same dimensions" - ); - let diff_matrix = (matr_a - matr_b).abs(); - diff_matrix * weightings -} - -pub fn calculate_weighted_euclidean_distance_matrix( - matr_a: &DMatrix, - matr_b: &DMatrix, - weightings: &DVector, -) -> DVector { - assert_eq!( - (matr_a.nrows(), matr_a.ncols()), - (matr_b.nrows(), matr_b.ncols()), - "Calcuting Euclidean distance: Matrices must have the same dimensions" - ); - let diff_matrix = matr_a - matr_b; - let weighted_squared_diff_vector = diff_matrix.component_mul(&diff_matrix) * weightings; - - weighted_squared_diff_vector.map(|x| x.sqrt()) -} - -pub fn standardize_feature_vec(feature_vec: &[f64], std_vec: &[(f64, f64)]) -> Vec { - feature_vec - .iter() - .zip(std_vec.iter()) - .map(|(&feature, &(mean, std_deviation))| (feature - mean) / std_deviation) - .collect() -}