Skip to content

Commit

Permalink
Require the address to be set with http or https directly
Browse files Browse the repository at this point in the history
Closes #228
  • Loading branch information
jellonek committed Oct 29, 2024
1 parent 497f015 commit e00cdf0
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 2 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions service/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ serde = { version = "1.0.214", features = ["derive"] }
range-set = "0.0.11"
serde_with = { version = "3.11.0", features = ["hex"] }
serde_json = "1.0.132"
http = "1.1.0"

[build-dependencies]
tonic-build = "0.12.3"
Expand Down
19 changes: 18 additions & 1 deletion service/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
//! It connects to the node and registers itself as a Post Service.
//! It then waits for requests from the node and forwards them to the Post Service.
use http::uri::{Scheme, Uri};
use std::time::Duration;

use post::metadata::PostMetadata;
Expand Down Expand Up @@ -68,9 +69,17 @@ impl<S: PostService> ServiceClient<S> {
tls: Option<(Option<String>, Certificate, Identity)>,
service: S,
) -> eyre::Result<Self> {
let listen_address = address.parse::<Uri>()?;
let parts = listen_address.into_parts();
let scheme = parts.scheme.unwrap_or(Scheme::HTTP);
if !["http", "https"].contains(&scheme.as_str()) {
return Err(eyre::eyre!("unknown client protocol"));
};

let endpoint = Channel::builder(address.parse()?)
.keep_alive_timeout(Duration::from_secs(20))
.http2_keep_alive_interval(Duration::from_secs(10 * 60));

let endpoint = match tls {
Some((domain, cert, identity)) => {
let domain = match domain {
Expand All @@ -90,7 +99,15 @@ impl<S: PostService> ServiceClient<S> {
.identity(identity),
)?
}
None => endpoint,
None => {
if scheme == Scheme::HTTPS {
return Err(eyre::eyre!(
"client protocol set to https but tls configuration not provided"
));
}

endpoint
}
};

Ok(Self { endpoint, service })
Expand Down
47 changes: 46 additions & 1 deletion service/tests/test_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use post_service::{
self, service_response, GenProofResponse, GenProofStatus, Metadata, MetadataResponse,
NodeRequest,
},
MockPostService,
MockPostService, ServiceClient,
},
service::ProofGenState,
};
Expand Down Expand Up @@ -64,6 +64,51 @@ async fn test_registers_tls() {
let _ = client_handle.await;
}

#[test]
fn test_client_creation_error_handling() {
let ca = rcgen::generate_simple_self_signed(vec![]).unwrap();
let client = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap();
let tls = Some((
Some("localhost".to_string()),
Certificate::from_pem(ca.serialize_pem().unwrap()).clone(),
Identity::from_pem(
client.serialize_pem_with_signer(&ca).unwrap(),
client.serialize_private_key_pem(),
),
));
let service = Arc::new(MockPostService::new());

// backward compatibility - default to http if no scheme provided.
// should work both with or without tls configuration
let result = ServiceClient::new("localhost:1234".to_string(), tls.clone(), service.clone());
assert!(result.is_ok());
let result = ServiceClient::new("localhost:1234".to_string(), None, service.clone());
assert!(result.is_ok());

let result = ServiceClient::new(
"http://localhost:1234".to_string(),
tls.clone(),
service.clone(),
);
assert!(result.is_ok());
let result = ServiceClient::new("http://localhost:1234".to_string(), None, service.clone());
assert!(result.is_ok());

// should fail only without tls configuration
let result = ServiceClient::new(
"https://localhost:1234".to_string(),
tls.clone(),
service.clone(),
);
assert!(result.is_ok());
let result = ServiceClient::new("https://localhost:1234".to_string(), None, service.clone());
assert!(result.is_err());

// should fail on unrecognized scheme
let result = ServiceClient::new("yolo://localhost:1234".to_string(), None, service.clone());
assert!(result.is_err());
}

#[tokio::test]
async fn test_gen_proof_in_progress() {
let mut test_server = TestServer::new(None).await;
Expand Down

0 comments on commit e00cdf0

Please sign in to comment.