diff --git a/router/src/infer.rs b/router/src/infer.rs
index d46ab0471..b8b4b656c 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -118,7 +118,7 @@ impl Infer {
         })?;
 
         let mut adapter_id = request.parameters.adapter_id.clone();
-        if adapter_id.is_none() {
+        if adapter_id.is_none() || adapter_id.as_ref().unwrap().is_empty() {
             adapter_id = Some(BASE_MODEL_ADAPTER_ID.to_string());
         }
         let mut adapter_source = request.parameters.adapter_source.clone();
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 6d1d987d5..ce5003bb6 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -308,6 +308,216 @@ pub(crate) struct ErrorResponse {
     pub error_type: String,
 }
 
+// OpenAI compatible structs
+
+#[derive(Serialize, ToSchema)]
+struct UsageInfo {
+    prompt_tokens: u32,
+    total_tokens: u32,
+    completion_tokens: Option<u32>,
+}
+
+#[derive(Clone, Debug, Deserialize, ToSchema)]
+struct ChatCompletionRequest {
+    model: String,
+    messages: Vec<ChatMessage>,
+    temperature: Option<f32>,
+    top_p: Option<f32>,
+    n: Option<i32>,
+    max_tokens: Option<i32>,
+    #[serde(default)]
+    stop: Vec<String>,
+    stream: Option<bool>,
+    presence_penalty: Option<f32>,
+    frequency_penalty: Option<f32>,
+    logit_bias: Option<HashMap<String, f32>>,
+    user: Option<String>,
+    // Additional parameters
+    // TODO(travis): add other LoRAX params here
+}
+
+#[derive(Clone, Debug, Deserialize, ToSchema)]
+struct CompletionRequest {
+    model: String,
+    prompt: String,
+    suffix: Option<String>,
+    max_tokens: Option<i32>,
+    temperature: Option<f32>,
+    top_p: Option<f32>,
+    n: Option<i32>,
+    stream: Option<bool>,
+    logprobs: Option<i32>,
+    echo: Option<bool>,
+    #[serde(default)]
+    stop: Vec<String>,
+    presence_penalty: Option<f32>,
+    frequency_penalty: Option<f32>,
+    best_of: Option<i32>,
+    logit_bias: Option<HashMap<String, f32>>,
+    user: Option<String>,
+    // Additional parameters
+    // TODO(travis): add other LoRAX params here
+}
+
+#[derive(Serialize, ToSchema)]
+struct LogProbs {
+    text_offset: Vec<usize>,
+    token_logprobs: Vec<Option<f32>>,
+    tokens: Vec<String>,
+    top_logprobs: Option<Vec<Option<HashMap<i32, f32>>>>,
+}
+
+#[derive(Serialize, ToSchema)]
+struct CompletionResponseChoice {
+    index: i32,
+    text: String,
+    logprobs: Option<LogProbs>,
+    finish_reason: Option<String>, // Literal replaced with String
+}
+
+#[derive(Serialize, ToSchema)]
+struct CompletionResponse {
+    id: String,
+    object: String,
+    created: i64,
+    model: String,
+    choices: Vec<CompletionResponseChoice>,
+    usage: UsageInfo,
+}
+
+#[derive(Serialize, ToSchema)]
+struct CompletionResponseStreamChoice {
+    index: i32,
+    text: String,
+    logprobs: Option<LogProbs>,
+    finish_reason: Option<String>, // Literal replaced with String
+}
+
+#[derive(Serialize, ToSchema)]
+struct CompletionStreamResponse {
+    id: String,
+    object: String,
+    created: i64,
+    model: String,
+    choices: Vec<CompletionResponseStreamChoice>,
+    usage: Option<UsageInfo>,
+}
+
+#[derive(Serialize, ToSchema)]
+struct ChatMessage {
+    role: String,
+    content: String,
+}
+
+#[derive(Serialize, ToSchema)]
+struct ChatCompletionResponseChoice {
+    index: i32,
+    message: ChatMessage,
+    finish_reason: Option<String>, // Literal replaced with String
+}
+
+#[derive(Serialize, ToSchema)]
+struct ChatCompletionResponse {
+    id: String,
+    object: String,
+    created: i64,
+    model: String,
+    choices: Vec<ChatCompletionResponseChoice>,
+    usage: UsageInfo,
+}
+
+impl From<CompletionRequest> for CompatGenerateRequest {
+    fn from(req: CompletionRequest) -> Self {
+        CompatGenerateRequest {
+            inputs: req.prompt,
+            parameters: GenerateParameters {
+                adapter_id: req.model.parse().ok(),
+                adapter_source: None,
+                api_token: None,
+                best_of: req.best_of.map(|x| x as usize),
+                temperature: req.temperature,
+                repetition_penalty: None,
+                top_k: None,
+                top_p: req.top_p,
+                typical_p: None,
+                do_sample: !req.n.is_none(),
+                max_new_tokens: req
+                    .max_tokens
+                    .map(|x| x as u32)
+                    .unwrap_or(default_max_new_tokens()),
+                return_full_text: req.echo,
+                stop: req.stop,
+                truncate: None,
+                watermark: false,
+                details: true,
+                decoder_input_details: req.logprobs.is_some(),
+                seed: None,
+            },
+            stream: req.stream.unwrap_or(false),
+        }
+    }
+}
+
+impl From<GenerateResponse> for CompletionResponse {
+    fn from(resp: GenerateResponse) -> Self {
+        let prompt_tokens = resp.details.as_ref().map(|x| x.prompt_tokens).unwrap_or(0);
+        let completion_tokens = resp
+            .details
+            .as_ref()
+            .map(|x| x.generated_tokens)
+            .unwrap_or(0);
+        let total_tokens = prompt_tokens + completion_tokens;
+
+        CompletionResponse {
+            id: "null".to_string(),
+            object: "text_completion".to_string(),
+            created: 0,
+            model: "null".to_string(),
+            choices: vec![CompletionResponseChoice {
+                index: 0,
+                text: resp.generated_text,
+                logprobs: None,
+                finish_reason: None,
+            }],
+            usage: UsageInfo {
+                prompt_tokens: prompt_tokens,
+                total_tokens: total_tokens,
+                completion_tokens: Some(completion_tokens),
+            },
+        }
+    }
+}
+
+impl From<StreamResponse> for CompletionStreamResponse {
+    fn from(resp: StreamResponse) -> Self {
+        let prompt_tokens = resp.details.as_ref().map(|x| x.prompt_tokens).unwrap_or(0);
+        let completion_tokens = resp
+            .details
+            .as_ref()
+            .map(|x| x.generated_tokens)
+            .unwrap_or(0);
+        let total_tokens = prompt_tokens + completion_tokens;
+
+        CompletionStreamResponse {
+            id: "null".to_string(),
+            object: "text_completion".to_string(),
+            created: 0,
+            model: "null".to_string(),
+            choices: vec![CompletionResponseStreamChoice {
+                index: 0,
+                text: resp.generated_text.unwrap_or_default(),
+                logprobs: None,
+                finish_reason: None,
+            }],
+            usage: Some(UsageInfo {
+                prompt_tokens: prompt_tokens,
+                total_tokens: total_tokens,
+                completion_tokens: Some(completion_tokens),
+            }),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::io::Write;
diff --git a/router/src/server.rs b/router/src/server.rs
index 662594039..41f54d4c9 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -3,9 +3,10 @@ use crate::health::Health;
 use crate::infer::{InferError, InferResponse, InferStreamResponse};
 use crate::validation::ValidationError;
 use crate::{
-    BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
-    GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
-    StreamDetails, StreamResponse, Token, Validation,
+    BestOfSequence, CompatGenerateRequest, CompletionRequest, CompletionResponse,
+    CompletionStreamResponse, Details, ErrorResponse, FinishReason, GenerateParameters,
+    GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken, StreamDetails,
+    StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, Method, StatusCode};
@@ -77,6 +78,66 @@ async fn compat_generate(
     }
 }
 
+/// Generate tokens if `stream == false` or a stream of tokens if `stream == true`
+#[utoipa::path(
+post,
+tag = "LoRAX",
+path = "/v1/completions",
+request_body = CompletionRequest,
+responses(
+(status = 200, description = "Generated Text",
+content(
+("application/json" = CompletionResponse),
+("text/event-stream" = CompletionStreamResponse),
+)),
+(status = 424, description = "Generation Error", body = ErrorResponse,
+example = json ! ({"error": "Request failed during generation"})),
+(status = 429, description = "Model is overloaded", body = ErrorResponse,
+example = json ! ({"error": "Model is overloaded"})),
+(status = 422, description = "Input validation error", body = ErrorResponse,
+example = json ! ({"error": "Input validation error"})),
+(status = 500, description = "Incomplete generation", body = ErrorResponse,
+example = json ! ({"error": "Incomplete generation"})),
+)
+)]
+#[instrument(skip(infer, req))]
+async fn completions_v1(
+    default_return_full_text: Extension<bool>,
+    infer: Extension<Infer>,
+    req: Json<CompletionRequest>,
+) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
+    let req = req.0;
+    let mut gen_req = CompatGenerateRequest::from(req);
+
+    // default return_full_text given the pipeline_tag
+    if gen_req.parameters.return_full_text.is_none() {
+        gen_req.parameters.return_full_text = Some(default_return_full_text.0)
+    }
+
+    // switch on stream
+    if gen_req.stream {
+        let callback = move |resp: StreamResponse| {
+            Event::default()
+                .json_data(CompletionStreamResponse::from(resp))
+                .map_or_else(
+                    |err| {
+                        tracing::error!("Failed to serialize CompletionStreamResponse: {err}");
+                        Event::default()
+                    },
+                    |data| data,
+                )
+        };
+
+        let (headers, stream) =
+            generate_stream_with_callback(infer, Json(gen_req.into()), callback).await;
+        Ok((headers, Sse::new(stream).keep_alive(KeepAlive::default())).into_response())
+    } else {
+        let (headers, generation) = generate(infer, Json(gen_req.into())).await?;
+        // wrap generation inside a Vec to match api-inference
+        Ok((headers, Json(vec![CompletionResponse::from(generation.0)])).into_response())
+    }
+}
+
 /// LoRAX endpoint info
 #[utoipa::path(
 get,
@@ -351,6 +412,16 @@ async fn generate_stream(
     HeaderMap,
     Sse<impl Stream<Item = Result<Event, Infallible>>>,
 ) {
+    let callback = |resp: StreamResponse| Event::default().json_data(resp).unwrap();
+    let (headers, stream) = generate_stream_with_callback(infer, req, callback).await;
+    (headers, Sse::new(stream).keep_alive(KeepAlive::default()))
+}
+
+async fn generate_stream_with_callback(
+    infer: Extension<Infer>,
+    req: Json<GenerateRequest>,
+    callback: impl Fn(StreamResponse) -> Event,
+) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
     let span = tracing::Span::current();
     let start_time = Instant::now();
     metrics::increment_counter!("lorax_request_count");
@@ -479,7 +550,7 @@ async fn generate_stream(
                                 details
                             };
 
-                            yield Ok(Event::default().json_data(stream_token).unwrap());
+                            yield Ok(callback(stream_token));
                             break;
                         }
                     }
@@ -510,7 +581,7 @@ async fn generate_stream(
         }
     };
 
-    (headers, Sse::new(stream).keep_alive(KeepAlive::default()))
+    (headers, stream)
 }
 
 /// Prometheus metrics scrape endpoint
@@ -699,6 +770,7 @@ pub async fn run(
         .route("/info", get(get_model_info))
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
+        .route("/v1/completions", post(completions_v1))
         // AWS Sagemaker route
         .route("/invocations", post(compat_generate))
         // Base Health route
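Example usage (not part of the patch): a minimal client-side sketch, in Rust, exercising the new OpenAI-compatible /v1/completions route added above. It assumes the LoRAX router is listening on localhost:8080 and that the reqwest (with its json feature), serde_json, and tokio crates are available; the URL, port, prompt, and parameter values are illustrative only.

use serde_json::{json, Value};

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let client = reqwest::Client::new();

    // Request body mirrors the `CompletionRequest` struct from router/src/lib.rs.
    // An empty `model` falls back to the base model via the infer.rs change above;
    // setting it to an adapter id routes the request to that LoRA adapter.
    let body = json!({
        "model": "",
        "prompt": "What is deep learning?",
        "max_tokens": 64,
        "temperature": 0.7,
        "stream": false
    });

    let resp: Value = client
        .post("http://localhost:8080/v1/completions") // hypothetical host/port
        .json(&body)
        .send()
        .await?
        .json()
        .await?;

    // The non-streaming handler wraps the completion in a Vec (to match
    // api-inference), so the first element holds the `CompletionResponse`.
    println!("text: {}", resp[0]["choices"][0]["text"]);
    println!("usage: {}", resp[0]["usage"]);
    Ok(())
}

With "stream": true the same route returns a text/event-stream of CompletionStreamResponse objects instead of a single JSON array.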