Skip to content

Commit

Permalink
KinesisClient as trait paramater to IteratorConfig (#8)
Browse files Browse the repository at this point in the history
KinesisClient as trait paramater to IteratorConfig
  • Loading branch information
gr211 authored May 9, 2023
1 parent b6b3d6b commit 2554bfa
Show file tree
Hide file tree
Showing 8 changed files with 420 additions and 177 deletions.
152 changes: 152 additions & 0 deletions src/aws.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
pub mod client {

use async_trait::async_trait;
use aws_config::meta::region::RegionProviderChain;
use aws_sdk_kinesis::config::Region;
use aws_sdk_kinesis::operation::get_records::GetRecordsOutput;
use aws_sdk_kinesis::operation::get_shard_iterator::GetShardIteratorOutput;
use aws_sdk_kinesis::operation::list_shards::ListShardsOutput;
use aws_sdk_kinesis::primitives::DateTime;
use aws_sdk_kinesis::types::ShardIteratorType;
use aws_sdk_kinesis::{Client, Error};
use chrono::Utc;

#[derive(Clone, Debug)]
pub struct AwsKinesisClient {
client: Client,
}

#[async_trait]
pub trait KinesisClient: Sync + Send + Clone + 'static {
async fn list_shards(&self, stream: &str) -> Result<ListShardsOutput, Error>;

async fn get_records(&self, shard_iterator: &str) -> Result<GetRecordsOutput, Error>;

async fn get_shard_iterator_at_timestamp(
&self,
stream: &str,
shard_id: &str,
timestamp: &chrono::DateTime<Utc>,
) -> Result<GetShardIteratorOutput, Error>;

async fn get_shard_iterator_at_sequence(
&self,
stream: &str,
shard_id: &str,
starting_sequence_number: &str,
) -> Result<GetShardIteratorOutput, Error>;

async fn get_shard_iterator_latest(
&self,
stream: &str,
shard_id: &str,
) -> Result<GetShardIteratorOutput, Error>;

fn get_region(&self) -> Option<&Region>;

fn to_aws_datetime(timestamp: &chrono::DateTime<Utc>) -> DateTime {
DateTime::from_millis(timestamp.timestamp_millis())
}
}

#[async_trait]
impl KinesisClient for AwsKinesisClient {
async fn list_shards(&self, stream: &str) -> Result<ListShardsOutput, Error> {
self.client
.list_shards()
.stream_name(stream)
.send()
.await
.map_err(|e| e.into())
}

async fn get_records(&self, shard_iterator: &str) -> Result<GetRecordsOutput, Error> {
self.client
.get_records()
.shard_iterator(shard_iterator)
.send()
.await
.map_err(|e| e.into())
}

async fn get_shard_iterator_at_timestamp(
&self,
stream: &str,
shard_id: &str,
timestamp: &chrono::DateTime<Utc>,
) -> Result<GetShardIteratorOutput, Error> {
self.client
.get_shard_iterator()
.shard_iterator_type(ShardIteratorType::AtTimestamp)
.timestamp(Self::to_aws_datetime(timestamp))
.stream_name(stream)
.shard_id(shard_id)
.send()
.await
.map_err(|e| e.into())
}

async fn get_shard_iterator_at_sequence(
&self,
stream: &str,
shard_id: &str,
starting_sequence_number: &str,
) -> Result<GetShardIteratorOutput, Error> {
self.client
.get_shard_iterator()
.shard_iterator_type(ShardIteratorType::AtSequenceNumber)
.starting_sequence_number(starting_sequence_number)
.stream_name(stream)
.shard_id(shard_id)
.send()
.await
.map_err(|e| e.into())
}

async fn get_shard_iterator_latest(
&self,
stream: &str,
shard_id: &str,
) -> Result<GetShardIteratorOutput, Error> {
self.client
.get_shard_iterator()
.shard_iterator_type(ShardIteratorType::Latest)
.stream_name(stream)
.shard_id(shard_id)
.send()
.await
.map_err(|e| e.into())
}

fn get_region(&self) -> Option<&Region> {
self.client.conf().region()
}
}

pub async fn create_client(
region: Option<String>,
endpoint_url: Option<String>,
) -> AwsKinesisClient {
let region_provider = RegionProviderChain::first_try(region.map(Region::new))
.or_default_provider()
.or_else(Region::new("us-east-1"));

let shared_config = {
let inner = aws_config::from_env().region(region_provider);

let inner = if endpoint_url.is_some() {
inner.endpoint_url(endpoint_url.unwrap().as_str())
} else {
inner
};

inner
}
.load()
.await;

let client = Client::new(&shared_config);

AwsKinesisClient { client }
}
}
63 changes: 22 additions & 41 deletions src/iterator.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use crate::aws::client::KinesisClient;
use async_trait::async_trait;
use aws_sdk_kinesis::operation::get_shard_iterator::GetShardIteratorOutput;
use aws_sdk_kinesis::primitives::DateTime;
use aws_sdk_kinesis::types::ShardIteratorType;
use aws_sdk_kinesis::{Client, Error};
use aws_sdk_kinesis::Error;
use chrono::Utc;

#[async_trait]
Expand All @@ -14,16 +13,15 @@ pub trait ShardIterator {
) -> Result<GetShardIteratorOutput, Error>;
}

fn to_aws_datetime(timestamp: &chrono::DateTime<Utc>) -> DateTime {
DateTime::from_millis(timestamp.timestamp_millis())
}

pub fn latest<'a>(client: &'a Client) -> Box<dyn ShardIterator + 'a + Send + Sync> {
pub fn latest<'a, K>(client: &'a K) -> Box<dyn ShardIterator + 'a + Send + Sync>
where
K: KinesisClient,
{
Box::new(LatestShardIterator { client })
}

pub fn at_sequence<'a>(
client: &'a Client,
pub fn at_sequence<'a, K: KinesisClient>(
client: &'a K,
starting_sequence_number: &'a str,
) -> Box<dyn ShardIterator + 'a + Send + Sync> {
Box::new(AtSequenceShardIterator {
Expand All @@ -32,79 +30,62 @@ pub fn at_sequence<'a>(
})
}

pub fn at_timestamp<'a>(
client: &'a Client,
pub fn at_timestamp<'a, K: KinesisClient>(
client: &'a K,
timestamp: &'a chrono::DateTime<Utc>,
) -> Box<dyn ShardIterator + 'a + Send + Sync> {
Box::new(AtTimestampShardIterator { client, timestamp })
}

struct LatestShardIterator<'a> {
client: &'a Client,
struct LatestShardIterator<'a, K: KinesisClient> {
client: &'a K,
}

struct AtSequenceShardIterator<'a> {
client: &'a Client,
struct AtSequenceShardIterator<'a, K: KinesisClient> {
client: &'a K,
starting_sequence_number: &'a str,
}

struct AtTimestampShardIterator<'a> {
client: &'a Client,
struct AtTimestampShardIterator<'a, K: KinesisClient> {
client: &'a K,
timestamp: &'a chrono::DateTime<Utc>,
}

#[async_trait]
impl ShardIterator for LatestShardIterator<'_> {
impl<K: KinesisClient> ShardIterator for LatestShardIterator<'_, K> {
async fn iterator<'a>(
&'a self,
stream: &'a str,
shard_id: &'a str,
) -> Result<GetShardIteratorOutput, Error> {
self.client
.get_shard_iterator()
.shard_iterator_type(ShardIteratorType::Latest)
.stream_name(stream)
.shard_id(shard_id)
.send()
.get_shard_iterator_latest(stream, shard_id)
.await
.map_err(|e| e.into())
}
}

#[async_trait]
impl ShardIterator for AtSequenceShardIterator<'_> {
impl<K: KinesisClient> ShardIterator for AtSequenceShardIterator<'_, K> {
async fn iterator<'a>(
&'a self,
stream: &'a str,
shard_id: &'a str,
) -> Result<GetShardIteratorOutput, Error> {
self.client
.get_shard_iterator()
.shard_iterator_type(ShardIteratorType::AtSequenceNumber)
.starting_sequence_number(self.starting_sequence_number)
.stream_name(stream)
.shard_id(shard_id)
.send()
.get_shard_iterator_at_sequence(stream, shard_id, self.starting_sequence_number)
.await
.map_err(|e| e.into())
}
}

#[async_trait]
impl ShardIterator for AtTimestampShardIterator<'_> {
impl<K: KinesisClient> ShardIterator for AtTimestampShardIterator<'_, K> {
async fn iterator<'a>(
&'a self,
stream: &'a str,
shard_id: &'a str,
) -> Result<GetShardIteratorOutput, Error> {
self.client
.get_shard_iterator()
.shard_iterator_type(ShardIteratorType::AtTimestamp)
.timestamp(to_aws_datetime(self.timestamp))
.stream_name(stream)
.shard_id(shard_id)
.send()
.get_shard_iterator_at_timestamp(stream, shard_id, self.timestamp)
.await
.map_err(|e| e.into())
}
}
78 changes: 11 additions & 67 deletions src/kinesis.rs
Original file line number Diff line number Diff line change
@@ -1,87 +1,34 @@
use std::fmt::Debug;

use crate::aws::client::KinesisClient;
use crate::kinesis::models::*;
use async_trait::async_trait;
use aws_sdk_kinesis::operation::get_shard_iterator::GetShardIteratorOutput;
use aws_sdk_kinesis::{Client, Error};
use chrono::Utc;
use aws_sdk_kinesis::Error;
use log::{debug, error};
use tokio::sync::mpsc;
use tokio::sync::mpsc::Sender;
use tokio::time::{sleep, Duration};

pub mod helpers;
pub mod models;

pub fn new(
client: Client,
stream: String,
shard_id: String,
from_datetime: Option<chrono::DateTime<Utc>>,
tx_records: Sender<Result<ShardProcessorADT, PanicError>>,
) -> Box<dyn ShardProcessor + Send + Sync> {
match from_datetime {
Some(from_datetime) => Box::new(ShardProcessorAtTimestamp {
config: ShardProcessorConfig {
client,
stream,
shard_id,
tx_records,
},
from_datetime,
}),
None => Box::new(ShardProcessorLatest {
config: ShardProcessorConfig {
client,
stream,
shard_id,
tx_records,
},
}),
}
}

#[async_trait]
pub trait IteratorProvider: Send + Sync + Debug + Clone {
fn get_config(&self) -> ShardProcessorConfig;
pub trait IteratorProvider<K: KinesisClient>: Send + Sync + Clone + 'static {
fn get_config(&self) -> ShardProcessorConfig<K>;

async fn get_iterator(&self) -> Result<GetShardIteratorOutput, Error>;
}

#[async_trait]
impl IteratorProvider for ShardProcessorLatest {
fn get_config(&self) -> ShardProcessorConfig {
self.config.clone()
}

async fn get_iterator(&self) -> Result<GetShardIteratorOutput, Error> {
helpers::get_latest_iterator(self.clone()).await
}
}

#[async_trait]
impl IteratorProvider for ShardProcessorAtTimestamp {
fn get_config(&self) -> ShardProcessorConfig {
self.config.clone()
}

async fn get_iterator(&self) -> Result<GetShardIteratorOutput, Error> {
helpers::get_iterator_at_timestamp(self.clone(), self.from_datetime).await
}
}

#[async_trait]
impl<T> ShardProcessor for T
impl<T, K> ShardProcessor<K> for T
where
T: IteratorProvider + Send + Sync + Debug + 'static,
K: KinesisClient,
T: IteratorProvider<K>,
{
async fn run(&self) -> Result<(), Error> {
let (tx_shard_iterator_progress, mut rx_shard_iterator_progress) =
mpsc::channel::<ShardIteratorProgress>(100);

{
let cloned_self = self.clone();

let tx_shard_iterator_progress = tx_shard_iterator_progress.clone();
tokio::spawn(async move {
#[allow(unused_assignments)]
Expand Down Expand Up @@ -182,13 +129,7 @@ where
shard_iterator: &str,
tx_shard_iterator_progress: Sender<ShardIteratorProgress>,
) -> Result<(), Error> {
let resp = self
.get_config()
.client
.get_records()
.shard_iterator(shard_iterator)
.send()
.await?;
let resp = self.get_config().client.get_records(shard_iterator).await?;

let next_shard_iterator = resp.next_shard_iterator();

Expand Down Expand Up @@ -233,3 +174,6 @@ where
Ok(())
}
}

#[cfg(test)]
mod tests;
Loading

0 comments on commit 2554bfa

Please sign in to comment.