Skip to content

Commit

Permalink
🚧 WIP Tracing in Rocket
Browse files Browse the repository at this point in the history
  • Loading branch information
RemiBardon committed Aug 13, 2024
1 parent 8b35e06 commit 8632ee7
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 14 deletions.
151 changes: 149 additions & 2 deletions src/orangutan-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@ mod request_guards;
mod routes;
mod util;

use std::{
convert::Infallible,
fmt::Display,
ops::Deref,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, RwLock, RwLockReadGuard,
},
};

use object_reader::ObjectReader;
use orangutan_helpers::{
generate::{self, *},
Expand All @@ -11,9 +21,10 @@ use orangutan_helpers::{
};
use rocket::{
catch, catchers,
fairing::AdHoc,
fairing::{self, AdHoc, Fairing},
fs::NamedFile,
http::Status,
request::{self, FromRequest},
response::{self, Responder},
Request,
};
Expand Down Expand Up @@ -53,7 +64,9 @@ fn rocket() -> _ {
rocket.shutdown().notify();
}
})
}));
}))
.attach(RequestIdFairing)
.attach(TracingSpanFairing);

// Add support for templating if needed
#[cfg(feature = "templating")]
Expand Down Expand Up @@ -178,3 +191,137 @@ impl From<orangutan_refresh_token::Error> for Error {
}
}
}

// ===== Request ID =====

#[derive(Debug, Clone)]
pub struct RequestId(String);
impl Deref for RequestId {
type Target = String;

fn deref(&self) -> &Self::Target {
&self.0
}
}
impl Display for RequestId {
fn fmt(
&self,
f: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
Display::fmt(self.deref(), f)
}
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for RequestId {
type Error = Infallible;

async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
static COUNTER: AtomicUsize = AtomicUsize::new(1);
let request_id = RequestId(
req.headers()
.get_one("X-Request-Id")
.map(ToString::to_string)
.unwrap_or_else(|| COUNTER.fetch_add(1, Ordering::Relaxed).to_string()),
);
request::Outcome::Success(request_id)
}
}

#[rocket::async_trait]
trait RequestIdTrait {
async fn id(&self) -> RequestId;
}

#[rocket::async_trait]
impl<'r> RequestIdTrait for Request<'r> {
async fn id(&self) -> RequestId {
self.guard::<RequestId>().await.unwrap().to_owned()
}
}

// ===== Request ID fairing =====

struct RequestIdFairing;

#[rocket::async_trait]
impl Fairing for RequestIdFairing {
fn info(&self) -> fairing::Info {
fairing::Info {
name: "Add a unique request ID to every request",
kind: fairing::Kind::Request,
}
}

/// See <https://rocket.rs/guide/v0.5/state/#request-local-state>
/// and <https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805/6>.
async fn on_request(
&self,
req: &mut Request<'_>,
_: &mut rocket::Data<'_>,
) {
static COUNTER: AtomicUsize = AtomicUsize::new(1);
let request_id = RequestId(
req.headers()
.get_one("X-Request-Id")
.map(ToString::to_string)
.unwrap_or_else(|| COUNTER.fetch_add(1, Ordering::Relaxed).to_string()),
);
req.local_cache(|| request_id);
}
}

// ===== Tracing span =====

#[derive(Clone)]
pub struct TracingSpan(Arc<RwLock<tracing::Span>>);

impl TracingSpan {
pub fn get(&self) -> RwLockReadGuard<'_, tracing::Span> {
self.0.read().unwrap()
}
}

#[rocket::async_trait]
impl<'r> FromRequest<'r> for TracingSpan {
type Error = ();

async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, ()> {
let span: &TracingSpan =
request.local_cache(|| panic!("Tracing span should be managed by then"));
request::Outcome::Success(span.to_owned())
}
}

// ===== Tracing span fairing =====

struct TracingSpanFairing;

#[rocket::async_trait]
impl Fairing for TracingSpanFairing {
fn info(&self) -> fairing::Info {
fairing::Info {
name: "Add request information to tracing span",
kind: fairing::Kind::Request,
}
}

async fn on_request(
&self,
req: &mut Request<'_>,
_: &mut rocket::Data<'_>,
) {
let user_agent = req.headers().get_one("User-Agent").unwrap_or("none");
let request_id = req.id().await;

let span = tracing::debug_span!(
"request",
request_id = %request_id,
http.method = %req.method(),
http.uri = %req.uri().path(),
http.user_agent = %user_agent,
otel.name=%format!("{} {}", req.method(), req.uri().path()),
);

req.local_cache(|| TracingSpan(Arc::new(RwLock::new(span))));
}
}
15 changes: 14 additions & 1 deletion src/orangutan-server/src/request_guards.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
use std::ops::Deref;

use biscuit_auth::Biscuit;
use rocket::{http::Status, outcome::Outcome, request, request::FromRequest, Request};
use rocket::{
http::Status,
outcome::{try_outcome, Outcome},
request::{self, FromRequest},
Request,
};
use tracing::{debug, trace};

use crate::{
config::*,
util::{add_cookie, add_padding, profiles},
TracingSpan,
};

pub struct Token {
Expand All @@ -31,13 +37,20 @@ impl Deref for Token {
pub enum TokenError {
// TODO: Re-enable Basic authentication
// Invalid,
InternalServerError,
}

#[rocket::async_trait]
impl<'r> FromRequest<'r> for Token {
type Error = TokenError;

async fn from_request(req: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
let span = try_outcome!(TracingSpan::from_request(req)
.await
.map_error(|(s, ())| (s, TokenError::InternalServerError)));
let _span = span.get();
let _span = _span.enter();

let mut biscuit: Option<Biscuit> = None;
let mut should_save: bool = false;

Expand Down
5 changes: 5 additions & 0 deletions src/orangutan-server/src/routes/auth_routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::{
error,
request_guards::Token,
util::{add_cookie, add_padding},
TracingSpan,
};

lazy_static! {
Expand All @@ -27,12 +28,16 @@ pub(super) fn routes() -> Vec<Route> {

#[get("/<_..>?<refresh_token>&<force>")]
fn handle_refresh_token(
span: TracingSpan,
origin: &Origin,
cookies: &CookieJar<'_>,
refresh_token: &str,
token: Option<Token>,
force: Option<bool>,
) -> Result<Redirect, Status> {
let _span = span.get();
let _span = _span.enter();

// URL-decode the string.
let mut refresh_token: String = urlencoding::decode(refresh_token).unwrap().to_string();

Expand Down
52 changes: 45 additions & 7 deletions src/orangutan-server/src/routes/debug_routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use lazy_static::lazy_static;
use rocket::{get, http::CookieJar, routes, Route};

use super::auth_routes::REVOKED_TOKENS;
use crate::{request_guards::Token, Error};
use crate::{request_guards::Token, Error, TracingSpan};

lazy_static! {
/// A list of runtime errors, used to show error logs in an admin page
Expand Down Expand Up @@ -45,7 +45,13 @@ pub(super) fn templates() -> Vec<(&'static str, &'static str)> {
}

#[get("/clear-cookies")]
fn clear_cookies(cookies: &CookieJar<'_>) -> &'static str {
fn clear_cookies(
span: TracingSpan,
cookies: &CookieJar<'_>,
) -> &'static str {
let _span = span.get();
let _span = _span.enter();

for cookie in cookies.iter().map(Clone::clone) {
cookies.remove(cookie);
}
Expand All @@ -54,7 +60,13 @@ fn clear_cookies(cookies: &CookieJar<'_>) -> &'static str {
}

#[get("/_info")]
fn get_user_info(token: Option<Token>) -> String {
fn get_user_info(
span: TracingSpan,
token: Option<Token>,
) -> String {
let _span = span.get();
let _span = _span.enter();

match token {
Some(Token { biscuit, .. }) => format!(
"**Biscuit:**\n\n{}\n\n\
Expand All @@ -74,7 +86,13 @@ pub struct ErrorLog {
}

#[get("/_errors")]
fn errors(token: Token) -> Result<String, Error> {
fn errors(
span: TracingSpan,
token: Token,
) -> Result<String, Error> {
let _span = span.get();
let _span = _span.enter();

if !token.profiles().contains(&"*".to_owned()) {
Err(Error::Unauthorized)?
}
Expand All @@ -101,7 +119,13 @@ pub struct AccessLog {
}

#[get("/_access-logs")]
fn access_logs(token: Token) -> Result<String, Error> {
fn access_logs(
span: TracingSpan,
token: Token,
) -> Result<String, Error> {
let _span = span.get();
let _span = _span.enter();

if !token.profiles().contains(&"*".to_owned()) {
Err(Error::Unauthorized)?
}
Expand Down Expand Up @@ -140,7 +164,13 @@ pub fn log_access(
}

#[get("/_revoked-tokens")]
fn revoked_tokens(token: Token) -> Result<String, Error> {
fn revoked_tokens(
span: TracingSpan,
token: Token,
) -> Result<String, Error> {
let _span = span.get();
let _span = _span.enter();

if !token.profiles().contains(&"*".to_owned()) {
Err(Error::Forbidden)?
}
Expand Down Expand Up @@ -168,7 +198,7 @@ pub mod token_generator {
context,
request_guards::Token,
util::{templating::render, WebsiteRoot},
Error,
Error, TracingSpan,
};

fn token_generation_form_(
Expand All @@ -187,10 +217,14 @@ pub mod token_generator {

#[get("/_generate-token")]
pub fn token_generation_form(
span: TracingSpan,
token: Token,
tera: &State<tera::Tera>,
website_root: WebsiteRoot,
) -> Result<RawHtml<String>, Error> {
let _span = span.get();
let _span = _span.enter();

if !token.profiles().contains(&"*".to_owned()) {
Err(Error::Unauthorized)?
}
Expand All @@ -208,11 +242,15 @@ pub mod token_generator {

#[post("/_generate-token", data = "<form>")]
pub fn generate_token(
span: TracingSpan,
token: Token,
tera: &State<tera::Tera>,
form: Form<Strict<GenerateTokenForm>>,
website_root: WebsiteRoot,
) -> Result<RawHtml<String>, Error> {
let _span = span.get();
let _span = _span.enter();

if !token.profiles().contains(&"*".to_owned()) {
Err(Error::Unauthorized)?
}
Expand Down
8 changes: 7 additions & 1 deletion src/orangutan-server/src/routes/main_route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,25 @@ use rocket::{
};
use tracing::{debug, trace};

use crate::{config::*, request_guards::Token, routes::debug_routes::log_access, util::error};
use crate::{
config::*, request_guards::Token, routes::debug_routes::log_access, util::error, TracingSpan,
};

pub(super) fn routes() -> Vec<Route> {
routes![handle_request]
}

#[get("/<_..>")]
async fn handle_request(
span: TracingSpan,
origin: &Origin<'_>,
token: Option<Token>,
object_reader: &State<ObjectReader>,
accept: Option<&Accept>,
) -> Result<Option<ReadObjectResponse>, crate::Error> {
let span = span.get();
let _span = span.enter();

// FIXME: Handle error
let path = urlencoding::decode(origin.path().as_str())
.unwrap()
Expand Down
Loading

0 comments on commit 8632ee7

Please sign in to comment.