Skip to content

Commit

Permalink
Re-introduced semaphore to limit concurrent connections
Browse files Browse the repository at this point in the history
  • Loading branch information
gr211 committed May 30, 2024
1 parent 88bdf52 commit b567bf9
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Arch Linux.
--progress Print progress status
--shard-id <SHARD_ID> Shard ID to tail from. Repeat option for each shard ID to filter on
-o, --output-file <OUTPUT_FILE> Output file to write to
-c, --concurrent <CONCURRENT> Concurrent number of shards to tail
-v, --verbose Display additional information
--base64 Base64 encode payloads (eg. for binary data)
--utf8 Forces UTF-8 printable payloads
Expand Down
6 changes: 6 additions & 0 deletions src/cli_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use chrono::{DateTime, Utc};
use clap::Parser;
use log::info;

/// Upper bound on the default number of semaphore permits, applied when the
/// `concurrent` option is not supplied on the command line.
pub const SEMAPHORE_DEFAULT_SIZE: usize = 50;

#[derive(Debug, Parser)]
#[command(
version = "{#RELEASE_VERSION} - Grum Ltd\nReport bugs to https://github.com/grumlimited/kinesis-tailr/issues"
Expand Down Expand Up @@ -79,6 +81,10 @@ pub struct Opt {
#[structopt(long, short)]
pub output_file: Option<String>,

/// Concurrent number of shards to tail
#[structopt(short, long)]
pub concurrent: Option<usize>,

/// Display additional information
#[structopt(short, long)]
pub verbose: bool,
Expand Down
7 changes: 7 additions & 0 deletions src/kinesis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ where
self.seed_shards(tx_shard_iterator_progress.clone()).await?;

while let Some(res) = rx_shard_iterator_progress.recv().await {
let permit = self.get_config().semaphore.clone().acquire_owned().await?;

let res_clone = res.clone();

match res.next_shard_iterator {
Expand Down Expand Up @@ -105,6 +107,8 @@ where
rx_shard_iterator_progress.close();
}
};

drop(permit);
}

debug!("ShardProcessor {} finished", self.get_config().shard_id);
Expand All @@ -129,6 +133,8 @@ where
&self,
tx_shard_iterator_progress: Sender<ShardIteratorProgress>,
) -> Result<()> {
let permit = self.get_config().semaphore.clone().acquire_owned().await?;

debug!("Seeding shard {}", self.get_config().shard_id);

match self.get_iterator().await {
Expand All @@ -149,6 +155,7 @@ where
}
}

drop(permit);
Ok(())
}

Expand Down
4 changes: 4 additions & 0 deletions src/kinesis/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use aws_sdk_kinesis::types::Shard;
use chrono::Utc;
use log::{debug, info};
use tokio::sync::mpsc::Sender;
use tokio::sync::Semaphore;
use tokio::time::sleep;

use crate::aws::client::AwsKinesisClient;
Expand All @@ -32,6 +33,7 @@ pub fn new(
shard_id: String,
from_datetime: Option<chrono::DateTime<Utc>>,
to_datetime: Option<chrono::DateTime<Utc>>,
semaphore: Arc<Semaphore>,
tx_records: Sender<Result<ShardProcessorADT, ProcessError>>,
tx_ticker_updates: Option<Sender<TickerMessage>>,
) -> Box<dyn ShardProcessor<AwsKinesisClient> + Send + Sync> {
Expand All @@ -44,6 +46,7 @@ pub fn new(
stream,
shard_id: Arc::new(shard_id),
to_datetime,
semaphore,
tx_records,
tx_ticker_updates,
},
Expand All @@ -55,6 +58,7 @@ pub fn new(
stream,
shard_id: Arc::new(shard_id),
to_datetime,
semaphore,
tx_records,
tx_ticker_updates,
},
Expand Down
2 changes: 2 additions & 0 deletions src/kinesis/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::fmt::Debug;
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::mpsc::Sender;
use tokio::sync::Semaphore;

#[derive(Debug, Clone)]
pub struct ShardIteratorProgress {
Expand Down Expand Up @@ -49,6 +50,7 @@ pub struct ShardProcessorConfig {
pub stream: String,
pub shard_id: Arc<String>,
pub to_datetime: Option<chrono::DateTime<Utc>>,
pub semaphore: Arc<Semaphore>,
pub tx_records: Sender<Result<ShardProcessorADT, ProcessError>>,
pub tx_ticker_updates: Option<Sender<TickerMessage>>,
}
Expand Down
21 changes: 20 additions & 1 deletion src/kinesis/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use aws_sdk_kinesis::types::error::InvalidArgumentException;
use aws_sdk_kinesis::types::{Record, Shard};
use chrono::prelude::*;
use chrono::Utc;
use tokio::sync::mpsc;
use tokio::sync::{mpsc, Semaphore};

use crate::aws::stream::StreamClient;
use crate::kinesis::helpers;
Expand All @@ -34,12 +34,15 @@ async fn seed_shards_test() {
done: Arc::new(Mutex::new(false)),
};

let semaphore: Arc<Semaphore> = Arc::new(Semaphore::new(10));

let processor = ShardProcessorLatest {
client,
config: ShardProcessorConfig {
stream: "test".to_string(),
shard_id: Arc::new("shardId-000000000000".to_string()),
to_datetime: None,
semaphore,
tx_records,
tx_ticker_updates: Some(tx_ticker_updates),
},
Expand Down Expand Up @@ -69,12 +72,15 @@ async fn seed_shards_test_timestamp_in_future() {

let client = TestTimestampInFutureKinesisClient {};

let semaphore: Arc<Semaphore> = Arc::new(Semaphore::new(10));

let processor = ShardProcessorAtTimestamp {
client,
config: ShardProcessorConfig {
stream: "test".to_string(),
shard_id: Arc::new("shardId-000000000000".to_string()),
to_datetime: None,
semaphore,
tx_records,
tx_ticker_updates: Some(tx_ticker_updates),
},
Expand All @@ -96,12 +102,15 @@ async fn produced_record_is_processed() {
done: Arc::new(Mutex::new(false)),
};

let semaphore: Arc<Semaphore> = Arc::new(Semaphore::new(10));

let processor = ShardProcessorLatest {
client: client.clone(),
config: ShardProcessorConfig {
stream: "test".to_string(),
shard_id: Arc::new("shardId-000000000000".to_string()),
to_datetime: None,
semaphore,
tx_records,
tx_ticker_updates: Some(tx_ticker_updates),
},
Expand Down Expand Up @@ -140,13 +149,16 @@ async fn beyond_to_timestamp_is_received() {
done: Arc::new(Mutex::new(false)),
};

let semaphore: Arc<Semaphore> = Arc::new(Semaphore::new(10));

let to_datetime = Utc.with_ymd_and_hms(2020, 6, 1, 12, 0, 0).unwrap();
let processor = ShardProcessorLatest {
client,
config: ShardProcessorConfig {
stream: "test".to_string(),
shard_id: Arc::new("shardId-000000000000".to_string()),
to_datetime: Some(to_datetime),
semaphore,
tx_records,
tx_ticker_updates: Some(tx_ticker_updates),
},
Expand Down Expand Up @@ -178,13 +190,16 @@ async fn has_records_beyond_end_ts_when_has_end_ts() {
done: Arc::new(Mutex::new(false)),
};

let semaphore: Arc<Semaphore> = Arc::new(Semaphore::new(10));

let to_datetime = Utc.with_ymd_and_hms(2020, 6, 1, 12, 0, 0).unwrap();
let processor = ShardProcessorLatest {
client,
config: ShardProcessorConfig {
stream: "test".to_string(),
shard_id: Arc::new("shardId-000000000000".to_string()),
to_datetime: Some(to_datetime),
semaphore,
tx_records,
tx_ticker_updates: Some(tx_ticker_updates),
},
Expand Down Expand Up @@ -236,12 +251,15 @@ async fn has_records_beyond_end_ts_when_no_end_ts() {
done: Arc::new(Mutex::new(false)),
};

let semaphore: Arc<Semaphore> = Arc::new(Semaphore::new(10));

let processor = ShardProcessorLatest {
client,
config: ShardProcessorConfig {
stream: "test".to_string(),
shard_id: Arc::new("shardId-000000000000".to_string()),
to_datetime: None,
semaphore,
tx_records,
tx_ticker_updates: Some(tx_ticker_updates),
},
Expand Down Expand Up @@ -285,6 +303,7 @@ async fn handle_iterator_refresh_ok() {
stream: "test".to_string(),
shard_id: Arc::new("shardId-000000000000".to_string()),
to_datetime: None,
semaphore: Arc::new(Semaphore::new(10)),
tx_records: mpsc::channel::<Result<ShardProcessorADT, ProcessError>>(10).0,
tx_ticker_updates: Some(mpsc::channel::<TickerMessage>(10).0),
},
Expand Down
16 changes: 15 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#![allow(clippy::result_large_err)]

use std::sync::Arc;

use anyhow::Result;
use clap::Parser;
use tokio::sync::mpsc;
use tokio::sync::{mpsc, Semaphore};
use tokio::task::JoinSet;

use kinesis::helpers::get_shards;
Expand Down Expand Up @@ -113,6 +115,8 @@ async fn main() -> Result<()> {
};

let shard_processors = {
let semaphore = semaphore(shard_count, opt.concurrent);

selected_shards
.iter()
.map(|shard_id| {
Expand All @@ -122,6 +126,7 @@ async fn main() -> Result<()> {
shard_id.clone(),
from_datetime,
to_datetime,
semaphore.clone(),
tx_records.clone(),
tx_ticker_updates.clone(),
);
Expand All @@ -145,3 +150,12 @@ async fn main() -> Result<()> {

Ok(())
}

/// Builds the shared semaphore that caps how many shards are worked on at once.
///
/// When `concurrent` is given it is used as the permit count; otherwise the
/// count defaults to `shard_count`, capped at `SEMAPHORE_DEFAULT_SIZE`.
///
/// The permit count is clamped to at least one: a semaphore created with zero
/// permits can never be acquired, so `--concurrent 0` would otherwise leave
/// every shard processor waiting forever on its first acquire.
fn semaphore(shard_count: usize, concurrent: Option<usize>) -> Arc<Semaphore> {
    let permits = concurrent
        .unwrap_or_else(|| shard_count.min(SEMAPHORE_DEFAULT_SIZE))
        .max(1);

    Arc::new(Semaphore::new(permits))
}
1 change: 0 additions & 1 deletion src/sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,6 @@ where
total_records_processed += records.len() as u32;
records.iter().for_each(|record| {
let data = self.format_record(record);
// writeln!(handle, "{}", data).unwrap();
let _ = handle.write(data.as_slice()).unwrap();
self.delimiter(handle).unwrap()
});
Expand Down

0 comments on commit b567bf9

Please sign in to comment.