diff --git a/shardus_net/src/lib.rs b/shardus_net/src/lib.rs index 3ef41ab..9d607f4 100644 --- a/shardus_net/src/lib.rs +++ b/shardus_net/src/lib.rs @@ -300,13 +300,13 @@ pub fn multi_send_with_header(mut cx: FunctionContext) -> JsResult let header_js_string: String = cx.argument::(3)?.value(cx) as String; let data_js_string: String = cx.argument::(4)?.value(cx) as String; let complete_cb = cx.argument::(5)?.root(cx); - let await_processing = cx.argument::(6)?.value(cx); // this flag lets us skip the processing on the stats and the callback + let schedule_complete_callback = cx.argument::(6)?.value(cx); // this flag lets us skip the processing on the stats and the callback let shardus_net_sender = cx.this().get::>, _, _>(cx, "_sender")?; let stats_incrementers = cx.this().get::, _, _>(cx, "_stats_incrementers")?; let this = cx.this().root(cx); - let channel = cx.channel(); + let nodejs_thread_channel = cx.channel(); for _ in 0..ports.len() { stats_incrementers.increment_outstanding_sends(); @@ -322,51 +322,44 @@ pub fn multi_send_with_header(mut cx: FunctionContext) -> JsResult let data = data_js_string.into_bytes().to_vec(); - // Create oneshot channels for each host-port pair - let mut senders = Vec::with_capacity(hosts.len()); - let mut receivers = Vec::with_capacity(hosts.len()); + let complete_cb = Arc::new(complete_cb); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); - // should a check be added to see if ports.len == hosts.len - for _ in 0..hosts.len() { - let (sender, receiver) = oneshot::channel::(); - senders.push(sender); - receivers.push(receiver); - } + RUNTIME.spawn(async move { - let complete_cb = Arc::new(complete_cb); - let this = Arc::new(this); - - // Handle the responses asynchronously - for receiver in receivers { - let channel = channel.clone(); - let complete_cb = complete_cb.clone(); - let this = this.clone(); - - RUNTIME.spawn(async move { - let result = receiver.await.expect("Complete send tx dropped before notify"); - - if await_processing { - RUNTIME.spawn_blocking(move || { - channel.send(move |mut cx| { - let cx = &mut cx; - let stats = this.to_inner(cx).get::>, _, _>(cx, "_stats")?; - (**stats).borrow_mut().decrement_outstanding_sends(); - - let this = cx.undefined(); - - if let Err(err) = result { - let error = cx.string(format!("{:?}", err)); - complete_cb.to_inner(cx).call(cx, this, [error.upcast()])?; - } else { - complete_cb.to_inner(cx).call(cx, this, [])?; - } - - Ok(()) - }); - }); - } - }); - } + let mut results = Vec::new(); + + // recv will return None when all tx are dropped + // So this'll not hang forever. + while let Some(result) = rx.recv().await { + results.push(result); + } + + if schedule_complete_callback { + nodejs_thread_channel.send(move |mut cx| { + let cx = &mut cx; + + let stats = this.to_inner(cx).get::>, _, _>(cx, "_stats")?; + let js_arr = cx.empty_array(); + let mut error_count = 0; + for i in 0..results.len() { + (**stats).borrow_mut().decrement_outstanding_sends(); + if let Err(err) = &results[i] { + let err = cx.string(format!("{:?}", err)); + js_arr.set(cx, error_count, err)?; + error_count += 1; + } + } + + let undef = cx.undefined(); + + complete_cb.to_inner(cx).call(cx, undef, [js_arr.upcast()])?; + + Ok(()) + }); + } + + }); let mut addresses = Vec::new(); for (host, port) in hosts.iter().zip(ports.iter()) { @@ -384,7 +377,7 @@ pub fn multi_send_with_header(mut cx: FunctionContext) -> JsResult } // Send each address with its corresponding sender - shardus_net_sender.multi_send_with_header(addresses, header_version, header, data, senders); + shardus_net_sender.multi_send_with_header(addresses, header_version, header, data, tx); Ok(cx.undefined()) } @@ -577,6 +570,7 @@ fn get_sender_address(mut cx: FunctionContext) -> JsResult { #[neon::main] fn main(mut cx: ModuleContext) -> NeonResult<()> { + cx.export_function("Sn", create_shardus_net)?; cx.export_function("setLoggingEnabled", set_logging_enabled)?; diff --git a/shardus_net/src/shardus_net_sender.rs b/shardus_net/src/shardus_net_sender.rs index c3a70f4..3f1209a 100644 --- a/shardus_net/src/shardus_net_sender.rs +++ b/shardus_net/src/shardus_net_sender.rs @@ -2,7 +2,6 @@ use super::runtime::RUNTIME; use crate::header::header_types::Header; use crate::header_factory::{header_serialize_factory, wrap_serialized_message}; use crate::message::Message; -use crate::oneshot::Sender; use crate::shardus_crypto; use log::error; #[cfg(debug)] @@ -31,9 +30,14 @@ pub enum SenderError { pub type SendResult = Result<(), SenderError>; +pub enum Transmitter { + Oneshot(tokio::sync::oneshot::Sender), + Mpsc(UnboundedSender), +} + pub struct ShardusNetSender { key_pair: crypto::KeyPair, - send_channel: UnboundedSender<(SocketAddr, Vec, Sender)>, + send_channel: UnboundedSender<(SocketAddr, Vec, Transmitter)>, evict_socket_channel: UnboundedSender, } @@ -53,15 +57,15 @@ impl ShardusNetSender { } // send: send data to a socket address without a header - pub fn send(&self, address: SocketAddr, data: String, complete_tx: Sender) { + pub fn send(&self, address: SocketAddr, data: String, complete_tx: tokio::sync::oneshot::Sender) { let data = data.into_bytes(); self.send_channel - .send((address, data, complete_tx)) + .send((address, data, Transmitter::Oneshot(complete_tx))) .expect("Unexpected! Failed to send data to channel. Sender task must have been dropped."); } // send_with_header: send data to a socket address with a header and signature - pub fn send_with_header(&self, address: SocketAddr, header_version: u8, mut header: Header, data: Vec, complete_tx: Sender) { + pub fn send_with_header(&self, address: SocketAddr, header_version: u8, mut header: Header, data: Vec, complete_tx: tokio::sync::oneshot::Sender) { let compressed_data = header.compress(data); header.set_message_length(compressed_data.len() as u32); let serialized_header = header_serialize_factory(header_version, header).expect("Failed to serialize header"); @@ -69,22 +73,22 @@ impl ShardusNetSender { message.sign(shardus_crypto::get_shardus_crypto_instance(), &self.key_pair); let serialized_message = wrap_serialized_message(message.serialize()); self.send_channel - .send((address, serialized_message, complete_tx)) + .send((address, serialized_message, Transmitter::Oneshot(complete_tx))) .expect("Unexpected! Failed to send data with header to channel. Sender task must have been dropped."); } // multi_send_with_header: send data to multiple socket addresses with a single header and signature - pub fn multi_send_with_header(&self, addresses: Vec, header_version: u8, mut header: Header, data: Vec, senders: Vec>) { + pub fn multi_send_with_header(&self, addresses: Vec, header_version: u8, mut header: Header, data: Vec, complete_tx: tokio::sync::mpsc::UnboundedSender) { let compressed_data = header.compress(data); header.set_message_length(compressed_data.len() as u32); let serialized_header = header_serialize_factory(header_version, header).expect("Failed to serialize header"); let mut message = Message::new_unsigned(header_version, serialized_header.clone(), compressed_data.clone()); message.sign(shardus_crypto::get_shardus_crypto_instance(), &self.key_pair); let serialized_message = wrap_serialized_message(message.serialize()); - - for (address, sender) in addresses.into_iter().zip(senders.into_iter()) { + + for address in addresses { self.send_channel - .send((address, serialized_message.clone(), sender)) + .send((address, serialized_message.clone(), Transmitter::Mpsc(complete_tx.clone()))) .expect("Failed to send data with header to channel"); } } @@ -111,7 +115,7 @@ impl ShardusNetSender { }); } - fn spawn_sender(send_channel_rx: UnboundedReceiver<(SocketAddr, Vec, Sender)>, connections: Arc>) { + fn spawn_sender(send_channel_rx: UnboundedReceiver<(SocketAddr, Vec, Transmitter)>, connections: Arc>) { RUNTIME.spawn(async move { let mut send_channel_rx = send_channel_rx; @@ -123,7 +127,14 @@ impl ShardusNetSender { RUNTIME.spawn(async move { let result = connection.send(data).await; - complete_tx.send(result).ok(); + match complete_tx { + Transmitter::Oneshot(complete_tx) => { + complete_tx.send(result).ok().expect("Failed to send result to oneshot rx") + } + Transmitter::Mpsc(complete_tx) => { + complete_tx.send(result).ok().expect("Failed to send result to mspc rx, rx might have been dropped") + } + } }); } diff --git a/src/index.ts b/src/index.ts index 939b584..e29d018 100644 --- a/src/index.ts +++ b/src/index.ts @@ -139,7 +139,7 @@ export const Sn = (opts: SnOpts) => { version: number headerData: CombinedHeader }, - awaitProcessing: boolean = true + callbackEnabled: boolean = true ) => { return new Promise<{ success: boolean; error?: string }>((resolve, reject) => { const stringifiedData = jsonStringify(augData, opts.customStringifier) @@ -148,11 +148,22 @@ export const Sn = (opts: SnOpts) => { : null /* prettier-ignore */ if(logFlags.net_verbose) logMessageInfo(augData, stringifiedData) - const sendCallback = (error) => { + const multiSendCallback = (error: string[]) => { + if (error.length == address.length) { + throw new Error(`_sendAug: request_id: ${augData.UUID} error sending from rust failure lib-net: ${error.join(', ')}`) + } + if (error.length > 0) { + return resolve({ success: false, error: error.join(', ') }) + } + return resolve({ success: true }) + } + + const sendCallback = (error?: string) => { if (error) { resolve({ success: false, error }) - } else { + }else { resolve({ success: true }) + } } try { @@ -167,8 +178,8 @@ export const Sn = (opts: SnOpts) => { optionalHeader.version, stringifiedHeader, stringifiedData, - sendCallback, - awaitProcessing + multiSendCallback, + callbackEnabled ) } else { if (logFlags.net_verbose) console.log('send_with_header') diff --git a/test/test_mutli_send.ts b/test/test_mutli_send.ts new file mode 100644 index 0000000..63f03b0 --- /dev/null +++ b/test/test_mutli_send.ts @@ -0,0 +1,133 @@ +import { Command } from 'commander' +import { Sn } from '../.' +import { AppHeader, Sign } from '../build/src/types' + +const setupLruSender = (port: number, lruSize: number) => { + return Sn({ + port, + address: '127.0.0.1', + crypto: { + signingSecretKeyHex: + 'c3774b92cc8850fb4026b073081290b82cab3c0f66cac250b4d710ee9aaf83ed8088b37f6f458104515ae18c2a05bde890199322f62ab5114d20c77bde5e6c9d', + hashKey: '69fa4195670576c0160d660c3be36556ff8d504725be8a59b5a96509e0c994bc', + }, + senderOpts: { + useLruCache: true, + lruSize: lruSize, + }, + headerOpts: { + sendHeaderVersion: 1, + }, + }) +} + +const main = async () => { + /* + create a cli with the following options: + -p, --port Port to listen on + -c, --cache Size of the LRU cache + + the cli should create a sender with the following options: + - lruSize: + - port: + + on running the cli a listener should be started and sending of message with input from terminal should be allowed + */ + + /* + Commands to use for multi_send_with_header + + ts-node test/test_multi_send.ts -p 44000 -c 2 + path/to/test_multi_send.ts -p -c + + data 3 ping + + */ + + console.log('Starting cli...') + + const program = new Command() + program.requiredOption('-p, --port ', 'Port to listen on') + program.option('-c, --cache ', 'Size of the LRU cache', '2') + program.parse(process.argv) + + const port = program.port.toString() + const cacheSize = program.cache.toString() + + console.log(`Starting listener on port ${port} with cache size ${cacheSize}`) + + const sn = setupLruSender(+port, +cacheSize) + + const input = process.stdin + input.addListener('data', async (data: Buffer) => { + const inputs = data.toString().trim().split(' ') + const basePort = 44001 + let ports: number[] = [] + + if (inputs.length === 3) { + let count = Number(inputs[1]) // number of ip addresses you want to send info to + // make sure you have enough servers setup to listen on based on count + for (let i = 0; i < count; i++) { + let port = basePort + i + console.log('The port ' + port) + ports.push(port) + } + const baseAddress = '127.0.0.1' + const addresses = Array(count).fill(baseAddress) + const message = inputs[2] + + await sn.multiSendWithHeader( + ports, + addresses, + { message, fromPort: +port }, + { + compression: 'Brotli', + sender_id: 'test', + }, + 1000, + (data: unknown, header?: AppHeader) => { + console.log('onResp: Received response:', JSON.stringify(data, null, 2)) + if (header) { + console.log('onResp: Received header:', JSON.stringify(header, null, 2)) + } + } + ) + console.log('Message sent') + } else if (inputs.length === 2) { + sn.evictSocket(+inputs[1], '127.0.0.1') + console.log('Cache cleared') + } else { + console.log('=> send ') + console.log('=> clear ') + } + }) + + sn.listen(async (data: any, remote, respond, header, sign) => { + if (data && data.message === 'ping') { + console.log('Received ping from:', data.fromPort) + console.log('Ping header:', JSON.stringify(header, null, 2)) + // await sleep(10000) + return respond( + { message: 'pong', fromPort: +port }, + { + compression: 'Brotli', + } + ) + } + if (data && data.message === 'pong') { + console.log('Received pong from:', data.fromPort) + } + if (header) { + console.log('Received header:', JSON.stringify(header, null, 2)) + } + if (sign) { + console.log('Received signature:', JSON.stringify(sign, null, 2)) + } + }) +} + +const sleep = (ms: number) => { + return new Promise((resolve) => setTimeout(resolve, ms)) +} + +main().catch((err) => console.log('ERROR: ', err))