From 3d7097c3f315f54ea603aad82ecfe2fd4e793397 Mon Sep 17 00:00:00 2001 From: Ivan Lazarevic <24551171+ivanbgd@users.noreply.github.com> Date: Fri, 19 Jul 2024 01:27:50 +0200 Subject: [PATCH] Pass static SK by ref because of UTs & use OnceLock for them --- src/handshake.rs | 4 ++-- src/interface.rs | 14 +++++++++----- src/main.rs | 2 +- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/handshake.rs b/src/handshake.rs index 2f93b83..68be88e 100644 --- a/src/handshake.rs +++ b/src/handshake.rs @@ -27,7 +27,7 @@ use crate::errors::HandshakeError; /// https://github.com/ethereum/devp2p/blob/master/rlpx.md #[instrument(level = "trace", skip_all)] pub async fn initiate_handshake( - static_secret_key: SecretKey, + static_secret_key: &SecretKey, stream: &mut TcpStream, username: String, hostname: String, @@ -71,7 +71,7 @@ pub async fn respond_to_handshake() { /// Step 1: initiator connects to recipient and sends its `auth` message #[instrument(level = "trace", skip_all)] async fn step_1( - static_secret_key: SecretKey, + static_secret_key: &SecretKey, stream: &mut TcpStream, username: &String, hostname: &String, diff --git a/src/interface.rs b/src/interface.rs index fb4fcf9..6b32241 100644 --- a/src/interface.rs +++ b/src/interface.rs @@ -31,7 +31,7 @@ use crate::handshake::initiate_handshake; /// # Errors /// - [`CliError::InvalidRecipientHostName`] /// - [`CliError::ConnectionError`] wrapping [`ConnError::TcpStreamError`] -pub async fn dial(static_secret_key: SecretKey, parsed_args: ParsedArgs) -> Result<(), CliError> { +pub async fn dial(static_secret_key: &SecretKey, parsed_args: ParsedArgs) -> Result<(), CliError> { let timeout = parsed_args.timeout; let username = parsed_args.username; let hostname = parsed_args.hostname; @@ -83,6 +83,8 @@ pub async fn answer(_timeout: u64) -> Result<(), CliError> { #[cfg(test)] mod tests { + use std::sync::OnceLock; + use k256::SecretKey; use rand_core::OsRng; @@ -91,9 +93,11 @@ mod tests { use super::*; + static STATIC_SK: OnceLock = OnceLock::new(); + #[tokio::test] async fn test_dial_pass() { - let static_secret_key: SecretKey = SecretKey::random(&mut OsRng); + STATIC_SK.get_or_init(|| SecretKey::random(&mut OsRng)); let parsed_args = ParsedArgs { timeout: 1000, @@ -101,12 +105,12 @@ mod tests { hostname: TEST_HOSTNAME.to_string(), }; - assert!(dial(static_secret_key, parsed_args).await.is_ok()); + assert!(dial(&STATIC_SK.get().unwrap(), parsed_args).await.is_ok()); } #[tokio::test] async fn test_dial_fail_missing_colon() { - let static_secret_key: SecretKey = SecretKey::random(&mut OsRng); + STATIC_SK.get_or_init(|| SecretKey::random(&mut OsRng)); let bad_hostname = TEST_HOSTNAME.replace(':', ""); @@ -116,7 +120,7 @@ mod tests { hostname: bad_hostname.clone(), }; - let result = dial(static_secret_key, parsed_args).await; + let result = dial(&STATIC_SK.get().unwrap(), parsed_args).await; assert!(result.is_err()); assert_eq!( diff --git a/src/main.rs b/src/main.rs index cea812a..2ee2b6b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -38,7 +38,7 @@ async fn main() -> eyre::Result<()> { let parsed_args = parse_cli_args()?; if !parsed_args.hostname.is_empty() { - dial(static_secret_key, parsed_args).await?; + dial(&static_secret_key, parsed_args).await?; } else { answer(parsed_args.timeout).await?; }