diff --git a/Cargo.toml b/Cargo.toml
index 24e42c1..375a443 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -6,7 +6,8 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-axum = "0.6.19"
+axum = { version = "0.6.19", features = ["ws"] }
+axum-extra = { version = "0.9.1", features = ["typed-header"] }
 clap = "3.2.6"
 tokio = { version = "1.35.1", features = ["full"] }
 serde = { version = "1.0", features = ["derive"] }
@@ -19,7 +20,7 @@ timer = "0.2.0"
 lettre = { version = "0.10.4", default-features = false, features = ["builder", "smtp-transport", "rustls-tls"] }
 dirs = "3.0.2"
 reqwest = { version = "0.11.23", default-features = false, features = ["blocking", "json", "rustls-tls"] }
-
+futures = "0.3"
 rdkafka = { version = "0.33.2", default-features = false, features = ["cmake-build"] }
 
 [dev-dependencies]
diff --git a/README.md b/README.md
index ba68ca0..08fdda1 100644
--- a/README.md
+++ b/README.md
@@ -71,15 +71,17 @@ Supported types of job execution include:
 
 | kafka_resp_topic | 返回任务执行结果的Topic | Topic for response execute result |
 
-Both real-time triggering and delayed triggering support HTTP and Kafka multi-channel access. By default, only HTTP is enabled. After setting the correct Kafka prefix parameters, tasks can be received from Kafka.
+Both real-time triggering and delayed triggering support HTTP, WebSocket, and Kafka multi-channel access. By default, only HTTP and WebSocket are enabled. After the kafka-prefixed parameters are set correctly, tasks can also be received from Kafka.
 
-实时触发和延迟触发均支持 HTTP 和 Kafka 多通道接入,默认只开启HTTP,在设置好正确的kafka前缀的参数后,即可从Kafka接收任务
+实时触发和延迟触发均支持 HTTP、WebSocket、Kafka 等多通道接入,默认只开启 HTTP 和 WebSocket,在设置好正确的 kafka 前缀的参数后,即可从 Kafka 接收任务
 
 http 接收任务的例子(Sample for http):
 
     curl -X POST http://127.0.0.1:8000/task_in_queue -H 'Content-Type: application/json' -d '{...}'
 
+WebSocket 连接的地址为(Address for incoming WebSocket connections):
+
+    ws://ip:port/ws_task
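+
+Any WebSocket client can be used for a quick manual test. For example, with the third-party websocat tool (shown here only as an illustration; it is not part of this project) and the same port as the curl example above, each line typed after connecting is sent as one text frame carrying the same JSON task body:
+
+    websocat ws://127.0.0.1:8000/ws_task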
 
 The format of the JSON body for HTTP requests and the message topic for Kafka is consistent, both in JSON format, as defined below:
 
diff --git a/src/main.rs b/src/main.rs
index 89dd2f0..b4231a9 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,4 +1,5 @@
 mod runner;
+mod ws;
 
 extern crate redis;
 use std::sync::{Arc, Mutex};
@@ -8,11 +9,12 @@ use redis::{AsyncCommands, Commands};
 use std::net::SocketAddr;
 use std::net::IpAddr;
 use std::net::Ipv4Addr;
-use axum::{extract::{Json, path::Path as PathParam, State}, Router, routing::{post, get}, response::IntoResponse};
+use axum::{extract::{Json, path::Path as PathParam, State, ws::{Message as WsMessage, WebSocket, WebSocketUpgrade}}, Router, routing::{post, get}, response::IntoResponse, ServiceExt};
+use axum_extra::{headers, TypedHeader};
 use std::{env, thread};
-
 use std::path::{Path, PathBuf};
 use std::str::FromStr;
+use axum::extract::ConnectInfo;
 use tokio::fs::File;
 use tokio::io::{self, BufReader, AsyncBufReadExt};
 use chrono::{DateTime, Utc, Local};
@@ -25,6 +27,7 @@ use rdkafka::config::ClientConfig;
 use rdkafka::consumer::{Consumer, BaseConsumer};
 use rdkafka::producer::{BaseProducer, BaseRecord, Producer};
 use rdkafka::Message;
+use crate::ws::{handle_socket, MessageExecutor};
 
 
 #[derive(Clone)]
@@ -32,7 +35,8 @@ pub struct AppState {
     pub config_path: String,
     pub redis_client: redis::Client,
     queue: QueueGroup,
-    pub config: AppConfig
+    pub config: AppConfig,
+    pub ws_executor: MessageExecutor
 }
 
 #[derive(Serialize, Deserialize, Debug)]
@@ -206,6 +210,34 @@ impl KafkaProducer {
     }
 }
 
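+// ResponseQueue bridges the synchronous worker threads and the async runtime:
+// worker OS threads push (task_id, response) pairs behind a std::sync::Mutex,
+// and a tokio task spawned in main() polls wait_for() to forward each pair to
+// the WebSocket client that submitted the task (see MessageExecutor in ws.rs).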
+#[derive(Clone)]
+struct ResponseQueue {
+    queue: Arc<Mutex<VecDeque<(String, String)>>>
+}
+
+impl ResponseQueue {
+    fn new() -> ResponseQueue {
+        ResponseQueue { queue: Arc::new(Mutex::new(VecDeque::new())) }
+    }
+
+    fn queue_resp(&mut self, task_id: String, resp: String) {
+        // lock() only fails if the mutex is poisoned; push exactly once.
+        if let Ok(mut queue) = self.queue.lock() {
+            queue.push_back((task_id, resp));
+        }
+    }
+
+    fn wait_for(&mut self) -> Option<(String, String)> {
+        if let Ok(mut queue) = self.queue.lock() {
+            return queue.pop_front();
+        }
+        None
+    }
+}
+
 
 const TASK_WRONG: &'static str = "task||wrong";
 
@@ -347,11 +379,15 @@ async fn main() {
 
     let mut queue_group = QueueGroup::init_by_number(workers);
     let client = redis::Client::open(redis).unwrap();
+    let ws_exec = MessageExecutor::new();
+
+    let resp_queue = ResponseQueue::new();
 
     for thread_id in 0..workers {
         let mut group = queue_group.clone();
         let mut redis_connection = client.get_connection().unwrap();
         let appconfig = appconfig.clone();
+        let mut resp_queue1 = resp_queue.clone();
         thread::spawn(move || {
             let mut pd = KafkaProducer::from_bootstrap(appconfig.kafka_servers.as_str(), appconfig.kafka_resp_topic.clone());
             loop {
@@ -388,6 +424,10 @@ async fn main() {
                         if task.src_chn == "kafka" {
                             pd.sent(serde_json::json!({"result": "OK", "request_id": task.id.clone()}));
                         }
+                        if task.src_chn == "ws" {
+                            resp_queue1.queue_resp(task.id.clone(), serde_json::to_string(&serde_json::json!({"result": "OK", "request_id": task.id.clone()})).unwrap());
+                        }
                     }
                 }
                 redis_connection.srem::<&str, String, ()>(TASK_WORKING, task.id.clone()).expect("redis error");
@@ -528,6 +568,18 @@ async fn main() {
         });
     }
 
+    // Forward acknowledgements queued by the worker threads to the matching
+    // WebSocket client; poll every 200ms while the queue is empty.
+    let mut resp_queue2 = resp_queue.clone();
+    let mut ws_exec1 = ws_exec.clone();
+    tokio::spawn(async move {
+        loop {
+            if let Some((tid, resp_str)) = resp_queue2.wait_for() {
+                ws_exec1.response_for(tid, resp_str).await;
+            } else {
+                tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
+            }
+        }
+    });
+
     if let Ok(mut main_conn) = client.get_async_connection().await {
         if let Ok(working_ids) = main_conn.smembers::<&str, Vec<String>>(TASK_WORKING).await {
             for tk_id in working_ids {
@@ -546,17 +598,19 @@ async fn main() {
         config_path: cron_path.to_string(),
         redis_client: client.clone(),
         queue: queue_group.clone(),
-        config: appconfig.clone()
+        config: appconfig.clone(),
+        ws_executor: ws_exec.clone()
     };
 
     let app = Router::new()
         .route("/task_in_queue", post(handler))
         .route("/task_resp/:key", get(waiting))
+        .route("/ws_task", get(ws_handler))
         .route("/sys_info", get(system_info_handler))
         .with_state(app_state);
 
     println!("server will start at 0.0.0.0:{}", port);
     let serv = axum::Server::bind(& SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), int_port))
-        .serve(app.into_make_service())
+        .serve(app.into_make_service_with_connect_info::<SocketAddr>())
         .await;
     match serv {
         Ok(_)=>{
@@ -625,6 +679,21 @@ async fn waiting_for_result(conn: &mut Connection, flag: String, waiting: usize)
     }
 }
 
+async fn ws_handler(
+    State(state): State<AppState>,
+    ws: WebSocketUpgrade,
+    ConnectInfo(addr): ConnectInfo<SocketAddr>,
+) -> impl IntoResponse {
+    println!("upgrading ws connection from {:?}", &addr);
+    // SocketAddr is Copy, so it can move into both closures directly.
+    ws.on_upgrade(move |socket| async move {
+        println!("socket {:?} connected, {} worker queues available", &addr, state.queue.size);
+        handle_socket(state.clone(), socket, addr).await
+    })
+}
+
 async fn system_info_handler(State(mut state): State<AppState>) -> impl IntoResponse {
     let mut conn = state.redis_client.get_async_connection().await.unwrap();
     let working_keys = if let Ok(_working_keys) = conn.smembers::<String, Vec<String>>(TASK_WORKING.to_string()).await {
diff --git a/src/ws.rs b/src/ws.rs
new file mode 100644
index 0000000..08c1bae
--- /dev/null
+++ b/src/ws.rs
@@ -0,0 +1,161 @@
+use std::ops::ControlFlow;
+use axum::extract::ws::{Message, WebSocket};
+use std::net::SocketAddr;
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::sync::mpsc::{channel, Sender, TryRecvError};
+use chrono::{DateTime, Utc};
+use redis::aio::Connection;
+use redis::AsyncCommands;
+use tokio::sync::Mutex;
+use crate::{AppState, TASK_DELAY, TASK_WORKING, WebTask};
+use futures::{sink::SinkExt, stream::StreamExt};
+
+#[derive(Clone)]
+enum WsMsg {
+    STR(String),
+    BYT(Vec<u8>)
+}
+
+#[derive(Clone)]
+pub struct MessageExecutor {
+    // Channel to each connected client's socket-writer task.
+    sender_map: Arc<Mutex<HashMap<SocketAddr, Sender<WsMsg>>>>,
+    // Which client submitted which (still pending) task.
+    req_map: Arc<Mutex<HashMap<String, SocketAddr>>>,
+}
+
+impl MessageExecutor {
+    pub fn new() -> MessageExecutor {
+        MessageExecutor {
+            sender_map: Arc::new(Mutex::new(HashMap::new())),
+            req_map: Arc::new(Mutex::new(HashMap::new())),
+        }
+    }
+
+    pub async fn bind_sender(&mut self, who: SocketAddr, sender: Sender<WsMsg>) {
+        let mut sm = self.sender_map.lock().await;
+        sm.insert(who, sender);
+    }
+
+    pub async fn clear_client(&mut self, who: SocketAddr) {
+        let mut sm = self.sender_map.lock().await;
+        sm.remove(&who);
+    }
+
+    pub async fn bind_request_id(&mut self, task_id: String, who: SocketAddr) {
+        // insert() replaces any previous binding for the same task id.
+        let mut req_map = self.req_map.lock().await;
+        req_map.insert(task_id, who);
+    }
+
+    pub async fn response_for(&mut self, task_id: String, resp: String) {
+        // Take the binding out of req_map so completed tasks do not accumulate.
+        let who_is = {
+            let mut req_map = self.req_map.lock().await;
+            req_map.remove(&task_id)
+        };
+        if let Some(the_who) = who_is {
+            let mut sender_map = self.sender_map.lock().await;
+            // Dropping the Sender after the reply ends the client's writer task.
+            if let Some(sender) = sender_map.remove(&the_who) {
+                let _ = sender.send(WsMsg::STR(resp));
+            }
+        }
+    }
+}
+
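+// Per-connection handler. The socket is split so replies can be pushed from
+// outside this task: the writer half drains the mpsc channel registered in
+// MessageExecutor, while the reader half parses incoming JSON tasks and
+// dispatches them the same way as the HTTP handler does.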
+pub async fn handle_socket(mut state: AppState, socket: WebSocket, who: SocketAddr) {
+    let mut redis_conn = state.redis_client.get_async_connection().await.expect("failed to open Redis connection");
+    let (mut sender, mut receiver) = socket.split();
+    let (mpsc_tx, mpsc_rx) = channel::<WsMsg>();
+    let mut executor = state.ws_executor.clone();
+    executor.bind_sender(who, mpsc_tx).await;
+
+    let mut send_task = tokio::spawn(async move {
+        // std::sync::mpsc must not block the async runtime, so poll with
+        // try_recv() and back off briefly while the channel is empty.
+        let mut sent = 0;
+        loop {
+            match mpsc_rx.try_recv() {
+                Ok(WsMsg::STR(str_resp)) => {
+                    if sender.send(Message::Text(str_resp)).await.is_err() {
+                        println!("error writing message to ws client");
+                        return sent;
+                    }
+                    sent += 1;
+                }
+                Ok(_) => {}
+                Err(TryRecvError::Empty) => {
+                    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+                }
+                // Every Sender is gone, so no more replies can arrive.
+                Err(TryRecvError::Disconnected) => return sent,
+            }
+        }
+    });
+
+    let mut recv_task = tokio::spawn(async move {
+        let mut cnt = 0;
+        while let Some(Ok(msg)) = receiver.next().await {
+            cnt += 1;
+            // Process the message and stop if the client sent a close frame.
+            if process_message(msg, who, &mut state, &mut redis_conn).await.is_break() {
+                break;
+            }
+        }
+        cnt
+    });
+
+    // Whichever side finishes first tears the other one down.
+    tokio::select! {
+        rv_a = (&mut send_task) => {
+            match rv_a {
+                Ok(a) => println!("{a} messages sent to {who}"),
+                Err(a) => println!("error sending messages {a:?}")
+            }
+            recv_task.abort();
+        },
+        rv_b = (&mut recv_task) => {
+            match rv_b {
+                Ok(b) => println!("received {b} messages"),
+                Err(b) => println!("error receiving messages {b:?}")
+            }
+            send_task.abort();
+        }
+    }
+    // Drop this client's Sender so no further replies are queued for it.
+    executor.clear_client(who).await;
+}
+
+async fn process_message(msg: Message, who: SocketAddr, state: &mut AppState, redis_conn: &mut Connection) -> ControlFlow<(), ()> {
+    let msg_need_to_process = match msg {
+        Message::Text(t) => Some(t),
+        Message::Close(c) => {
+            if let Some(cf) = c {
+                println!(">>> {} sent close with code {} and reason `{}`", who, cf.code, cf.reason);
+            } else {
+                println!(">>> {who} somehow sent close message without CloseFrame");
+            }
+            return ControlFlow::Break(());
+        }
+        // Ping/Pong frames are answered by axum itself; nothing to do here.
+        Message::Ping(_) | Message::Pong(_) => None,
+        _ => None
+    };
+    if let Some(msg_str) = msg_need_to_process {
+        if let Ok(web_task) = serde_json::from_str::<WebTask>(msg_str.as_str()) {
+            let mut task = web_task.gen_task(state.queue.size);
+            let now: DateTime<Utc> = Utc::now();
+            let now_ts = now.timestamp();
+            task.src_chn = "ws".to_string();
+            let redis_payload = serde_json::to_string(&task).unwrap();
+            redis_conn.set::<String, String, ()>(task.id.clone(), redis_payload).await.expect("set error");
+            // Remember who submitted this task so the reply can be routed back.
+            state.ws_executor.bind_request_id(task.id.clone(), who).await;
+            if web_task.delay == 0 {
+                redis_conn.sadd::<String, String, ()>(TASK_WORKING.to_string(), task.id.clone()).await.expect("set list error");
+                println!("dispatching task to a worker thread");
+                state.queue.dispatch_task(&task);
+            } else {
+                // Delayed tasks are stored as "<task_id> <due_timestamp>".
+                let delay_key = format!("{} {}", task.id, now_ts + web_task.delay as i64);
+                redis_conn.sadd::<String, String, ()>(TASK_DELAY.to_string(), delay_key).await.expect("set list error");
+            }
+        }
+    }
+    ControlFlow::Continue(())
+}
\ No newline at end of file