diff --git a/src/config.rs b/src/config.rs index a0ee3abb..7f58a305 100644 --- a/src/config.rs +++ b/src/config.rs @@ -2,7 +2,9 @@ use serde_derive::{Deserialize, Serialize}; use std::{path::PathBuf, time::Duration}; +use tikv_client_store::KvClientConfig; +const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(2); /// The configuration for either a [`RawClient`](crate::RawClient) or a /// [`TransactionClient`](crate::TransactionClient). /// @@ -16,10 +18,9 @@ pub struct Config { pub cert_path: Option, pub key_path: Option, pub timeout: Duration, + pub kv_config: KvClientConfig, } -const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(2); - impl Default for Config { fn default() -> Self { Config { @@ -27,6 +28,7 @@ impl Default for Config { cert_path: None, key_path: None, timeout: DEFAULT_REQUEST_TIMEOUT, + kv_config: KvClientConfig::default(), } } } @@ -80,4 +82,50 @@ impl Config { self.timeout = timeout; self } + + // TODO: add more config options for tivk client config + pub fn with_kv_timeout(mut self, timeout: u64) -> Self { + self.kv_config.request_timeout = timeout; + self + } + + pub fn with_kv_completion_queue_size(mut self, size: usize) -> Self { + self.kv_config.completion_queue_size = size; + self + } + + pub fn with_kv_grpc_keepalive_time(mut self, time: u64) -> Self { + self.kv_config.grpc_keepalive_time = time; + self + } + + pub fn with_kv_grpc_keepalive_timeout(mut self, timeout: u64) -> Self { + self.kv_config.grpc_keepalive_timeout = timeout; + self + } + + pub fn with_kv_allow_batch(mut self, allow_batch: bool) -> Self { + self.kv_config.allow_batch = allow_batch; + self + } + + pub fn with_kv_overload_threshold(mut self, threshold: u64) -> Self { + self.kv_config.overload_threshold = threshold; + self + } + + pub fn with_kv_max_batch_wait_time(mut self, wait: u64) -> Self { + self.kv_config.max_batch_wait_time = wait; + self + } + + pub fn with_kv_max_batch_size(mut self, size: usize) -> Self { + self.kv_config.max_batch_size = 
size; + self + } + + pub fn with_kv_max_inflight_requests(mut self, requests: usize) -> Self { + self.kv_config.max_inflight_requests = requests; + self + } } diff --git a/src/mock.rs b/src/mock.rs index 02a0bef9..107a31b7 100644 --- a/src/mock.rs +++ b/src/mock.rs @@ -16,7 +16,7 @@ use derive_new::new; use slog::{Drain, Logger}; use std::{any::Any, sync::Arc}; use tikv_client_proto::metapb; -use tikv_client_store::{KvClient, KvConnect, Request}; +use tikv_client_store::{KvClient, KvClientConfig, KvConnect, Request}; /// Create a `PdRpcClient` with it's internals replaced with mocks so that the /// client can be tested without doing any RPC calls. @@ -52,13 +52,15 @@ pub async fn pd_rpc_client() -> PdRpcClient { #[derive(new, Default, Clone)] pub struct MockKvClient { pub addr: String, - dispatch: Option Result> + Send + Sync + 'static>>, + dispatch: Option< + Arc Result> + Send + Sync + 'static>, + >, } impl MockKvClient { pub fn with_dispatch_hook(dispatch: F) -> MockKvClient where - F: Fn(&dyn Any) -> Result> + Send + Sync + 'static, + F: Fn(&(dyn Any + Send)) -> Result> + Send + Sync + 'static, { MockKvClient { addr: String::new(), @@ -78,7 +80,7 @@ pub struct MockPdClient { #[async_trait] impl KvClient for MockKvClient { - async fn dispatch(&self, req: &dyn Request) -> Result> { + async fn dispatch(&self, req: Box) -> Result> { match &self.dispatch { Some(f) => f(req.as_any()), None => panic!("no dispatch hook set"), @@ -89,7 +91,7 @@ impl KvClient for MockKvClient { impl KvConnect for MockKvConnect { type KvClient = MockKvClient; - fn connect(&self, address: &str) -> Result { + fn connect(&self, address: &str, _: KvClientConfig) -> Result { Ok(MockKvClient { addr: address.to_owned(), dispatch: None, diff --git a/src/pd/client.rs b/src/pd/client.rs index 75610544..2e4db808 100644 --- a/src/pd/client.rs +++ b/src/pd/client.rs @@ -16,10 +16,9 @@ use slog::Logger; use std::{collections::HashMap, sync::Arc, thread}; use tikv_client_pd::Cluster; use 
tikv_client_proto::{kvrpcpb, metapb}; -use tikv_client_store::{KvClient, KvConnect, TikvConnect}; +use tikv_client_store::{KvClient, KvClientConfig, KvConnect, TikvConnect}; use tokio::sync::RwLock; -const CQ_COUNT: usize = 1; const CLIENT_PREFIX: &str = "tikv-client"; /// The PdClient handles all the encoding stuff. @@ -210,6 +209,7 @@ pub struct PdRpcClient>, kv_connect: KvC, kv_client_cache: Arc>>, + kv_config: KvClientConfig, enable_codec: bool, region_cache: RegionCache>, logger: Logger, @@ -304,7 +304,7 @@ impl PdRpcClient { { let env = Arc::new( EnvBuilder::new() - .cq_count(CQ_COUNT) + .cq_count(config.kv_config.completion_queue_size) .name_prefix(thread_name(CLIENT_PREFIX)) .build(), ); @@ -324,6 +324,7 @@ impl PdRpcClient { pd: pd.clone(), kv_client_cache, kv_connect: kv_connect(env, security_mgr), + kv_config: config.kv_config, enable_codec, region_cache: RegionCache::new(pd), logger, @@ -335,7 +336,7 @@ impl PdRpcClient { return Ok(client.clone()); }; info!(self.logger, "connect to tikv endpoint: {:?}", address); - match self.kv_connect.connect(address) { + match self.kv_connect.connect(address, self.kv_config.clone()) { Ok(client) => { self.kv_client_cache .write() diff --git a/src/raw/client.rs b/src/raw/client.rs index b9ac740d..dcbca4d8 100644 --- a/src/raw/client.rs +++ b/src/raw/client.rs @@ -6,6 +6,7 @@ use std::{str::FromStr, sync::Arc, u32}; use slog::{Drain, Logger}; use tikv_client_common::Error; use tikv_client_proto::metapb; +use tikv_client_store::KvClientConfig; use crate::{ backoff::DEFAULT_REGION_BACKOFF, @@ -31,6 +32,7 @@ pub struct Client { /// Whether to use the [`atomic mode`](Client::with_atomic_for_cas). 
atomic: bool, logger: Logger, + kv_config: KvClientConfig, } impl Clone for Client { @@ -40,6 +42,7 @@ impl Clone for Client { cf: self.cf.clone(), atomic: self.atomic, logger: self.logger.clone(), + kv_config: self.kv_config.clone(), } } } @@ -106,13 +109,15 @@ impl Client { }); debug!(logger, "creating new raw client"); let pd_endpoints: Vec = pd_endpoints.into_iter().map(Into::into).collect(); - let rpc = - Arc::new(PdRpcClient::connect(&pd_endpoints, config, false, logger.clone()).await?); + let rpc = Arc::new( + PdRpcClient::connect(&pd_endpoints, config.clone(), false, logger.clone()).await?, + ); Ok(Client { rpc, cf: None, atomic: false, logger, + kv_config: config.kv_config, }) } @@ -147,6 +152,7 @@ impl Client { cf: Some(cf), atomic: self.atomic, logger: self.logger.clone(), + kv_config: self.kv_config.clone(), } } @@ -164,6 +170,7 @@ impl Client { cf: self.cf.clone(), atomic: true, logger: self.logger.clone(), + kv_config: self.kv_config.clone(), } } } @@ -773,7 +780,7 @@ mod tests { o!(), ); let pd_client = Arc::new(MockPdClient::new(MockKvClient::with_dispatch_hook( - move |req: &dyn Any| { + move |req: &(dyn Any + Send)| { if let Some(req) = req.downcast_ref::() { assert_eq!(req.copr_name, "example"); assert_eq!(req.copr_version_req, "0.1.0"); @@ -781,7 +788,7 @@ mod tests { data: req.data.clone(), ..Default::default() }; - Ok(Box::new(resp) as Box) + Ok(Box::new(resp) as Box) } else { unreachable!() } @@ -792,6 +799,7 @@ mod tests { cf: Some(ColumnFamily::Default), atomic: false, logger, + kv_config: KvClientConfig::default(), }; let resps = client .coprocessor( diff --git a/src/raw/requests.rs b/src/raw/requests.rs index bd678aaf..b8f0c4a8 100644 --- a/src/raw/requests.rs +++ b/src/raw/requests.rs @@ -361,7 +361,11 @@ pub struct RawCoprocessorRequest { #[async_trait] impl Request for RawCoprocessorRequest { - async fn dispatch(&self, client: &TikvClient, options: CallOption) -> Result> { + async fn dispatch( + &self, + client: &TikvClient, + options: 
CallOption, + ) -> Result> { self.inner.dispatch(client, options).await } @@ -369,13 +373,17 @@ impl Request for RawCoprocessorRequest { self.inner.label() } - fn as_any(&self) -> &dyn Any { + fn as_any(&self) -> &(dyn Any + Send) { self.inner.as_any() } fn set_context(&mut self, context: kvrpcpb::Context) { self.inner.set_context(context); } + + fn to_batch_request(&self) -> tikv_client_proto::tikvpb::batch_commands_request::Request { + todo!() + } } impl KvRequest for RawCoprocessorRequest { @@ -483,7 +491,7 @@ mod test { #[ignore] fn test_raw_scan() { let client = Arc::new(MockPdClient::new(MockKvClient::with_dispatch_hook( - |req: &dyn Any| { + |req: &(dyn Any + Send)| { let req: &kvrpcpb::RawScanRequest = req.downcast_ref().unwrap(); assert!(req.key_only); assert_eq!(req.limit, 10); @@ -497,7 +505,7 @@ mod test { resp.kvs.push(kv); } - Ok(Box::new(resp) as Box) + Ok(Box::new(resp) as Box) }, ))); diff --git a/src/request/mod.rs b/src/request/mod.rs index 959ff5e7..e429f907 100644 --- a/src/request/mod.rs +++ b/src/request/mod.rs @@ -105,7 +105,7 @@ mod test { #[async_trait] impl Request for MockKvRequest { - async fn dispatch(&self, _: &TikvClient, _: CallOption) -> Result> { + async fn dispatch(&self, _: &TikvClient, _: CallOption) -> Result> { Ok(Box::new(MockRpcResponse {})) } @@ -113,13 +113,19 @@ mod test { "mock" } - fn as_any(&self) -> &dyn Any { + fn as_any(&self) -> &(dyn Any + Send) { self } fn set_context(&mut self, _: kvrpcpb::Context) { unreachable!(); } + + fn to_batch_request( + &self, + ) -> tikv_client_proto::tikvpb::batch_commands_request::Request { + todo!() + } } #[async_trait] @@ -162,7 +168,7 @@ mod test { }; let pd_client = Arc::new(MockPdClient::new(MockKvClient::with_dispatch_hook( - |_: &dyn Any| Ok(Box::new(MockRpcResponse) as Box), + |_: &(dyn Any + Send)| Ok(Box::new(MockRpcResponse) as Box), ))); let plan = crate::request::PlanBuilder::new(pd_client.clone(), request) @@ -179,12 +185,12 @@ mod test { #[tokio::test] async fn 
test_extract_error() { let pd_client = Arc::new(MockPdClient::new(MockKvClient::with_dispatch_hook( - |_: &dyn Any| { + |_: &(dyn Any + Send)| { Ok(Box::new(kvrpcpb::CommitResponse { region_error: None, error: Some(kvrpcpb::KeyError::default()), commit_version: 0, - }) as Box) + }) as Box) }, ))); diff --git a/src/request/plan.rs b/src/request/plan.rs index ce785262..31e97cc8 100644 --- a/src/request/plan.rs +++ b/src/request/plan.rs @@ -48,7 +48,7 @@ impl Plan for Dispatch { .kv_client .as_ref() .expect("Unreachable: kv_client has not been initialised in Dispatch") - .dispatch(&self.request) + .dispatch(Box::new(self.request.clone())) .await; let result = stats.done(result); result.map(|r| { @@ -85,7 +85,7 @@ where preserve_region_results: bool, ) -> Result<::Result> { let shards = current_plan.shards(&pd_client).collect::>().await; - let mut handles = Vec::new(); + let mut handles = Vec::with_capacity(shards.len()); for shard in shards { let (shard, region_store) = shard?; let mut clone = current_plan.clone(); diff --git a/src/store.rs b/src/store.rs index 44c56a04..11696582 100644 --- a/src/store.rs +++ b/src/store.rs @@ -8,7 +8,7 @@ use std::{ sync::Arc, }; use tikv_client_proto::kvrpcpb; -use tikv_client_store::{KvClient, KvConnect, TikvConnect}; +use tikv_client_store::KvClient; #[derive(new, Clone)] pub struct RegionStore { @@ -16,15 +16,15 @@ pub struct RegionStore { pub client: Arc, } -pub trait KvConnectStore: KvConnect { - fn connect_to_store(&self, region: RegionWithLeader, address: String) -> Result { - log::info!("connect to tikv endpoint: {:?}", &address); - let client = self.connect(address.as_str())?; - Ok(RegionStore::new(region, Arc::new(client))) - } -} +// pub trait KvConnectStore: KvConnect { +// fn connect_to_store(&self, region: RegionWithLeader, address: String) -> Result { +// log::info!("connect to tikv endpoint: {:?}", &address); +// let client = self.connect(address.as_str())?; +// Ok(RegionStore::new(region, Arc::new(client))) +// } 
+// } -impl KvConnectStore for TikvConnect {} +// impl KvConnectStore for TikvConnect {} /// Maps keys to a stream of stores. `key_data` must be sorted in increasing order pub fn store_stream_for_keys( diff --git a/src/transaction/client.rs b/src/transaction/client.rs index 69f11bce..eac4eb45 100644 --- a/src/transaction/client.rs +++ b/src/transaction/client.rs @@ -13,6 +13,7 @@ use crate::{ use slog::{Drain, Logger}; use std::{mem, sync::Arc}; use tikv_client_proto::{kvrpcpb, pdpb::Timestamp}; +use tikv_client_store::KvClientConfig; // FIXME: cargo-culted value const SCAN_LOCK_BATCH_SIZE: u32 = 1024; @@ -36,6 +37,7 @@ const SCAN_LOCK_BATCH_SIZE: u32 = 1024; pub struct Client { pd: Arc, logger: Logger, + kv_config: KvClientConfig, } impl Clone for Client { @@ -43,6 +45,7 @@ impl Clone for Client { Self { pd: self.pd.clone(), logger: self.logger.clone(), + kv_config: self.kv_config.clone(), } } } @@ -112,8 +115,14 @@ impl Client { }); debug!(logger, "creating new transactional client"); let pd_endpoints: Vec = pd_endpoints.into_iter().map(Into::into).collect(); - let pd = Arc::new(PdRpcClient::connect(&pd_endpoints, config, true, logger.clone()).await?); - Ok(Client { pd, logger }) + let pd = Arc::new( + PdRpcClient::connect(&pd_endpoints, config.clone(), true, logger.clone()).await?, + ); + Ok(Client { + pd, + logger, + kv_config: config.kv_config, + }) } /// Creates a new optimistic [`Transaction`]. 
diff --git a/src/transaction/lock.rs b/src/transaction/lock.rs index 871effc6..05900284 100644 --- a/src/transaction/lock.rs +++ b/src/transaction/lock.rs @@ -151,15 +151,15 @@ mod tests { fail::cfg("region-error", "9*return").unwrap(); let client = Arc::new(MockPdClient::new(MockKvClient::with_dispatch_hook( - |_: &dyn Any| { + |_: &(dyn Any + Send)| { fail::fail_point!("region-error", |_| { let resp = kvrpcpb::ResolveLockResponse { region_error: Some(errorpb::Error::default()), ..Default::default() }; - Ok(Box::new(resp) as Box) + Ok(Box::new(resp) as Box) }); - Ok(Box::new(kvrpcpb::ResolveLockResponse::default()) as Box) + Ok(Box::new(kvrpcpb::ResolveLockResponse::default()) as Box) }, ))); diff --git a/src/transaction/transaction.rs b/src/transaction/transaction.rs index 30f9602a..7ee5049f 100644 --- a/src/transaction/transaction.rs +++ b/src/transaction/transaction.rs @@ -1385,14 +1385,14 @@ mod tests { let heartbeats = Arc::new(AtomicUsize::new(0)); let heartbeats_cloned = heartbeats.clone(); let pd_client = Arc::new(MockPdClient::new(MockKvClient::with_dispatch_hook( - move |req: &dyn Any| { + move |req: &(dyn Any + Send)| { if req.downcast_ref::().is_some() { heartbeats_cloned.fetch_add(1, Ordering::SeqCst); - Ok(Box::new(kvrpcpb::TxnHeartBeatResponse::default()) as Box) + Ok(Box::new(kvrpcpb::TxnHeartBeatResponse::default()) as Box) } else if req.downcast_ref::().is_some() { - Ok(Box::new(kvrpcpb::PrewriteResponse::default()) as Box) + Ok(Box::new(kvrpcpb::PrewriteResponse::default()) as Box) } else { - Ok(Box::new(kvrpcpb::CommitResponse::default()) as Box) + Ok(Box::new(kvrpcpb::CommitResponse::default()) as Box) } }, ))); @@ -1429,19 +1429,20 @@ mod tests { let heartbeats = Arc::new(AtomicUsize::new(0)); let heartbeats_cloned = heartbeats.clone(); let pd_client = Arc::new(MockPdClient::new(MockKvClient::with_dispatch_hook( - move |req: &dyn Any| { + move |req: &(dyn Any + Send)| { if req.downcast_ref::().is_some() { heartbeats_cloned.fetch_add(1, 
Ordering::SeqCst); - Ok(Box::new(kvrpcpb::TxnHeartBeatResponse::default()) as Box) + Ok(Box::new(kvrpcpb::TxnHeartBeatResponse::default()) as Box) } else if req.downcast_ref::().is_some() { - Ok(Box::new(kvrpcpb::PrewriteResponse::default()) as Box) + Ok(Box::new(kvrpcpb::PrewriteResponse::default()) as Box) } else if req .downcast_ref::() .is_some() { - Ok(Box::new(kvrpcpb::PessimisticLockResponse::default()) as Box) + Ok(Box::new(kvrpcpb::PessimisticLockResponse::default()) + as Box) } else { - Ok(Box::new(kvrpcpb::CommitResponse::default()) as Box) + Ok(Box::new(kvrpcpb::CommitResponse::default()) as Box) } }, ))); diff --git a/tikv-client-common/src/security.rs b/tikv-client-common/src/security.rs index a89b2fca..37aa480e 100644 --- a/tikv-client-common/src/security.rs +++ b/tikv-client-common/src/security.rs @@ -67,6 +67,8 @@ impl SecurityManager { &self, env: Arc, addr: &str, + keepalive: u64, + keepalive_timeout: u64, factory: Factory, ) -> Result where @@ -77,8 +79,8 @@ impl SecurityManager { let addr = SCHEME_REG.replace(addr, ""); let cb = ChannelBuilder::new(env) - .keepalive_time(Duration::from_secs(10)) - .keepalive_timeout(Duration::from_secs(3)) + .keepalive_time(Duration::from_millis(keepalive)) + .keepalive_timeout(Duration::from_millis(keepalive_timeout)) .use_local_subchannel_pool(true); let channel = if self.ca.is_empty() { diff --git a/tikv-client-pd/src/cluster.rs b/tikv-client-pd/src/cluster.rs index 063532e9..a4e49eb2 100644 --- a/tikv-client-pd/src/cluster.rs +++ b/tikv-client-pd/src/cluster.rs @@ -180,9 +180,9 @@ impl Connection { addr: &str, timeout: Duration, ) -> Result<(pdpb::PdClient, pdpb::GetMembersResponse)> { - let client = self - .security_mgr - .connect(self.env.clone(), addr, pdpb::PdClient::new)?; + let client = + self.security_mgr + .connect(self.env.clone(), addr, 10000, 2000, pdpb::PdClient::new)?; let option = CallOption::default().timeout(timeout); let resp = client 
.get_members_async_opt(&pdpb::GetMembersRequest::default(), option) diff --git a/tikv-client-store/Cargo.toml b/tikv-client-store/Cargo.toml index efdc18c5..08891bc1 100644 --- a/tikv-client-store/Cargo.toml +++ b/tikv-client-store/Cargo.toml @@ -13,5 +13,8 @@ derive-new = "0.5" futures = { version = "0.3", features = ["compat", "async-await", "thread-pool"] } grpcio = { version = "0.10", features = [ "prost-codec" ], default-features = false } log = "0.4" +serde = "1.0" +serde_derive = "1.0" +tokio = { version = "1", features = [ "sync", "rt-multi-thread", "macros" ] } tikv-client-common = { version = "0.1.0", path = "../tikv-client-common" } tikv-client-proto = { version = "0.1.0", path = "../tikv-client-proto" } diff --git a/tikv-client-store/src/batch.rs b/tikv-client-store/src/batch.rs new file mode 100644 index 00000000..81977fd0 --- /dev/null +++ b/tikv-client-store/src/batch.rs @@ -0,0 +1,288 @@ +use crate::{request::from_batch_commands_resp, Error, Request, Result}; +use core::any::Any; +use futures::{ + channel::{mpsc, oneshot}, + executor::block_on, + join, pin_mut, + prelude::*, + task::{AtomicWaker, Context, Poll}, +}; +use grpcio::{CallOption, WriteFlags}; +use log::debug; +use std::{ + cell::RefCell, + collections::HashMap, + pin::Pin, + rc::Rc, + sync::{ + atomic::{AtomicBool, AtomicU64, Ordering}, + Arc, + }, + thread, + time::{Duration, Instant}, +}; +use tikv_client_common::internal_err; +use tikv_client_proto::tikvpb::{BatchCommandsRequest, BatchCommandsResponse, TikvClient}; + +static ID_ALLOC: AtomicU64 = AtomicU64::new(0); +const BATCH_WORKER_NAME: &str = "batch-worker"; + +type Response = oneshot::Sender>; +pub struct RequestEntry { + cmd: Box, + tx: Response, + transport_layer_load: u64, + id: u64, +} + +impl RequestEntry { + pub fn new(cmd: Box, tx: Response, transport_layer_load: u64) -> Self { + Self { + cmd, + tx, + transport_layer_load, + id: ID_ALLOC.fetch_add(1, Ordering::Relaxed), + } + } +} + +/// BatchWorker provides request in 
batch and return the result in batch. +#[derive(Clone)] +pub struct BatchWorker { + request_tx: mpsc::Sender, + last_transport_layer_load_report: Arc, + is_running: Arc, + max_batch_size: usize, + max_inflight_requests: usize, + max_delay_duration: u64, + overload_threshold: u64, + options: CallOption, +} + +impl BatchWorker { + pub fn new( + kv_client: Arc, + max_batch_size: usize, + max_inflight_requests: usize, + max_delay_duration: u64, + overload_threshold: u64, + options: CallOption, + ) -> Result { + let (request_tx, request_rx) = mpsc::channel(max_inflight_requests); + + // Create rpc sender and receiver + let (rpc_sender, rpc_receiver) = kv_client.batch_commands_opt(options.clone())?; + + let last_transport_layer_load_report = Arc::new(AtomicU64::new(0)); + let is_running_status = Arc::new(AtomicBool::new(true)); + let is_running_status_cloned = is_running_status.clone(); + + // Start a background thread to handle batch requests and responses + let last_transport_layer_load_report_clone = last_transport_layer_load_report.clone(); + thread::Builder::new() + .name(BATCH_WORKER_NAME.to_owned()) + .spawn(move || { + block_on(run_batch_worker( + rpc_sender.sink_err_into(), + rpc_receiver.err_into(), + is_running_status_cloned, + request_rx, + max_batch_size, + max_inflight_requests, + max_delay_duration, + overload_threshold, + last_transport_layer_load_report_clone, + )) + }) + .unwrap(); + + Ok(BatchWorker { + request_tx, + last_transport_layer_load_report, + is_running: is_running_status, + max_batch_size, + max_inflight_requests, + max_delay_duration, + overload_threshold, + options, + }) + } + + pub fn is_running(&self) -> bool { + self.is_running.load(Ordering::Relaxed) + } + + pub fn max_batch_size(&self) -> usize { + self.max_batch_size + } + + pub fn max_inflight_requests(&self) -> usize { + self.max_inflight_requests + } + + pub fn max_delay_duration(&self) -> u64 { + self.max_delay_duration + } + + pub fn overload_threshold(&self) -> u64 { + 
self.overload_threshold + } + + pub fn options(&self) -> CallOption { + self.options.clone() + } + + pub async fn dispatch(mut self, request: Box) -> Result> { + let (tx, rx) = oneshot::channel(); + // Generate BatchCommandRequestEntry + let last_transport_layer_load = self + .last_transport_layer_load_report + .load(Ordering::Relaxed); + + // Save the load of transport layer in RequestEntry + let entry = RequestEntry::new(request, tx, last_transport_layer_load); + // Send request entry to the background thread to handle the request, response will be + // received in rx channel. + self.request_tx + .send(entry) + .await + .map_err(|_| internal_err!("Failed to send request to batch worker".to_owned()))?; + rx.await + .map_err(|_| internal_err!("Failed to receive response from batch worker".to_owned())) + } +} + +#[allow(clippy::too_many_arguments)] +async fn run_batch_worker( + mut tx: impl Sink<(BatchCommandsRequest, WriteFlags), Error = Error> + Unpin, + mut rx: impl Stream> + Unpin, + is_running: Arc, + request_rx: mpsc::Receiver, + max_batch_size: usize, + max_inflight_requests: usize, + max_delay_duration: u64, + overload_threshold: u64, + last_transport_layer_load_report: Arc, +) { + // Inflight requests which are waiting for the response from rpc server + let inflight_requests = Rc::new(RefCell::new(HashMap::new())); + + let waker = Rc::new(AtomicWaker::new()); + + pin_mut!(request_rx); + let mut request_stream = BatchCommandsRequestStream { + request_rx, + inflight_requests: inflight_requests.clone(), + self_waker: waker.clone(), + max_batch_size, + max_inflight_requests, + max_delay_duration, + overload_threshold, + } + .map(Ok); + + let send_requests = tx.send_all(&mut request_stream); + + let recv_handle_response = async move { + while let Some(Ok(mut batch_resp)) = rx.next().await { + let mut inflight_requests = inflight_requests.borrow_mut(); + + if inflight_requests.len() == max_inflight_requests { + waker.wake(); + } + + let trasport_layer_load = 
batch_resp.get_transport_layer_load(); + // Store the load of transport layer + last_transport_layer_load_report.store(trasport_layer_load, Ordering::Relaxed); + + for (id, resp) in batch_resp + .take_request_ids() + .into_iter() + .zip(batch_resp.take_responses()) + { + if let Some(tx) = inflight_requests.remove(&id) { + let inner_resp = from_batch_commands_resp(resp); + debug!("Received response for request_id {}", id); + tx.send(inner_resp.unwrap()).unwrap(); + } + } + } + }; + + let (tx_res, rx_res) = join!(send_requests, recv_handle_response); + + is_running.store(false, Ordering::Relaxed); + + debug!("Batch sender finished: {:?}", tx_res); + debug!("Batch receiver finished: {:?}", rx_res); +} + +struct BatchCommandsRequestStream<'a> { + request_rx: Pin<&'a mut mpsc::Receiver>, + inflight_requests: Rc>>, + self_waker: Rc, + max_batch_size: usize, + max_inflight_requests: usize, + max_delay_duration: u64, + overload_threshold: u64, +} + +impl Stream for BatchCommandsRequestStream<'_> { + type Item = (BatchCommandsRequest, WriteFlags); + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let inflight_requests = self.inflight_requests.clone(); + let mut inflight_requests = inflight_requests.borrow_mut(); + if inflight_requests.len() == 0 { + self.self_waker.register(cx.waker()); + } + + // Collect user requests + let mut requests = vec![]; + let mut request_ids = vec![]; + let latency_timer = Instant::now(); + while requests.len() < self.max_batch_size + && inflight_requests.len() < self.max_inflight_requests + { + // We can not wait longger than max_deplay_duration + if latency_timer.elapsed() > Duration::from_millis(self.max_delay_duration) { + break; + } + + match self.request_rx.as_mut().poll_next(cx) { + Poll::Ready(Some(entry)) => { + inflight_requests.insert(entry.id, entry.tx); + requests.push(entry.cmd.to_batch_request()); + request_ids.push(entry.id); + + // Check the transport layer load received in RequestEntry + let 
load_reported = entry.transport_layer_load; + if load_reported > 0 + && self.overload_threshold > 0 + && load_reported > self.overload_threshold + { + break; + } + } + Poll::Ready(None) => { + return Poll::Ready(None); + } + Poll::Pending => { + break; + } + } + } + + // The requests is the commands will be convert to a batch request + if !requests.is_empty() { + let mut batch_request = BatchCommandsRequest::new_(); + batch_request.set_requests(requests); + batch_request.set_request_ids(request_ids); + let write_flags = WriteFlags::default().buffer_hint(false); + Poll::Ready(Some((batch_request, write_flags))) + } else { + self.self_waker.register(cx.waker()); + Poll::Pending + } + } +} diff --git a/tikv-client-store/src/client.rs b/tikv-client-store/src/client.rs index 5e3534c5..cdf0573c 100644 --- a/tikv-client-store/src/client.rs +++ b/tikv-client-store/src/client.rs @@ -1,17 +1,58 @@ // Copyright 2020 TiKV Project Authors. Licensed under Apache-2.0. -use crate::{request::Request, Result, SecurityManager}; +use crate::{batch::BatchWorker, request::Request, Result, SecurityManager}; use async_trait::async_trait; use derive_new::new; use grpcio::{CallOption, Environment}; +use serde_derive::{Deserialize, Serialize}; use std::{any::Any, sync::Arc, time::Duration}; use tikv_client_proto::tikvpb::TikvClient; +use tokio::sync::RwLock; +const DEFAULT_REQUEST_TIMEOUT: u64 = 2000; +const DEFAULT_GRPC_KEEPALIVE_TIME: u64 = 10000; +const DEFAULT_GRPC_KEEPALIVE_TIMEOUT: u64 = 3000; +const DEFAULT_GRPC_COMPLETION_QUEUE_SIZE: usize = 1; +const DEFAULT_MAX_BATCH_WAIT_TIME: u64 = 10; +const DEFAULT_MAX_BATCH_SIZE: usize = 10; +const DEFAULT_MAX_INFLIGHT_REQUESTS: usize = 10000; +const DEFAULT_OVERLOAD_THRESHOLD: u64 = 1000; /// A trait for connecting to TiKV stores. 
pub trait KvConnect: Sized + Send + Sync + 'static { type KvClient: KvClient + Clone + Send + Sync + 'static; - fn connect(&self, address: &str) -> Result; + fn connect(&self, address: &str, kv_config: KvClientConfig) -> Result; +} + +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] +#[serde(default)] +#[serde(rename_all = "kebab-case")] +pub struct KvClientConfig { + pub request_timeout: u64, + pub completion_queue_size: usize, + pub grpc_keepalive_time: u64, + pub grpc_keepalive_timeout: u64, + pub allow_batch: bool, + pub overload_threshold: u64, + pub max_batch_wait_time: u64, + pub max_batch_size: usize, + pub max_inflight_requests: usize, +} + +impl Default for KvClientConfig { + fn default() -> Self { + Self { + request_timeout: DEFAULT_REQUEST_TIMEOUT, + completion_queue_size: DEFAULT_GRPC_COMPLETION_QUEUE_SIZE, + grpc_keepalive_time: DEFAULT_GRPC_KEEPALIVE_TIME, + grpc_keepalive_timeout: DEFAULT_GRPC_KEEPALIVE_TIMEOUT, + allow_batch: false, + overload_threshold: DEFAULT_OVERLOAD_THRESHOLD, + max_batch_wait_time: DEFAULT_MAX_BATCH_WAIT_TIME, + max_batch_size: DEFAULT_MAX_BATCH_SIZE, + max_inflight_requests: DEFAULT_MAX_INFLIGHT_REQUESTS, + } + } } #[derive(new, Clone)] @@ -24,16 +65,41 @@ pub struct TikvConnect { impl KvConnect for TikvConnect { type KvClient = KvRpcClient; - fn connect(&self, address: &str) -> Result { + fn connect(&self, address: &str, kv_config: KvClientConfig) -> Result { self.security_mgr - .connect(self.env.clone(), address, TikvClient::new) - .map(|c| KvRpcClient::new(Arc::new(c), self.timeout)) + .connect( + self.env.clone(), + address, + kv_config.grpc_keepalive_time, + kv_config.grpc_keepalive_timeout, + TikvClient::new, + ) + .map(|c| { + // Create batch worker if needed + let c = Arc::new(c); + let batch_worker = if kv_config.allow_batch { + Some(Arc::new(RwLock::new( + BatchWorker::new( + c.clone(), + kv_config.max_batch_size, + kv_config.max_inflight_requests, + kv_config.max_batch_wait_time, + 
kv_config.overload_threshold, + CallOption::default(), + ) + .unwrap(), + ))) + } else { + None + }; + KvRpcClient::new(c, self.timeout, batch_worker) + }) } } #[async_trait] pub trait KvClient { - async fn dispatch(&self, req: &dyn Request) -> Result>; + async fn dispatch(&self, req: Box) -> Result>; } /// This client handles requests for a single TiKV node. It converts the data @@ -42,16 +108,39 @@ pub trait KvClient { pub struct KvRpcClient { rpc_client: Arc, timeout: Duration, + batch_worker: Option>>, } #[async_trait] impl KvClient for KvRpcClient { - async fn dispatch(&self, request: &dyn Request) -> Result> { - request - .dispatch( - &self.rpc_client, - CallOption::default().timeout(self.timeout), + async fn dispatch(&self, request: Box) -> Result> { + if let Some(batch_worker_arc) = self.batch_worker.clone() && request.support_batch(){ + let batch_worker = batch_worker_arc.read().await; + if batch_worker.is_running() { + return batch_worker.clone().dispatch(request).await; + } + drop(batch_worker); + + let mut batch_worker = batch_worker_arc.write().await; + // batch worker is not running, because of gRPC channel is broken, create a new one + *batch_worker = BatchWorker::new( + self.rpc_client.clone(), + batch_worker.max_batch_size(), + batch_worker.max_inflight_requests(), + batch_worker.max_delay_duration(), + batch_worker.overload_threshold(), + batch_worker.options(), ) - .await + .unwrap(); + batch_worker.clone().dispatch(request).await + } else { + // Batch no needed if not batch enabled + request + .dispatch( + &self.rpc_client, + CallOption::default().timeout(self.timeout), + ) + .await + } } } diff --git a/tikv-client-store/src/lib.rs b/tikv-client-store/src/lib.rs index 5df938ff..994818db 100644 --- a/tikv-client-store/src/lib.rs +++ b/tikv-client-store/src/lib.rs @@ -1,12 +1,14 @@ // Copyright 2018 TiKV Project Authors. Licensed under Apache-2.0. 
+mod batch;
 mod client;
 mod errors;
 mod request;
 
 #[doc(inline)]
 pub use crate::{
-    client::{KvClient, KvConnect, TikvConnect},
+    batch::{BatchWorker, RequestEntry},
+    client::{KvClient, KvClientConfig, KvConnect, TikvConnect},
     errors::{HasKeyErrors, HasRegionError, HasRegionErrors},
     request::Request,
 };
diff --git a/tikv-client-store/src/request.rs b/tikv-client-store/src/request.rs
index 290f142a..e5531d90 100644
--- a/tikv-client-store/src/request.rs
+++ b/tikv-client-store/src/request.rs
@@ -4,29 +4,49 @@ use crate::{Error, Result};
 use async_trait::async_trait;
 use grpcio::CallOption;
 use std::any::Any;
-use tikv_client_proto::{kvrpcpb, tikvpb::TikvClient};
+use tikv_client_common::internal_err;
+use tikv_client_proto::{
+    kvrpcpb,
+    tikvpb::{
+        batch_commands_request::{self, request::Cmd::*},
+        batch_commands_response, TikvClient,
+    },
+};
 
 #[async_trait]
 pub trait Request: Any + Sync + Send + 'static {
-    async fn dispatch(&self, client: &TikvClient, options: CallOption) -> Result<Box<dyn Any>>;
+    async fn dispatch(
+        &self,
+        client: &TikvClient,
+        options: CallOption,
+    ) -> Result<Box<dyn Any + Send>>;
     fn label(&self) -> &'static str;
-    fn as_any(&self) -> &dyn Any;
+    fn as_any(&self) -> &(dyn Any + Send);
     fn set_context(&mut self, context: kvrpcpb::Context);
+    /// Converts this request into a `batch_commands_request::Request`. The default
+    /// implementation carries no command.
+    fn to_batch_request(&self) -> batch_commands_request::Request {
+        batch_commands_request::Request { cmd: None }
+    }
+    /// Whether this request can be sent through the batch worker. Defaults to `false`.
+    fn support_batch(&self) -> bool {
+        false
+    }
 }
 
 macro_rules! impl_request {
-    ($name: ident, $fun: ident, $label: literal) => {
+    ($name: ident, $fun: ident, $label: literal, $cmd: ident) => {
         #[async_trait]
         impl Request for kvrpcpb::$name {
             async fn dispatch(
                 &self,
                 client: &TikvClient,
                 options: CallOption,
-            ) -> Result<Box<dyn Any>> {
+            ) -> Result<Box<dyn Any + Send>> {
                 client
                     .$fun(self, options)?
                     .await
-                    .map(|r| Box::new(r) as Box<dyn Any>)
+                    .map(|r| Box::new(r) as Box<dyn Any + Send>)
                     .map_err(Error::Grpc)
             }
 
@@ -34,94 +54,215 @@ macro_rules! impl_request {
                 $label
             }
 
-            fn as_any(&self) -> &dyn Any {
+            fn as_any(&self) -> &(dyn Any + Send) {
                 self
             }
 
             fn set_context(&mut self, context: kvrpcpb::Context) {
                 kvrpcpb::$name::set_context(self, context)
             }
+
+            fn to_batch_request(&self) -> batch_commands_request::Request {
+                batch_commands_request::Request {
+                    cmd: Some($cmd(self.clone())),
+                }
+            }
+
+            fn support_batch(&self) -> bool {
+                true
+            }
         }
     };
 }
 
-impl_request!(RawGetRequest, raw_get_async_opt, "raw_get");
-impl_request!(RawBatchGetRequest, raw_batch_get_async_opt, "raw_batch_get");
-impl_request!(RawPutRequest, raw_put_async_opt, "raw_put");
-impl_request!(RawBatchPutRequest, raw_batch_put_async_opt, "raw_batch_put");
-impl_request!(RawDeleteRequest, raw_delete_async_opt, "raw_delete");
+impl_request!(RawGetRequest, raw_get_async_opt, "raw_get", RawGet);
+impl_request!(
+    RawBatchGetRequest,
+    raw_batch_get_async_opt,
+    "raw_batch_get",
+    RawBatchGet
+);
+impl_request!(RawPutRequest, raw_put_async_opt, "raw_put", RawPut);
+impl_request!(
+    RawBatchPutRequest,
+    raw_batch_put_async_opt,
+    "raw_batch_put",
+    RawBatchPut
+);
+impl_request!(
+    RawDeleteRequest,
+    raw_delete_async_opt,
+    "raw_delete",
+    RawDelete
+);
 impl_request!(
     RawBatchDeleteRequest,
     raw_batch_delete_async_opt,
-    "raw_batch_delete"
+    "raw_batch_delete",
+    RawBatchDelete
 );
-impl_request!(RawScanRequest, raw_scan_async_opt, "raw_scan");
+impl_request!(RawScanRequest, raw_scan_async_opt, "raw_scan", RawScan);
 impl_request!(
     RawBatchScanRequest,
     raw_batch_scan_async_opt,
-    "raw_batch_scan"
+    "raw_batch_scan",
+    RawBatchScan
 );
 impl_request!(
     RawDeleteRangeRequest,
     raw_delete_range_async_opt,
-    "raw_delete_range"
-);
-impl_request!(
-    RawCasRequest,
-    raw_compare_and_swap_async_opt,
-    "raw_compare_and_swap"
+    "raw_delete_range",
+    RawDeleteRange
 );
+
 impl_request!(
     RawCoprocessorRequest,
     raw_coprocessor_async_opt,
-    "raw_coprocessor"
+    "raw_coprocessor",
+    RawCoprocessor
 );
 
-impl_request!(GetRequest, kv_get_async_opt, "kv_get");
-impl_request!(ScanRequest, kv_scan_async_opt, "kv_scan");
-impl_request!(PrewriteRequest, kv_prewrite_async_opt, "kv_prewrite");
-impl_request!(CommitRequest, kv_commit_async_opt, "kv_commit");
-impl_request!(CleanupRequest, kv_cleanup_async_opt, "kv_cleanup");
-impl_request!(BatchGetRequest, kv_batch_get_async_opt, "kv_batch_get");
+impl_request!(GetRequest, kv_get_async_opt, "kv_get", Get);
+impl_request!(ScanRequest, kv_scan_async_opt, "kv_scan", Scan);
+impl_request!(
+    PrewriteRequest,
+    kv_prewrite_async_opt,
+    "kv_prewrite",
+    Prewrite
+);
+impl_request!(CommitRequest, kv_commit_async_opt, "kv_commit", Commit);
+impl_request!(CleanupRequest, kv_cleanup_async_opt, "kv_cleanup", Cleanup);
+impl_request!(
+    BatchGetRequest,
+    kv_batch_get_async_opt,
+    "kv_batch_get",
+    BatchGet
+);
 impl_request!(
     BatchRollbackRequest,
     kv_batch_rollback_async_opt,
-    "kv_batch_rollback"
+    "kv_batch_rollback",
+    BatchRollback
 );
 impl_request!(
     PessimisticRollbackRequest,
     kv_pessimistic_rollback_async_opt,
-    "kv_pessimistic_rollback"
+    "kv_pessimistic_rollback",
+    PessimisticRollback
 );
 impl_request!(
     ResolveLockRequest,
     kv_resolve_lock_async_opt,
-    "kv_resolve_lock"
+    "kv_resolve_lock",
+    ResolveLock
+);
+impl_request!(
+    ScanLockRequest,
+    kv_scan_lock_async_opt,
+    "kv_scan_lock",
+    ScanLock
 );
-impl_request!(ScanLockRequest, kv_scan_lock_async_opt, "kv_scan_lock");
 impl_request!(
     PessimisticLockRequest,
     kv_pessimistic_lock_async_opt,
-    "kv_pessimistic_lock"
+    "kv_pessimistic_lock",
+    PessimisticLock
 );
 impl_request!(
     TxnHeartBeatRequest,
     kv_txn_heart_beat_async_opt,
-    "kv_txn_heart_beat"
+    "kv_txn_heart_beat",
+    TxnHeartBeat
 );
 impl_request!(
     CheckTxnStatusRequest,
     kv_check_txn_status_async_opt,
-    "kv_check_txn_status"
+    "kv_check_txn_status",
+    CheckTxnStatus
 );
 impl_request!(
     CheckSecondaryLocksRequest,
     kv_check_secondary_locks_async_opt,
-    "kv_check_secondary_locks_request"
+    "kv_check_secondary_locks_request",
+    CheckSecondaryLocks
 );
-impl_request!(GcRequest, kv_gc_async_opt, "kv_gc");
+impl_request!(GcRequest, kv_gc_async_opt, "kv_gc", Gc);
 impl_request!(
     DeleteRangeRequest,
     kv_delete_range_async_opt,
-    "kv_delete_range"
+    "kv_delete_range",
+    DeleteRange
 );
+
+// `RawCasRequest` is implemented by hand instead of through `impl_request!`: it keeps
+// the trait defaults for `to_batch_request` and `support_batch`, so it is always
+// dispatched as an individual call.
+#[async_trait]
+impl Request for kvrpcpb::RawCasRequest {
+    async fn dispatch(
+        &self,
+        client: &TikvClient,
+        options: CallOption,
+    ) -> Result<Box<dyn Any + Send>> {
+        client
+            .raw_compare_and_swap_async_opt(self, options)?
+            .await
+            .map(|r| Box::new(r) as Box<dyn Any + Send>)
+            .map_err(Error::Grpc)
+    }
+    fn label(&self) -> &'static str {
+        "raw_compare_and_swap"
+    }
+    fn as_any(&self) -> &(dyn Any + Send) {
+        self
+    }
+    fn set_context(&mut self, context: kvrpcpb::Context) {
+        kvrpcpb::RawCasRequest::set_context(self, context)
+    }
+}
+
+/// Unwraps the command carried by one batch response entry and returns its payload as a
+/// `Box<dyn Any + Send>`.
+///
+/// # Errors
+///
+/// Returns an internal error when `resp.cmd` matches none of the arms below.
+pub fn from_batch_commands_resp(
+    resp: batch_commands_response::Response,
+) -> Result<Box<dyn Any + Send>> {
+    use batch_commands_response::response::Cmd;
+
+    match resp.cmd {
+        Some(Cmd::Get(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::Scan(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::Prewrite(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::Commit(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::Import(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::Cleanup(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::BatchGet(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::BatchRollback(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::ScanLock(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::ResolveLock(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::Gc(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::DeleteRange(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::RawGet(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::RawBatchGet(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::RawPut(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::RawBatchPut(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::RawDelete(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::RawBatchDelete(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::RawScan(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::RawDeleteRange(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::RawBatchScan(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::Coprocessor(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::PessimisticLock(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::PessimisticRollback(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::CheckTxnStatus(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::TxnHeartBeat(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::CheckSecondaryLocks(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        Some(Cmd::RawCoprocessor(cmd)) => Ok(Box::new(cmd) as Box<dyn Any + Send>),
+        // NOTE(review): this arm is also reached by any variant not listed above, in which
+        // case the message below does not describe the actual cause.
+        _ => Err(internal_err!("batch_commands_resp.cmd is None".to_owned())),
+    }
+}