From b343f9df25a1db6e243f2bdad098254f68704b8a Mon Sep 17 00:00:00 2001 From: Smityz Date: Wed, 9 Aug 2023 18:21:34 +0800 Subject: [PATCH] add config to solve by custom dns server Signed-off-by: Smityz --- Cargo.toml | 2 ++ src/common/security.rs | 19 ++++++++++-- src/config.rs | 10 +++++++ src/mock.rs | 4 +-- src/pd/client.rs | 10 +++---- src/pd/cluster.rs | 33 ++++++++++----------- src/pd/retry.rs | 27 +++++++++-------- src/raw/client.rs | 2 +- src/store/client.rs | 7 +++-- src/transaction/client.rs | 2 +- src/util/dns.rs | 62 +++++++++++++++++++++++++++++++++++++++ src/util/mod.rs | 1 + 12 files changed, 133 insertions(+), 46 deletions(-) create mode 100644 src/util/dns.rs diff --git a/Cargo.toml b/Cargo.toml index 9535a6e4..f13681aa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,8 @@ serde_derive = "1.0" thiserror = "1" tokio = { version = "1", features = ["sync", "rt-multi-thread", "macros"] } tonic = { version = "0.9", features = ["tls"] } +trust-dns-resolver = "0.19.4" +url = "2.4" [dev-dependencies] clap = "2" diff --git a/src/common/security.rs b/src/common/security.rs index 483759cf..869529d8 100644 --- a/src/common/security.rs +++ b/src/common/security.rs @@ -4,7 +4,6 @@ use std::fs::File; use std::io::Read; use std::path::Path; use std::path::PathBuf; -use std::time::Duration; use log::info; use regex::Regex; @@ -15,6 +14,7 @@ use tonic::transport::Identity; use crate::internal_err; use crate::Result; +use crate::{util, Config}; lazy_static::lazy_static! { static ref SCHEME_REG: Regex = Regex::new(r"^\s*(https?://)").unwrap(); @@ -73,17 +73,30 @@ impl SecurityManager { // env: Arc, addr: &str, factory: Factory, + config: &Config, ) -> Result where Factory: FnOnce(Channel) -> Client, { let addr = "http://".to_string() + &SCHEME_REG.replace(addr, ""); + let addr = match config.dns_server_addr { + Some(ref dns_server_addr) => { + util::dns::custom_dns( + addr, + dns_server_addr.clone(), + config.dns_search_domain.clone(), + ) + .await? + } + None => addr, + }; + info!("connect to rpc server at endpoint: {:?}", addr); let mut builder = Channel::from_shared(addr)? - .tcp_keepalive(Some(Duration::from_secs(10))) - .keep_alive_timeout(Duration::from_secs(3)); + .tcp_keepalive(config.tcp_keepalive) + .keep_alive_timeout(config.keep_alive_timeout); if !self.ca.is_empty() { let tls = ClientTlsConfig::new() diff --git a/src/config.rs b/src/config.rs index 1be273cc..f90c7b49 100644 --- a/src/config.rs +++ b/src/config.rs @@ -19,9 +19,15 @@ pub struct Config { pub cert_path: Option, pub key_path: Option, pub timeout: Duration, + pub tcp_keepalive: Option, + pub keep_alive_timeout: Duration, + pub dns_server_addr: Option, + pub dns_search_domain: Vec, } const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(2); +const DEFAULT_TCP_KEEPALIVE: Duration = Duration::from_secs(10); +const DEFAULT_KEEP_ALIVE_TIMEOUT: Duration = Duration::from_secs(3); impl Default for Config { fn default() -> Self { @@ -30,6 +36,10 @@ impl Default for Config { cert_path: None, key_path: None, timeout: DEFAULT_REQUEST_TIMEOUT, + tcp_keepalive: Some(DEFAULT_TCP_KEEPALIVE), + keep_alive_timeout: DEFAULT_KEEP_ALIVE_TIMEOUT, + dns_server_addr: None, + dns_search_domain: vec![], } } } diff --git a/src/mock.rs b/src/mock.rs index eada6a8e..823ae99f 100644 --- a/src/mock.rs +++ b/src/mock.rs @@ -33,12 +33,12 @@ use crate::Timestamp; pub async fn pd_rpc_client() -> PdRpcClient { let config = Config::default(); PdRpcClient::new( - config.clone(), + &config, |_| MockKvConnect, |sm| { futures::future::ok(RetryClient::new_with_cluster( sm, - config.timeout, + config.clone(), MockCluster, )) }, diff --git a/src/pd/client.rs b/src/pd/client.rs index 31f88968..5736393c 100644 --- a/src/pd/client.rs +++ b/src/pd/client.rs @@ -260,13 +260,13 @@ impl PdClient for PdRpcClient { impl PdRpcClient { pub async fn connect( pd_endpoints: &[String], - config: Config, + config: &Config, enable_codec: bool, ) -> Result { PdRpcClient::new( - config.clone(), - |security_mgr| TikvConnect::new(security_mgr, config.timeout), - |security_mgr| RetryClient::connect(pd_endpoints, security_mgr, config.timeout), + config, + |security_mgr| TikvConnect::new(security_mgr, config.clone()), + |security_mgr| RetryClient::connect(pd_endpoints, security_mgr, config), enable_codec, ) .await @@ -275,7 +275,7 @@ impl PdRpcClient { impl PdRpcClient { pub async fn new( - config: Config, + config: &Config, kv_connect: MakeKvC, pd: MakePd, enable_codec: bool, diff --git a/src/pd/cluster.rs b/src/pd/cluster.rs index 3df4d255..32b11d8c 100644 --- a/src/pd/cluster.rs +++ b/src/pd/cluster.rs @@ -16,6 +16,7 @@ use tonic::Request; use super::timestamp::TimestampOracle; use crate::internal_err; use crate::proto::pdpb; +use crate::Config; use crate::Result; use crate::SecurityManager; use crate::Timestamp; @@ -103,13 +104,9 @@ impl Connection { Connection { security_mgr } } - pub async fn connect_cluster( - &self, - endpoints: &[String], - timeout: Duration, - ) -> Result { - let members = self.validate_endpoints(endpoints, timeout).await?; - let (client, members) = self.try_connect_leader(&members, timeout).await?; + pub async fn connect_cluster(&self, endpoints: &[String], config: &Config) -> Result { + let members = self.validate_endpoints(endpoints, config).await?; + let (client, members) = self.try_connect_leader(&members, config).await?; let id = members.header.as_ref().unwrap().cluster_id; let tso = TimestampOracle::new(id, &client)?; let cluster = Cluster { @@ -122,10 +119,10 @@ impl Connection { } // Re-establish connection with PD leader in asynchronous fashion. - pub async fn reconnect(&self, cluster: &mut Cluster, timeout: Duration) -> Result<()> { + pub async fn reconnect(&self, cluster: &mut Cluster, config: &Config) -> Result<()> { warn!("updating pd client"); let start = Instant::now(); - let (client, members) = self.try_connect_leader(&cluster.members, timeout).await?; + let (client, members) = self.try_connect_leader(&cluster.members, config).await?; let tso = TimestampOracle::new(cluster.id, &client)?; *cluster = Cluster { id: cluster.id, @@ -141,7 +138,7 @@ impl Connection { async fn validate_endpoints( &self, endpoints: &[String], - timeout: Duration, + config: &Config, ) -> Result { let mut endpoints_set = HashSet::with_capacity(endpoints.len()); @@ -152,7 +149,7 @@ impl Connection { return Err(internal_err!("duplicated PD endpoint {}", ep)); } - let (_, resp) = match self.connect(ep, timeout).await { + let (_, resp) = match self.connect(ep, config).await { Ok(resp) => resp, // Ignore failed PD node. Err(e) => { @@ -193,11 +190,11 @@ impl Connection { async fn connect( &self, addr: &str, - _timeout: Duration, + config: &Config, ) -> Result<(pdpb::pd_client::PdClient, pdpb::GetMembersResponse)> { let mut client = self .security_mgr - .connect(addr, pdpb::pd_client::PdClient::::new) + .connect(addr, pdpb::pd_client::PdClient::::new, config) .await?; let resp: pdpb::GetMembersResponse = client .get_members(pdpb::GetMembersRequest::default()) @@ -210,9 +207,9 @@ impl Connection { &self, addr: &str, cluster_id: u64, - timeout: Duration, + config: &Config, ) -> Result<(pdpb::pd_client::PdClient, pdpb::GetMembersResponse)> { - let (client, r) = self.connect(addr, timeout).await?; + let (client, r) = self.connect(addr, config).await?; Connection::validate_cluster_id(addr, &r, cluster_id)?; Ok((client, r)) } @@ -238,7 +235,7 @@ impl Connection { async fn try_connect_leader( &self, previous: &pdpb::GetMembersResponse, - timeout: Duration, + config: &Config, ) -> Result<(pdpb::pd_client::PdClient, pdpb::GetMembersResponse)> { let previous_leader = previous.leader.as_ref().unwrap(); let members = &previous.members; @@ -252,7 +249,7 @@ impl Connection { .chain(Some(previous_leader)) { for ep in &m.client_urls { - match self.try_connect(ep.as_str(), cluster_id, timeout).await { + match self.try_connect(ep.as_str(), cluster_id, config).await { Ok((_, r)) => { resp = Some(r); break 'outer; @@ -269,7 +266,7 @@ impl Connection { if let Some(resp) = resp { let leader = resp.leader.as_ref().unwrap(); for ep in &leader.client_urls { - let r = self.try_connect(ep.as_str(), cluster_id, timeout).await; + let r = self.try_connect(ep.as_str(), cluster_id, config).await; if r.is_ok() { return r; } diff --git a/src/pd/retry.rs b/src/pd/retry.rs index 548fc269..b72dae95 100644 --- a/src/pd/retry.rs +++ b/src/pd/retry.rs @@ -20,6 +20,7 @@ use crate::region::RegionId; use crate::region::RegionWithLeader; use crate::region::StoreId; use crate::stats::pd_stats; +use crate::Config; use crate::Error; use crate::Result; use crate::SecurityManager; @@ -51,21 +52,21 @@ pub struct RetryClient { // Tuple is the cluster and the time of the cluster's last reconnect. cluster: RwLock<(Cl, Instant)>, connection: Connection, - timeout: Duration, + config: Config, } #[cfg(test)] impl RetryClient { pub fn new_with_cluster( security_mgr: Arc, - timeout: Duration, + config: Config, cluster: Cl, ) -> RetryClient { let connection = Connection::new(security_mgr); RetryClient { cluster: RwLock::new((cluster, Instant::now())), connection, - timeout, + config, } } } @@ -107,17 +108,17 @@ impl RetryClient { pub async fn connect( endpoints: &[String], security_mgr: Arc, - timeout: Duration, + config: &Config, ) -> Result { let connection = Connection::new(security_mgr); let cluster = RwLock::new(( - connection.connect_cluster(endpoints, timeout).await?, + connection.connect_cluster(endpoints, config).await?, Instant::now(), )); Ok(RetryClient { cluster, connection, - timeout, + config: config.clone(), }) } } @@ -131,7 +132,7 @@ impl RetryClientTrait for RetryClient { let key = key.clone(); async { cluster - .get_region(key.clone(), self.timeout) + .get_region(key.clone(), self.config.timeout) .await .and_then(|resp| { region_from_response(resp, || Error::RegionForKeyNotFound { key }) @@ -143,7 +144,7 @@ impl RetryClientTrait for RetryClient { async fn get_region_by_id(self: Arc, region_id: RegionId) -> Result { retry!(self, "get_region_by_id", |cluster| async { cluster - .get_region_by_id(region_id, self.timeout) + .get_region_by_id(region_id, self.config.timeout) .await .and_then(|resp| { region_from_response(resp, || Error::RegionNotFoundInResponse { region_id }) @@ -154,7 +155,7 @@ impl RetryClientTrait for RetryClient { async fn get_store(self: Arc, id: StoreId) -> Result { retry!(self, "get_store", |cluster| async { cluster - .get_store(id, self.timeout) + .get_store(id, self.config.timeout) .await .map(|resp| resp.store.unwrap()) }) @@ -164,7 +165,7 @@ impl RetryClientTrait for RetryClient { async fn get_all_stores(self: Arc) -> Result> { retry!(self, "get_all_stores", |cluster| async { cluster - .get_all_stores(self.timeout) + .get_all_stores(self.config.timeout) .await .map(|resp| resp.stores.into_iter().map(Into::into).collect()) }) @@ -177,7 +178,7 @@ impl RetryClientTrait for RetryClient { async fn update_safepoint(self: Arc, safepoint: u64) -> Result { retry!(self, "update_gc_safepoint", |cluster| async { cluster - .update_safepoint(safepoint, self.timeout) + .update_safepoint(safepoint, self.config.timeout) .await .map(|resp| resp.new_safe_point == safepoint) }) @@ -187,7 +188,7 @@ impl RetryClientTrait for RetryClient { impl fmt::Debug for RetryClient { fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { fmt.debug_struct("pd::RetryClient") - .field("timeout", &self.timeout) + .field("timeout", &self.config.timeout) .finish() } } @@ -219,7 +220,7 @@ impl Reconnect for RetryClient { // a concurrent reconnect is just succeed when this thread trying to get write lock let should_connect = reconnect_begin > *last_connected + Duration::from_secs(interval_sec); if should_connect { - self.connection.reconnect(cluster, self.timeout).await?; + self.connection.reconnect(cluster, &self.config).await?; *last_connected = Instant::now(); } Ok(()) diff --git a/src/raw/client.rs b/src/raw/client.rs index 0bdc2f8b..08e058fc 100644 --- a/src/raw/client.rs +++ b/src/raw/client.rs @@ -100,7 +100,7 @@ impl Client { config: Config, ) -> Result { let pd_endpoints: Vec = pd_endpoints.into_iter().map(Into::into).collect(); - let rpc = Arc::new(PdRpcClient::connect(&pd_endpoints, config, false).await?); + let rpc = Arc::new(PdRpcClient::connect(&pd_endpoints, &config, false).await?); Ok(Client { rpc, cf: None, diff --git a/src/store/client.rs b/src/store/client.rs index 363d4137..7b2a5d8d 100644 --- a/src/store/client.rs +++ b/src/store/client.rs @@ -10,6 +10,7 @@ use tonic::transport::Channel; use super::Request; use crate::proto::tikvpb::tikv_client::TikvClient; +use crate::Config; use crate::Result; use crate::SecurityManager; @@ -24,7 +25,7 @@ pub trait KvConnect: Sized + Send + Sync + 'static { #[derive(new, Clone)] pub struct TikvConnect { security_mgr: Arc, - timeout: Duration, + config: Config, } #[async_trait] @@ -33,9 +34,9 @@ impl KvConnect for TikvConnect { async fn connect(&self, address: &str) -> Result { self.security_mgr - .connect(address, TikvClient::new) + .connect(address, TikvClient::new, &self.config) .await - .map(|c| KvRpcClient::new(c, self.timeout)) + .map(|c| KvRpcClient::new(c, self.config.timeout)) } } diff --git a/src/transaction/client.rs b/src/transaction/client.rs index 64d32451..ee061ecb 100644 --- a/src/transaction/client.rs +++ b/src/transaction/client.rs @@ -102,7 +102,7 @@ impl Client { ) -> Result { debug!("creating new transactional client"); let pd_endpoints: Vec = pd_endpoints.into_iter().map(Into::into).collect(); - let pd = Arc::new(PdRpcClient::connect(&pd_endpoints, config, true).await?); + let pd = Arc::new(PdRpcClient::connect(&pd_endpoints, &config, true).await?); Ok(Client { pd }) } diff --git a/src/util/dns.rs b/src/util/dns.rs new file mode 100644 index 00000000..48bff061 --- /dev/null +++ b/src/util/dns.rs @@ -0,0 +1,62 @@ +use std::{net::SocketAddr, str::FromStr}; + +use crate::{Error, Result}; +use url::Url; + +use trust_dns_resolver::{ + config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts}, + Name, Resolver, +}; + +pub async fn custom_dns( + target: String, + dns_addr: String, + search_domain: Vec, +) -> Result { + let server: SocketAddr = dns_addr.parse().map_err(|e| Error::InternalError { + message: format!("dns server error: {}", e), + })?; + let mut search_names: Vec = Vec::new(); + for d in search_domain { + let n = Name::from_str(d.as_str()).map_err(|e| Error::InternalError { + message: format!("dns search domain error: {}", e), + })?; + search_names.push(n); + } + let resolver_config = ResolverConfig::from_parts( + None, + search_names, + vec![NameServerConfig { + socket_addr: server, + protocol: Protocol::Udp, + tls_dns_name: None, + }], + ); + let resolver = Resolver::new(resolver_config, ResolverOpts::default()).map_err(|e| { + Error::InternalError { + message: format!("dns resolver error: {}", e), + } + })?; + let mut url = Url::parse(&target).map_err(|e| Error::InternalError { + message: format!("url parse error: {}", e), + })?; + let hostname = url.host_str().ok_or(Error::InternalError { + message: format!("url parse error: {}", url), + })?; + let ip = resolver + .lookup_ip(hostname) + .map_err(|e| Error::InternalError { + message: format!("dns resolve error: {}", e), + })? + .iter() + .next() + .ok_or(Error::InternalError { + message: format!("can't resolve hostname {}", hostname), + })? + .to_string(); + url.set_host(Some(ip.as_str())) + .map_err(|e| Error::InternalError { + message: format!("url parse error: {}", e), + })?; + Ok(String::from(url)) +} diff --git a/src/util/mod.rs b/src/util/mod.rs index 6d3bf763..59ea28d6 100644 --- a/src/util/mod.rs +++ b/src/util/mod.rs @@ -1,3 +1,4 @@ // Copyright 2021 TiKV Project Authors. Licensed under Apache-2.0. +pub mod dns; pub mod iter;