Skip to content

Commit

Permalink
Support auto-reload of tls certificate
Browse files Browse the repository at this point in the history
  • Loading branch information
erebe committed Dec 29, 2023
1 parent c9bc107 commit 640102f
Show file tree
Hide file tree
Showing 5 changed files with 316 additions and 18 deletions.
128 changes: 128 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ nix = { version = "0.27.1", features = ["socket", "net", "uio"] }
once_cell = { version = "1.19.0", features = [] }
parking_lot = "0.12.1"
pin-project = "1"
notify = { version = "6.1.1", features = [] }

rustls-native-certs = { version = "0.7.0", features = [] }
rustls-pemfile = { version = "2.0.0", features = [] }
Expand Down
27 changes: 17 additions & 10 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use futures_util::{stream, TryStreamExt};
use hickory_resolver::config::{NameServerConfig, ResolverConfig, ResolverOpts};
use hyper::header::HOST;
use hyper::http::{HeaderName, HeaderValue};
use parking_lot::Mutex;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
use std::fmt::{Debug, Formatter};
Expand Down Expand Up @@ -217,10 +218,12 @@ struct Server {
restrict_http_upgrade_path_prefix: Option<Vec<String>>,

/// [Optional] Use custom certificate (.crt) instead of the default embedded self signed certificate.
/// The certificate will be automatically reloaded if it changes
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
tls_certificate: Option<PathBuf>,

/// [Optional] Use a custom tls key (.key) that the server will use instead of the default embedded one
/// The private key will be automatically reloaded if it changes
#[arg(long, value_name = "FILE_PATH", verbatim_doc_comment)]
tls_private_key: Option<PathBuf>,
}
Expand Down Expand Up @@ -481,13 +484,14 @@ pub struct TlsClientConfig {
pub tls_verify_certificate: bool,
}

#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct TlsServerConfig {
pub tls_certificate: Vec<Certificate>,
pub tls_key: PrivateKey,
pub tls_certificate: Mutex<Vec<Certificate>>,
pub tls_key: Mutex<PrivateKey>,
pub tls_certificate_path: Option<PathBuf>,
pub tls_key_path: Option<PathBuf>,
}

#[derive(Clone)]
pub struct WsServerConfig {
pub socket_so_mark: Option<u32>,
pub bind: SocketAddr,
Expand Down Expand Up @@ -814,20 +818,23 @@ async fn main() {
}
Commands::Server(args) => {
let tls_config = if args.remote_addr.scheme() == "wss" {
let tls_certificate = if let Some(cert_path) = args.tls_certificate {
tls::load_certificates_from_pem(&cert_path).expect("Cannot load tls certificate")
let tls_certificate = if let Some(cert_path) = &args.tls_certificate {
tls::load_certificates_from_pem(cert_path).expect("Cannot load tls certificate")
} else {
embedded_certificate::TLS_CERTIFICATE.clone()
};

let tls_key = if let Some(key_path) = args.tls_private_key {
tls::load_private_key_from_file(&key_path).expect("Cannot load tls private key")
let tls_key = if let Some(key_path) = &args.tls_private_key {
tls::load_private_key_from_file(key_path).expect("Cannot load tls private key")
} else {
embedded_certificate::TLS_PRIVATE_KEY.clone()
};

Some(TlsServerConfig {
tls_certificate,
tls_key,
tls_certificate: Mutex::new(tls_certificate),
tls_key: Mutex::new(tls_key),
tls_certificate_path: args.tls_certificate,
tls_key_path: args.tls_private_key,
})
} else {
None
Expand Down
2 changes: 1 addition & 1 deletion src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ pub fn tls_acceptor(tls_cfg: &TlsServerConfig, alpn_protocols: Option<Vec<Vec<u8
let mut config = rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(tls_cfg.tls_certificate.clone(), tls_cfg.tls_key.clone())
.with_single_cert(tls_cfg.tls_certificate.lock().clone(), tls_cfg.tls_key.lock().clone())
.with_context(|| "invalid tls certificate or private key")?;

if let Some(alpn_protocols) = alpn_protocols {
Expand Down
Loading

0 comments on commit 640102f

Please sign in to comment.