Skip to content

Commit

Permalink
Improve server start ergonomy (#10)
Browse files Browse the repository at this point in the history
* add multiple ways of starting the server based on memory requirements

* bump version before compile

* use lazy_static only if necessary
  • Loading branch information
gabotechs authored Oct 15, 2023
1 parent 8c3040d commit 0e0cc3c
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 51 deletions.
8 changes: 5 additions & 3 deletions .github/workflows/test-lint-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ jobs:
- name: Install bump tool
run: cargo install cargo-workspaces cargo-zigbuild

- name: Bump versions
run: |
SEM_VER=$(.github/semver.sh)
cargo workspaces version $SEM_VER -a -y --force '*' --no-git-commit
- name: Compile for x86_64-unknown-linux-gnu
run: |
rustup target add x86_64-unknown-linux-gnu
Expand All @@ -107,9 +112,6 @@ jobs:
- name: Tag
id: tag
run: |
SEM_VER=$(.github/semver.sh)
cargo workspaces version $SEM_VER -a -y --force '*' --no-git-commit
version=`grep '^version = ' Cargo.toml | sed 's/version = //; s/\"//; s/\"//'`
git config user.name github-actions
git config user.email [email protected]
Expand Down
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ edition = "2021"
members = ["server"]

[dependencies]
lazy_static = "^1.4.0"
signway-server = { path = "server" }
tokio = { version = "1.28.1", features = ["full"] }
tracing = "^0.1.37"
Expand Down
2 changes: 1 addition & 1 deletion server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ hmac = "^0.12.1"
sha2 = "^0.10.6"
time = { version = "^0.3.6", features = ["formatting", "macros", "parsing"] }
percent-encoding = "^2.2.0"
lazy_static = "^1.4.0"
tracing = "^0.1.37"
async-trait = "^0.1.68"

[dev-dependencies]
lazy_static = "^1.4.0"
reqwest = { version = "^0.11.18", features = ["json"] }
97 changes: 70 additions & 27 deletions server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ impl<T: SecretGetter> SignwayServer<T> {
client,
}
}

pub fn from_port(secret_getter: T, port: u16) -> SignwayServer<T> {
let https = HttpsConnector::new();
let client = hyper::Client::builder().build::<_, Body>(https);
Expand Down Expand Up @@ -146,6 +147,35 @@ impl<T: SecretGetter> SignwayServer<T> {
res
}

async fn service_fn(&self, req: Request<Body>) -> Result<Response<Body>> {
let res = if req.method() == Method::OPTIONS {
Ok(ok())
} else {
self.route_gateway(req).await
};
if let Ok(res) = res {
Ok(self.with_cors_headers(res))
} else {
Ok(res?)
}
}

/// Starts the server using Box::leak so that its lifetime becomes 'static. This way, the
/// memmory occupied by the the `SignwayServer` instance will never be freed and will live
/// forever in the program, even after this function has finished. This avoids the
/// runtime costs of maintaining a reference count to the `SignwayServer` instance to
/// share across requests, at the cost of never freeing the memory occupied by the instance.
/// Usually, an application that spawns a server will expect the server to live forever,
/// so it is a fair assumption to keep its memory forever.
pub async fn start_leak(self) -> Result<()> {
let self_leak = Box::leak(Box::new(self));
self_leak.start().await
}

/// Starts the server expecting it to be 'static, meaning that it is going to live forever.
/// Applications might choose to store the server in a static variable by themselves, so that
/// they can just call this method, which does not have the performance implications of
/// maintaining reference counting for the server instance per request.
pub async fn start(&'static self) -> Result<()> {
let in_addr: SocketAddr = ([0, 0, 0, 0], self.port).into();

Expand All @@ -155,18 +185,39 @@ impl<T: SecretGetter> SignwayServer<T> {
loop {
let (stream, _) = listener.accept().await?;

let service = service_fn(move |req| async move {
let res = if req.method() == Method::OPTIONS {
Ok(ok())
} else {
self.route_gateway(req).await
};
if let Ok(res) = res {
Ok(self.with_cors_headers(res))
} else {
res
let service = service_fn(|req| self.service_fn(req));

tokio::spawn(async move {
if let Err(err) = hyper::server::conn::Http::new()
.serve_connection(stream, service)
.await
{
error!("Failed to serve the connection: {:?}", err);
}
});
}
}

/// Starts the server by maintaining a reference count in each request handle. This will grant
/// that the memmory occupied by the server will be freed after this function has finished and
/// all the requests have already been handled. Use this function if your application will
/// continue running after stopping the server. Calling this function has a small runtime cost
/// for maintaining the reference counting.
pub async fn start_arc(self) -> Result<()> {
let in_addr: SocketAddr = ([0, 0, 0, 0], self.port).into();

let listener = TcpListener::bind(in_addr).await?;

let arc_self = Arc::new(self);
info!("Server running in {}", in_addr);
loop {
let (stream, _) = listener.accept().await?;

let arc_self = arc_self.clone();
let service = service_fn(move |req| {
let arc_self = arc_self.clone();
async move { arc_self.service_fn(req).await }
});

tokio::spawn(async move {
if let Err(err) = hyper::server::conn::Http::new()
Expand All @@ -186,7 +237,6 @@ mod tests {

use hyper::http::HeaderValue;
use hyper::StatusCode;
use lazy_static::lazy_static;
use reqwest::header::HeaderMap;
use time::{OffsetDateTime, PrimitiveDateTime};
use url::Url;
Expand All @@ -197,16 +247,6 @@ mod tests {

use super::*;

lazy_static! {
static ref SERVER: SignwayServer<InMemorySecretGetter> =
server_for_testing([("foo", "foo-secret")], 3000);
}

async fn init() -> &'static str {
tokio::spawn(SERVER.start());
"http://localhost:3000"
}

fn server_for_testing<const N: usize>(
config: [(&str, &str); N],
port: u16,
Expand Down Expand Up @@ -240,9 +280,10 @@ mod tests {

#[tokio::test]
async fn simple_get_works() {
let host = init().await;
let server = server_for_testing([("foo", "foo-secret")], 3000);
tokio::spawn(server.start_leak());
let signer = UrlSigner::new("foo", "foo-secret");
let signed_url = signer.get_signed_url(host, &base_request()).unwrap();
let signed_url = signer.get_signed_url("http://localhost:3000", &base_request()).unwrap();

let response = reqwest::Client::new().get(signed_url).send().await.unwrap();

Expand All @@ -259,9 +300,10 @@ mod tests {

#[tokio::test]
async fn options_returns_cors() {
let host = init().await;
let server = server_for_testing([("foo", "foo-secret")], 3001);
tokio::spawn(server.start_arc());
let response = reqwest::Client::new()
.request(Method::OPTIONS, host)
.request(Method::OPTIONS, "http://localhost:3001")
.send()
.await
.unwrap();
Expand Down Expand Up @@ -293,10 +335,11 @@ mod tests {

#[tokio::test]
async fn signed_with_different_secret_does_not_work() {
let host = init().await;
let server = server_for_testing([("foo", "foo-secret")], 3002);
tokio::spawn(server.start_arc());
let bad_signer = UrlSigner::new("foo", "bad-secret");

let signed_url = bad_signer.get_signed_url(host, &base_request()).unwrap();
let signed_url = bad_signer.get_signed_url("http://localhost:3002", &base_request()).unwrap();

let response = reqwest::Client::new().get(signed_url).send().await.unwrap();

Expand Down
25 changes: 7 additions & 18 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use anyhow::anyhow;
use lazy_static::lazy_static;
use std::str::FromStr;

use async_trait::async_trait;
use clap::Parser;
use tracing::info;

use signway_server::hyper::header::HeaderName;
use signway_server::hyper::{Body, Response, StatusCode};
use signway_server::{
Expand Down Expand Up @@ -113,7 +114,10 @@ impl OnBytesTransferred for BytesTransferredLogger {
}
}

fn make_server() -> anyhow::Result<SignwayServer<Config>> {
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt().json().init();

let args: Args = Args::parse();
let config: Config = args.clone().try_into()?;
let mut server = SignwayServer::from_env(config);
Expand All @@ -130,24 +134,9 @@ fn make_server() -> anyhow::Result<SignwayServer<Config>> {
if let Some(value) = args.access_control_allow_origin {
server = server.access_control_allow_origin(&value)?;
}
Ok(server)
}

lazy_static! {
static ref SERVER: anyhow::Result<SignwayServer<Config>> = make_server();
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt().json().init();

let server = match SERVER.as_ref() {
Ok(server) => server,
Err(err) => return Err(anyhow::anyhow!(err))
};

tokio::select! {
result = server.start() => {
result = server.start_leak() => {
result
}
_ = tokio::signal::ctrl_c() => {
Expand Down

0 comments on commit 0e0cc3c

Please sign in to comment.