diff --git a/pdc-update/src/main.rs b/pdc-update/src/main.rs index 642a17c..4e205b4 100644 --- a/pdc-update/src/main.rs +++ b/pdc-update/src/main.rs @@ -1,12 +1,12 @@ use std::{collections::HashMap, time::Duration}; use clap::Parser; -use cost::CostModel; mod cost; mod metrics; mod updater; +use cost::CostModel; use metrics::MetricsScraper; use updater::{WorkerUpdate, WorkerUpdater}; @@ -47,6 +47,72 @@ struct Cli { granularity: f64, } +/// Wait for termination signal (either SIGINT or SIGTERM) +#[cfg(unix)] +async fn wait_terminate() { + use futures::{stream::FuturesUnordered, StreamExt}; + use tokio::signal::unix::{signal, SignalKind}; + let mut signals = Vec::new(); + + // Register signal handlers + for sig in [SignalKind::terminate(), SignalKind::interrupt()] { + match signal(sig) { + Ok(sig) => signals.push(sig), + Err(err) => log::error!("Could not register signal handler: {err}"), + } + } + + // Wait for the first signal to trigger + let mut signals = signals + .iter_mut() + .map(|sig| sig.recv()) + .collect::>(); + + loop { + match signals.next().await { + // One of the signal triggered -> stop waiting + Some(Some(())) => break, + // One of the signal handler has been stopped -> continue waiting for the others + Some(None) => (), + // No more signal handlers are available, so wait indefinitely + None => futures::future::pending::<()>().await, + } + } +} + +#[cfg(windows)] +macro_rules! win_signal { + ($($sig:ident),*$(,)?) => { + $( + let $sig = async { + match tokio::signal::windows::$sig() { + Ok(mut $sig) => { + if $sig.recv().await.is_some() { + return; + } + } + Err(err) => log::error!( + "Could not register signal handler for {}: {err}", + stringify!($sig), + ), + } + futures::future::pending::<()>().await; + }; + )* + tokio::select! { + $( + _ = $sig => {} + )* + } + } +} + +/// Wait for termination signal (either SIGINT or SIGTERM) +#[cfg(windows)] +async fn wait_terminate() { + win_signal!(ctrl_c, ctrl_close, ctrl_logoff, ctrl_shutdown); +} + #[tokio::main] async fn main() -> Result<(), Box> { env_logger::init(); @@ -66,8 +132,14 @@ async fn main() -> Result<(), Box> { let mut interval = tokio::time::interval(Duration::from_secs_f64(cli.period)); interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + let mut wait_terminate = std::pin::pin!(wait_terminate()); + loop { - interval.tick().await; + // Wait for next tick or termination + tokio::select! { + _ = interval.tick() => {} + _ = &mut wait_terminate => break, + } // Scrap metrics let metrics = metrics_scraper.scrap_metrics().await?; @@ -112,4 +184,6 @@ async fn main() -> Result<(), Box> { _ => log::info!("{n} pods have been updated"), } } + + Ok(()) }