diff --git a/src/server/mod.rs b/src/server/mod.rs index 367f3eb..6010a6a 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -46,7 +46,7 @@ pub struct ServerHandle { impl Server { pub fn new(port: u16, cert: ParsedPkcs12_2, data: AppState) -> Self { - let handle = Arc::new(ServerHandle::new(create_ssl_acceptor(&cert))); + let handle = Arc::new(ServerHandle::new(create_ssl_acceptor(cert))); let mut listener = bind(SocketAddr::from(([0, 0, 0, 0], port))); let mut http = http1::Builder::new(); @@ -148,7 +148,7 @@ impl ServerHandle { } pub fn update_cert(&self, cert: ParsedPkcs12_2) { - let acceptor = create_ssl_acceptor(&cert); + let acceptor = create_ssl_acceptor(cert); self.ssl_acceptor.store(Arc::new(acceptor)); } diff --git a/src/server/ssl.rs b/src/server/ssl.rs index ccceb72..9c771f5 100644 --- a/src/server/ssl.rs +++ b/src/server/ssl.rs @@ -1,9 +1,20 @@ +use std::{sync::Arc, time::Duration}; + +use arc_swap::{access::Access, ArcSwapOption}; +use chrono::{DateTime, NaiveDateTime, Utc}; +use log::debug; use openssl::{ + base64, + hash::MessageDigest, + ocsp::{OcspCertId, OcspCertStatus, OcspRequest, OcspResponse}, pkcs12::ParsedPkcs12_2, ssl::{SslAcceptor, SslMethod, SslMode, SslOptions, SslSessionCacheMode}, + x509::X509VerifyResult, }; +use reqwest::Url; +use tokio::time::{sleep_until, Instant}; -pub fn create_ssl_acceptor(cert: &ParsedPkcs12_2) -> SslAcceptor { +pub fn create_ssl_acceptor(cert: ParsedPkcs12_2) -> SslAcceptor { // TODO error handle let mut builder = SslAcceptor::mozilla_intermediate(SslMethod::tls_server()).unwrap(); builder.clear_options(SslOptions::NO_TLSV1_3); @@ -61,6 +72,37 @@ pub fn create_ssl_acceptor(cert: &ParsedPkcs12_2) -> SslAcceptor { if let Some(i) = &cert.ca { i.iter().for_each(|j| builder.add_extra_chain_cert(j.to_owned()).unwrap()); } + + // OCSP stapling + let ocsp_store = Arc::new(ArcSwapOption::empty()); + let weak_ocsp_store = Arc::downgrade(&ocsp_store); + tokio::spawn(async move { + debug!("Start new OCSP stapling worker"); + loop { + let mut next_run = Instant::now() + Duration::from_secs(3600); + if let Some(ocsp_store) = weak_ocsp_store.upgrade() { + ocsp_store.store(None); + if let Some((res, date)) = fetch_ocsp(&cert).await { + debug!("Fetch OCSP response success"); + ocsp_store.store(Some(Arc::new(res))); + next_run = Instant::now() + date.signed_duration_since(Utc::now()).to_std().unwrap_or_default(); + } + } else { + debug!("Cert dropped, stop old OCSP stapling worker"); + break; + } + + sleep_until(next_run).await; + } + }); + let _ = builder.set_status_callback(move |ssl| { + if let Some(res) = ocsp_store.load().as_ref() { + Ok(ssl.set_ocsp_status(res).is_ok()) + } else { + Ok(false) + } + }); + builder.build() } @@ -74,3 +116,32 @@ fn aes_support() -> bool { fn aes_support() -> bool { false // Unable to check AES acceleration support, assumed negative. } + +async fn fetch_ocsp(full_chain: &ParsedPkcs12_2) -> Option<(Vec, DateTime)> { + let cert = full_chain.cert.as_ref()?; + let chain = full_chain.ca.as_ref()?; + let issuer = chain.iter().find(|ca| ca.issued(cert) == X509VerifyResult::OK)?; + let mut url = cert.ocsp_responders().ok()?.get(0).and_then(|url| Url::parse(url).ok())?; + let mut request = OcspRequest::new().unwrap(); + // Let's Encrypt follow rfc5019 2.1.1: Clients MUST use SHA1 as the hashing algorithm for the CertID.issuerNameHash and the CertID.issuerKeyHash values. + // ref: https://github.com/letsencrypt/boulder/issues/5523#issuecomment-877301162 + request + .add_id(OcspCertId::from_cert(MessageDigest::sha1(), cert, issuer).ok()?) + .ok()?; + let ocsp_encoded = request.to_der().map(|data| base64::encode_block(&data)).ok()?; + url.path_segments_mut().unwrap().push(&ocsp_encoded); + + // Check response + let response = reqwest::get(url).await.ok()?.bytes().await.ok()?; + let ocsp = OcspResponse::from_der(&response).ok()?.basic().ok()?; + let oid = OcspCertId::from_cert(MessageDigest::sha1(), cert, issuer).unwrap(); + let ocsp_status = ocsp.find_status(&oid)?; + if ocsp_status.status == OcspCertStatus::GOOD { + let update = NaiveDateTime::parse_from_str(&ocsp_status.next_update.to_string(), "%b %d %H:%M:%S %Y %Z") + .ok()? + .and_utc(); + return Some((response.to_vec(), update)); + } + + None +}