Skip to content

Commit

Permalink
Console made generic (#7)
Browse files Browse the repository at this point in the history
* Console made generic (#7)
* Fix split_at
  • Loading branch information
gr211 authored May 8, 2023
1 parent a887ad3 commit b6b3d6b
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 141 deletions.
57 changes: 0 additions & 57 deletions src/console/tests.rs

This file was deleted.

12 changes: 6 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ use std::io;
use tokio::sync::mpsc;

use crate::cli_helpers::parse_date;
use crate::console::Console;
use crate::sink::console::ConsoleSink;
use crate::sink::Sink;
use kinesis::helpers::get_shards;
use kinesis::models::*;
mod console;

mod iterator;
mod kinesis;
mod sink;

#[derive(Debug, Parser)]
struct Opt {
Expand Down Expand Up @@ -160,16 +162,14 @@ async fn main() -> Result<(), io::Error> {
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
}

Console::new(
ConsoleSink::new(
max_messages,
print_key,
print_shard,
print_timestamp,
print_delimiter,
rx_records,
tx_records,
)
.run()
.run(tx_records, rx_records)
.await
}

Expand Down
186 changes: 108 additions & 78 deletions src/console.rs → src/sink.rs
Original file line number Diff line number Diff line change
@@ -1,70 +1,106 @@
use crate::kinesis::models::*;
use chrono::*;
use std::io::{self, BufWriter, Error, Write};
use std::rc::Rc;
use async_trait::async_trait;
use chrono::TimeZone;
use std::io;
use std::io::{BufWriter, Error, Write};
use std::sync::Arc;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::Mutex;
pub const CONSOLE_BUF_SIZE: usize = 8 * 1024; // 8kB

pub struct Console {
use crate::kinesis::models::{PanicError, RecordResult, ShardProcessorADT};

pub mod console;

#[derive(Clone)]
pub struct SinkConfig {
max_messages: Option<u32>,
print_key: bool,
print_shardid: bool,
print_timestamp: bool,
print_delimiter: bool,
exit_after_termination: bool,
rx_records: Receiver<Result<ShardProcessorADT, PanicError>>,
tx_records: Sender<Result<ShardProcessorADT, PanicError>>,
}

impl Console {
pub fn new(
max_messages: Option<u32>,
print_key: bool,
print_shardid: bool,
print_timestamp: bool,
print_delimiter: bool,
pub trait Configurable {
fn get_config(&self) -> SinkConfig;
}

#[async_trait]
pub trait SinkOutput<W>
where
W: Write + Send,
{
fn offer(&mut self) -> BufWriter<W>;
}

#[async_trait]
pub trait Sink<T, W>
where
W: Write + Send,
T: SinkOutput<W> + Configurable + Send + Sync,
{
async fn run_inner(
&mut self,
tx_records: Sender<Result<ShardProcessorADT, PanicError>>,
rx_records: Receiver<Result<ShardProcessorADT, PanicError>>,
handle: &mut BufWriter<W>,
) -> io::Result<()>;

async fn run(
&mut self,
tx_records: Sender<Result<ShardProcessorADT, PanicError>>,
) -> Console {
Console {
max_messages,
print_key,
print_shardid,
print_timestamp,
print_delimiter,
exit_after_termination: true,
rx_records,
tx_records,
rx_records: Receiver<Result<ShardProcessorADT, PanicError>>,
) -> io::Result<()>;

fn handle_termination(&self, tx_records: Sender<Result<ShardProcessorADT, PanicError>>);

fn delimiter(&self, handle: &mut BufWriter<W>) -> Result<(), Error>;

fn format_nb_messages(&self, messages_processed: u32) -> String {
match messages_processed {
0 => "0 message processed".to_string(),
1 => "1 message processed".to_string(),
_ => format!("{} messages processed", messages_processed),
}
}

pub async fn run(&mut self) -> io::Result<()> {
let stdout = io::stdout(); // get the global stdout entity
let mut handle = io::BufWriter::with_capacity(CONSOLE_BUF_SIZE, stdout);
fn format_record(&self, record_result: &RecordResult) -> String;

self.run_inner(&mut handle).await
fn format_records(&self, record_results: &[RecordResult]) -> Vec<String> {
record_results
.iter()
.map(|record_result| self.format_record(record_result))
.collect()
}
}

pub async fn run_inner<W>(&mut self, handle: &mut BufWriter<W>) -> io::Result<()>
where
W: std::io::Write,
{
#[async_trait]
impl<T, W> Sink<T, W> for T
where
W: Write + Send,
T: SinkOutput<W> + Configurable + Send + Sync,
{
async fn run_inner(
&mut self,
tx_records: Sender<Result<ShardProcessorADT, PanicError>>,
mut rx_records: Receiver<Result<ShardProcessorADT, PanicError>>,
handle: &mut BufWriter<W>,
) -> io::Result<()> {
self.delimiter(handle).unwrap();
let count = Rc::new(Mutex::new(0));

self.handle_termination();
let count = Arc::new(Mutex::new(0));

self.handle_termination(tx_records.clone());

while let Some(res) = self.rx_records.recv().await {
while let Some(res) = rx_records.recv().await {
match res {
Ok(adt) => match adt {
ShardProcessorADT::Progress(res) => {
let mut lock = count.lock().await;

match self.max_messages {
match self.get_config().max_messages {
Some(max_messages) => {
if *lock >= max_messages {
self.tx_records
tx_records
.send(Ok(ShardProcessorADT::Termination))
.await
.unwrap();
Expand All @@ -77,9 +113,10 @@ impl Console {
};

if remaining > 0 && !res.is_empty() {
*lock += res.len() as u32;
let split_at = std::cmp::min(remaining as usize, res.len());
*lock += split_at as u32;

let split = res.split_at(remaining as usize);
let split = res.split_at(split_at);
let to_display = split.0;

let data = self.format_records(to_display);
Expand All @@ -88,7 +125,7 @@ impl Console {
writeln!(handle, "{}", data).unwrap();
});
self.delimiter(handle)?
};
}
}
None => {
let data = self.format_records(res.as_slice());
Expand All @@ -102,51 +139,51 @@ impl Console {
}
}
ShardProcessorADT::Termination => {
let messages_processed = match self.max_messages {
Some(max_messages) => max_messages,
_ => *count.lock().await,
};
let messages_processed = *count.lock().await;

writeln!(handle, "{}", self.format_nb_messages(messages_processed))?;
handle.flush()?;
self.rx_records.close();
rx_records.close();

if self.exit_after_termination {
if self.get_config().exit_after_termination {
std::process::exit(0)
}
}
},
Err(e) => {
panic!("Error: {:?}", e);
}
};
}
}

Ok(())
}

fn format_nb_messages(&self, messages_processed: u32) -> String {
match messages_processed {
0 => "0 message processed".to_string(),
1 => "1 message processed".to_string(),
_ => format!("{} messages processed", messages_processed),
}
async fn run(
&mut self,
tx_records: Sender<Result<ShardProcessorADT, PanicError>>,
rx_records: Receiver<Result<ShardProcessorADT, PanicError>>,
) -> io::Result<()> {
let r = &mut self.offer();
self.run_inner(tx_records, rx_records, r).await
}

fn handle_termination(&self) {
let tx_records_clone = self.tx_records.clone();
ctrlc_async::set_async_handler(async move {
tx_records_clone
.send(Ok(ShardProcessorADT::Termination))
.await
.unwrap();
})
.expect("Error setting Ctrl-C handler");
fn handle_termination(&self, tx_records: Sender<Result<ShardProcessorADT, PanicError>>) {
// Note: the exit_after_termination check is to help
// with tests where only one handler can be registered.
if self.get_config().exit_after_termination {
ctrlc_async::set_async_handler(async move {
tx_records
.send(Ok(ShardProcessorADT::Termination))
.await
.unwrap();
})
.expect("Error setting Ctrl-C handler");
}
}

fn delimiter<W>(&self, handle: &mut BufWriter<W>) -> Result<(), Error>
where
W: std::io::Write,
{
if self.print_delimiter {
fn delimiter(&self, handle: &mut BufWriter<W>) -> Result<(), Error> {
if self.get_config().print_delimiter {
writeln!(
handle,
"------------------------------------------------------------------------"
Expand All @@ -155,31 +192,24 @@ impl Console {
Ok(())
}

fn format_records(&self, record_results: &[RecordResult]) -> Vec<String> {
record_results
.iter()
.map(|record_result| self.format_record(record_result))
.collect()
}

fn format_record(&self, record_result: &RecordResult) -> String {
let data = std::str::from_utf8(record_result.data.as_slice())
.unwrap()
.to_string();

let data = if self.print_key {
let data = if self.get_config().print_key {
format!("{} {}", record_result.sequence_id, data)
} else {
data
};

let data = if self.print_shardid {
let data = if self.get_config().print_shardid {
format!("{} {}", record_result.shard_id, data)
} else {
data
};

if self.print_timestamp {
if self.get_config().print_timestamp {
let date = chrono::Utc
.timestamp_opt(record_result.datetime.secs(), 0)
.unwrap();
Expand Down
Loading

0 comments on commit b6b3d6b

Please sign in to comment.