From 767e9304c546c001e3a18f32afafc804d38b088f Mon Sep 17 00:00:00 2001 From: limbooverlambda Date: Mon, 17 Jun 2024 14:08:33 -0700 Subject: [PATCH] fixing the shard issue with batch_put Signed-off-by: limbooverlambda --- src/kv/kvpair.rs | 2 +- src/raw/client.rs | 30 +++++++++++++ src/raw/requests.rs | 100 ++++++++++++++++++++++++++++++++++++++++++-- src/store/mod.rs | 2 +- 4 files changed, 129 insertions(+), 5 deletions(-) diff --git a/src/kv/kvpair.rs b/src/kv/kvpair.rs index cfc6ee1c..f609f230 100644 --- a/src/kv/kvpair.rs +++ b/src/kv/kvpair.rs @@ -25,7 +25,7 @@ use crate::proto::kvrpcpb; /// /// Many functions which accept a `KvPair` accept an `Into`, which means all of the above /// types (Like a `(Key, Value)`) can be passed directly to those functions. -#[derive(Default, Clone, Eq, PartialEq)] +#[derive(Default, Clone, Eq, PartialEq, Hash)] #[cfg_attr(test, derive(Arbitrary))] pub struct KvPair(pub Key, pub Value); diff --git a/src/raw/client.rs b/src/raw/client.rs index 71d40b2a..e885b4ec 100644 --- a/src/raw/client.rs +++ b/src/raw/client.rs @@ -876,6 +876,36 @@ mod tests { use crate::proto::kvrpcpb; use crate::Result; + #[tokio::test] + async fn test_batch_put_with_ttl() -> Result<()> { + let pd_client = Arc::new(MockPdClient::new(MockKvClient::with_dispatch_hook( + move |req: &dyn Any| { + if let Some(_) = req.downcast_ref::() { + let resp = kvrpcpb::RawBatchPutResponse { + ..Default::default() + }; + Ok(Box::new(resp) as Box) + } else { + unreachable!() + } + }, + ))); + let client = Client { + rpc: pd_client, + cf: Some(ColumnFamily::Default), + backoff: DEFAULT_REGION_BACKOFF, + atomic: false, + keyspace: Keyspace::Enable { keyspace_id: 0 }, + }; + let pairs = vec![ + KvPair(vec![11].into(), vec![12].into()), + KvPair(vec![11].into(), vec![12].into()), + ]; + let ttls = vec![0, 0]; + assert!(client.batch_put_with_ttl(pairs, ttls).await.is_ok()); + Ok(()) + } + #[tokio::test] async fn test_raw_coprocessor() -> Result<()> { let pd_client = Arc::new(MockPdClient::new(MockKvClient::with_dispatch_hook( diff --git a/src/raw/requests.rs b/src/raw/requests.rs index 201ac657..f6e876d4 100644 --- a/src/raw/requests.rs +++ b/src/raw/requests.rs @@ -1,12 +1,14 @@ // Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0. use std::any::Any; +use std::collections::HashMap; use std::ops::Range; use std::sync::Arc; use std::time::Duration; use async_trait::async_trait; use futures::stream::BoxStream; +use futures::StreamExt; use tonic::transport::Channel; use super::RawRpcRequest; @@ -190,23 +192,44 @@ impl KvRequest for kvrpcpb::RawBatchPutRequest { } impl Shardable for kvrpcpb::RawBatchPutRequest { - type Shard = Vec; + type Shard = Vec<(kvrpcpb::KvPair, u64)>; fn shards( &self, pd_client: &Arc, ) -> BoxStream<'static, Result<(Self::Shard, RegionStore)>> { + // Maintain a map of the pair and its associated ttl + let kvs = self.pairs.clone(); + let kv_pair = kvs.into_iter().map(Into::::into); + let kv_ttl = kv_pair.zip(self.ttls.clone()).collect::>(); let mut pairs = self.pairs.clone(); pairs.sort_by(|a, b| a.key.cmp(&b.key)); store_stream_for_keys( pairs.into_iter().map(Into::::into), pd_client.clone(), ) + .map(move |r| { + let s = r.map(|(kv, store)| { + let kv_ttls = kv + .into_iter() + .map(|k: KvPair| { + let kv: kvrpcpb::KvPair = k.clone().into(); + let ttl = *kv_ttl.get(&k).unwrap(); + (kv, ttl) + }) + .collect::>(); + (kv_ttls, store) + }); + s + }) + .boxed() } fn apply_shard(&mut self, shard: Self::Shard, store: &RegionStore) -> Result<()> { + let (pairs, ttls) = shard.into_iter().unzip(); self.set_leader(&store.region_with_leader)?; - self.pairs = shard; + self.pairs = pairs; + self.ttls = ttls; Ok(()) } } @@ -531,21 +554,34 @@ impl_raw_rpc_request!(RawDeleteRangeRequest); impl_raw_rpc_request!(RawCasRequest); impl HasLocks for kvrpcpb::RawGetResponse {} + impl HasLocks for kvrpcpb::RawBatchGetResponse {} + impl HasLocks for kvrpcpb::RawGetKeyTtlResponse {} + impl HasLocks for kvrpcpb::RawPutResponse {} + impl HasLocks for kvrpcpb::RawBatchPutResponse {} + impl HasLocks for kvrpcpb::RawDeleteResponse {} + impl HasLocks for kvrpcpb::RawBatchDeleteResponse {} + impl HasLocks for kvrpcpb::RawScanResponse {} + impl HasLocks for kvrpcpb::RawBatchScanResponse {} + impl HasLocks for kvrpcpb::RawDeleteRangeResponse {} + impl HasLocks for kvrpcpb::RawCasResponse {} + impl HasLocks for kvrpcpb::RawCoprocessorResponse {} #[cfg(test)] mod test { use std::any::Any; + use std::ops::Deref; + use std::sync::Mutex; use super::*; use crate::backoff::DEFAULT_REGION_BACKOFF; @@ -555,7 +591,7 @@ mod test { use crate::proto::kvrpcpb; use crate::request::Keyspace; use crate::request::Plan; - use crate::Key; + #[rstest::rstest] #[case(Keyspace::Disable)] @@ -600,4 +636,62 @@ mod test { assert_eq!(scan.len(), 49); // FIXME test the keys returned. } + + #[tokio::test] + async fn test_raw_batch_put() -> Result<()> { + let region1_kvs = vec![KvPair(vec![9].into(), vec![12].into())]; + let region1_ttls = vec![0]; + let region2_kvs = vec![ + KvPair(vec![11].into(), vec![12].into()), + KvPair( + "FFF".to_string().as_bytes().to_vec().into(), + vec![12].into(), + ), + ]; + let region2_ttls = vec![0, 1]; + + let expected_map = HashMap::from([ + (region1_kvs.clone(), region1_ttls.clone()), + (region2_kvs.clone(), region2_ttls.clone()), + ]); + + let pairs: Vec = [region1_kvs, region2_kvs] + .concat() + .into_iter() + .map(|kv| kv.into()) + .collect(); + let ttls = [region1_ttls, region2_ttls].concat(); + let cf = ColumnFamily::Default; + + let actual_map: Arc, Vec>>> = + Arc::new(Mutex::new(HashMap::new())); + let fut_actual_map = actual_map.clone(); + let client = Arc::new(MockPdClient::new(MockKvClient::with_dispatch_hook( + move |req: &dyn Any| { + let req: &kvrpcpb::RawBatchPutRequest = req.downcast_ref().unwrap(); + let kv_pair = req + .pairs + .clone() + .into_iter() + .map(|p| p.into()) + .collect::>(); + let ttls = req.ttls.clone(); + fut_actual_map.lock().unwrap().insert(kv_pair, ttls); + let resp = kvrpcpb::RawBatchPutResponse::default(); + Ok(Box::new(resp) as Box) + }, + ))); + + let batch_put_request = + new_raw_batch_put_request(pairs.clone(), ttls.clone(), Some(cf), false); + let keyspace = Keyspace::Enable { keyspace_id: 0 }; + let plan = crate::request::PlanBuilder::new(client, keyspace, batch_put_request) + .resolve_lock(OPTIMISTIC_BACKOFF, keyspace) + .retry_multi_region(DEFAULT_REGION_BACKOFF) + .plan(); + let _ = plan.execute().await; + assert_eq!(actual_map.lock().unwrap().len(), 2); + assert_eq!(actual_map.lock().unwrap().deref().clone(), expected_map); + Ok(()) + } } diff --git a/src/store/mod.rs b/src/store/mod.rs index f21373b4..b8381de3 100644 --- a/src/store/mod.rs +++ b/src/store/mod.rs @@ -36,7 +36,7 @@ pub struct RegionStore { pub struct Store { pub client: Arc, } - +#[allow(dead_code)] /// Maps keys to a stream of stores. `key_data` must be sorted in increasing order pub fn store_stream_for_keys( key_data: impl Iterator + Send + Sync + 'static,