From 767b68d8ae292e02431a737c6d69a84e8f119d57 Mon Sep 17 00:00:00 2001 From: Xiangjin Date: Mon, 1 Apr 2024 13:05:52 +0800 Subject: [PATCH] fix API --- Cargo.lock | 2 +- src/connector/Cargo.toml | 2 +- src/connector/src/connector_common/common.rs | 10 ++++------ src/connector/src/connector_common/mqtt_common.rs | 15 ++++++--------- src/connector/src/error.rs | 2 +- 5 files changed, 13 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2aa3d7a9f262..afe3079f7601 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9311,6 +9311,7 @@ dependencies = [ "rust_decimal", "rustls-native-certs 0.7.0", "rustls-pemfile 2.1.1", + "rustls-pki-types", "rw_futures_util", "serde", "serde_derive", @@ -9327,7 +9328,6 @@ dependencies = [ "time", "tokio-postgres", "tokio-retry", - "tokio-rustls 0.24.1", "tokio-stream", "tokio-util", "tracing", diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index 548d77ac0805..648061333b5e 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -120,6 +120,7 @@ rumqttc = { version = "0.24.0", features = ["url"] } rust_decimal = "1" rustls-native-certs = "0.7" rustls-pemfile = "2" +rustls-pki-types = "1" rw_futures_util = { workspace = true } serde = { version = "1", features = ["derive", "rc"] } serde_derive = "1" @@ -143,7 +144,6 @@ tokio = { version = "0.2", package = "madsim-tokio", features = [ ] } tokio-postgres = { version = "0.7", features = ["with-uuid-1"] } tokio-retry = "0.3" -tokio-rustls = "0.24" tokio-stream = "0.1" tokio-util = { version = "0.7", features = ["codec", "io"] } tonic = { workspace = true } diff --git a/src/connector/src/connector_common/common.rs b/src/connector/src/connector_common/common.rs index 122383400f5c..302b68dd664a 100644 --- a/src/connector/src/connector_common/common.rs +++ b/src/connector/src/connector_common/common.rs @@ -687,7 +687,7 @@ impl NatsCommon { pub(crate) fn load_certs( certificates: &str, -) -> ConnectorResult> { +) -> ConnectorResult>> { let cert_bytes = if let Some(path) = certificates.strip_prefix("fs://") { std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())? } else { @@ -695,13 +695,13 @@ pub(crate) fn load_certs( }; rustls_pemfile::certs(&mut cert_bytes.as_slice()) - .map(|cert| Ok(tokio_rustls::rustls::Certificate(cert?.to_vec()))) + .map(|cert| Ok(cert?)) .collect() } pub(crate) fn load_private_key( certificate: &str, -) -> ConnectorResult { +) -> ConnectorResult> { let cert_bytes = if let Some(path) = certificate.strip_prefix("fs://") { std::fs::read_to_string(path).map(|cert| cert.as_bytes().to_owned())? } else { @@ -711,7 +711,5 @@ pub(crate) fn load_private_key( let cert = rustls_pemfile::pkcs8_private_keys(&mut cert_bytes.as_slice()) .next() .ok_or_else(|| anyhow!("No private key found"))?; - Ok(tokio_rustls::rustls::PrivateKey( - cert?.secret_pkcs8_der().to_vec(), - )) + Ok(cert?.into()) } diff --git a/src/connector/src/connector_common/mqtt_common.rs b/src/connector/src/connector_common/mqtt_common.rs index b771fd34143d..e607decff58a 100644 --- a/src/connector/src/connector_common/mqtt_common.rs +++ b/src/connector/src/connector_common/mqtt_common.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use rumqttc::tokio_rustls::rustls; use rumqttc::v5::mqttbytes::QoS; use rumqttc::v5::{AsyncClient, EventLoop, MqttOptions}; use serde_derive::Deserialize; @@ -141,26 +142,22 @@ impl MqttCommon { .unwrap_or(QoS::AtMostOnce) } - fn get_tls_config(&self) -> ConnectorResult { - let mut root_cert_store = tokio_rustls::rustls::RootCertStore::empty(); + fn get_tls_config(&self) -> ConnectorResult { + let mut root_cert_store = rustls::RootCertStore::empty(); if let Some(ca) = &self.ca { let certificates = load_certs(ca)?; for cert in certificates { - root_cert_store.add(&cert).unwrap(); + root_cert_store.add(cert).unwrap(); } } else { for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") { - root_cert_store - .add(&tokio_rustls::rustls::Certificate(cert.to_vec())) - .unwrap(); + root_cert_store.add(cert).unwrap(); } } - let builder = tokio_rustls::rustls::ClientConfig::builder() - .with_safe_defaults() - .with_root_certificates(root_cert_store); + let builder = rustls::ClientConfig::builder().with_root_certificates(root_cert_store); let tls_config = if let (Some(client_cert), Some(client_key)) = (self.client_cert.as_ref(), self.client_key.as_ref()) diff --git a/src/connector/src/error.rs b/src/connector/src/error.rs index 17885931f58e..925f90c20c96 100644 --- a/src/connector/src/error.rs +++ b/src/connector/src/error.rs @@ -58,7 +58,7 @@ def_anyhow_newtype! { redis::RedisError => "Redis error", arrow_schema::ArrowError => "Arrow error", google_cloud_pubsub::client::google_cloud_auth::error::Error => "Google Cloud error", - tokio_rustls::rustls::Error => "TLS error", + rumqttc::tokio_rustls::rustls::Error => "TLS error", rumqttc::v5::ClientError => "MQTT error", rumqttc::v5::OptionError => "MQTT error",