OpenAI v1 Completions API #170

Merged · 12 commits · Jan 9, 2024
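This PR adds an OpenAI-compatible `POST /v1/completions` route to the LoRAX router, translating OpenAI-style completion requests into LoRAX generate requests. A minimal client sketch, assuming a LoRAX server listening on `localhost:8080` (the port and adapter name are illustrative, not part of this PR) and `reqwest` with the `blocking` and `json` features:

use serde_json::json;

fn main() -> Result<(), reqwest::Error> {
    // Post an OpenAI-style completion request to the new route;
    // `model` is mapped to a LoRAX adapter_id on the server side.
    let resp = reqwest::blocking::Client::new()
        .post("http://localhost:8080/v1/completions")
        .json(&json!({
            "model": "my-adapter",
            "prompt": "Hello, world",
            "max_tokens": 32,
            "stream": false
        }))
        .send()?;
    println!("{}", resp.text()?);
    Ok(())
}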
2 changes: 1 addition & 1 deletion router/src/infer.rs
@@ -118,7 +118,7 @@ impl Infer {
})?;

let mut adapter_id = request.parameters.adapter_id.clone();
- if adapter_id.is_none() {
+ if adapter_id.is_none() || adapter_id.as_ref().unwrap().is_empty() {
adapter_id = Some(BASE_MODEL_ADAPTER_ID.to_string());
}
let mut adapter_source = request.parameters.adapter_source.clone();
210 changes: 210 additions & 0 deletions router/src/lib.rs
@@ -308,6 +308,216 @@ pub(crate) struct ErrorResponse {
pub error_type: String,
}

// OpenAI compatible structs

#[derive(Serialize, ToSchema)]
struct UsageInfo {
prompt_tokens: u32,
total_tokens: u32,
completion_tokens: Option<u32>,
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
struct ChatCompletionRequest {
model: String,
messages: Vec<String>,
temperature: Option<f32>,
top_p: Option<f32>,
n: Option<i32>,
max_tokens: Option<i32>,
#[serde(default)]
stop: Vec<String>,
stream: Option<bool>,
presence_penalty: Option<f32>,
frequency_penalty: Option<f32>,
logit_bias: Option<std::collections::HashMap<String, f32>>,
user: Option<String>,
// Additional parameters
// TODO(travis): add other LoRAX params here
}

#[derive(Clone, Debug, Deserialize, ToSchema)]
struct CompletionRequest {
model: String,
prompt: String,
suffix: Option<String>,
max_tokens: Option<i32>,
temperature: Option<f32>,
top_p: Option<f32>,
n: Option<i32>,
stream: Option<bool>,
logprobs: Option<i32>,
echo: Option<bool>,
#[serde(default)]
stop: Vec<String>,
presence_penalty: Option<f32>,
frequency_penalty: Option<f32>,
best_of: Option<i32>,
logit_bias: Option<std::collections::HashMap<String, f32>>,
user: Option<String>,
// Additional parameters
// TODO(travis): add other LoRAX params here
}
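
For illustration, a body this struct accepts (a hedged sketch: `my-adapter` is a placeholder, and since the type is crate-private this only runs inside the router crate's own tests):

let body = r#"{"model": "my-adapter", "prompt": "Say hi", "max_tokens": 16}"#;
let req: CompletionRequest = serde_json::from_str(body).unwrap();
// #[serde(default)] supplies the empty Vec when `stop` is omitted.
assert!(req.stop.is_empty());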

#[derive(Serialize, ToSchema)]
struct LogProbs {
text_offset: Vec<i32>,
token_logprobs: Vec<Option<f32>>,
tokens: Vec<String>,
top_logprobs: Option<Vec<Option<std::collections::HashMap<i32, f32>>>>,
}

#[derive(Serialize, ToSchema)]
struct CompletionResponseChoice {
index: i32,
text: String,
logprobs: Option<LogProbs>,
finish_reason: Option<String>, // Literal replaced with String
}

#[derive(Serialize, ToSchema)]
struct CompletionResponse {
id: String,
object: String,
created: i64,
model: String,
choices: Vec<CompletionResponseChoice>,
usage: UsageInfo,
}

#[derive(Serialize, ToSchema)]
struct CompletionResponseStreamChoice {
index: i32,
text: String,
logprobs: Option<LogProbs>,
finish_reason: Option<String>, // Literal replaced with String
}

#[derive(Serialize, ToSchema)]
struct CompletionStreamResponse {
id: String,
object: String,
created: i64,
model: String,
choices: Vec<CompletionResponseStreamChoice>,
usage: Option<UsageInfo>,
}

#[derive(Serialize, ToSchema)]
struct ChatMessage {
role: String,
content: String,
}

#[derive(Serialize, ToSchema)]
struct ChatCompletionResponseChoice {
index: i32,
message: ChatMessage,
finish_reason: Option<String>, // Literal replaced with String
}

#[derive(Serialize, ToSchema)]
struct ChatCompletionResponse {
id: String,
object: String,
created: i64,
model: String,
choices: Vec<ChatCompletionResponseChoice>,
usage: UsageInfo,
}

impl From<CompletionRequest> for CompatGenerateRequest {
fn from(req: CompletionRequest) -> Self {
CompatGenerateRequest {
inputs: req.prompt,
parameters: GenerateParameters {
adapter_id: req.model.parse().ok(),
adapter_source: None,
api_token: None,
best_of: req.best_of.map(|x| x as usize),
temperature: req.temperature,
repetition_penalty: None,
top_k: None,
top_p: req.top_p,
typical_p: None,
do_sample: req.n.is_some(),
max_new_tokens: req
.max_tokens
.map(|x| x as u32)
.unwrap_or(default_max_new_tokens()),
return_full_text: req.echo,
stop: req.stop,
truncate: None,
watermark: false,
details: true,
decoder_input_details: req.logprobs.is_some(),
seed: None,
},
stream: req.stream.unwrap_or(false),
}
}
}
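
The mapping above (model → adapter_id, echo → return_full_text, max_tokens → max_new_tokens, logprobs → decoder_input_details) could be pinned down with a test; a sketch, assuming it lives in this crate's tests module (field values are illustrative):

#[test]
fn completion_request_maps_to_generate_params() {
    let req: CompletionRequest = serde_json::from_str(
        r#"{"model": "my-adapter", "prompt": "hi", "max_tokens": 8, "echo": true}"#,
    )
    .unwrap();
    let gen = CompatGenerateRequest::from(req);
    assert_eq!(gen.parameters.adapter_id.as_deref(), Some("my-adapter"));
    assert_eq!(gen.parameters.max_new_tokens, 8);
    assert_eq!(gen.parameters.return_full_text, Some(true));
    // `stream` was omitted, so it defaults to false.
    assert!(!gen.stream);
}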

impl From<GenerateResponse> for CompletionResponse {
fn from(resp: GenerateResponse) -> Self {
let prompt_tokens = resp.details.as_ref().map(|x| x.prompt_tokens).unwrap_or(0);
let completion_tokens = resp
.details
.as_ref()
.map(|x| x.generated_tokens)
.unwrap_or(0);
let total_tokens = prompt_tokens + completion_tokens;

CompletionResponse {
id: "null".to_string(),
object: "text_completion".to_string(),
created: 0,
model: "null".to_string(),
choices: vec![CompletionResponseChoice {
index: 0,
text: resp.generated_text,
logprobs: None,
finish_reason: None,
}],
usage: UsageInfo {
prompt_tokens,
total_tokens,
completion_tokens: Some(completion_tokens),
},
}
}
}

impl From<StreamResponse> for CompletionStreamResponse {
fn from(resp: StreamResponse) -> Self {
let prompt_tokens = resp.details.as_ref().map(|x| x.prompt_tokens).unwrap_or(0);
let completion_tokens = resp
.details
.as_ref()
.map(|x| x.generated_tokens)
.unwrap_or(0);
let total_tokens = prompt_tokens + completion_tokens;

CompletionStreamResponse {
id: "null".to_string(),
object: "text_completion".to_string(),
created: 0,
model: "null".to_string(),
choices: vec![CompletionResponseStreamChoice {
index: 0,
text: resp.generated_text.unwrap_or_default(),
logprobs: None,
finish_reason: None,
}],
usage: Some(UsageInfo {
prompt_tokens,
total_tokens,
completion_tokens: Some(completion_tokens),
}),
}
}
}
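
On the wire, each streamed chunk produced by this impl is serialized by the SSE callback in server.rs into a `data:` line. A sketch of one event (the text value is illustrative; the token counts stay zero until `details` arrives with the final message):

data: {"id":"null","object":"text_completion","created":0,"model":"null","choices":[{"index":0,"text":" world","logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":0,"total_tokens":0,"completion_tokens":0}}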

#[cfg(test)]
mod tests {
use std::io::Write;
82 changes: 77 additions & 5 deletions router/src/server.rs
@@ -3,9 +3,10 @@ use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
- BestOfSequence, CompatGenerateRequest, Details, ErrorResponse, FinishReason,
- GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
- StreamDetails, StreamResponse, Token, Validation,
+ BestOfSequence, CompatGenerateRequest, CompletionRequest, CompletionResponse,
+ CompletionStreamResponse, Details, ErrorResponse, FinishReason, GenerateParameters,
+ GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken, StreamDetails,
+ StreamResponse, Token, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@@ -77,6 +78,66 @@ async fn compat_generate(
}
}

/// Generate tokens if `stream == false` or a stream of tokens if `stream == true`
#[utoipa::path(
post,
tag = "LoRAX",
path = "/v1/completions",
request_body = CompletionRequest,
responses(
(status = 200, description = "Generated Text",
content(
("application/json" = CompletionResponse),
("text/event-stream" = CompletionStreamResponse),
)),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json!({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json!({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json!({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json!({"error": "Incomplete generation"})),
)
)]
#[instrument(skip(infer, req))]
async fn completions_v1(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
req: Json<CompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let req = req.0;
let mut gen_req = CompatGenerateRequest::from(req);

// default return_full_text given the pipeline_tag
if gen_req.parameters.return_full_text.is_none() {
gen_req.parameters.return_full_text = Some(default_return_full_text.0)
}

// switch on stream
if gen_req.stream {
let callback = move |resp: StreamResponse| {
Event::default()
.json_data(CompletionStreamResponse::from(resp))
.map_or_else(
|err| {
tracing::error!("Failed to serialize CompletionStreamResponse: {err}");
Event::default()
},
|data| data,
)
};

let (headers, stream) =
generate_stream_with_callback(infer, Json(gen_req.into()), callback).await;
Ok((headers, Sse::new(stream).keep_alive(KeepAlive::default())).into_response())
} else {
let (headers, generation) = generate(infer, Json(gen_req.into())).await?;
// wrap generation inside a Vec to match api-inference
Ok((headers, Json(vec![CompletionResponse::from(generation.0)])).into_response())
}
}
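
Note the non-streaming branch wraps the single response in a Vec, so the JSON a client receives is an array with one element, roughly (values illustrative):

[{"id":"null","object":"text_completion","created":0,"model":"null","choices":[{"index":0,"text":"Hello!","logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":5,"total_tokens":12,"completion_tokens":7}}]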

/// LoRAX endpoint info
#[utoipa::path(
get,
@@ -351,6 +412,16 @@ async fn generate_stream(
HeaderMap,
Sse<impl Stream<Item = Result<Event, Infallible>>>,
) {
let callback = |resp: StreamResponse| Event::default().json_data(resp).unwrap();
let (headers, stream) = generate_stream_with_callback(infer, req, callback).await;
(headers, Sse::new(stream).keep_alive(KeepAlive::default()))
}

async fn generate_stream_with_callback(
infer: Extension<Infer>,
req: Json<GenerateRequest>,
callback: impl Fn(StreamResponse) -> Event,
) -> (HeaderMap, impl Stream<Item = Result<Event, Infallible>>) {
let span = tracing::Span::current();
let start_time = Instant::now();
metrics::increment_counter!("lorax_request_count");
@@ -479,7 +550,7 @@
details
};

- yield Ok(Event::default().json_data(stream_token).unwrap());
+ yield Ok(callback(stream_token));
break;
}
}
@@ -510,7 +581,7 @@
}
};

- (headers, Sse::new(stream).keep_alive(KeepAlive::default()))
+ (headers, stream)
}

/// Prometheus metrics scrape endpoint
@@ -699,6 +770,7 @@ pub async fn run(
.route("/info", get(get_model_info))
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/v1/completions", post(completions_v1))
// AWS Sagemaker route
.route("/invocations", post(compat_generate))
// Base Health route