Gold-180 - fix: improper nodejs task scheduling, use mpsc for multi_send #10

Open · wants to merge 5 commits into base: dev
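Instead of creating one oneshot channel per destination and scheduling a separate Node.js callback for every send, this change funnels all SendResults through a single unbounded tokio mpsc channel and schedules one completion callback with the collected errors. Below is a minimal, self-contained sketch of that fan-in pattern (assuming tokio with the rt and macros features); fake_send and the addresses are illustrative stand-ins, not part of shardus_net.

// Sketch only: N concurrent sends report into one mpsc channel,
// and a single collector aggregates every result.
use tokio::sync::mpsc;

#[derive(Debug)]
enum SenderError {
    ConnectionFailed(String),
}

type SendResult = Result<(), SenderError>;

async fn fake_send(addr: &'static str) -> SendResult {
    // Stand-in for the real TCP send.
    if addr.ends_with(":0") {
        Err(SenderError::ConnectionFailed(addr.to_string()))
    } else {
        Ok(())
    }
}

#[tokio::main]
async fn main() {
    let addresses = ["10.0.0.1:9001", "10.0.0.2:9001", "10.0.0.3:0"];
    let (tx, mut rx) = mpsc::unbounded_channel::<SendResult>();

    // Every destination gets a clone of the same tx instead of its own oneshot.
    for addr in addresses {
        let tx = tx.clone();
        tokio::spawn(async move {
            tx.send(fake_send(addr).await).ok();
        });
    }
    // Drop the original tx so rx.recv() yields None once every task has reported.
    drop(tx);

    let mut errors = Vec::new();
    while let Some(result) = rx.recv().await {
        if let Err(err) = result {
            errors.push(format!("{:?}", err));
        }
    }
    // In the PR this is the point where one Neon callback is scheduled with `errors`.
    println!("{} of {} sends failed: {:?}", errors.len(), addresses.len(), errors);
}

The property the collector relies on is that rx.recv() only returns None after every clone of tx has been dropped, so it can wait for all destinations without tracking an explicit count.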
86 changes: 40 additions & 46 deletions shardus_net/src/lib.rs
@@ -300,13 +300,13 @@ pub fn multi_send_with_header(mut cx: FunctionContext) -> JsResult<JsUndefined>
let header_js_string: String = cx.argument::<JsString>(3)?.value(cx) as String;
let data_js_string: String = cx.argument::<JsString>(4)?.value(cx) as String;
let complete_cb = cx.argument::<JsFunction>(5)?.root(cx);
let await_processing = cx.argument::<JsBoolean>(6)?.value(cx); // this flag lets us skip the processing on the stats and the callback
let schedule_complete_callback = cx.argument::<JsBoolean>(6)?.value(cx); // when false, skip the stats bookkeeping and scheduling the completion callback

let shardus_net_sender = cx.this().get::<JsBox<Arc<ShardusNetSender>>, _, _>(cx, "_sender")?;
let stats_incrementers = cx.this().get::<JsBox<Incrementers>, _, _>(cx, "_stats_incrementers")?;

let this = cx.this().root(cx);
let channel = cx.channel();
let nodejs_thread_channel = cx.channel();

for _ in 0..ports.len() {
stats_incrementers.increment_outstanding_sends();
@@ -322,51 +322,44 @@ pub fn multi_send_with_header(mut cx: FunctionContext) -> JsResult<JsUndefined>

let data = data_js_string.into_bytes().to_vec();

// Create oneshot channels for each host-port pair
let mut senders = Vec::with_capacity(hosts.len());
let mut receivers = Vec::with_capacity(hosts.len());
let complete_cb = Arc::new(complete_cb);
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<SendResult>();

// should a check be added to see if ports.len == hosts.len
for _ in 0..hosts.len() {
let (sender, receiver) = oneshot::channel::<SendResult>();
senders.push(sender);
receivers.push(receiver);
}
RUNTIME.spawn(async move {

let complete_cb = Arc::new(complete_cb);
let this = Arc::new(this);

// Handle the responses asynchronously
for receiver in receivers {
let channel = channel.clone();
let complete_cb = complete_cb.clone();
let this = this.clone();

RUNTIME.spawn(async move {
let result = receiver.await.expect("Complete send tx dropped before notify");

if await_processing {
RUNTIME.spawn_blocking(move || {
channel.send(move |mut cx| {
let cx = &mut cx;
let stats = this.to_inner(cx).get::<JsBox<RefCell<Stats>>, _, _>(cx, "_stats")?;
(**stats).borrow_mut().decrement_outstanding_sends();

let this = cx.undefined();

if let Err(err) = result {
let error = cx.string(format!("{:?}", err));
complete_cb.to_inner(cx).call(cx, this, [error.upcast()])?;
} else {
complete_cb.to_inner(cx).call(cx, this, [])?;
}

Ok(())
});
});
}
});
}
let mut results = Vec::new();

// recv will return None once all tx clones are dropped,
// so this loop will not hang forever.
while let Some(result) = rx.recv().await {
results.push(result);
}

if schedule_complete_callback {
nodejs_thread_channel.send(move |mut cx| {
let cx = &mut cx;

let stats = this.to_inner(cx).get::<JsBox<RefCell<Stats>>, _, _>(cx, "_stats")?;
let js_arr = cx.empty_array();
let mut error_count = 0;
for i in 0..results.len() {
(**stats).borrow_mut().decrement_outstanding_sends();
if let Err(err) = &results[i] {
let err = cx.string(format!("{:?}", err));
js_arr.set(cx, error_count, err)?;
error_count += 1;
}
}

let undef = cx.undefined();

complete_cb.to_inner(cx).call(cx, undef, [js_arr.upcast()])?;

Ok(())
});
}

});

let mut addresses = Vec::new();
for (host, port) in hosts.iter().zip(ports.iter()) {
@@ -384,7 +377,7 @@ pub fn multi_send_with_header(mut cx: FunctionContext) -> JsResult<JsUndefined>
}

// Send each address with its corresponding sender
shardus_net_sender.multi_send_with_header(addresses, header_version, header, data, senders);
shardus_net_sender.multi_send_with_header(addresses, header_version, header, data, tx);

Ok(cx.undefined())
}
@@ -577,6 +570,7 @@ fn get_sender_address(mut cx: FunctionContext) -> JsResult<JsObject> {

#[neon::main]
fn main(mut cx: ModuleContext) -> NeonResult<()> {

cx.export_function("Sn", create_shardus_net)?;

cx.export_function("setLoggingEnabled", set_logging_enabled)?;
35 changes: 23 additions & 12 deletions shardus_net/src/shardus_net_sender.rs
@@ -2,7 +2,6 @@ use super::runtime::RUNTIME;
use crate::header::header_types::Header;
use crate::header_factory::{header_serialize_factory, wrap_serialized_message};
use crate::message::Message;
use crate::oneshot::Sender;
use crate::shardus_crypto;
use log::error;
#[cfg(debug)]
@@ -31,9 +30,14 @@ pub enum SenderError {

pub type SendResult = Result<(), SenderError>;

pub enum Transmitter<T> {
Oneshot(tokio::sync::oneshot::Sender<T>),
Mpsc(UnboundedSender<T>),
}

pub struct ShardusNetSender {
key_pair: crypto::KeyPair,
send_channel: UnboundedSender<(SocketAddr, Vec<u8>, Sender<SendResult>)>,
send_channel: UnboundedSender<(SocketAddr, Vec<u8>, Transmitter<SendResult>)>,
evict_socket_channel: UnboundedSender<SocketAddr>,
}

@@ -53,38 +57,38 @@ impl ShardusNetSender {
}

// send: send data to a socket address without a header
pub fn send(&self, address: SocketAddr, data: String, complete_tx: Sender<SendResult>) {
pub fn send(&self, address: SocketAddr, data: String, complete_tx: tokio::sync::oneshot::Sender<SendResult>) {
let data = data.into_bytes();
self.send_channel
.send((address, data, complete_tx))
.send((address, data, Transmitter::Oneshot(complete_tx)))
.expect("Unexpected! Failed to send data to channel. Sender task must have been dropped.");
}

// send_with_header: send data to a socket address with a header and signature
pub fn send_with_header(&self, address: SocketAddr, header_version: u8, mut header: Header, data: Vec<u8>, complete_tx: Sender<SendResult>) {
pub fn send_with_header(&self, address: SocketAddr, header_version: u8, mut header: Header, data: Vec<u8>, complete_tx: tokio::sync::oneshot::Sender<SendResult>) {
let compressed_data = header.compress(data);
header.set_message_length(compressed_data.len() as u32);
let serialized_header = header_serialize_factory(header_version, header).expect("Failed to serialize header");
let mut message = Message::new_unsigned(header_version, serialized_header, compressed_data);
message.sign(shardus_crypto::get_shardus_crypto_instance(), &self.key_pair);
let serialized_message = wrap_serialized_message(message.serialize());
self.send_channel
.send((address, serialized_message, complete_tx))
.send((address, serialized_message, Transmitter::Oneshot(complete_tx)))
.expect("Unexpected! Failed to send data with header to channel. Sender task must have been dropped.");
}

// multi_send_with_header: send data to multiple socket addresses with a single header and signature
pub fn multi_send_with_header(&self, addresses: Vec<SocketAddr>, header_version: u8, mut header: Header, data: Vec<u8>, senders: Vec<Sender<SendResult>>) {
pub fn multi_send_with_header(&self, addresses: Vec<SocketAddr>, header_version: u8, mut header: Header, data: Vec<u8>, complete_tx: tokio::sync::mpsc::UnboundedSender<SendResult>) {
let compressed_data = header.compress(data);
header.set_message_length(compressed_data.len() as u32);
let serialized_header = header_serialize_factory(header_version, header).expect("Failed to serialize header");
let mut message = Message::new_unsigned(header_version, serialized_header.clone(), compressed_data.clone());
message.sign(shardus_crypto::get_shardus_crypto_instance(), &self.key_pair);
let serialized_message = wrap_serialized_message(message.serialize());

for (address, sender) in addresses.into_iter().zip(senders.into_iter()) {
for address in addresses {
self.send_channel
.send((address, serialized_message.clone(), sender))
.send((address, serialized_message.clone(), Transmitter::Mpsc(complete_tx.clone())))
.expect("Failed to send data with header to channel");
}
}
@@ -111,7 +115,7 @@ impl ShardusNetSender {
});
}

fn spawn_sender(send_channel_rx: UnboundedReceiver<(SocketAddr, Vec<u8>, Sender<SendResult>)>, connections: Arc<Mutex<dyn ConnectionCache + Send>>) {
fn spawn_sender(send_channel_rx: UnboundedReceiver<(SocketAddr, Vec<u8>, Transmitter<SendResult>)>, connections: Arc<Mutex<dyn ConnectionCache + Send>>) {
RUNTIME.spawn(async move {
let mut send_channel_rx = send_channel_rx;

@@ -123,7 +127,14 @@

RUNTIME.spawn(async move {
let result = connection.send(data).await;
complete_tx.send(result).ok();
match complete_tx {
Transmitter::Oneshot(complete_tx) => {
complete_tx.send(result).ok().expect("Failed to send result to oneshot rx")
}
Transmitter::Mpsc(complete_tx) => {
complete_tx.send(result).ok().expect("Failed to send result to mspc rx, rx might have been dropped")
}
}
});
}

21 changes: 16 additions & 5 deletions src/index.ts
@@ -139,7 +139,7 @@ export const Sn = (opts: SnOpts) => {
version: number
headerData: CombinedHeader
},
awaitProcessing: boolean = true
callbackEnabled: boolean = true
) => {
return new Promise<{ success: boolean; error?: string }>((resolve, reject) => {
const stringifiedData = jsonStringify(augData, opts.customStringifier)
@@ -148,11 +148,22 @@
: null
/* prettier-ignore */ if(logFlags.net_verbose) logMessageInfo(augData, stringifiedData)

const sendCallback = (error) => {
const multiSendCallback = (error: string[]) => {
if (error.length == address.length) {
throw new Error(`_sendAug: request_id: ${augData.UUID} error sending from rust failure lib-net: ${error.join(', ')}`)
}
if (error.length > 0) {
return resolve({ success: false, error: error.join(', ') })
}
return resolve({ success: true })
}

const sendCallback = (error?: string) => {
if (error) {
resolve({ success: false, error })
} else {
resolve({ success: true })
}
}
try {
@@ -167,8 +178,8 @@
optionalHeader.version,
stringifiedHeader,
stringifiedData,
sendCallback,
awaitProcessing
multiSendCallback,
callbackEnabled
)
} else {
if (logFlags.net_verbose) console.log('send_with_header')
133 changes: 133 additions & 0 deletions test/test_mutli_send.ts
@@ -0,0 +1,133 @@
import { Command } from 'commander'
import { Sn } from '../.'
import { AppHeader, Sign } from '../build/src/types'

const setupLruSender = (port: number, lruSize: number) => {
return Sn({
port,
address: '127.0.0.1',
crypto: {
signingSecretKeyHex:
'c3774b92cc8850fb4026b073081290b82cab3c0f66cac250b4d710ee9aaf83ed8088b37f6f458104515ae18c2a05bde890199322f62ab5114d20c77bde5e6c9d',
hashKey: '69fa4195670576c0160d660c3be36556ff8d504725be8a59b5a96509e0c994bc',
},
senderOpts: {
useLruCache: true,
lruSize: lruSize,
},
headerOpts: {
sendHeaderVersion: 1,
},
})
}

const main = async () => {
/*
create a cli with the following options:
-p, --port <port> Port to listen on
-c, --cache <size> Size of the LRU cache

the cli should create a sender with the following options:
- lruSize: <size>
- port: <port>

on running the cli a listener should be started and sending of message with input from terminal should be allowed
*/

/*
Commands to use for multi_send_with_header

ts-node test/test_multi_send.ts -p 44000 -c 2
path/to/test_multi_send.ts -p <port> -c <cache_size>

data 3 ping
<route> <connections to send data> <message>
*/

console.log('Starting cli...')

const program = new Command()
program.requiredOption('-p, --port <port>', 'Port to listen on')
program.option('-c, --cache <size>', 'Size of the LRU cache', '2')
program.parse(process.argv)

const port = program.port.toString()
const cacheSize = program.cache.toString()

console.log(`Starting listener on port ${port} with cache size ${cacheSize}`)

const sn = setupLruSender(+port, +cacheSize)

const input = process.stdin
input.addListener('data', async (data: Buffer) => {
const inputs = data.toString().trim().split(' ')
const basePort = 44001
let ports: number[] = []

if (inputs.length === 3) {
let count = Number(inputs[1]) // number of ip addresses you want to send info to
// make sure you have enough servers setup to listen on based on count
for (let i = 0; i < count; i++) {
let port = basePort + i
console.log('The port ' + port)
ports.push(port)
}
const baseAddress = '127.0.0.1'
const addresses = Array(count).fill(baseAddress)
const message = inputs[2]

await sn.multiSendWithHeader(
ports,
addresses,
{ message, fromPort: +port },
{
compression: 'Brotli',
sender_id: 'test',
},
1000,
(data: unknown, header?: AppHeader) => {
console.log('onResp: Received response:', JSON.stringify(data, null, 2))
if (header) {
console.log('onResp: Received header:', JSON.stringify(header, null, 2))
}
}
)
console.log('Message sent')
} else if (inputs.length === 2) {
sn.evictSocket(+inputs[1], '127.0.0.1')
console.log('Cache cleared')
} else {
console.log('=> send <port> <message>')
console.log('=> clear <port>')
}
})

sn.listen(async (data: any, remote, respond, header, sign) => {
if (data && data.message === 'ping') {
console.log('Received ping from:', data.fromPort)
console.log('Ping header:', JSON.stringify(header, null, 2))
// await sleep(10000)
return respond(
{ message: 'pong', fromPort: +port },
{
compression: 'Brotli',
}
)
}
if (data && data.message === 'pong') {
console.log('Received pong from:', data.fromPort)
}
if (header) {
console.log('Received header:', JSON.stringify(header, null, 2))
}
if (sign) {
console.log('Received signature:', JSON.stringify(sign, null, 2))
}
})
}

const sleep = (ms: number) => {
return new Promise((resolve) => setTimeout(resolve, ms))
}

main().catch((err) => console.log('ERROR: ', err))