diff --git a/src/kv/key.rs b/src/kv/key.rs index 94fe8a94..fa19d3f6 100644 --- a/src/kv/key.rs +++ b/src/kv/key.rs @@ -16,6 +16,7 @@ use super::HexRepr; use crate::kv::codec::BytesEncoder; use crate::kv::codec::{self}; use crate::proto::kvrpcpb; +use crate::proto::kvrpcpb::KvPair; const _PROPTEST_KEY_MAX: usize = 1024 * 2; // 2 KB @@ -79,6 +80,20 @@ impl AsRef for kvrpcpb::Mutation { } } +pub struct KvPairTTL(pub KvPair, pub u64); + +impl AsRef for KvPairTTL { + fn as_ref(&self) -> &Key { + self.0.key.as_ref() + } +} + +impl From for (KvPair, u64) { + fn from(value: KvPairTTL) -> Self { + (value.0, value.1) + } +} + impl Key { /// The empty key. pub const EMPTY: Self = Key(Vec::new()); diff --git a/src/kv/kvpair.rs b/src/kv/kvpair.rs index cfc6ee1c..f609f230 100644 --- a/src/kv/kvpair.rs +++ b/src/kv/kvpair.rs @@ -25,7 +25,7 @@ use crate::proto::kvrpcpb; /// /// Many functions which accept a `KvPair` accept an `Into`, which means all of the above /// types (Like a `(Key, Value)`) can be passed directly to those functions. -#[derive(Default, Clone, Eq, PartialEq)] +#[derive(Default, Clone, Eq, PartialEq, Hash)] #[cfg_attr(test, derive(Arbitrary))] pub struct KvPair(pub Key, pub Value); diff --git a/src/kv/mod.rs b/src/kv/mod.rs index d0958ee2..41da842e 100644 --- a/src/kv/mod.rs +++ b/src/kv/mod.rs @@ -10,6 +10,7 @@ mod value; pub use bound_range::BoundRange; pub use bound_range::IntoOwnedRange; pub use key::Key; +pub use key::KvPairTTL; pub use kvpair::KvPair; pub use value::Value; diff --git a/src/raw/client.rs b/src/raw/client.rs index 9a166278..76d40b65 100644 --- a/src/raw/client.rs +++ b/src/raw/client.rs @@ -875,6 +875,36 @@ mod tests { use crate::proto::kvrpcpb; use crate::Result; + #[tokio::test] + async fn test_batch_put_with_ttl() -> Result<()> { + let pd_client = Arc::new(MockPdClient::new(MockKvClient::with_dispatch_hook( + move |req: &dyn Any| { + if req.downcast_ref::().is_some() { + let resp = kvrpcpb::RawBatchPutResponse { + ..Default::default() + }; + Ok(Box::new(resp) as Box) + } else { + unreachable!() + } + }, + ))); + let client = Client { + rpc: pd_client, + cf: Some(ColumnFamily::Default), + backoff: DEFAULT_REGION_BACKOFF, + atomic: false, + keyspace: Keyspace::Enable { keyspace_id: 0 }, + }; + let pairs = vec![ + KvPair(vec![11].into(), vec![12]), + KvPair(vec![11].into(), vec![12]), + ]; + let ttls = vec![0, 0]; + assert!(client.batch_put_with_ttl(pairs, ttls).await.is_ok()); + Ok(()) + } + #[tokio::test] async fn test_raw_coprocessor() -> Result<()> { let pd_client = Arc::new(MockPdClient::new(MockKvClient::with_dispatch_hook( diff --git a/src/raw/requests.rs b/src/raw/requests.rs index 201ac657..4422c883 100644 --- a/src/raw/requests.rs +++ b/src/raw/requests.rs @@ -1,16 +1,8 @@ // Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0. -use std::any::Any; -use std::ops::Range; -use std::sync::Arc; -use std::time::Duration; - -use async_trait::async_trait; -use futures::stream::BoxStream; -use tonic::transport::Channel; - use super::RawRpcRequest; use crate::collect_single; +use crate::kv::KvPairTTL; use crate::pd::PdClient; use crate::proto::kvrpcpb; use crate::proto::metapb; @@ -41,6 +33,13 @@ use crate::Key; use crate::KvPair; use crate::Result; use crate::Value; +use async_trait::async_trait; +use futures::stream::BoxStream; +use std::any::Any; +use std::ops::Range; +use std::sync::Arc; +use std::time::Duration; +use tonic::transport::Channel; pub fn new_raw_get_request(key: Vec, cf: Option) -> kvrpcpb::RawGetRequest { let mut req = kvrpcpb::RawGetRequest::default(); @@ -190,23 +189,28 @@ impl KvRequest for kvrpcpb::RawBatchPutRequest { } impl Shardable for kvrpcpb::RawBatchPutRequest { - type Shard = Vec; + type Shard = Vec<(kvrpcpb::KvPair, u64)>; fn shards( &self, pd_client: &Arc, ) -> BoxStream<'static, Result<(Self::Shard, RegionStore)>> { - let mut pairs = self.pairs.clone(); - pairs.sort_by(|a, b| a.key.cmp(&b.key)); - store_stream_for_keys( - pairs.into_iter().map(Into::::into), - pd_client.clone(), - ) + let kvs = self.pairs.clone(); + let ttls = self.ttls.clone(); + let mut kv_ttl: Vec = kvs + .into_iter() + .zip(ttls) + .map(|(kv, ttl)| KvPairTTL(kv, ttl)) + .collect(); + kv_ttl.sort_by(|a, b| a.0.key.cmp(&b.0.key)); + store_stream_for_keys(kv_ttl.into_iter(), pd_client.clone()) } fn apply_shard(&mut self, shard: Self::Shard, store: &RegionStore) -> Result<()> { + let (pairs, ttls) = shard.into_iter().unzip(); self.set_leader(&store.region_with_leader)?; - self.pairs = shard; + self.pairs = pairs; + self.ttls = ttls; Ok(()) } } @@ -531,21 +535,35 @@ impl_raw_rpc_request!(RawDeleteRangeRequest); impl_raw_rpc_request!(RawCasRequest); impl HasLocks for kvrpcpb::RawGetResponse {} + impl HasLocks for kvrpcpb::RawBatchGetResponse {} + impl HasLocks for kvrpcpb::RawGetKeyTtlResponse {} + impl HasLocks for kvrpcpb::RawPutResponse {} + impl HasLocks for kvrpcpb::RawBatchPutResponse {} + impl HasLocks for kvrpcpb::RawDeleteResponse {} + impl HasLocks for kvrpcpb::RawBatchDeleteResponse {} + impl HasLocks for kvrpcpb::RawScanResponse {} + impl HasLocks for kvrpcpb::RawBatchScanResponse {} + impl HasLocks for kvrpcpb::RawDeleteRangeResponse {} + impl HasLocks for kvrpcpb::RawCasResponse {} + impl HasLocks for kvrpcpb::RawCoprocessorResponse {} #[cfg(test)] mod test { use std::any::Any; + use std::collections::HashMap; + use std::ops::Deref; + use std::sync::Mutex; use super::*; use crate::backoff::DEFAULT_REGION_BACKOFF; @@ -555,7 +573,6 @@ mod test { use crate::proto::kvrpcpb; use crate::request::Keyspace; use crate::request::Plan; - use crate::Key; #[rstest::rstest] #[case(Keyspace::Disable)] @@ -600,4 +617,58 @@ mod test { assert_eq!(scan.len(), 49); // FIXME test the keys returned. } + + #[tokio::test] + async fn test_raw_batch_put() -> Result<()> { + let region1_kvs = vec![KvPair(vec![9].into(), vec![12])]; + let region1_ttls = vec![0]; + let region2_kvs = vec![ + KvPair(vec![11].into(), vec![12]), + KvPair("FFF".to_string().as_bytes().to_vec().into(), vec![12]), + ]; + let region2_ttls = vec![0, 1]; + + let expected_map = HashMap::from([ + (region1_kvs.clone(), region1_ttls.clone()), + (region2_kvs.clone(), region2_ttls.clone()), + ]); + + let pairs: Vec = [region1_kvs, region2_kvs] + .concat() + .into_iter() + .map(|kv| kv.into()) + .collect(); + let ttls = [region1_ttls, region2_ttls].concat(); + let cf = ColumnFamily::Default; + + let actual_map: Arc, Vec>>> = + Arc::new(Mutex::new(HashMap::new())); + let fut_actual_map = actual_map.clone(); + let client = Arc::new(MockPdClient::new(MockKvClient::with_dispatch_hook( + move |req: &dyn Any| { + let req: &kvrpcpb::RawBatchPutRequest = req.downcast_ref().unwrap(); + let kv_pair = req + .pairs + .clone() + .into_iter() + .map(|p| p.into()) + .collect::>(); + let ttls = req.ttls.clone(); + fut_actual_map.lock().unwrap().insert(kv_pair, ttls); + let resp = kvrpcpb::RawBatchPutResponse::default(); + Ok(Box::new(resp) as Box) + }, + ))); + + let batch_put_request = + new_raw_batch_put_request(pairs.clone(), ttls.clone(), Some(cf), false); + let keyspace = Keyspace::Enable { keyspace_id: 0 }; + let plan = crate::request::PlanBuilder::new(client, keyspace, batch_put_request) + .resolve_lock(OPTIMISTIC_BACKOFF, keyspace) + .retry_multi_region(DEFAULT_REGION_BACKOFF) + .plan(); + let _ = plan.execute().await; + assert_eq!(actual_map.lock().unwrap().deref(), &expected_map); + Ok(()) + } }