Skip to content

Commit

Permalink
Shard pagination (#44)
Browse files (browse the repository at this point in the history)
Shard pagination (#44)
  • Loading branch information
gr211 authored Aug 7, 2023
1 parent e8acc02 commit 5b209cc
Show file tree
Hide file tree
Showing 11 changed files with 432 additions and 548 deletions.
633 changes: 221 additions & 412 deletions Cargo.lock

Large diffs are not rendered by default.

9 changes: 4 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@ edition = "2021"
anyhow = "1.0"
async-trait = "0.1"
aws-config = { version = "0.55" }
aws-sdk-kinesis = { version = "0.28" }
chrono = "0.4"
aws-sdk-kinesis = { version = "0.28" }
chrono = { version = "0.4", features = ["clock", "std"] }
clap = { version = "4.3", features = ["derive"] }
colored = "2.0"
config = "0.13"
ctrlc-async = "3.2"
env_logger = "0.10"
humantime = "2.1"
Expand All @@ -26,7 +25,7 @@ log4rs = "1.2"
nix = "0.26"
rand = "0.8"
thiserror = "1.0"
tokio = { version = "1.28", features = ["rt-multi-thread", "macros"] }
tokio = { version = "1.29", features = ["rt-multi-thread", "macros"] }

[features]
default = ["clap/cargo", "clap/derive", "config/json"]
default = ["clap/cargo", "clap/derive"]
24 changes: 16 additions & 8 deletions src/aws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ pub mod client {

#[async_trait]
pub trait KinesisClient: Sync + Send + Clone {
async fn list_shards(&self, stream: &str) -> Result<ListShardsOutput>;
async fn list_shards(
&self,
stream: &str,
next_token: Option<&str>,
) -> Result<ListShardsOutput>;

async fn get_records(&self, shard_iterator: &str) -> Result<GetRecordsOutput>;

Expand Down Expand Up @@ -52,13 +56,17 @@ pub mod client {

#[async_trait]
impl KinesisClient for AwsKinesisClient {
async fn list_shards(&self, stream: &str) -> Result<ListShardsOutput> {
self.client
.list_shards()
.stream_name(stream)
.send()
.await
.map_err(|e| e.into())
async fn list_shards(
&self,
stream: &str,
next_token: Option<&str>,
) -> Result<ListShardsOutput> {
let builder = match next_token {
Some(token) => self.client.list_shards().next_token(token),
None => self.client.list_shards().stream_name(stream),
};

builder.send().await.map_err(|e| e.into())
}

async fn get_records(&self, shard_iterator: &str) -> Result<GetRecordsOutput> {
Expand Down
5 changes: 3 additions & 2 deletions src/cli_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use chrono::{DateTime, Utc};
use clap::Parser;
use log::info;

pub const SEMAPHORE_DEFAULT_SIZE: usize = 50;

#[derive(Debug, Parser)]
#[command(
version = "{#RELEASE_VERSION} - Grum Ltd\nReport bugs to https://github.com/grumlimited/kinesis-tailr/issues"
Expand Down Expand Up @@ -67,8 +69,7 @@ pub struct Opt {

/// Concurrent number of shards to tail
#[structopt(short, long)]
#[clap(default_value_t = 10)]
pub concurrent: usize,
pub concurrent: Option<usize>,

/// Display additional information
#[structopt(short, long)]
Expand Down
78 changes: 46 additions & 32 deletions src/kinesis.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
use crate::aws::client::KinesisClient;
use crate::kinesis::helpers::wait_secs;
use crate::kinesis::models::*;
use crate::kinesis::ticker::TickerUpdate;
use anyhow::Result;
use async_trait::async_trait;
use aws_sdk_kinesis::operation::get_records::GetRecordsError;
use aws_sdk_kinesis::operation::get_shard_iterator::GetShardIteratorOutput;
use chrono::prelude::*;
use chrono::{DateTime, Utc};
use log::{debug, info};
use log::debug;
use tokio::sync::mpsc;
use tokio::sync::mpsc::Sender;
use tokio::time::{sleep, Duration};
use GetRecordsError::{ExpiredIteratorException, ProvisionedThroughputExceededException};

use crate::aws::client::KinesisClient;
use crate::kinesis::helpers::wait_secs;
use crate::kinesis::models::*;
use crate::kinesis::ticker::{ShardCountUpdate, TickerMessage};

pub mod helpers;
pub mod models;
pub mod ticker;
Expand Down Expand Up @@ -47,7 +48,6 @@ where
let result = self
.publish_records_shard(
&shard_iterator,
res.shard_id.clone(),
self.get_config().tx_ticker_updates.clone(),
tx_shard_iterator_progress.clone(),
)
Expand All @@ -56,7 +56,11 @@ where
if let Err(e) = result {
match e.downcast_ref::<GetRecordsError>() {
Some(ExpiredIteratorException(inner)) => {
debug!("ExpiredIteratorException: {}", inner);
debug!(
"ExpiredIteratorException [{}]: {}",
self.get_config().shard_id,
inner
);
helpers::handle_iterator_refresh(
res_clone.clone(),
self.clone(),
Expand Down Expand Up @@ -92,18 +96,31 @@ where
}
None => {
self.get_config()
.tx_records
.send(Err(ProcessError::PanicError(
"ShardIterator is None".to_string(),
)))
.tx_ticker_updates
.send(TickerMessage::RemoveShard(res.shard_id.clone()))
.await
.expect("");
.expect("Could not send RemoveShard to tx_ticker_updates");
rx_shard_iterator_progress.close();
}
};

drop(permit);
}

debug!("ShardProcessor {} finished", self.get_config().shard_id);

self.get_config()
.tx_ticker_updates
.send(TickerMessage::RemoveShard(
self.get_config().shard_id.clone(),
))
.await?;

self.get_config()
.tx_records
.send(Ok(ShardProcessorADT::BeyondToTimestamp))
.await?;

Ok(())
}

Expand All @@ -115,8 +132,6 @@ where

debug!("Seeding shard {}", self.get_config().shard_id);

let tx_shard_iterator_progress = tx_shard_iterator_progress.clone();

match self.get_iterator().await {
Ok(resp) => {
let shard_iterator: Option<String> = resp.shard_iterator().map(|s| s.into());
Expand All @@ -142,16 +157,12 @@ where
}

/**
* Publish records from a shard iterator.
* Because shards are multiplexed per ShardProcessor, we need to keep
* track of the shard_id for each shard_iterator.
* Publish records from a shard iterator.
*/
async fn publish_records_shard(
&self,
shard_iterator: &str,
shard_id: String,
tx_ticker_updates: Sender<TickerUpdate>,
tx_ticker_updates: Sender<TickerMessage>,
tx_shard_iterator_progress: Sender<ShardIteratorProgress>,
) -> Result<()> {
let resp = self.get_config().client.get_records(shard_iterator).await?;
Expand All @@ -167,7 +178,7 @@ where
let datetime = *record.approximate_arrival_timestamp().unwrap();

RecordResult {
shard_id: shard_id.clone(),
shard_id: self.get_config().shard_id,
sequence_id: record.sequence_number().unwrap().into(),
partition_key: record.partition_key().unwrap_or("none").into(),
datetime,
Expand All @@ -182,10 +193,10 @@ where

if let Some(millis_behind) = resp.millis_behind_latest() {
tx_ticker_updates
.send(TickerUpdate {
shard_id: shard_id.clone(),
.send(TickerMessage::CountUpdate(ShardCountUpdate {
shard_id: self.get_config().shard_id.clone(),
millis_behind,
})
}))
.await
.expect("Could not send TickerUpdate to tx_ticker_updates");
}
Expand Down Expand Up @@ -213,7 +224,7 @@ where
.map(|s| s.into());

let shard_iterator_progress = ShardIteratorProgress {
shard_id: shard_id.clone(),
shard_id: self.get_config().shard_id,
last_sequence_id,
next_shard_iterator: next_shard_iterator.map(|s| s.into()),
};
Expand All @@ -223,21 +234,24 @@ where
.await
.unwrap();
} else {
info!(
debug!(
"{} records in batch for shard-id {} and {} records before {}",
nb_records,
shard_id,
self.get_config().shard_id,
nb_records_before_end_ts,
self.get_config()
.to_datetime
.map(|ts| ts.to_rfc3339())
.unwrap_or("[No end timestamp]".to_string())
);
self.get_config()
.tx_records
.send(Ok(ShardProcessorADT::BeyondToTimestamp))
.await
.expect("Could not send BeyondToTimestamp to tx_records");

tx_shard_iterator_progress
.send(ShardIteratorProgress {
shard_id: self.get_config().shard_id.clone(),
last_sequence_id: None,
next_shard_iterator: None,
})
.await?;
}

Ok(())
Expand Down
58 changes: 39 additions & 19 deletions src/kinesis/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ use aws_sdk_kinesis::error::SdkError::ServiceError;
use aws_sdk_kinesis::operation::get_shard_iterator::{
GetShardIteratorError, GetShardIteratorOutput,
};
use aws_sdk_kinesis::operation::list_shards::ListShardsError;
use aws_sdk_kinesis::operation::list_shards::{ListShardsError, ListShardsOutput};
use chrono::Utc;
use log::debug;
use log::{debug, info};
use tokio::sync::mpsc::Sender;
use tokio::sync::Semaphore;
use tokio::time::sleep;
Expand All @@ -23,7 +23,7 @@ use crate::kinesis::models::{
ProcessError, ShardProcessor, ShardProcessorADT, ShardProcessorAtTimestamp,
ShardProcessorConfig, ShardProcessorLatest,
};
use crate::kinesis::ticker::TickerUpdate;
use crate::kinesis::ticker::TickerMessage;
use crate::kinesis::{IteratorProvider, ShardIteratorProgress};

#[allow(clippy::too_many_arguments)]
Expand All @@ -35,7 +35,7 @@ pub fn new(
to_datetime: Option<chrono::DateTime<Utc>>,
semaphore: Arc<Semaphore>,
tx_records: Sender<Result<ShardProcessorADT, ProcessError>>,
tx_ticker_updates: Sender<TickerUpdate>,
tx_ticker_updates: Sender<TickerMessage>,
) -> Box<dyn ShardProcessor<AwsKinesisClient> + Send + Sync> {
debug!("Creating ShardProcessor with shard {}", shard_id);

Expand Down Expand Up @@ -150,27 +150,47 @@ where
}

pub async fn get_shards(client: &AwsKinesisClient, stream: &str) -> io::Result<Vec<String>> {
let resp = client
.list_shards(stream)
.await
.map_err(|e| {
let mut seed = client.list_shards(stream, None).await;

let mut results: Vec<ListShardsOutput> = vec![];

while let Ok(result) = &seed {
results.push(result.clone());
if let Some(next_token) = result.next_token() {
let result = client.list_shards(stream, Some(next_token)).await;
seed = result;
} else {
break;
}
}

match seed {
Ok(_) => {
let shards: Vec<String> = results
.iter()
.flat_map(|r| {
r.shards()
.unwrap()
.iter()
.map(|s| s.shard_id().unwrap().to_string())
.collect::<Vec<String>>()
})
.collect::<Vec<String>>();

info!("Found {} shards", shards.len());

Ok(shards)
}
Err(e) => {
let message = match e.downcast_ref::<SdkError<ListShardsError>>() {
Some(ServiceError(inner)) => inner.err().to_string(),
Some(other) => other.to_string(),
_ => e.to_string(),
};

io::Error::new(io::ErrorKind::Other, message)
})
.map(|e| {
e.shards()
.unwrap()
.iter()
.map(|s| s.shard_id.as_ref().unwrap().clone())
.collect::<Vec<String>>()
})?;

Ok(resp)
Err(io::Error::new(io::ErrorKind::Other, message))
}
}
}

pub fn wait_secs() -> u64 {
Expand Down
7 changes: 3 additions & 4 deletions src/kinesis/models.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::aws::client::KinesisClient;
use crate::iterator::ShardIterator;
use crate::iterator::{at_timestamp, latest};
use crate::kinesis::ticker::TickerUpdate;
use crate::kinesis::ticker::TickerMessage;
use crate::kinesis::IteratorProvider;
use anyhow::Result;
use async_trait::async_trait;
Expand Down Expand Up @@ -51,7 +51,7 @@ pub struct ShardProcessorConfig<K: KinesisClient> {
pub to_datetime: Option<chrono::DateTime<Utc>>,
pub semaphore: Arc<Semaphore>,
pub tx_records: Sender<Result<ShardProcessorADT, ProcessError>>,
pub tx_ticker_updates: Sender<TickerUpdate>,
pub tx_ticker_updates: Sender<TickerMessage>,
}

#[derive(Clone)]
Expand Down Expand Up @@ -101,8 +101,7 @@ pub trait ShardProcessor<K: KinesisClient>: Send + Sync {
async fn publish_records_shard(
&self,
shard_iterator: &str,
shard_id: String,
tx_ticker: Sender<TickerUpdate>,
tx_ticker: Sender<TickerMessage>,
tx_shard_iterator_progress: Sender<ShardIteratorProgress>,
) -> Result<()>;

Expand Down
Loading

0 comments on commit 5b209cc

Please sign in to comment.