Skip to content

Commit

Permalink
feat: WebSocket upgrade support (supabase#271)
Browse files Browse the repository at this point in the history
* stamp(sb_workers): polishing

* stamp: add `http_utils` crate

* stamp: update dependencies

* stamp: update `Cargo.lock`

* refactor: remove watcher api and introduce supabase tag api

* feat: websocket upgrade support

* stamp: rid duplicate things

* stamp(base): add dev dependencies

* stamp: add websocket upgrade test

* stamp: add websocket upgrade examples

* stamp: add websocket upgrade example for main worker

* stamp(http_utils): add license header
  • Loading branch information
nyannyacha authored Feb 27, 2024
1 parent 832531e commit 1399bb2
Show file tree
Hide file tree
Showing 27 changed files with 1,297 additions and 188 deletions.
49 changes: 49 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion crates/base/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
http_utils = { version = "0.1.0", path = "../http_utils" }
async-trait.workspace = true
thiserror.workspace = true
monch.workspace = true
Expand All @@ -31,7 +32,7 @@ deno_webidl = { workspace = true }
deno_web = { workspace = true }
deno_websocket = { workspace = true }
httparse = { version = "1.8.0" }
hyper = { workspace = true, features = ["full"] }
hyper = { workspace = true, features = ["full", "backports"] }
http = { version = "0.2" }
import_map.workspace = true
log = { workspace = true }
Expand Down Expand Up @@ -65,7 +66,10 @@ deno_webgpu.workspace = true
sb_ai = { version = "0.1.0", path = "../sb_ai" }

[dev-dependencies]
tokio-util = { workspace = true, features = ["rt", "compat"] }
serial_test = { version = "3.0.0" }
async-tungstenite = { version = "0.25.0", default-features = false }
tungstenite = { version = "0.21.0", default-features = false, features = ["handshake"] }

[build-dependencies]
sb_core = { version = "0.1.0", path = "../sb_core" }
Expand Down
4 changes: 3 additions & 1 deletion crates/base/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ mod supabase_startup_snapshot {
use event_worker::js_interceptors::sb_events_js_interceptors;
use event_worker::sb_user_event_worker;
use sb_ai::sb_ai;
use sb_core::http_start::sb_core_http;
use sb_core::http::sb_core_http;
use sb_core::http_start::sb_core_http_start;
use sb_core::net::sb_core_net;
use sb_core::permissions::sb_core_permissions;
use sb_core::runtime::sb_core_runtime;
Expand Down Expand Up @@ -206,6 +207,7 @@ mod supabase_startup_snapshot {
sb_core_main_js::init_ops_and_esm(),
sb_core_net::init_ops_and_esm(),
sb_core_http::init_ops_and_esm(),
sb_core_http_start::init_ops_and_esm(),
deno_node::init_ops_and_esm::<Permissions>(None, fs),
sb_core_runtime::init_ops_and_esm(None),
];
Expand Down
4 changes: 3 additions & 1 deletion crates/base/src/deno_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ use futures_util::future::poll_fn;
use log::{error, trace};
use once_cell::sync::{Lazy, OnceCell};
use sb_core::conn_sync::ConnSync;
use sb_core::http::sb_core_http;
use sb_core::http_start::sb_core_http_start;
use sb_core::util::sync::AtomicFlag;
use serde::de::DeserializeOwned;
use std::collections::HashMap;
Expand All @@ -37,7 +39,6 @@ use sb_ai::sb_ai;
use sb_core::cache::CacheSetting;
use sb_core::cert::ValueRootCertStoreProvider;
use sb_core::external_memory::custom_allocator;
use sb_core::http_start::sb_core_http;
use sb_core::net::sb_core_net;
use sb_core::permissions::{sb_core_permissions, Permissions};
use sb_core::runtime::sb_core_runtime;
Expand Down Expand Up @@ -293,6 +294,7 @@ impl DenoRuntime {
sb_core_main_js::init_ops(),
sb_core_net::init_ops(),
sb_core_http::init_ops(),
sb_core_http_start::init_ops(),
deno_node::init_ops::<Permissions>(Some(npm_resolver), file_system),
sb_core_runtime::init_ops(Some(main_module_url.clone())),
];
Expand Down
89 changes: 74 additions & 15 deletions crates/base/src/rt_worker/worker_ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ use cpu_timer::CPUTimer;
use event_worker::events::{
BootEvent, ShutdownEvent, WorkerEventWithMetadata, WorkerEvents, WorkerMemoryUsed,
};
use http::StatusCode;
use http_utils::io::Upgraded2;
use http_utils::utils::{emit_status_code, get_upgrade_type};
use hyper::client::conn::http1;
use hyper::upgrade::OnUpgrade;
use hyper::{Body, Request, Response};
use log::{debug, error};
use sb_core::conn_sync::ConnSync;
Expand All @@ -22,6 +27,7 @@ use sb_workers::errors::WorkerError;
use std::future::pending;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::io::copy_bidirectional;
use tokio::net::UnixStream;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::{mpsc, oneshot, watch, Notify};
Expand Down Expand Up @@ -83,43 +89,96 @@ async fn handle_request(
// create a unix socket pair
let (sender_stream, recv_stream) = UnixStream::pair()?;
let WorkerRequestMsg {
req,
mut req,
res_tx,
conn_watch,
} = msg;

let _ = unix_stream_tx.send((recv_stream, conn_watch.clone()));
let req_upgrade_type = get_upgrade_type(req.headers());
let req_upgrade = req_upgrade_type
.clone()
.and_then(|it| Some(it).zip(req.extensions_mut().remove::<OnUpgrade>()));

// send the HTTP request to the worker over Unix stream
let (mut request_sender, connection) = hyper::client::conn::handshake(sender_stream).await?;
let (mut request_sender, connection) = http1::Builder::new()
.writev(true)
.handshake(sender_stream)
.await?;

let (upgrade_tx, upgrade_rx) = oneshot::channel();

// spawn a task to poll the connection and drive the HTTP state
tokio::task::spawn(async move {
match connection.without_shutdown().await {
Err(e) => {
error!("Error in worker connection: {}", e.message(),);
}
tokio::task::spawn({
async move {
match connection.without_shutdown().await {
Err(e) => {
error!("Error in worker connection: {}", e.message());
}

Ok(parts) => {
if let Some((requested, req_upgrade)) = req_upgrade {
if let Ok((Some(accepted), status)) = upgrade_rx.await {
if status == StatusCode::SWITCHING_PROTOCOLS && accepted == requested {
tokio::spawn(relay_upgraded_request_and_response(
req_upgrade,
parts,
));

Ok(parts) => {
if let Some(mut watcher) = conn_watch {
if watcher.wait_for(|it| *it == ConnSync::Recv).await.is_err() {
error!("cannot track outbound connection correctly");
return;
}
};
}
}

drop(parts);
if let Some(mut watcher) = conn_watch {
if watcher.wait_for(|it| *it == ConnSync::Recv).await.is_err() {
error!("cannot track outbound connection correctly");
}
}
}
}
}
});

tokio::task::yield_now().await;

let result = request_sender.send_request(req).await;
let _ = res_tx.send(result);
let res = request_sender.send_request(req).await;
let Ok(res) = res else {
drop(res_tx.send(res));
return Ok(());
};

if let Some(requested) = req_upgrade_type {
let res_upgrade_type = get_upgrade_type(res.headers());
let _ = upgrade_tx.send((res_upgrade_type.clone(), res.status()));

match res_upgrade_type {
Some(accepted) if accepted == requested => {}
_ => {
drop(res_tx.send(Ok(emit_status_code(StatusCode::BAD_GATEWAY))));
return Ok(());
}
}
}

drop(res_tx.send(Ok(res)));
Ok(())
}

async fn relay_upgraded_request_and_response(
downstream: OnUpgrade,
parts: http1::Parts<UnixStream>,
) {
let mut upstream = Upgraded2::new(parts.io, parts.read_buf);
let mut downstream = downstream.await.expect("failed to upgrade request");

copy_bidirectional(&mut upstream, &mut downstream)
.await
.expect("coping between upgraded connections failed");

// XXX(Nyannyacha): Here you might want to emit the event metadata.
}

#[allow(clippy::too_many_arguments)]
pub fn create_supervisor(
key: Uuid,
Expand Down
5 changes: 3 additions & 2 deletions crates/base/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ impl Service<Request<Body>> for WorkerService {
parts,
Body::wrap_stream(NotifyOnEos {
inner: body,
cancel: Some(cancel.clone()),
cancel: Some(cancel),
}),
);

Expand Down Expand Up @@ -294,7 +294,8 @@ impl Server {
let _guard = cancel.drop_guard();

let conn_fut = Http::new()
.serve_connection(conn, service);
.serve_connection(conn, service)
.with_upgrades();

if let Err(e) = conn_fut.await {
// Most common cause for these errors are
Expand Down
13 changes: 13 additions & 0 deletions crates/base/test_cases/websocket-upgrade/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Deno.serve(async (req: Request) => {
const { socket, response } = Deno.upgradeWebSocket(req);

socket.onopen = () => {
socket.send("meow");
};

socket.onmessage = ev => {
socket.send(ev.data);
};

return response;
});
Loading

0 comments on commit 1399bb2

Please sign in to comment.