Skip to content

Commit

Permalink
Reworked message and action types to be more direct.
Browse files Browse the repository at this point in the history
  • Loading branch information
davidv1992 authored and rnijveld committed Aug 8, 2024
1 parent a1ac540 commit c297ae0
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 66 deletions.
51 changes: 22 additions & 29 deletions ntp-proto/src/source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::{
use rand::{thread_rng, Rng};
use serde::{Deserialize, Serialize};
use std::{
fmt::Debug,
io::Cursor,
net::{IpAddr, SocketAddr},
time::Duration,
Expand Down Expand Up @@ -327,12 +328,12 @@ impl Default for ProtocolVersion {
}
}

pub struct NtpSourceUpdate<Controller: SourceController> {
pub struct NtpSourceUpdate<SourceMessage> {
pub(crate) snapshot: NtpSourceSnapshot,
pub(crate) message: Option<Controller::SourceMessage>,
pub(crate) message: Option<SourceMessage>,
}

impl<Controller: SourceController> std::fmt::Debug for NtpSourceUpdate<Controller> {
impl<SourceMessage: Debug> std::fmt::Debug for NtpSourceUpdate<SourceMessage> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("NtpSourceUpdate")
.field("snapshot", &self.snapshot)
Expand All @@ -341,7 +342,7 @@ impl<Controller: SourceController> std::fmt::Debug for NtpSourceUpdate<Controlle
}
}

impl<Controller: SourceController> Clone for NtpSourceUpdate<Controller> {
impl<SourceMessage: Clone> Clone for NtpSourceUpdate<SourceMessage> {
fn clone(&self) -> Self {
Self {
snapshot: self.snapshot,
Expand All @@ -351,30 +352,22 @@ impl<Controller: SourceController> Clone for NtpSourceUpdate<Controller> {
}

#[cfg(feature = "__internal-test")]
impl<Controller: SourceController> NtpSourceUpdate<Controller> {
impl<SourceMessage> NtpSourceUpdate<SourceMessage> {
pub fn snapshot(snapshot: NtpSourceSnapshot) -> Self {
NtpSourceUpdate {
snapshot,
message: None,
}
}

// TODO: Cleanup
/*pub fn measurement(snapshot: NtpSourceSnapshot, measurement: Measurement) -> Self {
NtpSourceUpdate {
snapshot,
measurement: Some(measurement),
}
}*/
}

#[derive(Debug, Clone)]
#[allow(clippy::large_enum_variant)]
pub enum NtpSourceAction<Controller: SourceController> {
pub enum NtpSourceAction<SourceMessage> {
/// Send a message over the network. When this is issued, the network port maybe changed.
Send(Vec<u8>),
/// Send an update to [`System`](crate::system::System)
UpdateSystem(NtpSourceUpdate<Controller>),
UpdateSystem(NtpSourceUpdate<SourceMessage>),
/// Call [`NtpSource::handle_timer`] after given duration
SetTimer(Duration),
/// A complete reset of the connection is necessary, including a potential new NTSKE client session and/or DNS lookup.
Expand All @@ -384,28 +377,28 @@ pub enum NtpSourceAction<Controller: SourceController> {
}

#[derive(Debug)]
pub struct NtpSourceActionIterator<Controller: SourceController> {
iter: <Vec<NtpSourceAction<Controller>> as IntoIterator>::IntoIter,
pub struct NtpSourceActionIterator<SourceMessage> {
iter: <Vec<NtpSourceAction<SourceMessage>> as IntoIterator>::IntoIter,
}

impl<Controller: SourceController> Default for NtpSourceActionIterator<Controller> {
impl<SourceMessage> Default for NtpSourceActionIterator<SourceMessage> {
fn default() -> Self {
Self {
iter: vec![].into_iter(),
}
}
}

impl<Controller: SourceController> Iterator for NtpSourceActionIterator<Controller> {
type Item = NtpSourceAction<Controller>;
impl<SourceMessage> Iterator for NtpSourceActionIterator<SourceMessage> {
type Item = NtpSourceAction<SourceMessage>;

fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
}
}

impl<Controller: SourceController> NtpSourceActionIterator<Controller> {
fn from(data: Vec<NtpSourceAction<Controller>>) -> Self {
impl<SourceMessage> NtpSourceActionIterator<SourceMessage> {
fn from(data: Vec<NtpSourceAction<SourceMessage>>) -> Self {
Self {
iter: data.into_iter(),
}
Expand Down Expand Up @@ -438,7 +431,7 @@ impl<Controller: SourceController> NtpSource<Controller> {
source_defaults_config: SourceDefaultsConfig,
protocol_version: ProtocolVersion,
controller: Controller,
) -> (Self, NtpSourceActionIterator<Controller>) {
) -> (Self, NtpSourceActionIterator<Controller::SourceMessage>) {
(
Self {
nts: None,
Expand Down Expand Up @@ -477,7 +470,7 @@ impl<Controller: SourceController> NtpSource<Controller> {
protocol_version: ProtocolVersion,
controller: Controller,
nts: Box<SourceNtsData>,
) -> (Self, NtpSourceActionIterator<Controller>) {
) -> (Self, NtpSourceActionIterator<Controller::SourceMessage>) {
let (base, actions) = Self::new(
source_addr,
source_defaults_config,
Expand Down Expand Up @@ -511,7 +504,7 @@ impl<Controller: SourceController> NtpSource<Controller> {
}

#[cfg_attr(not(feature = "ntpv5"), allow(unused_mut))]
pub fn handle_timer(&mut self) -> NtpSourceActionIterator<Controller> {
pub fn handle_timer(&mut self) -> NtpSourceActionIterator<Controller::SourceMessage> {
if !self.reach.is_reachable() && self.tries >= STARTUP_TRIES_THRESHOLD {
return actions!(NtpSourceAction::Reset);
}
Expand Down Expand Up @@ -596,8 +589,8 @@ impl<Controller: SourceController> NtpSource<Controller> {
#[instrument(skip(self, update), fields(source = debug(self.source_id)))]
pub fn handle_system_update(
&mut self,
update: SystemSourceUpdate<Controller>,
) -> NtpSourceActionIterator<Controller> {
update: SystemSourceUpdate<Controller::ControllerMessage>,
) -> NtpSourceActionIterator<Controller::SourceMessage> {
self.controller.handle_message(update.message);
actions!()
}
Expand All @@ -609,7 +602,7 @@ impl<Controller: SourceController> NtpSource<Controller> {
local_clock_time: NtpInstant,
send_time: NtpTimestamp,
recv_time: NtpTimestamp,
) -> NtpSourceActionIterator<Controller> {
) -> NtpSourceActionIterator<Controller::SourceMessage> {
let message =
match NtpPacket::deserialize(message, &self.nts.as_ref().map(|nts| nts.s2c.as_ref())) {
Ok((packet, _)) => packet,
Expand Down Expand Up @@ -711,7 +704,7 @@ impl<Controller: SourceController> NtpSource<Controller> {
local_clock_time: NtpInstant,
send_time: NtpTimestamp,
recv_time: NtpTimestamp,
) -> NtpSourceActionIterator<Controller> {
) -> NtpSourceActionIterator<Controller::SourceMessage> {
trace!("Packet accepted for processing");
// For reachability, mark that we have had a response
self.reach.received_packet();
Expand Down
41 changes: 20 additions & 21 deletions ntp-proto/src/system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::sync::Arc;
use std::time::Duration;
use std::{fmt::Debug, hash::Hash};

use crate::algorithm::SourceController;
#[cfg(feature = "ntpv5")]
use crate::packet::v5::server_reference_id::{BloomFilter, ServerId};
use crate::source::NtpSourceUpdate;
Expand Down Expand Up @@ -112,19 +111,19 @@ impl Default for SystemSnapshot {
}
}

pub struct SystemSourceUpdate<Controller: SourceController> {
pub(crate) message: Controller::ControllerMessage,
pub struct SystemSourceUpdate<ControllerMessage> {
pub(crate) message: ControllerMessage,
}

impl<Controller: SourceController> std::fmt::Debug for SystemSourceUpdate<Controller> {
impl<ControllerMessage: Debug> std::fmt::Debug for SystemSourceUpdate<ControllerMessage> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SystemSourceUpdate")
.field("message", &self.message)
.finish()
}
}

impl<Controller: SourceController> Clone for SystemSourceUpdate<Controller> {
impl<ControllerMessage: Clone> Clone for SystemSourceUpdate<ControllerMessage> {
fn clone(&self) -> Self {
Self {
message: self.message.clone(),
Expand All @@ -134,36 +133,36 @@ impl<Controller: SourceController> Clone for SystemSourceUpdate<Controller> {

#[derive(Debug, Clone)]
#[allow(clippy::large_enum_variant)]
pub enum SystemAction<Controller: SourceController> {
UpdateSources(SystemSourceUpdate<Controller>),
pub enum SystemAction<ControllerMessage> {
UpdateSources(SystemSourceUpdate<ControllerMessage>),
SetTimer(Duration),
}

#[derive(Debug)]
pub struct SystemActionIterator<Controller: SourceController> {
iter: <Vec<SystemAction<Controller>> as IntoIterator>::IntoIter,
pub struct SystemActionIterator<ControllerMessage> {
iter: <Vec<SystemAction<ControllerMessage>> as IntoIterator>::IntoIter,
}

impl<Controller: SourceController> Default for SystemActionIterator<Controller> {
impl<ControllerMessage> Default for SystemActionIterator<ControllerMessage> {
fn default() -> Self {
Self {
iter: vec![].into_iter(),
}
}
}

impl<Controller: SourceController> From<Vec<SystemAction<Controller>>>
for SystemActionIterator<Controller>
impl<ControllerMessage> From<Vec<SystemAction<ControllerMessage>>>
for SystemActionIterator<ControllerMessage>
{
fn from(value: Vec<SystemAction<Controller>>) -> Self {
fn from(value: Vec<SystemAction<ControllerMessage>>) -> Self {
Self {
iter: value.into_iter(),
}
}
}

impl<Controller: SourceController> Iterator for SystemActionIterator<Controller> {
type Item = SystemAction<Controller>;
impl<ControllerMessage> Iterator for SystemActionIterator<ControllerMessage> {
type Item = SystemAction<ControllerMessage>;

fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
Expand Down Expand Up @@ -252,7 +251,7 @@ impl<SourceId: Hash + Eq + Copy + Debug, Controller: TimeSyncController<SourceId
) -> Result<
(
NtpSource<Controller::SourceController>,
NtpSourceActionIterator<Controller::SourceController>,
NtpSourceActionIterator<Controller::SourceMessage>,
),
<Controller::Clock as NtpClock>::Error,
> {
Expand All @@ -277,7 +276,7 @@ impl<SourceId: Hash + Eq + Copy + Debug, Controller: TimeSyncController<SourceId
) -> Result<
(
NtpSource<Controller::SourceController>,
NtpSourceActionIterator<Controller::SourceController>,
NtpSourceActionIterator<Controller::SourceMessage>,
),
<Controller::Clock as NtpClock>::Error,
> {
Expand Down Expand Up @@ -306,9 +305,9 @@ impl<SourceId: Hash + Eq + Copy + Debug, Controller: TimeSyncController<SourceId
pub fn handle_source_update(
&mut self,
id: SourceId,
update: NtpSourceUpdate<Controller::SourceController>,
update: NtpSourceUpdate<Controller::SourceMessage>,
) -> Result<
SystemActionIterator<Controller::SourceController>,
SystemActionIterator<Controller::ControllerMessage>,
<Controller::Clock as NtpClock>::Error,
> {
let usable = update
Expand All @@ -332,7 +331,7 @@ impl<SourceId: Hash + Eq + Copy + Debug, Controller: TimeSyncController<SourceId
fn handle_algorithm_state_update(
&mut self,
update: StateUpdate<SourceId, Controller::ControllerMessage>,
) -> SystemActionIterator<Controller::SourceController> {
) -> SystemActionIterator<Controller::ControllerMessage> {
let mut actions = vec![];
if let Some(ref used_sources) = update.used_sources {
self.system
Expand All @@ -355,7 +354,7 @@ impl<SourceId: Hash + Eq + Copy + Debug, Controller: TimeSyncController<SourceId
actions.into()
}

pub fn handle_timer(&mut self) -> SystemActionIterator<Controller::SourceController> {
pub fn handle_timer(&mut self) -> SystemActionIterator<Controller::ControllerMessage> {
tracing::debug!("Timer expired");
let update = self.controller.time_update();
self.handle_algorithm_state_update(update)
Expand Down
21 changes: 8 additions & 13 deletions ntpd/src/daemon/ntp_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ use std::{
};

use ntp_proto::{
KalmanSourceController, NtpClock, NtpInstant, NtpSource, NtpSourceActionIterator,
NtpSourceUpdate, NtpTimestamp, ObservableSourceState, SystemSourceUpdate,
KalmanControllerMessage, KalmanSourceController, KalmanSourceMessage, NtpClock, NtpInstant,
NtpSource, NtpSourceActionIterator, NtpSourceUpdate, NtpTimestamp, ObservableSourceState,
SystemSourceUpdate,
};
#[cfg(target_os = "linux")]
use timestamped_socket::socket::open_interface_udp;
Expand Down Expand Up @@ -39,14 +40,14 @@ pub enum MsgForSystem {
/// Source is unreachable, and should be restarted with new resolved addr.
Unreachable(SourceId),
/// Update from source
SourceUpdate(SourceId, NtpSourceUpdate<KalmanSourceController<SourceId>>),
SourceUpdate(SourceId, NtpSourceUpdate<KalmanSourceMessage<SourceId>>),
}

#[derive(Debug)]
pub struct SourceChannels {
pub msg_for_system_sender: tokio::sync::mpsc::Sender<MsgForSystem>,
pub system_update_receiver:
tokio::sync::broadcast::Receiver<SystemSourceUpdate<KalmanSourceController<SourceId>>>,
tokio::sync::broadcast::Receiver<SystemSourceUpdate<KalmanControllerMessage>>,
pub source_snapshots:
Arc<std::sync::RwLock<HashMap<SourceId, ObservableSourceState<SourceId>>>>,
}
Expand Down Expand Up @@ -119,7 +120,7 @@ where
Recv(Result<RecvResult<SocketAddr>, std::io::Error>),
SystemUpdate(
Result<
SystemSourceUpdate<KalmanSourceController<SourceId>>,
SystemSourceUpdate<KalmanControllerMessage>,
tokio::sync::broadcast::error::RecvError,
>,
),
Expand Down Expand Up @@ -332,16 +333,10 @@ where
timestamp_mode: TimestampMode,
channels: SourceChannels,
source: NtpSource<KalmanSourceController<SourceId>>,
initial_actions: NtpSourceActionIterator<KalmanSourceController<SourceId>>,
initial_actions: NtpSourceActionIterator<KalmanSourceMessage<SourceId>>,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(
(async move {
/*let (source, initial_actions) = if let Some(nts) = nts {
NtpSource::new_nts(source_addr, config_snapshot, protocol_version, nts)
} else {
NtpSource::new(source_addr, config_snapshot, protocol_version)
};*/

let poll_wait = tokio::time::sleep(std::time::Duration::default());
tokio::pin!(poll_wait);

Expand Down Expand Up @@ -582,7 +577,7 @@ mod tests {
SourceTask<TestClock, T>,
Socket<SocketAddr, Open>,
mpsc::Receiver<MsgForSystem>,
broadcast::Sender<SystemSourceUpdate<KalmanSourceController<SourceId>>>,
broadcast::Sender<SystemSourceUpdate<KalmanControllerMessage>>,
) {
// Note: Ports must be unique among tests to deal with parallelism, hence
// port_base
Expand Down
6 changes: 3 additions & 3 deletions ntpd/src/daemon/system.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use std::{
};

use ntp_proto::{
KalmanClockController, KalmanSourceController, KeySet, NtpClock, ObservableSourceState,
KalmanClockController, KalmanControllerMessage, KeySet, NtpClock, ObservableSourceState,
SourceDefaultsConfig, System, SystemActionIterator, SystemSnapshot, SystemSourceUpdate,
};
use timestamped_socket::interface::InterfaceName;
Expand Down Expand Up @@ -170,7 +170,7 @@ struct SystemTask<C: NtpClock, T: Wait> {

system_snapshot_sender: tokio::sync::watch::Sender<SystemSnapshot>,
system_update_sender:
tokio::sync::broadcast::Sender<SystemSourceUpdate<KalmanSourceController<SourceId>>>,
tokio::sync::broadcast::Sender<SystemSourceUpdate<KalmanControllerMessage>>,
source_snapshots: Arc<std::sync::RwLock<HashMap<SourceId, ObservableSourceState<SourceId>>>>,
server_data_sender: tokio::sync::watch::Sender<Vec<ServerData>>,
keyset: tokio::sync::watch::Receiver<Arc<KeySet>>,
Expand Down Expand Up @@ -327,7 +327,7 @@ impl<C: NtpClock + Sync, T: Wait> SystemTask<C, T> {

fn handle_state_update(
&mut self,
actions: SystemActionIterator<KalmanSourceController<SourceId>>,
actions: SystemActionIterator<KalmanControllerMessage>,
wait: &mut Pin<&mut SingleshotSleep<T>>,
) {
// Don't care if there is no receiver.
Expand Down

0 comments on commit c297ae0

Please sign in to comment.