Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Otel v2 #642

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
router
noyoshi committed Oct 15, 2024
commit 13e6ad41375ce3b41c321b5a8c77d98a20e27151
3 changes: 2 additions & 1 deletion router/Cargo.toml
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@ path = "src/main.rs"
[dependencies]
async-stream = "0.3.3"
axum = { version = "0.6.4", features = ["json"] }
axum-tracing-opentelemetry = "0.10.0"
axum-tracing-opentelemetry = "0.14.1"
clap = { version = "4.1.4", features = ["derive", "env"] }
flume = "0.10.14"
futures = "0.3.26"
@@ -64,6 +64,7 @@ image = "0.25.1"
rustls = "0.22.4"
webpki = "0.22.2"
base64 = "0.22.0"
init-tracing-opentelemetry = { version = "0.22.0", features = ["tracing_subscriber_ext"] }

[build-dependencies]
vergen = { version = "8.2.5", features = ["build", "git", "gitcl"] }
94 changes: 91 additions & 3 deletions router/src/server.rs
Original file line number Diff line number Diff line change
@@ -14,15 +14,14 @@ use crate::{
EmbedRequest, EmbedResponse, Entity, ErrorResponse, FinishReason, GenerateParameters,
GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, LogProbs, PrefillToken,
ResponseFormat, ResponseFormatType, SimpleToken, StreamDetails, StreamResponse, Token,
TokenizeRequest, TokenizeResponse, UsageInfo, Validation,
TokenizeRequest, TokenizeResponse, UsageInfo, Validation, logging,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use axum::routing::{get, post};
use axum::{http, Json, Router};
use axum_tracing_opentelemetry::opentelemetry_tracing_layer;
use futures::stream::StreamExt;
use futures::Stream;
use lorax_client::{ShardInfo, ShardedClient};
@@ -43,6 +42,11 @@ use tower_http::cors::{
AllowCredentials, AllowHeaders, AllowMethods, AllowOrigin, CorsLayer, ExposeHeaders,
};
use tracing::{info_span, instrument, Instrument};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use axum::{extract::Request, middleware::Next, response::Response};
use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId};
use opentelemetry::Context;
use axum_tracing_opentelemetry::middleware::{OtelAxumLayer, OtelInResponseLayer};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;

@@ -75,6 +79,7 @@ async fn compat_generate(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
info: Extension<Info>,
Extension(context): Extension<Option<opentelemetry::Context>>,
request_logger_sender: Extension<
Arc<mpsc::Sender<(i64, String, String, String, String, String)>>,
>,
@@ -93,6 +98,7 @@ async fn compat_generate(
Ok(generate_stream(
infer,
info,
axum::Extension(context),
request_logger_sender,
req_headers,
Json(req.into()),
@@ -103,6 +109,7 @@ async fn compat_generate(
let (headers, generation) = generate(
infer,
info,
axum::Extension(context),
request_logger_sender,
req_headers,
Json(req.into()),
@@ -142,6 +149,7 @@ async fn completions_v1(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
info: Extension<Info>,
Extension(context): Extension<Option<opentelemetry::Context>>,
request_logger_sender: Extension<
Arc<mpsc::Sender<(i64, String, String, String, String, String)>>,
>,
@@ -178,6 +186,7 @@ async fn completions_v1(
let (headers, stream) = generate_stream_with_callback(
infer,
info,
axum::Extension(context),
request_logger_sender,
req_headers,
Json(gen_req.into()),
@@ -190,6 +199,7 @@ async fn completions_v1(
let (headers, generation) = generate(
infer,
info,
axum::Extension(context),
request_logger_sender,
req_headers,
Json(gen_req.into()),
@@ -227,6 +237,7 @@ async fn chat_completions_v1(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
info: Extension<Info>,
Extension(context): Extension<Option<opentelemetry::Context>>,
request_logger_sender: Extension<
Arc<mpsc::Sender<(i64, String, String, String, String, String)>>,
>,
@@ -315,6 +326,7 @@ async fn chat_completions_v1(
let (headers, stream) = generate_stream_with_callback(
infer,
info,
axum::Extension(context),
request_logger_sender,
req_headers,
Json(gen_req.into()),
@@ -327,6 +339,7 @@ async fn chat_completions_v1(
let (headers, generation) = generate(
infer,
info,
axum::Extension(context),
request_logger_sender,
req_headers,
Json(gen_req.into()),
@@ -486,13 +499,17 @@ seed,
async fn generate(
infer: Extension<Infer>,
info: Extension<Info>,
Extension(context): Extension<Option<opentelemetry::Context>>,
request_logger_sender: Extension<
Arc<mpsc::Sender<(i64, String, String, String, String, String)>>,
>,
req_headers: HeaderMap,
mut req: Json<GenerateRequest>,
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
if let Some(context) = context {
span.set_parent(context);
}
let start_time = Instant::now();
metrics::increment_counter!("lorax_request_count");

@@ -742,6 +759,7 @@ seed,
async fn generate_stream(
infer: Extension<Infer>,
info: Extension<Info>,
Extension(context): Extension<Option<opentelemetry::Context>>,
request_logger_sender: Extension<
Arc<mpsc::Sender<(i64, String, String, String, String, String)>>,
>,
@@ -755,6 +773,7 @@ async fn generate_stream(
let (headers, stream) = generate_stream_with_callback(
infer,
info,
axum::Extension(context),
request_logger_sender,
req_headers,
req,
@@ -767,6 +786,7 @@ async fn generate_stream(
async fn generate_stream_with_callback(
infer: Extension<Infer>,
info: Extension<Info>,
Extension(context): Extension<Option<opentelemetry::Context>>,
request_logger_sender: Extension<
Arc<mpsc::Sender<(i64, String, String, String, String, String)>>,
>,
@@ -776,6 +796,9 @@ async fn generate_stream_with_callback(
end_event: Option<Event>,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
let span = tracing::Span::current();
if let Some(context) = context {
span.set_parent(context);
}
let start_time = Instant::now();
metrics::increment_counter!("lorax_request_count");

@@ -1356,7 +1379,9 @@ pub async fn run(
.layer(Extension(compat_return_full_text))
.layer(Extension(infer))
.layer(Extension(prom_handle.clone()))
.layer(opentelemetry_tracing_layer())
.layer(OtelAxumLayer::default())
.layer(OtelInResponseLayer::default())
.layer(axum::middleware::from_fn(trace_context_middleware))
.layer(cors_layer)
.layer(Extension(cloned_tokenizer));

@@ -1533,9 +1558,13 @@ async fn embed(
#[instrument(skip_all)]
async fn classify(
infer: Extension<Infer>,
Extension(context): Extension<Option<opentelemetry::Context>>,
Json(req): Json<ClassifyRequest>,
) -> Result<(HeaderMap, Json<Vec<Entity>>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
if let Some(context) = context {
span.set_parent(context);
}
let start_time = Instant::now();
metrics::increment_counter!("lorax_request_count");
tracing::debug!("Input: {}", req.inputs);
@@ -1602,9 +1631,13 @@ async fn classify(
#[instrument(skip_all)]
async fn classify_batch(
infer: Extension<Infer>,
Extension(context): Extension<Option<opentelemetry::Context>>,
Json(req): Json<BatchClassifyRequest>,
) -> Result<(HeaderMap, Json<Vec<Vec<Entity>>>), (StatusCode, Json<ErrorResponse>)> {
let span = tracing::Span::current();
if let Some(context) = context {
span.set_parent(context);
}
let start_time = Instant::now();
metrics::increment_counter!("lorax_request_count");
tracing::debug!("Inputs: {:?}", req.inputs);
@@ -1725,3 +1758,58 @@ async fn tokenize(
))
}
}



/// Decoded fields of a W3C `traceparent` header
/// (`{version}-{trace_id}-{parent_id}-{trace_flags}`), produced by
/// `parse_traceparent` and used to rebuild a remote `SpanContext`.
struct TraceParent {
    // Header format version (00 today); parsed but not otherwise used.
    #[allow(dead_code)]
    version: u8,
    // 16-byte trace id shared by every span in the distributed trace.
    trace_id: TraceId,
    // 8-byte id of the caller's span; becomes this service's parent span.
    parent_id: SpanId,
    // Sampling / option bits (e.g. the "sampled" flag) from the header.
    trace_flags: TraceFlags,
}


/// Parse a W3C Trace Context `traceparent` header value of the form
/// `{version}-{trace_id}-{parent_id}-{trace_flags}`.
///
/// Returns `None` for malformed input, for the reserved version `0xff`,
/// and for the all-zero trace/span ids — all of which the W3C spec
/// defines as invalid and requires receivers to ignore.
fn parse_traceparent(header_value: &str) -> Option<TraceParent> {
    let parts: Vec<&str> = header_value.split('-').collect();
    if parts.len() != 4 {
        return None;
    }

    // The version and flags fields must be exactly two hex digits; without
    // this check `u8::from_str_radix` would accept e.g. "0ff" (= 255).
    if parts[0].len() != 2 || parts[3].len() != 2 {
        return None;
    }

    let version = u8::from_str_radix(parts[0], 16).ok()?;
    if version == 0xff {
        return None; // version 0xff is forbidden by the spec
    }

    let trace_id = TraceId::from_hex(parts[1]).ok()?;
    let parent_id = SpanId::from_hex(parts[2]).ok()?;
    // All-zero ids mark an invalid traceparent per the spec.
    if trace_id == TraceId::INVALID || parent_id == SpanId::INVALID {
        return None;
    }

    let trace_flags = u8::from_str_radix(parts[3], 16).ok()?;

    Some(TraceParent {
        version,
        trace_id,
        parent_id,
        trace_flags: TraceFlags::new(trace_flags),
    })
}

/// Axum middleware that extracts the W3C `traceparent` header from an
/// incoming request and stores the resulting remote OpenTelemetry
/// `Context` (or `None`) as a request extension, so handlers can attach
/// their spans to the caller's trace via `span.set_parent(context)`.
pub async fn trace_context_middleware(mut request: Request, next: Next) -> Response {
    let context = request
        .headers()
        .get("traceparent")
        .and_then(|v| v.to_str().ok())
        .and_then(parse_traceparent)
        .map(|traceparent| {
            Context::new().with_remote_span_context(SpanContext::new(
                traceparent.trace_id,
                traceparent.parent_id,
                traceparent.trace_flags,
                true, // remote: this context came from another process
                Default::default(),
            ))
        });

    // Always insert the extension — even when `None` — so handlers
    // extracting `Extension<Option<Context>>` never hit a missing-extension
    // error on requests without a traceparent header.
    request.extensions_mut().insert(context);

    next.run(request).await
}