diff --git a/Cargo.lock b/Cargo.lock index 4c58b0d6..7c14cbc3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3867,6 +3867,7 @@ dependencies = [ "env_logger", "eyre", "hex", + "http 1.1.0", "httpmock", "log", "mockall", diff --git a/service/Cargo.toml b/service/Cargo.toml index b0f88715..4de35a8a 100644 --- a/service/Cargo.toml +++ b/service/Cargo.toml @@ -34,6 +34,7 @@ serde = { version = "1.0.214", features = ["derive"] } range-set = "0.0.11" serde_with = { version = "3.11.0", features = ["hex"] } serde_json = "1.0.132" +http = "1.1.0" [build-dependencies] tonic-build = "0.12.3" diff --git a/service/src/client.rs b/service/src/client.rs index d1c85318..4ba45a17 100644 --- a/service/src/client.rs +++ b/service/src/client.rs @@ -4,6 +4,7 @@ //! It connects to the node and registers itself as a Post Service. //! It then waits for requests from the node and forwards them to the Post Service. +use http::uri::{Scheme, Uri}; use std::time::Duration; use post::metadata::PostMetadata; @@ -68,9 +69,17 @@ impl ServiceClient { tls: Option<(Option, Certificate, Identity)>, service: S, ) -> eyre::Result { + let listen_address = address.parse::()?; + let parts = listen_address.into_parts(); + let scheme = parts.scheme.unwrap_or(Scheme::HTTP); + if !["http", "https"].contains(&scheme.as_str()) { + return Err(eyre::eyre!("unknown client protocol")); + }; + let endpoint = Channel::builder(address.parse()?) .keep_alive_timeout(Duration::from_secs(20)) .http2_keep_alive_interval(Duration::from_secs(10 * 60)); + let endpoint = match tls { Some((domain, cert, identity)) => { let domain = match domain { @@ -90,7 +99,15 @@ impl ServiceClient { .identity(identity), )? } - None => endpoint, + None => { + if scheme == Scheme::HTTPS { + return Err(eyre::eyre!( + "client protocol set to https but tls configuration not provided" + )); + } + + endpoint + } }; Ok(Self { endpoint, service }) diff --git a/service/tests/test_client.rs b/service/tests/test_client.rs index f5a64585..d8fd6587 100644 --- a/service/tests/test_client.rs +++ b/service/tests/test_client.rs @@ -17,7 +17,7 @@ use post_service::{ self, service_response, GenProofResponse, GenProofStatus, Metadata, MetadataResponse, NodeRequest, }, - MockPostService, + MockPostService, ServiceClient, }, service::ProofGenState, }; @@ -64,6 +64,51 @@ async fn test_registers_tls() { let _ = client_handle.await; } +#[test] +fn test_client_creation_error_handling() { + let ca = rcgen::generate_simple_self_signed(vec![]).unwrap(); + let client = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); + let tls = Some(( + Some("localhost".to_string()), + Certificate::from_pem(ca.serialize_pem().unwrap()).clone(), + Identity::from_pem( + client.serialize_pem_with_signer(&ca).unwrap(), + client.serialize_private_key_pem(), + ), + )); + let service = Arc::new(MockPostService::new()); + + // backward compatibility - default to http if no scheme provided. + // should work both with or without tls configuration + let result = ServiceClient::new("localhost:1234".to_string(), tls.clone(), service.clone()); + assert!(result.is_ok()); + let result = ServiceClient::new("localhost:1234".to_string(), None, service.clone()); + assert!(result.is_ok()); + + let result = ServiceClient::new( + "http://localhost:1234".to_string(), + tls.clone(), + service.clone(), + ); + assert!(result.is_ok()); + let result = ServiceClient::new("http://localhost:1234".to_string(), None, service.clone()); + assert!(result.is_ok()); + + // should fail only without tls configuration + let result = ServiceClient::new( + "https://localhost:1234".to_string(), + tls.clone(), + service.clone(), + ); + assert!(result.is_ok()); + let result = ServiceClient::new("https://localhost:1234".to_string(), None, service.clone()); + assert!(result.is_err()); + + // should fail on unrecognized scheme + let result = ServiceClient::new("yolo://localhost:1234".to_string(), None, service.clone()); + assert!(result.is_err()); +} + #[tokio::test] async fn test_gen_proof_in_progress() { let mut test_server = TestServer::new(None).await;