From 740fa0b6b0bcd0ab70e22c6b0b925ca7bdde60d7 Mon Sep 17 00:00:00 2001
From: Jeffrey Tang <810895+jeffreyftang@users.noreply.github.com>
Date: Mon, 26 Feb 2024 11:54:54 -0600
Subject: [PATCH] enh: Propagate bearer token from header if one exists for
 OpenAI-compatible endpoints (#278)

---
 router/src/server.rs | 48 ++++++++++++++++++++++++++++++++++----------
 1 file changed, 38 insertions(+), 10 deletions(-)

diff --git a/router/src/server.rs b/router/src/server.rs
index 26dae2782..08dd4d0d8 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -64,6 +64,7 @@ example = json ! ({"error": "Incomplete generation"})),
 async fn compat_generate(
     default_return_full_text: Extension<bool>,
     infer: Extension<Infer>,
+    req_headers: HeaderMap,
     req: Json<CompatGenerateRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let mut req = req.0;
@@ -75,11 +76,11 @@ async fn compat_generate(
 
     // switch on stream
     if req.stream {
-        Ok(generate_stream(infer, Json(req.into()))
+        Ok(generate_stream(infer, req_headers, Json(req.into()))
             .await
             .into_response())
     } else {
-        let (headers, generation) = generate(infer, Json(req.into())).await?;
+        let (headers, generation) = generate(infer, req_headers, Json(req.into())).await?;
         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(vec![generation.0])).into_response())
     }
@@ -111,6 +112,7 @@ example = json ! ({"error": "Incomplete generation"})),
 async fn completions_v1(
     default_return_full_text: Extension<bool>,
     infer: Extension<Infer>,
+    req_headers: HeaderMap,
     req: Json<CompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let req = req.0;
@@ -136,10 +138,10 @@ async fn completions_v1(
         };
 
         let (headers, stream) =
-            generate_stream_with_callback(infer, Json(gen_req.into()), callback).await;
+            generate_stream_with_callback(infer, req_headers, Json(gen_req.into()), callback).await;
         Ok((headers, Sse::new(stream).keep_alive(KeepAlive::default())).into_response())
     } else {
-        let (headers, generation) = generate(infer, Json(gen_req.into())).await?;
+        let (headers, generation) = generate(infer, req_headers, Json(gen_req.into())).await?;
         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(CompletionResponse::from(generation.0))).into_response())
     }
@@ -171,6 +173,7 @@ example = json ! ({"error": "Incomplete generation"})),
 async fn chat_completions_v1(
     default_return_full_text: Extension<bool>,
     infer: Extension<Infer>,
+    req_headers: HeaderMap,
     req: Json<ChatCompletionRequest>,
 ) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
     let req = req.0;
@@ -196,10 +199,10 @@ async fn chat_completions_v1(
         };
 
         let (headers, stream) =
-            generate_stream_with_callback(infer, Json(gen_req.into()), callback).await;
+            generate_stream_with_callback(infer, req_headers, Json(gen_req.into()), callback).await;
         Ok((headers, Sse::new(stream).keep_alive(KeepAlive::default())).into_response())
     } else {
-        let (headers, generation) = generate(infer, Json(gen_req.into())).await?;
+        let (headers, generation) = generate(infer, req_headers, Json(gen_req.into())).await?;
         // wrap generation inside a Vec to match api-inference
         Ok((headers, Json(ChatCompletionResponse::from(generation.0))).into_response())
     }
@@ -274,7 +277,8 @@ seed,
 )]
 async fn generate(
     infer: Extension<Infer>,
-    req: Json<GenerateRequest>,
+    req_headers: HeaderMap,
+    mut req: Json<GenerateRequest>,
 ) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
     let span = tracing::Span::current();
     let start_time = Instant::now();
@@ -295,6 +299,17 @@ async fn generate(
         req.0.parameters.adapter_parameters.clone(),
     );
 
+    if req.parameters.api_token.is_none() {
+        // If no API token was explicitly provided in the request payload, try to set it from the request headers.
+        let _ = req_headers.get("authorization").map_or((), |x| {
+            x.to_str().map_or((), |y| {
+                y.strip_prefix("Bearer ").map_or((), |token| {
+                    req.parameters.api_token = Some(token.to_string());
+                })
+            })
+        });
+    }
+
     // Inference
     let (response, best_of_responses) = match req.0.parameters.best_of {
         Some(best_of) if best_of > 1 => {
@@ -495,19 +510,21 @@ seed,
 )]
 async fn generate_stream(
     infer: Extension<Infer>,
-    req: Json<GenerateRequest>,
+    req_headers: HeaderMap,
+    mut req: Json<GenerateRequest>,
 ) -> (
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
 ) {
     let callback = |resp: StreamResponse| Event::default().json_data(resp).unwrap();
-    let (headers, stream) = generate_stream_with_callback(infer, req, callback).await;
+    let (headers, stream) = generate_stream_with_callback(infer, req_headers, req, callback).await;
     (headers, Sse::new(stream).keep_alive(KeepAlive::default()))
 }
 
 async fn generate_stream_with_callback(
     infer: Extension<Infer>,
-    req: Json<GenerateRequest>,
+    req_headers: HeaderMap,
+    mut req: Json<GenerateRequest>,
     callback: impl Fn(StreamResponse) -> Event,
 ) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
     let span = tracing::Span::current();
@@ -526,6 +543,17 @@ async fn generate_stream_with_callback(
     );
     headers.insert("X-Accel-Buffering", "no".parse().unwrap());
 
+    if req.parameters.api_token.is_none() {
+        // If no API token was explicitly provided in the request payload, try to set it from the request headers.
+        let _ = req_headers.get("authorization").map_or((), |x| {
+            x.to_str().map_or((), |y| {
+                y.strip_prefix("Bearer ").map_or((), |token| {
+                    req.parameters.api_token = Some(token.to_string());
+                })
+            })
+        });
+    }
+
     let stream = async_stream::stream! {
         // Inference
         let mut end_reached = false;