From 349548015cfe00402b432e786ea136fcb71dbd5b Mon Sep 17 00:00:00 2001
From: Nikhil Sinha <131262146+nikhilsinhaparseable@users.noreply.github.com>
Date: Mon, 16 Sep 2024 15:46:56 +0530
Subject: [PATCH] feat: accept other certificates (#889)

env var P_TRUSTED_CA_CERTS_DIR accepts a directory path
where user can keep all the certificates intended to be accepted by the server
---
 server/src/cli.rs                             | 13 ++++++++++
 .../src/handlers/http/modal/ingest_server.rs  |  1 +
 .../src/handlers/http/modal/query_server.rs   |  1 +
 server/src/handlers/http/modal/server.rs      |  1 +
 .../src/handlers/http/modal/ssl_acceptor.rs   | 25 +++++++++++++++++--
 5 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/server/src/cli.rs b/server/src/cli.rs
index d479ab9b6..d5efaac30 100644
--- a/server/src/cli.rs
+++ b/server/src/cli.rs
@@ -34,6 +34,9 @@ pub struct Cli {
     /// The location of TLS Private Key file
     pub tls_key_path: Option<PathBuf>,
 
+    /// The location of other certificates to accept
+    pub trusted_ca_certs_path: Option<PathBuf>,
+
     /// The address on which the http server will listen.
     pub address: String,
 
@@ -122,6 +125,7 @@ impl Cli {
     // identifiers for arguments
     pub const TLS_CERT: &'static str = "tls-cert-path";
     pub const TLS_KEY: &'static str = "tls-key-path";
+    pub const TRUSTED_CA_CERTS_PATH: &'static str = "trusted-ca-certs-path";
     pub const ADDRESS: &'static str = "address";
     pub const DOMAIN_URI: &'static str = "origin";
     pub const STAGING: &'static str = "local-staging-path";
@@ -224,6 +228,14 @@ impl Cli {
                      .value_parser(validation::file_path)
                      .help("Local path on this device where private key file is located. Required to enable TLS"),
              )
+            .arg(
+                Arg::new(Self::TRUSTED_CA_CERTS_PATH)
+                    .long(Self::TRUSTED_CA_CERTS_PATH)
+                    .env("P_TRUSTED_CA_CERTS_DIR")
+                    .value_name("DIR")
+                    .value_parser(validation::canonicalize_path)
+                    .help("Local path on this device where all trusted certificates are located.")
+            )
              .arg(
                  Arg::new(Self::ADDRESS)
                      .long(Self::ADDRESS)
@@ -509,6 +521,7 @@ impl FromArgMatches for Cli {
         self.query_cache_path = m.get_one::<PathBuf>(Self::QUERY_CACHE).cloned();
         self.tls_cert_path = m.get_one::<PathBuf>(Self::TLS_CERT).cloned();
         self.tls_key_path = m.get_one::<PathBuf>(Self::TLS_KEY).cloned();
+        self.trusted_ca_certs_path = m.get_one::<PathBuf>(Self::TRUSTED_CA_CERTS_PATH).cloned();
         self.domain_address = m.get_one::<Url>(Self::DOMAIN_URI).cloned();
 
         self.address = m
diff --git a/server/src/handlers/http/modal/ingest_server.rs b/server/src/handlers/http/modal/ingest_server.rs
index c19517899..6789819fc 100644
--- a/server/src/handlers/http/modal/ingest_server.rs
+++ b/server/src/handlers/http/modal/ingest_server.rs
@@ -83,6 +83,7 @@ impl ParseableServer for IngestServer {
         let ssl = get_ssl_acceptor(
             &CONFIG.parseable.tls_cert_path,
             &CONFIG.parseable.tls_key_path,
+            &CONFIG.parseable.trusted_ca_certs_path,
         )?;
 
         // fn that creates the app
diff --git a/server/src/handlers/http/modal/query_server.rs b/server/src/handlers/http/modal/query_server.rs
index 9861990de..28c39a63e 100644
--- a/server/src/handlers/http/modal/query_server.rs
+++ b/server/src/handlers/http/modal/query_server.rs
@@ -65,6 +65,7 @@ impl ParseableServer for QueryServer {
         let ssl = get_ssl_acceptor(
             &CONFIG.parseable.tls_cert_path,
             &CONFIG.parseable.tls_key_path,
+            &CONFIG.parseable.trusted_ca_certs_path,
         )?;
 
         let create_app_fn = move || {
diff --git a/server/src/handlers/http/modal/server.rs b/server/src/handlers/http/modal/server.rs
index cd469acee..d3d56eb90 100644
--- a/server/src/handlers/http/modal/server.rs
+++ b/server/src/handlers/http/modal/server.rs
@@ -96,6 +96,7 @@ impl ParseableServer for Server {
         let ssl = get_ssl_acceptor(
             &CONFIG.parseable.tls_cert_path,
             &CONFIG.parseable.tls_key_path,
+            &CONFIG.parseable.trusted_ca_certs_path,
         )?;
 
         // Create a channel to trigger server shutdown
diff --git a/server/src/handlers/http/modal/ssl_acceptor.rs b/server/src/handlers/http/modal/ssl_acceptor.rs
index 84b27ebf8..850b4868b 100644
--- a/server/src/handlers/http/modal/ssl_acceptor.rs
+++ b/server/src/handlers/http/modal/ssl_acceptor.rs
@@ -16,13 +16,18 @@
  *
  */
 
-use std::{fs::File, io::BufReader, path::PathBuf};
+use std::{
+    fs::{self, File},
+    io::BufReader,
+    path::PathBuf,
+};
 
 use rustls::ServerConfig;
 
 pub fn get_ssl_acceptor(
     tls_cert: &Option<PathBuf>,
     tls_key: &Option<PathBuf>,
+    other_certs: &Option<PathBuf>,
 ) -> anyhow::Result<Option<ServerConfig>> {
     match (tls_cert, tls_key) {
         (Some(cert), Some(key)) => {
@@ -30,7 +35,23 @@ pub fn get_ssl_acceptor(
 
             let cert_file = &mut BufReader::new(File::open(cert)?);
             let key_file = &mut BufReader::new(File::open(key)?);
-            let certs = rustls_pemfile::certs(cert_file).collect::<Result<Vec<_>, _>>()?;
+
+            let mut certs = rustls_pemfile::certs(cert_file).collect::<Result<Vec<_>, _>>()?;
+            // Load CA certificates from the directory
+            if let Some(other_cert_dir) = other_certs {
+                if other_cert_dir.is_dir() {
+                    for entry in fs::read_dir(other_cert_dir)? {
+                        let path = entry.unwrap().path();
+
+                        if path.is_file() {
+                            let other_cert_file = &mut BufReader::new(File::open(&path)?);
+                            let mut other_certs = rustls_pemfile::certs(other_cert_file)
+                                .collect::<Result<Vec<_>, _>>()?;
+                            certs.append(&mut other_certs);
+                        }
+                    }
+                }
+            }
             let private_key = rustls_pemfile::private_key(key_file)?
                 .ok_or(anyhow::anyhow!("Could not parse private key."))?;