Skip to content

Commit

Permalink
Pass static SK by ref because of UTs & use OnceLock for them
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanbgd committed Jul 18, 2024
1 parent 1062c55 commit 3d7097c
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::errors::HandshakeError;
/// https://github.com/ethereum/devp2p/blob/master/rlpx.md
#[instrument(level = "trace", skip_all)]
pub async fn initiate_handshake(
static_secret_key: SecretKey,
static_secret_key: &SecretKey,
stream: &mut TcpStream,
username: String,
hostname: String,
Expand Down Expand Up @@ -71,7 +71,7 @@ pub async fn respond_to_handshake() {
/// Step 1: initiator connects to recipient and sends its `auth` message
#[instrument(level = "trace", skip_all)]
async fn step_1(
static_secret_key: SecretKey,
static_secret_key: &SecretKey,
stream: &mut TcpStream,
username: &String,
hostname: &String,
Expand Down
14 changes: 9 additions & 5 deletions src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::handshake::initiate_handshake;
/// # Errors
/// - [`CliError::InvalidRecipientHostName`]
/// - [`CliError::ConnectionError`] wrapping [`ConnError::TcpStreamError`]
pub async fn dial(static_secret_key: SecretKey, parsed_args: ParsedArgs) -> Result<(), CliError> {
pub async fn dial(static_secret_key: &SecretKey, parsed_args: ParsedArgs) -> Result<(), CliError> {
let timeout = parsed_args.timeout;
let username = parsed_args.username;
let hostname = parsed_args.hostname;
Expand Down Expand Up @@ -83,6 +83,8 @@ pub async fn answer(_timeout: u64) -> Result<(), CliError> {

#[cfg(test)]
mod tests {
use std::sync::OnceLock;

use k256::SecretKey;
use rand_core::OsRng;

Expand All @@ -91,22 +93,24 @@ mod tests {

use super::*;

static STATIC_SK: OnceLock<SecretKey> = OnceLock::new();

#[tokio::test]
async fn test_dial_pass() {
let static_secret_key: SecretKey = SecretKey::random(&mut OsRng);
STATIC_SK.get_or_init(|| SecretKey::random(&mut OsRng));

let parsed_args = ParsedArgs {
timeout: 1000,
username: TEST_USERNAME.to_string(),
hostname: TEST_HOSTNAME.to_string(),
};

assert!(dial(static_secret_key, parsed_args).await.is_ok());
assert!(dial(&STATIC_SK.get().unwrap(), parsed_args).await.is_ok());
}

#[tokio::test]
async fn test_dial_fail_missing_colon() {
let static_secret_key: SecretKey = SecretKey::random(&mut OsRng);
STATIC_SK.get_or_init(|| SecretKey::random(&mut OsRng));

let bad_hostname = TEST_HOSTNAME.replace(':', "");

Expand All @@ -116,7 +120,7 @@ mod tests {
hostname: bad_hostname.clone(),
};

let result = dial(static_secret_key, parsed_args).await;
let result = dial(&STATIC_SK.get().unwrap(), parsed_args).await;

assert!(result.is_err());
assert_eq!(
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async fn main() -> eyre::Result<()> {
let parsed_args = parse_cli_args()?;

if !parsed_args.hostname.is_empty() {
dial(static_secret_key, parsed_args).await?;
dial(&static_secret_key, parsed_args).await?;
} else {
answer(parsed_args.timeout).await?;
}
Expand Down

0 comments on commit 3d7097c

Please sign in to comment.