Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async support #23

Merged
merged 15 commits into from
Jan 4, 2024
6 changes: 5 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ jobs:
run: cargo clippy --no-default-features --features="std log utils"
- name: Run clippy with no_std
run: cargo clippy --no-default-features --features="log"
- name: Run clippy for tokio feature
run: cargo clippy --features="tokio"
- name: Run clippy for async feature
run: cargo clippy --no-default-feature --features="async"
vpetrigo marked this conversation as resolved.
Show resolved Hide resolved
check_format:
runs-on: ubuntu-latest
steps:
Expand All @@ -34,4 +38,4 @@ jobs:
- name: Run tests with std
run: cargo test
- name: Run tests with no_std
run: cargo test --no-default-features
run: cargo test --no-default-features
11 changes: 10 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,18 @@ exclude = [
default = ["std"]
std = []
utils = ["std", "chrono/clock"]
async = []
async_tokio = ["std", "async", "tokio", "async-trait"]

[dependencies]
log = { version = "~0.4", optional = true }
chrono = { version = "~0.4", default-features = false, optional = true }
# requred till this https://github.com/rust-lang/rfcs/pull/2832 is not addressed
no-std-net = "~0.6"

async-trait = { version = "0.1", optional = true }
tokio = { version = "1", features = ["full"], optional = true }

[dev-dependencies]
simple_logger = { version = "~1.13" }
smoltcp = { version = "~0.9", default-features = false, features = ["phy-tuntap_interface", "socket-udp", "proto-ipv4"] }
Expand All @@ -48,4 +53,8 @@ required-features = ["utils"]

[[example]]
name = "smoltcp_request"
required-features = ["std"]
required-features = ["std"]

[[example]]
name = "tokio"
required-features = ["async_tokio"]
46 changes: 46 additions & 0 deletions examples/tokio.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use sntpc::{
r#async::{get_time, NtpUdpSocket},
Error, NtpContext, Result, StdTimestampGen,
};
use std::net::SocketAddr;
use tokio::net::{ToSocketAddrs, UdpSocket};

const POOL_NTP_ADDR: &str = "pool.ntp.org:123";

#[derive(Debug)]
struct Socket {
sock: UdpSocket,
}

#[async_trait::async_trait]
impl NtpUdpSocket for Socket {
async fn send_to<T: ToSocketAddrs + Send>(
&self,
buf: &[u8],
addr: T,
) -> Result<usize> {
self.sock
.send_to(buf, addr)
.await
.map_err(|_| Error::Network)
}

async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
self.sock.recv_from(buf).await.map_err(|_| Error::Network)
}
}

#[tokio::main]
async fn main() {
let sock = UdpSocket::bind("0.0.0.0:0".parse::<SocketAddr>().unwrap())
.await
.expect("Socket creation");
let socket = Socket { sock: sock };
let ntp_context = NtpContext::new(StdTimestampGen::default());

let res = get_time(POOL_NTP_ADDR, socket, ntp_context)
.await
.expect("get_time error");

println!("RESULT: {:?}", res);
}
149 changes: 149 additions & 0 deletions src/async.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
use crate::types::{
Error, NtpContext, NtpPacket, NtpResult, NtpTimestampGenerator,
RawNtpPacket, Result, SendRequestResult,
};
use crate::{get_ntp_timestamp, process_response};
use core::fmt::Debug;
#[cfg(feature = "log")]
use log::debug;

#[cfg(feature = "std")]
use std::net::SocketAddr;
#[cfg(feature = "tokio")]
use tokio::net::{lookup_host, ToSocketAddrs};

#[cfg(not(feature = "std"))]
pub use no_std_net::{IpAddr, Ipv4Addr, SocketAddr, ToSocketAddrs};

#[cfg(not(feature = "std"))]
async fn lookup_host<T>(host: T) -> Result<impl Iterator<Item = SocketAddr>>
where
T: ToSocketAddrs,
{
#[allow(unused_variables)]
host.to_socket_addrs().map_err(|e| {
#[cfg(feature = "log")]
debug!("ToScoketAddrs: {}", e);
Error::AddressResolve
})
}

#[cfg(feature = "tokio")]
#[async_trait::async_trait]
pub trait NtpUdpSocket {
async fn send_to<T: ToSocketAddrs + Send>(
&self,
buf: &[u8],
addr: T,
) -> Result<usize>;

async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)>;
}

#[cfg(not(feature = "std"))]
pub trait NtpUdpSocket {
fn send_to<T: ToSocketAddrs + Send>(
&self,
buf: &[u8],
addr: T,
) -> impl core::future::Future<Output = Result<usize>>;

fn recv_from(
&self,
buf: &mut [u8],
) -> impl core::future::Future<Output = Result<(usize, SocketAddr)>>;
}

pub async fn sntp_send_request<A, U, T>(
dest: A,
socket: &U,
context: NtpContext<T>,
) -> Result<SendRequestResult>
where
A: ToSocketAddrs + Debug + Send,
U: NtpUdpSocket + Debug,
T: NtpTimestampGenerator + Copy,
{
#[cfg(feature = "log")]
debug!("Address: {:?}, Socket: {:?}", dest, *socket);
let request = NtpPacket::new(context.timestamp_gen);

send_request(dest, &request, socket).await?;
Ok(SendRequestResult::from(request))
}

async fn send_request<A: ToSocketAddrs + Send, U: NtpUdpSocket>(
dest: A,
req: &NtpPacket,
socket: &U,
) -> core::result::Result<(), Error> {
let buf = RawNtpPacket::from(req);

match socket.send_to(&buf.0, dest).await {
Ok(size) => {
if size == buf.0.len() {
Ok(())
} else {
Err(Error::Network)
}
}
Err(_) => Err(Error::Network),
}
}

pub async fn sntp_process_response<A, U, T>(
dest: A,
socket: &U,
mut context: NtpContext<T>,
send_req_result: SendRequestResult,
) -> Result<NtpResult>
where
A: ToSocketAddrs + Debug,
U: NtpUdpSocket + Debug,
T: NtpTimestampGenerator + Copy,
{
let mut response_buf = RawNtpPacket::default();
let (response, src) = socket.recv_from(response_buf.0.as_mut()).await?;
context.timestamp_gen.init();
let recv_timestamp = get_ntp_timestamp(context.timestamp_gen);
#[cfg(feature = "log")]
debug!("Response: {}", response);

match lookup_host(dest).await {
Err(_) => return Err(Error::AddressResolve),
Ok(mut it) => {
if !it.any(|addr| addr == src) {
return Err(Error::ResponseAddressMismatch);
}
}
}

if response != core::mem::size_of::<NtpPacket>() {
return Err(Error::IncorrectPayload);
}

let result =
process_response(send_req_result, response_buf, recv_timestamp);

if let Ok(_r) = &result {
#[cfg(feature = "log")]
debug!("{:?}", _r);
}

result
}

pub async fn get_time<A, U, T>(
pool_addrs: A,
socket: U,
context: NtpContext<T>,
) -> Result<NtpResult>
where
A: ToSocketAddrs + Copy + Debug + Send,
U: NtpUdpSocket + Debug,
T: NtpTimestampGenerator + Copy,
{
let result = sntp_send_request(pool_addrs, &socket, context).await?;

sntp_process_response(pool_addrs, &socket, context, result).await
}
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ pub mod utils;
mod types;
pub use crate::types::*;

#[cfg(feature = "async")]
pub mod r#async;
vpetrigo marked this conversation as resolved.
Show resolved Hide resolved

use core::fmt::Debug;
use core::iter::Iterator;
use core::marker::Copy;
Expand Down
Loading