diff --git a/theia/proxy-rs/Cargo.lock b/theia/proxy-rs/Cargo.lock index 205b136b..d3615106 100644 --- a/theia/proxy-rs/Cargo.lock +++ b/theia/proxy-rs/Cargo.lock @@ -36,6 +36,21 @@ version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anyhow" version = "1.0.82" @@ -236,6 +251,20 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "chrono" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" +dependencies = [ + "android-tzdata", + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + "windows-targets 0.52.5", +] + [[package]] name = "const-oid" version = "0.9.6" @@ -253,6 +282,12 @@ dependencies = [ "version_check", ] +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + [[package]] name = "cpufeatures" version = "0.2.12" @@ -669,6 +704,22 @@ dependencies = [ "pin-project-lite", "smallvec", "tokio", + "want", +] + +[[package]] +name = "hyper-tungstenite" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a343d17fe7885302ed7252767dc7bb83609a874b6ff581142241ec4b73957ad" +dependencies = [ + "http-body-util", + "hyper", + "hyper-util", + "pin-project-lite", + "tokio", + "tokio-tungstenite", + "tungstenite", ] [[package]] @@ -678,6 +729,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa" dependencies = [ "bytes", + "futures-channel", "futures-util", "http", "http-body", @@ -685,6 +737,32 @@ dependencies = [ "pin-project-lite", "socket2", "tokio", + "tower", + "tower-service", + "tracing", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", ] [[package]] @@ -1110,8 +1188,10 @@ dependencies = [ "anyhow", "axum", "axum-extra", + "chrono", "futures-util", "hyper", + "hyper-tungstenite", "hyper-util", "jsonwebtoken", "lazy_static", @@ -1906,6 +1986,12 @@ dependencies = [ "tracing-log", ] +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + [[package]] name = "tungstenite" version = "0.21.0" @@ -2011,6 +2097,15 @@ version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -2109,6 +2204,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets 0.52.5", +] + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/theia/proxy-rs/Cargo.toml b/theia/proxy-rs/Cargo.toml index 2d776877..7a7b8c74 100644 --- a/theia/proxy-rs/Cargo.toml +++ b/theia/proxy-rs/Cargo.toml @@ -9,9 +9,11 @@ edition = "2021" anyhow = "1.0.82" axum = { version = "0.7.5", features = ["ws"] } axum-extra = { version = "0.9.3", features = ["cookie"] } +chrono = "0.4.38" futures-util = "0.3.30" hyper = "1.3.1" -hyper-util = { version = "0.1.3", features = ["tokio"] } +hyper-tungstenite = "0.13.0" +hyper-util = { version = "0.1.3", features = ["client", "client-legacy"] } jsonwebtoken = "9.3.0" lazy_static = "1.4.0" serde = "1.0.200" diff --git a/theia/proxy-rs/src/db.rs b/theia/proxy-rs/src/db.rs new file mode 100644 index 00000000..f5ac755a --- /dev/null +++ b/theia/proxy-rs/src/db.rs @@ -0,0 +1,47 @@ +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use sqlx::{prelude::FromRow, MySqlPool}; + +#[derive(Deserialize, Serialize, FromRow)] +struct TheiaSession { + cluster_address: Option, +} + +pub async fn get_cluster_address(pool: &MySqlPool, session_id: &str) -> Result { + let row: Option = sqlx::query_as( + r#" + SELECT cluster_address + FROM theia_session + WHERE id = ? AND active = 1 + "#, + ) + .bind(session_id) + .fetch_optional(pool) + .await?; + + match row { + Some(session) => { + if let Some(cluster_address) = session.cluster_address { + Ok(cluster_address) + } else { + Err(anyhow::anyhow!("cluster address not found")) + } + } + None => Err(anyhow::anyhow!("session not found")), + } +} + +pub async fn update_last_proxy_time(session_id: &str, pool: &MySqlPool) -> Result<()> { + sqlx::query( + r#" + UPDATE theia_session + SET last_proxy = NOW() + WHERE id = ? + "#, + ) + .bind(session_id) + .execute(pool) + .await?; + + Ok(()) +} diff --git a/theia/proxy-rs/src/main.rs b/theia/proxy-rs/src/main.rs index 8b6df135..13464615 100644 --- a/theia/proxy-rs/src/main.rs +++ b/theia/proxy-rs/src/main.rs @@ -1,29 +1,34 @@ +mod db; +mod proxy; mod ws; -use anyhow::Result; +use anyhow::{Context, Result}; use axum::{ + body::Body, extract::{Path, Query, Request, WebSocketUpgrade}, http::StatusCode, - response::{IntoResponse, Redirect}, + response::{IntoResponse, Redirect, Response}, routing::get, Extension, Router, }; use axum_extra::extract::cookie::{Cookie, CookieJar}; +use chrono::Utc; +use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Validation}; use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; -use sqlx::{mysql::MySqlPoolOptions, prelude::FromRow, MySqlPool}; +use sqlx::{mysql::MySqlPoolOptions, MySqlPool}; use std::time::Duration; use std::{env::var, sync::Arc}; use tower_http::trace::{self, TraceLayer}; use tracing::Level; const PROXY_SERVER_PORT: u64 = 5000; -const MAX_PROXY_PORT: u64 = 8010; +// TODO: add back port range +// const MAX_PROXY_PORT: u64 = 8010; const MAX_DB_CONNECTIONS: u32 = 50; const FAILED_REDIRECT_URL: &str = "https://anubis-lms.io/error"; -/// JWT expires in 6 hours -const JWT_EXPIRATION: usize = 6 * 60 * 60; +const JWT_EXPIRATION_HOURS: i64 = 6; // Lazy static evaluation of environment variables lazy_static! { @@ -43,12 +48,18 @@ lazy_static! { let host = var("DB_HOST").unwrap_or("127.0.0.1".to_string()); let port = var("DB_PORT").unwrap_or("3306".to_string()); let user = String::from("anubis"); + let database = String::from("anubis"); let password = var("DB_PASSWORD").unwrap_or("anubis".to_string()); - format!("mysql://{}:{}@{}:{}", user, password, host, port) + format!( + "mysql://{}:{}@{}:{}/{}", + user, password, host, port, database + ) }; } +pub type Client = hyper_util::client::legacy::Client; + #[derive(Debug, Serialize, Deserialize)] struct Claims { exp: usize, @@ -62,17 +73,19 @@ fn authenticate_jwt(token: &str) -> Result { let validation = Validation::new(Algorithm::HS256); let decoded = decode::(token, &key, &validation)?; - if decoded.claims.exp < JWT_EXPIRATION { - return Err(anyhow::anyhow!("Token expired")); - } - Ok(decoded.claims) } fn create_jwt(session_id: &str, net_id: &str) -> Result { let key = EncodingKey::from_secret(SECRET_KEY.as_bytes()); + + let expiration = Utc::now() + .checked_add_signed(chrono::Duration::hours(JWT_EXPIRATION_HOURS)) + .expect("valid timestamp") + .timestamp(); + let claims = Claims { - exp: JWT_EXPIRATION, + exp: expiration as usize, session_id: session_id.to_string(), net_id: net_id.to_string(), }; @@ -82,27 +95,11 @@ fn create_jwt(session_id: &str, net_id: &str) -> Result { Ok(token) } -async fn update_last_proxy_time(session_id: &str, pool: &MySqlPool) -> Result<()> { - sqlx::query( - r#" - UPDATE theia_session - SET last_proxy = NOW() - WHERE id = $1 - "#, - ) - .bind(session_id) - .execute(pool) - .await?; - - Ok(()) -} - async fn ping(jar: CookieJar, Extension(pool): Extension>) -> (StatusCode, String) { - tracing::info!("Ping"); match jar.get("ide") { Some(cookie) => match authenticate_jwt(cookie.value()) { Ok(claims) => { - update_last_proxy_time(&claims.session_id, &pool) + db::update_last_proxy_time(&claims.session_id, &pool) .await .unwrap(); } @@ -137,8 +134,7 @@ async fn initialize(params: Query, jar: CookieJar) -> imp }; let new_token = create_jwt(&token.session_id, &token.net_id).unwrap(); - let mut ide_cookie = Cookie::new("ide", new_token); - ide_cookie.set_http_only(true); + let ide_cookie = Cookie::new("ide", new_token); let new_jar = jar.add(ide_cookie); @@ -154,60 +150,49 @@ async fn initialize(params: Query, jar: CookieJar) -> imp ) } -#[derive(Deserialize, Serialize, FromRow)] -struct TheiaSession { - cluster_address: Option, +#[derive(Debug)] +pub struct Target { + pub host: String, + pub port: u16, } -async fn get_cluster_address(pool: &MySqlPool, session_id: &str) -> Result { - let row: Option = sqlx::query_as( - r#" - SELECT cluster_address - FROM theia_session - WHERE id = $1 AND active = 1 - "#, - ) - .bind(session_id) - .fetch_optional(pool) - .await?; - - match row { - Some(session) => { - if let Some(cluster_address) = session.cluster_address { - Ok(cluster_address) - } else { - Err(anyhow::anyhow!("Cluster address not found")) - } - } - None => Err(anyhow::anyhow!("Session not found")), - } +#[derive(Debug)] +enum ProxyType { + Http, + WebSocket, } +/// Handles generic incoming requests to be proxied to the coder server +/// Depending on wether the upgrade websocket header is present +/// the request will be proxied as a websocket or http request async fn handle( - ws: WebSocketUpgrade, + ws_upgrade: Option, Extension(pool): Extension>, + Extension(client): Extension>, + _path: Option>, jar: CookieJar, - _req: Request, -) -> impl IntoResponse { - // let port = path.parse::().unwrap(); - - // if port > MAX_PROXY_PORT { - // return (StatusCode::BAD_REQUEST, "Invalid port".to_string()); - // } + req: Request, +) -> Result { + let proxy_type = match ws_upgrade { + Some(_) => ProxyType::WebSocket, + None => ProxyType::Http, + }; let token = match jar.get("ide") { Some(cookie) => match authenticate_jwt(cookie.value()) { Ok(claims) => claims, - Err(_) => { - return (StatusCode::UNAUTHORIZED, "Invalid token".to_string()); + Err(err) => { + tracing::error!("failed to authenticate jwt: {}", err); + return Err((StatusCode::UNAUTHORIZED, "Invalid token".to_string())); } }, None => { - return (StatusCode::UNAUTHORIZED, "No token provided".to_string()); + tracing::error!("could not find jwt"); + return Err((StatusCode::UNAUTHORIZED, "No token provided".to_string())); } }; - let cluster_address = get_cluster_address(&pool, &token.session_id) + let cluster_address = db::get_cluster_address(&pool, &token.session_id) .await .map_err(|e| { eprintln!("Error: {}", e); @@ -215,37 +200,73 @@ async fn handle( StatusCode::INTERNAL_SERVER_ERROR, "Internal server error".to_string(), ) - }) - .unwrap(); + })?; - let host = format!("ws://{}:{}", cluster_address, MAX_PROXY_PORT); + let target = Target { + host: cluster_address.clone(), + // TODO: handle port from the request path + port: 5000, + }; - let _result = ws.on_upgrade(move |socket| ws::forward(host, socket)); + tracing::debug!("proxying request to target: {:?}", target); + + let result = match proxy_type { + ProxyType::Http => match proxy::proxy_req(client, req, target).await { + Ok(res) => res, + Err(err) => { + tracing::error!("failed to proxy http request: {}", err); + return Err(( + StatusCode::INTERNAL_SERVER_ERROR, + "failed to proxy http request".to_string(), + )); + } + }, + ProxyType::WebSocket => { + if let Some(ws) = ws_upgrade { + // upgrade the connection to websocket and proxy the websocket request to the + // target host and port + ws.on_upgrade(move |socket| ws::forward(socket, target)) + } else { + return Err(( + StatusCode::BAD_REQUEST, + "no ws upgrade found for ws proxy request".to_string(), + )); + } + } + }; - (StatusCode::OK, "authorized".to_string()) + Ok(result) } #[tokio::main] -async fn main() { +async fn main() -> Result<()> { tracing_subscriber::fmt() .with_target(false) .compact() .init(); - let pool = MySqlPoolOptions::new() - .max_connections(MAX_DB_CONNECTIONS) - .acquire_timeout(Duration::from_secs(5)) - .connect(&DB_URL) - .await - .unwrap(); - - let pool = Arc::new(pool); + let pool = Arc::new( + MySqlPoolOptions::new() + .max_connections(MAX_DB_CONNECTIONS) + .acquire_timeout(Duration::from_secs(5)) + .connect(&DB_URL) + .await + .context("failed to connect to db")?, + ); + + // client used for proxying http requests + let client: Arc = Arc::new( + hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) + .build(HttpConnector::new()), + ); let app = Router::new() .route("/ping", get(ping)) .route("/initialize", get(initialize)) .route("/", get(handle)) + .route("/*path", get(handle)) .layer(Extension(pool)) + .layer(Extension(client)) .layer( TraceLayer::new_for_http() .make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO)) @@ -262,4 +283,6 @@ async fn main() { .unwrap(); axum::serve(listener, app).await.unwrap(); + + Ok(()) } diff --git a/theia/proxy-rs/src/proxy.rs b/theia/proxy-rs/src/proxy.rs new file mode 100644 index 00000000..d646945c --- /dev/null +++ b/theia/proxy-rs/src/proxy.rs @@ -0,0 +1,24 @@ +use axum::{ + extract::Request, + response::{IntoResponse, Response}, +}; + +use anyhow::Result; +use std::sync::Arc; + +use super::{Client, Target}; + +pub async fn proxy_req(client: Arc, mut req: Request, target: Target) -> Result { + let path = req.uri().path(); + let path_query = req + .uri() + .path_and_query() + .map(|v| v.as_str()) + .unwrap_or(path); + + *req.uri_mut() = format!("http://{}:{}{}", target.host, target.port, path_query) + .parse() + .unwrap(); + + Ok(client.request(req).await.unwrap().into_response()) +} diff --git a/theia/proxy-rs/src/ws.rs b/theia/proxy-rs/src/ws.rs index 2f47d581..566d8ad6 100644 --- a/theia/proxy-rs/src/ws.rs +++ b/theia/proxy-rs/src/ws.rs @@ -1,3 +1,4 @@ +use super::Target; use axum::extract::ws::{CloseFrame, Message as AxumMessage, WebSocket}; use futures_util::{SinkExt, StreamExt}; use tokio_tungstenite::tungstenite; @@ -12,6 +13,11 @@ struct WebSocketMessage { message: WebSocketMessageType, } +/// The incoming websocket upgrade connection is of type axum +/// but the outgoing connection to the target server is of type tungstenite +/// so we need to convert between the two message types +/// +/// TODO: find a better way to handle this impl WebSocketMessage { fn tungstenite(message: TsMessage) -> Self { Self { @@ -66,11 +72,12 @@ impl WebSocketMessage { } } -pub async fn forward(url: String, client_ws: WebSocket) { +pub async fn forward(client_ws: WebSocket, target: Target) { + let url = format!("ws://{}:{}", target.host, target.port); let server_ws = match connect_async(url).await { Ok((ws, _)) => ws, Err(e) => { - tracing::warn!("connect error: {}", e); + tracing::warn!("failed to connect to target websocket during proxy: {}", e); return; } }; @@ -86,12 +93,12 @@ pub async fn forward(url: String, client_ws: WebSocket) { let message = WebSocketMessage::axum(message); let res = server_write.send(message.into_tungstenite()).await; if let Err(e) = res { - tracing::warn!("client write error: {}", e); + tracing::warn!("client ws write error: {}", e); continue; } } Err(e) => { - tracing::warn!("client read error: {}", e); + tracing::warn!("client ws read error: {}", e); continue; } } @@ -104,12 +111,12 @@ pub async fn forward(url: String, client_ws: WebSocket) { let message = WebSocketMessage::tungstenite(message); let res = client_write.send(message.into_axum()).await; if let Err(e) = res { - tracing::warn!("client write error: {}", e); + tracing::warn!("server ws write error: {}", e); continue; } } Err(e) => { - tracing::warn!("client read error: {}", e); + tracing::warn!("server ws read error: {}", e); continue; } }