diff --git a/src/aws.rs b/src/aws.rs new file mode 100644 index 0000000..ec4d2fe --- /dev/null +++ b/src/aws.rs @@ -0,0 +1,152 @@ +pub mod client { + + use async_trait::async_trait; + use aws_config::meta::region::RegionProviderChain; + use aws_sdk_kinesis::config::Region; + use aws_sdk_kinesis::operation::get_records::GetRecordsOutput; + use aws_sdk_kinesis::operation::get_shard_iterator::GetShardIteratorOutput; + use aws_sdk_kinesis::operation::list_shards::ListShardsOutput; + use aws_sdk_kinesis::primitives::DateTime; + use aws_sdk_kinesis::types::ShardIteratorType; + use aws_sdk_kinesis::{Client, Error}; + use chrono::Utc; + + #[derive(Clone, Debug)] + pub struct AwsKinesisClient { + client: Client, + } + + #[async_trait] + pub trait KinesisClient: Sync + Send + Clone + 'static { + async fn list_shards(&self, stream: &str) -> Result; + + async fn get_records(&self, shard_iterator: &str) -> Result; + + async fn get_shard_iterator_at_timestamp( + &self, + stream: &str, + shard_id: &str, + timestamp: &chrono::DateTime, + ) -> Result; + + async fn get_shard_iterator_at_sequence( + &self, + stream: &str, + shard_id: &str, + starting_sequence_number: &str, + ) -> Result; + + async fn get_shard_iterator_latest( + &self, + stream: &str, + shard_id: &str, + ) -> Result; + + fn get_region(&self) -> Option<&Region>; + + fn to_aws_datetime(timestamp: &chrono::DateTime) -> DateTime { + DateTime::from_millis(timestamp.timestamp_millis()) + } + } + + #[async_trait] + impl KinesisClient for AwsKinesisClient { + async fn list_shards(&self, stream: &str) -> Result { + self.client + .list_shards() + .stream_name(stream) + .send() + .await + .map_err(|e| e.into()) + } + + async fn get_records(&self, shard_iterator: &str) -> Result { + self.client + .get_records() + .shard_iterator(shard_iterator) + .send() + .await + .map_err(|e| e.into()) + } + + async fn get_shard_iterator_at_timestamp( + &self, + stream: &str, + shard_id: &str, + timestamp: &chrono::DateTime, + ) -> Result { + self.client + .get_shard_iterator() + .shard_iterator_type(ShardIteratorType::AtTimestamp) + .timestamp(Self::to_aws_datetime(timestamp)) + .stream_name(stream) + .shard_id(shard_id) + .send() + .await + .map_err(|e| e.into()) + } + + async fn get_shard_iterator_at_sequence( + &self, + stream: &str, + shard_id: &str, + starting_sequence_number: &str, + ) -> Result { + self.client + .get_shard_iterator() + .shard_iterator_type(ShardIteratorType::AtSequenceNumber) + .starting_sequence_number(starting_sequence_number) + .stream_name(stream) + .shard_id(shard_id) + .send() + .await + .map_err(|e| e.into()) + } + + async fn get_shard_iterator_latest( + &self, + stream: &str, + shard_id: &str, + ) -> Result { + self.client + .get_shard_iterator() + .shard_iterator_type(ShardIteratorType::Latest) + .stream_name(stream) + .shard_id(shard_id) + .send() + .await + .map_err(|e| e.into()) + } + + fn get_region(&self) -> Option<&Region> { + self.client.conf().region() + } + } + + pub async fn create_client( + region: Option, + endpoint_url: Option, + ) -> AwsKinesisClient { + let region_provider = RegionProviderChain::first_try(region.map(Region::new)) + .or_default_provider() + .or_else(Region::new("us-east-1")); + + let shared_config = { + let inner = aws_config::from_env().region(region_provider); + + let inner = if endpoint_url.is_some() { + inner.endpoint_url(endpoint_url.unwrap().as_str()) + } else { + inner + }; + + inner + } + .load() + .await; + + let client = Client::new(&shared_config); + + AwsKinesisClient { client } + } +} diff --git a/src/iterator.rs b/src/iterator.rs index d1b8ed4..246d391 100644 --- a/src/iterator.rs +++ b/src/iterator.rs @@ -1,8 +1,7 @@ +use crate::aws::client::KinesisClient; use async_trait::async_trait; use aws_sdk_kinesis::operation::get_shard_iterator::GetShardIteratorOutput; -use aws_sdk_kinesis::primitives::DateTime; -use aws_sdk_kinesis::types::ShardIteratorType; -use aws_sdk_kinesis::{Client, Error}; +use aws_sdk_kinesis::Error; use chrono::Utc; #[async_trait] @@ -14,16 +13,15 @@ pub trait ShardIterator { ) -> Result; } -fn to_aws_datetime(timestamp: &chrono::DateTime) -> DateTime { - DateTime::from_millis(timestamp.timestamp_millis()) -} - -pub fn latest<'a>(client: &'a Client) -> Box { +pub fn latest<'a, K>(client: &'a K) -> Box +where + K: KinesisClient, +{ Box::new(LatestShardIterator { client }) } -pub fn at_sequence<'a>( - client: &'a Client, +pub fn at_sequence<'a, K: KinesisClient>( + client: &'a K, starting_sequence_number: &'a str, ) -> Box { Box::new(AtSequenceShardIterator { @@ -32,79 +30,62 @@ pub fn at_sequence<'a>( }) } -pub fn at_timestamp<'a>( - client: &'a Client, +pub fn at_timestamp<'a, K: KinesisClient>( + client: &'a K, timestamp: &'a chrono::DateTime, ) -> Box { Box::new(AtTimestampShardIterator { client, timestamp }) } -struct LatestShardIterator<'a> { - client: &'a Client, +struct LatestShardIterator<'a, K: KinesisClient> { + client: &'a K, } -struct AtSequenceShardIterator<'a> { - client: &'a Client, +struct AtSequenceShardIterator<'a, K: KinesisClient> { + client: &'a K, starting_sequence_number: &'a str, } -struct AtTimestampShardIterator<'a> { - client: &'a Client, +struct AtTimestampShardIterator<'a, K: KinesisClient> { + client: &'a K, timestamp: &'a chrono::DateTime, } #[async_trait] -impl ShardIterator for LatestShardIterator<'_> { +impl ShardIterator for LatestShardIterator<'_, K> { async fn iterator<'a>( &'a self, stream: &'a str, shard_id: &'a str, ) -> Result { self.client - .get_shard_iterator() - .shard_iterator_type(ShardIteratorType::Latest) - .stream_name(stream) - .shard_id(shard_id) - .send() + .get_shard_iterator_latest(stream, shard_id) .await - .map_err(|e| e.into()) } } #[async_trait] -impl ShardIterator for AtSequenceShardIterator<'_> { +impl ShardIterator for AtSequenceShardIterator<'_, K> { async fn iterator<'a>( &'a self, stream: &'a str, shard_id: &'a str, ) -> Result { self.client - .get_shard_iterator() - .shard_iterator_type(ShardIteratorType::AtSequenceNumber) - .starting_sequence_number(self.starting_sequence_number) - .stream_name(stream) - .shard_id(shard_id) - .send() + .get_shard_iterator_at_sequence(stream, shard_id, self.starting_sequence_number) .await - .map_err(|e| e.into()) } } #[async_trait] -impl ShardIterator for AtTimestampShardIterator<'_> { +impl ShardIterator for AtTimestampShardIterator<'_, K> { async fn iterator<'a>( &'a self, stream: &'a str, shard_id: &'a str, ) -> Result { self.client - .get_shard_iterator() - .shard_iterator_type(ShardIteratorType::AtTimestamp) - .timestamp(to_aws_datetime(self.timestamp)) - .stream_name(stream) - .shard_id(shard_id) - .send() + .get_shard_iterator_at_timestamp(stream, shard_id, self.timestamp) .await - .map_err(|e| e.into()) } } diff --git a/src/kinesis.rs b/src/kinesis.rs index c434103..3062e35 100644 --- a/src/kinesis.rs +++ b/src/kinesis.rs @@ -1,79 +1,27 @@ -use std::fmt::Debug; - +use crate::aws::client::KinesisClient; use crate::kinesis::models::*; use async_trait::async_trait; use aws_sdk_kinesis::operation::get_shard_iterator::GetShardIteratorOutput; -use aws_sdk_kinesis::{Client, Error}; -use chrono::Utc; +use aws_sdk_kinesis::Error; use log::{debug, error}; use tokio::sync::mpsc; use tokio::sync::mpsc::Sender; use tokio::time::{sleep, Duration}; - pub mod helpers; pub mod models; -pub fn new( - client: Client, - stream: String, - shard_id: String, - from_datetime: Option>, - tx_records: Sender>, -) -> Box { - match from_datetime { - Some(from_datetime) => Box::new(ShardProcessorAtTimestamp { - config: ShardProcessorConfig { - client, - stream, - shard_id, - tx_records, - }, - from_datetime, - }), - None => Box::new(ShardProcessorLatest { - config: ShardProcessorConfig { - client, - stream, - shard_id, - tx_records, - }, - }), - } -} - #[async_trait] -pub trait IteratorProvider: Send + Sync + Debug + Clone { - fn get_config(&self) -> ShardProcessorConfig; +pub trait IteratorProvider: Send + Sync + Clone + 'static { + fn get_config(&self) -> ShardProcessorConfig; async fn get_iterator(&self) -> Result; } #[async_trait] -impl IteratorProvider for ShardProcessorLatest { - fn get_config(&self) -> ShardProcessorConfig { - self.config.clone() - } - - async fn get_iterator(&self) -> Result { - helpers::get_latest_iterator(self.clone()).await - } -} - -#[async_trait] -impl IteratorProvider for ShardProcessorAtTimestamp { - fn get_config(&self) -> ShardProcessorConfig { - self.config.clone() - } - - async fn get_iterator(&self) -> Result { - helpers::get_iterator_at_timestamp(self.clone(), self.from_datetime).await - } -} - -#[async_trait] -impl ShardProcessor for T +impl ShardProcessor for T where - T: IteratorProvider + Send + Sync + Debug + 'static, + K: KinesisClient, + T: IteratorProvider, { async fn run(&self) -> Result<(), Error> { let (tx_shard_iterator_progress, mut rx_shard_iterator_progress) = @@ -81,7 +29,6 @@ where { let cloned_self = self.clone(); - let tx_shard_iterator_progress = tx_shard_iterator_progress.clone(); tokio::spawn(async move { #[allow(unused_assignments)] @@ -182,13 +129,7 @@ where shard_iterator: &str, tx_shard_iterator_progress: Sender, ) -> Result<(), Error> { - let resp = self - .get_config() - .client - .get_records() - .shard_iterator(shard_iterator) - .send() - .await?; + let resp = self.get_config().client.get_records(shard_iterator).await?; let next_shard_iterator = resp.next_shard_iterator(); @@ -233,3 +174,6 @@ where Ok(()) } } + +#[cfg(test)] +mod tests; diff --git a/src/kinesis/helpers.rs b/src/kinesis/helpers.rs index 84e8964..7d0044f 100644 --- a/src/kinesis/helpers.rs +++ b/src/kinesis/helpers.rs @@ -1,15 +1,51 @@ +use crate::aws::client::{AwsKinesisClient, KinesisClient}; use aws_sdk_kinesis::operation::get_shard_iterator::GetShardIteratorOutput; -use aws_sdk_kinesis::{Client, Error}; +use aws_sdk_kinesis::Error; use chrono::Utc; +use log::debug; use tokio::sync::mpsc::Sender; +use crate::iterator::at_sequence; use crate::iterator::latest; -use crate::iterator::{at_sequence, at_timestamp}; +use crate::kinesis::models::{ + PanicError, ShardProcessor, ShardProcessorADT, ShardProcessorAtTimestamp, ShardProcessorConfig, + ShardProcessorLatest, +}; use crate::kinesis::{IteratorProvider, ShardIteratorProgress}; -pub async fn get_latest_iterator(iterator_provider: T) -> Result +pub fn new( + client: AwsKinesisClient, + stream: String, + shard_id: String, + from_datetime: Option>, + tx_records: Sender>, +) -> Box + Send + Sync> { + match from_datetime { + Some(from_datetime) => Box::new(ShardProcessorAtTimestamp { + config: ShardProcessorConfig { + client, + stream, + shard_id, + tx_records, + }, + from_datetime, + }), + None => Box::new(ShardProcessorLatest { + config: ShardProcessorConfig { + client, + stream, + shard_id, + tx_records, + }, + }), + } +} + +pub async fn get_latest_iterator( + iterator_provider: T, +) -> Result where - T: IteratorProvider, + T: IteratorProvider, { latest(&iterator_provider.get_config().client) .iterator( @@ -19,12 +55,12 @@ where .await } -pub async fn get_iterator_since( +pub async fn get_iterator_since( iterator_provider: T, starting_sequence_number: &str, ) -> Result where - T: IteratorProvider, + T: IteratorProvider, { at_sequence( &iterator_provider.get_config().client, @@ -37,43 +73,30 @@ where .await } -pub async fn get_iterator_at_timestamp( - iterator_provider: T, - timestamp: chrono::DateTime, -) -> Result -where - T: IteratorProvider, -{ - at_timestamp(&iterator_provider.get_config().client, ×tamp) - .iterator( - &iterator_provider.get_config().stream, - &iterator_provider.get_config().shard_id, - ) - .await -} - -pub async fn handle_iterator_refresh( +pub async fn handle_iterator_refresh( shard_iterator_progress: ShardIteratorProgress, - reader: T, + iterator_provider: T, tx_shard_iterator_progress: Sender, ) where - T: IteratorProvider, + T: IteratorProvider, { let (sequence_id, iterator) = match shard_iterator_progress.last_sequence_id { Some(last_sequence_id) => { - let resp = get_iterator_since(reader, &last_sequence_id).await.unwrap(); + let resp = get_iterator_since(iterator_provider, &last_sequence_id) + .await + .unwrap(); ( Some(last_sequence_id), resp.shard_iterator().map(|v| v.into()), ) } None => { - let resp = get_latest_iterator(reader).await.unwrap(); + let resp = get_latest_iterator(iterator_provider).await.unwrap(); (None, resp.shard_iterator().map(|v| v.into())) } }; - println!( + debug!( "Refreshing with next_shard_iterator: {:?} / last_sequence_id {:?}", iterator, sequence_id ); @@ -87,8 +110,8 @@ pub async fn handle_iterator_refresh( .unwrap(); } -pub async fn get_shards(client: &Client, stream: &str) -> Result, Error> { - let resp = client.list_shards().stream_name(stream).send().await?; +pub async fn get_shards(client: &AwsKinesisClient, stream: &str) -> Result, Error> { + let resp = client.list_shards(stream).await?; Ok(resp .shards() diff --git a/src/kinesis/models.rs b/src/kinesis/models.rs index b6ef374..0fc1db1 100644 --- a/src/kinesis/models.rs +++ b/src/kinesis/models.rs @@ -1,6 +1,11 @@ +use crate::aws::client::KinesisClient; +use crate::iterator::at_timestamp; +use crate::kinesis::helpers::get_latest_iterator; +use crate::kinesis::IteratorProvider; use async_trait::async_trait; +use aws_sdk_kinesis::operation::get_shard_iterator::GetShardIteratorOutput; use aws_sdk_kinesis::primitives::DateTime; -use aws_sdk_kinesis::{Client, Error}; +use aws_sdk_kinesis::Error; use chrono::Utc; use std::fmt::Debug; use tokio::sync::mpsc::Sender; @@ -29,27 +34,51 @@ pub struct RecordResult { pub data: Vec, } -#[derive(Debug, Clone)] -pub struct ShardProcessorConfig { - pub client: Client, +#[derive(Clone)] +pub struct ShardProcessorConfig { + pub client: K, pub stream: String, pub shard_id: String, pub tx_records: Sender>, } -#[derive(Debug, Clone)] -pub struct ShardProcessorLatest { - pub config: ShardProcessorConfig, +#[derive(Clone)] +pub struct ShardProcessorLatest { + pub config: ShardProcessorConfig, } -#[derive(Debug, Clone)] -pub struct ShardProcessorAtTimestamp { - pub config: ShardProcessorConfig, +#[derive(Clone)] +pub struct ShardProcessorAtTimestamp { + pub config: ShardProcessorConfig, pub from_datetime: chrono::DateTime, } #[async_trait] -pub trait ShardProcessor: Send + Sync + Debug { +impl IteratorProvider for ShardProcessorLatest { + fn get_config(&self) -> ShardProcessorConfig { + self.config.clone() + } + + async fn get_iterator(&self) -> Result { + get_latest_iterator(self.clone()).await + } +} + +#[async_trait] +impl IteratorProvider for ShardProcessorAtTimestamp { + fn get_config(&self) -> ShardProcessorConfig { + self.config.clone() + } + + async fn get_iterator(&self) -> Result { + at_timestamp(&self.config.client, &self.from_datetime) + .iterator(&self.config.stream, &self.config.shard_id) + .await + } +} + +#[async_trait] +pub trait ShardProcessor: Send + Sync { async fn run(&self) -> Result<(), Error>; async fn publish_records_shard( diff --git a/src/kinesis/tests.rs b/src/kinesis/tests.rs new file mode 100644 index 0000000..b14458c --- /dev/null +++ b/src/kinesis/tests.rs @@ -0,0 +1,129 @@ +use crate::aws::client::KinesisClient; +use crate::kinesis::models::{ + PanicError, ShardProcessor, ShardProcessorADT, ShardProcessorConfig, ShardProcessorLatest, +}; +use async_trait::async_trait; +use aws_sdk_kinesis::config::Region; +use aws_sdk_kinesis::operation::get_records::GetRecordsOutput; +use aws_sdk_kinesis::operation::get_shard_iterator::GetShardIteratorOutput; +use aws_sdk_kinesis::operation::list_shards::ListShardsOutput; +use aws_sdk_kinesis::primitives::{Blob, DateTime}; +use aws_sdk_kinesis::types::{Record, Shard}; +use aws_sdk_kinesis::Error; +use chrono::Utc; +use std::time::Duration; +use tokio::sync::mpsc; +use tokio::time::sleep; + +#[tokio::test] +async fn produced_record_is_processed() { + let (tx_records, mut rx_records) = mpsc::channel::>(10); + + let client = TestKinesisClient { + region: Some(Region::new("us-east-1")), + }; + + let processor = ShardProcessorLatest { + config: ShardProcessorConfig { + client, + stream: "test".to_string(), + shard_id: "shardId-000000000000".to_string(), + tx_records, + }, + }; + + // start producer + tokio::spawn(async move { processor.run().await }); + + let mut done_processing = false; + let mut closed_resources = false; + let mut count = 0; + + while let Some(res) = rx_records.recv().await { + if !done_processing { + match res { + Ok(adt) => match adt { + ShardProcessorADT::Progress(res) => { + count += res.len(); + } + _ => {} + }, + Err(_) => {} + } + + done_processing = true; + } else { + if !closed_resources { + sleep(Duration::from_millis(100)).await; + rx_records.close(); + } + closed_resources = true; + } + } + + assert_eq!(count, 1) +} + +#[derive(Clone, Debug)] +pub struct TestKinesisClient { + region: Option, +} + +#[async_trait] +impl KinesisClient for TestKinesisClient { + async fn list_shards(&self, _stream: &str) -> Result { + Ok(ListShardsOutput::builder() + .shards(Shard::builder().shard_id("000001").build()) + .build()) + } + + async fn get_records(&self, _shard_iterator: &str) -> Result { + let dt = DateTime::from_secs(5000); + let record = Record::builder() + .approximate_arrival_timestamp(dt) + .sequence_number("1") + .data(Blob::new("data")) + .build(); + + Ok(GetRecordsOutput::builder() + .records(record) + .next_shard_iterator("shard_iterator2".to_string()) + .build()) + } + + async fn get_shard_iterator_at_timestamp( + &self, + _stream: &str, + _shard_id: &str, + _timestamp: &chrono::DateTime, + ) -> Result { + Ok(GetShardIteratorOutput::builder() + .shard_iterator("shard_iterator".to_string()) + .build()) + } + + async fn get_shard_iterator_at_sequence( + &self, + _stream: &str, + _shard_id: &str, + _starting_sequence_number: &str, + ) -> Result { + Ok(GetShardIteratorOutput::builder() + .shard_iterator("shard_iterator".to_string()) + .build()) + } + + async fn get_shard_iterator_latest( + &self, + _stream: &str, + _shard_id: &str, + ) -> Result { + Ok(GetShardIteratorOutput::builder() + .shard_iterator("shard_iterator".to_string()) + .build()) + } + + fn get_region(&self) -> Option<&Region> { + self.region.as_ref() + } +} diff --git a/src/main.rs b/src/main.rs index b6f92de..06286ea 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,13 @@ #![allow(clippy::result_large_err)] -use aws_config::meta::region::RegionProviderChain; -use aws_sdk_kinesis::{config::Region, meta::PKG_VERSION, Client}; +use aws_sdk_kinesis::{config::Region, meta::PKG_VERSION}; use clap::Parser; use log::info; use std::io; use tokio::sync::mpsc; +use crate::aws::client::*; + use crate::cli_helpers::parse_date; use crate::sink::console::ConsoleSink; use crate::sink::Sink; @@ -17,6 +18,8 @@ mod iterator; mod kinesis; mod sink; +mod aws; + #[derive(Debug, Parser)] struct Opt { /// AWS Region @@ -80,19 +83,16 @@ async fn main() -> Result<(), io::Error> { endpoint_url, } = Opt::parse(); - env_logger::init(); - - let region_provider = RegionProviderChain::first_try(region.map(Region::new)) - .or_default_provider() - .or_else(Region::new("us-east-1")); + env_logger::init_from_env(env_logger::Env::default().default_filter_or("info")); let from_datetime = parse_date(from.as_deref()); + let client = aws::client::create_client(region, endpoint_url).await; if verbose { info!("Kinesis client version: {}", PKG_VERSION); info!( "Region: {}", - region_provider.region().await.unwrap().as_ref() + client.get_region().unwrap_or(&Region::new("us-east-1")) ); info!("Stream name: {}", &stream_name); from_datetime.iter().for_each(|f| { @@ -100,22 +100,6 @@ async fn main() -> Result<(), io::Error> { }); } - let shared_config = { - let inner = aws_config::from_env().region(region_provider); - - let inner = if endpoint_url.is_some() { - inner.endpoint_url(endpoint_url.unwrap().as_str()) - } else { - inner - }; - - inner - } - .load() - .await; - - let client = Client::new(&shared_config); - let (tx_records, rx_records) = mpsc::channel::>(500); let shards = get_shards(&client, &stream_name) @@ -148,7 +132,7 @@ async fn main() -> Result<(), io::Error> { } for shard_id in &selected_shards { - let shard_processor = kinesis::new( + let shard_processor = kinesis::helpers::new( client.clone(), stream_name.clone(), shard_id.clone(), diff --git a/src/sink.rs b/src/sink.rs index 535ddd9..8140373 100644 --- a/src/sink.rs +++ b/src/sink.rs @@ -1,8 +1,9 @@ -use async_trait::async_trait; -use chrono::TimeZone; use std::io; use std::io::{BufWriter, Error, Write}; use std::sync::Arc; + +use async_trait::async_trait; +use chrono::TimeZone; use tokio::sync::mpsc::{Receiver, Sender}; use tokio::sync::Mutex; @@ -123,8 +124,8 @@ where data.iter().for_each(|data| { writeln!(handle, "{}", data).unwrap(); + self.delimiter(handle).unwrap(); }); - self.delimiter(handle)? } } None => { @@ -133,8 +134,8 @@ where *lock += data.len() as u32; data.iter().for_each(|data| { writeln!(handle, "{}", data).unwrap(); + self.delimiter(handle).unwrap() }); - self.delimiter(handle)? } } }