Skip to content

Commit

Permalink
better implementation/refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielePalaia committed Oct 15, 2024
1 parent 8105c4a commit e409ac5
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 92 deletions.
3 changes: 1 addition & 2 deletions src/environment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,8 @@ impl Environment {
) -> SuperStreamProducerBuilder<NoDedup> {
SuperStreamProducerBuilder {
environment: self.clone(),
name: None,
data: PhantomData,
filter_value_extractor: None,
//filter_value_extractor: None,
routing_strategy: routing_strategy,
}
}
Expand Down
7 changes: 4 additions & 3 deletions src/superstream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ pub struct DefaultSuperStreamMetadata {
impl DefaultSuperStreamMetadata {
pub async fn partitions(&mut self) -> Vec<String> {
if self.partitions.len() == 0 {
println!("partition len is 0");
let response = self.client.partitions(self.super_stream.clone()).await;

self.partitions = response.unwrap().streams;
}

return self.partitions.clone();
}
pub async fn routes(&mut self, routing_key: String) -> Vec<String> {
Expand Down Expand Up @@ -73,14 +73,15 @@ impl HashRoutingMurmurStrategy {
message: Message,
metadata: &mut DefaultSuperStreamMetadata,
) -> Vec<String> {
println!("im in routes");
let mut streams: Vec<String> = Vec::new();

let key = (self.routing_extractor)(message);
let key = (self.routing_extractor)(message.clone());
let hash_result = murmur3_32(&mut Cursor::new(key), 104729);

let number_of_partitions = metadata.partitions().await.len();
let route = hash_result.unwrap() % number_of_partitions as u32;
let partitions: Vec<String> = metadata.partitions().await;
let partitions = metadata.partitions().await;
let stream = partitions.into_iter().nth(route as usize).unwrap();
streams.push(stream);

Expand Down
91 changes: 48 additions & 43 deletions src/superstream_producer.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::error::ProducerCloseError;
use crate::{
client::Client,
environment::Environment,
Expand All @@ -17,25 +18,23 @@ type FilterValueExtractor = Arc<dyn Fn(&Message) -> String + 'static + Send + Sy
pub struct SuperStreamProducer<T>(
Arc<SuperStreamProducerInternal>,
HashMap<String, Producer<T>>,
DefaultSuperStreamMetadata,
PhantomData<T>,
);

/// Builder for [`SuperStreamProducer`]
pub struct SuperStreamProducerBuilder<T> {
pub(crate) environment: Environment,
pub(crate) name: Option<String>,
pub filter_value_extractor: Option<FilterValueExtractor>,
//pub filter_value_extractor: Option<FilterValueExtractor>,
pub routing_strategy: RoutingStrategy,
pub(crate) data: PhantomData<T>,
}

pub struct SuperStreamProducerInternal {
pub(crate) environment: Environment,
client: Client,
super_stream: String,
publish_version: u16,
filter_value_extractor: Option<FilterValueExtractor>,
super_stream_metadata: DefaultSuperStreamMetadata,
// TODO: implement filtering for superstream
//filter_value_extractor: Option<FilterValueExtractor>,
routing_strategy: RoutingStrategy,
}

Expand All @@ -54,35 +53,48 @@ impl SuperStreamProducer<NoDedup> {
{
let routes = match self.0.routing_strategy.clone() {
RoutingStrategy::HashRoutingStrategy(routing_strategy) => {
routing_strategy
.routes(message.clone(), &mut self.0.super_stream_metadata.clone())
.await
routing_strategy.routes(message.clone(), &mut self.2).await
}
RoutingStrategy::RoutingKeyStrategy(routing_strategy) => {
routing_strategy
.routes(message.clone(), &mut self.0.super_stream_metadata.clone())
.await
routing_strategy.routes(message.clone(), &mut self.2).await
}
};

for route in routes.into_iter() {
if !self.1.contains_key(route.as_str()) {
let producer = self.0.environment.producer().build(route.as_str()).await;

self.1.insert(route.clone(), producer.unwrap());
let producer = self.0.environment.producer().build(route.as_str()).await;
self.1.insert(route.clone(), producer.unwrap());
}

println!("sending message to super_stream {}", route.clone());

let producer = self.1.get(route.as_str()).unwrap();
let result = producer.send(message.clone(), cb.clone()).await;
match result {
Ok(()) => println!("Message correctly sent"),
Err(e) => println!("Error {}", e),
}
let result = producer.send(message.clone(), cb.clone()).await?;
}
Ok(())
}


pub async fn close(self) -> Result<(), ProducerCloseError> {
self.0.client.close().await?;

let mut err: Option<ProducerCloseError> = None;
let mut is_error = false;
for (_, producer) in self.1.into_iter() {
let close = producer.close().await;
match close {
Err(e) => {
is_error = true;
err = Some(e);
}
_ => (),
}
}

if is_error == false {
return Ok(());
} else {
return Err(err.unwrap());
}
}
}

impl<T> SuperStreamProducerBuilder<T> {
Expand All @@ -95,36 +107,29 @@ impl<T> SuperStreamProducerBuilder<T> {
// to the leader anyway - it is the only one capable of writing.
let client = self.environment.create_client().await?;

let mut publish_version = 1;

if self.filter_value_extractor.is_some() {
if client.filtering_supported() {
publish_version = 2
} else {
return Err(ProducerCreateError::FilteringNotSupport);
}
}

let producers = HashMap::new();

let super_stream_metadata = DefaultSuperStreamMetadata {
super_stream: super_stream.to_string(),
client: self.environment.create_client().await?,
partitions: Vec::new(),
routes: Vec::new(),
};

let super_stream_producer = SuperStreamProducerInternal {
environment: self.environment.clone(),
super_stream: super_stream.to_string(),
client,
publish_version,
filter_value_extractor: self.filter_value_extractor,
//filter_value_extractor: self.filter_value_extractor,
routing_strategy: self.routing_strategy,
super_stream_metadata: DefaultSuperStreamMetadata {
super_stream: super_stream.to_string(),
client: self.environment.create_client().await?,
partitions: Vec::new(),
routes: Vec::new(),
},
};

let internal_producer = Arc::new(super_stream_producer);
let super_stream_producer =
SuperStreamProducer(internal_producer.clone(), producers, PhantomData);
let super_stream_producer = SuperStreamProducer(
internal_producer.clone(),
producers,
super_stream_metadata,
PhantomData,
);

Ok(super_stream_producer)
}
Expand Down
71 changes: 31 additions & 40 deletions tests/integration/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::collections::HashMap;

use fake::{Fake, Faker};
use rabbitmq_stream_client::{Client, ClientOptions, Environment};
use rabbitmq_stream_protocol::commands::generic::GenericResponse;
use rabbitmq_stream_protocol::ResponseCode;

pub struct TestClient {
Expand Down Expand Up @@ -38,26 +39,7 @@ impl TestClient {
let super_stream: String = Faker.fake();
let client = Client::connect(ClientOptions::default()).await.unwrap();

let partitions: Vec<String> = [
super_stream.to_string() + "-0",
super_stream.to_string() + "-1",
super_stream.to_string() + "-2",
]
.iter()
.map(|x| x.into())
.collect();

let binding_keys: Vec<String> = ["0", "1", "2"].iter().map(|&x| x.into()).collect();

let response = client
.create_super_stream(
&super_stream,
partitions.clone(),
binding_keys,
HashMap::new(),
)
.await
.unwrap();
let (response, partitions) = create_generic_super_stream(&super_stream, &client).await;

assert_eq!(&ResponseCode::Ok, response.code());
TestClient {
Expand Down Expand Up @@ -109,26 +91,7 @@ impl TestEnvironment {
let client = Client::connect(ClientOptions::default()).await.unwrap();
let env = Environment::builder().build().await.unwrap();

let partitions: Vec<String> = [
super_stream.to_string() + "-0",
super_stream.to_string() + "-1",
super_stream.to_string() + "-2",
]
.iter()
.map(|x| x.into())
.collect();

let binding_keys: Vec<String> = ["0", "1", "2"].iter().map(|&x| x.into()).collect();

let response = client
.create_super_stream(
&super_stream,
partitions.clone(),
binding_keys,
HashMap::new(),
)
.await
.unwrap();
let (response, partitions) = create_generic_super_stream(&super_stream, &client).await;

assert_eq!(&ResponseCode::Ok, response.code());
TestEnvironment {
Expand Down Expand Up @@ -160,3 +123,31 @@ impl Drop for TestEnvironment {
}
}
}

pub async fn create_generic_super_stream(
super_stream: &String,
client: &Client,
) -> (GenericResponse, Vec<String>) {
let partitions: Vec<String> = [
super_stream.to_string() + "-0",
super_stream.to_string() + "-1",
super_stream.to_string() + "-2",
]
.iter()
.map(|x| x.into())
.collect();

let binding_keys: Vec<String> = ["0", "1", "2"].iter().map(|&x| x.into()).collect();

let response = client
.create_super_stream(
&super_stream,
partitions.clone(),
binding_keys,
HashMap::new(),
)
.await
.unwrap();

return (response, partitions);
}
7 changes: 3 additions & 4 deletions tests/integration/producer_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@ use rabbitmq_stream_client::types::{
};

use crate::common::TestEnvironment;
use rabbitmq_stream_protocol::message::Value;
use rabbitmq_stream_protocol::utils::TupleMapperSecond;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use tokio::sync::Notify;
use tokio::time::{sleep, Duration};

#[tokio::test(flavor = "multi_thread")]
async fn producer_send_no_name_ok() {
Expand Down Expand Up @@ -400,7 +397,6 @@ fn routing_key_strategy_value_extractor(message: Message) -> String {

fn hash_strategy_value_extractor(message: Message) -> String {
let s = String::from_utf8(Vec::from(message.data().unwrap())).expect("Found invalid UTF-8");

return s;
}

Expand Down Expand Up @@ -441,6 +437,7 @@ async fn key_super_steam_producer_test() {
}

notify_on_send.notified().await;
_ = super_stream_producer.close();
}

#[tokio::test(flavor = "multi_thread")]
Expand All @@ -462,6 +459,7 @@ async fn hash_super_steam_producer_test() {
.unwrap();

for i in 0..message_count {
println!("sending message {}", i);
let counter = confirmed_messages.clone();
let notifier = notify_on_send.clone();
let msg = Message::builder().body(format!("message{}", i)).build();
Expand All @@ -480,6 +478,7 @@ async fn hash_super_steam_producer_test() {
}

notify_on_send.notified().await;
_ = super_stream_producer.close();
}

#[tokio::test(flavor = "multi_thread")]
Expand Down

0 comments on commit e409ac5

Please sign in to comment.