OpenAI v1 Chat Completions API (#171)
tgaddair authored Jan 10, 2024
1 parent 82dac66 commit a90d443
Showing 24 changed files with 219 additions and 27 deletions.
5 changes: 5 additions & 0 deletions docs/reference/openapi.json
@@ -607,6 +607,11 @@
"api_token": {
"type": "string",
"nullable": true
},
"apply_chat_template": {
"type": "boolean",
"default": "false",
"example": true
}
}
},
2 changes: 2 additions & 0 deletions proto/generate.proto
@@ -102,6 +102,8 @@ message Request {
bool prefill_logprobs = 6;
/// Adapter index
uint32 adapter_index = 7;
/// Apply chat template to inputs
bool apply_chat_template = 8;
}

message Batch {
1 change: 1 addition & 0 deletions router/client/src/client.rs
@@ -134,6 +134,7 @@ impl Client {
}),
adapter_index: 0,
prefill_logprobs: true,
apply_chat_template: false,
});
n_tokens += max_input_length;
}
1 change: 1 addition & 0 deletions router/src/health.rs
@@ -52,6 +52,7 @@ impl Health {
ignore_eos_token: false,
}),
adapter_index: 0,
apply_chat_template: false,
};
let batch = Batch {
id: BATCH_ID,
40 changes: 39 additions & 1 deletion router/src/lib.rs
@@ -145,6 +145,9 @@ pub(crate) struct GenerateParameters {
#[schema(default = "true")]
pub decoder_input_details: bool,
#[serde(default)]
#[schema(default = "false")]
pub apply_chat_template: bool,
#[serde(default)]
#[schema(
exclusive_minimum = 0,
nullable = true,
@@ -177,6 +180,7 @@ fn default_parameters() -> GenerateParameters {
watermark: false,
details: false,
decoder_input_details: false,
apply_chat_template: false,
seed: None,
}
}
@@ -320,7 +324,7 @@ struct UsageInfo {
#[derive(Clone, Debug, Deserialize, ToSchema)]
struct ChatCompletionRequest {
model: String,
messages: Vec<String>,
messages: Vec<std::collections::HashMap<String, String>>,
temperature: Option<f32>,
top_p: Option<f32>,
n: Option<i32>,
@@ -451,6 +455,40 @@ impl From<CompletionRequest> for CompatGenerateRequest {
watermark: false,
details: true,
decoder_input_details: req.logprobs.is_some(),
apply_chat_template: false,
seed: None,
},
stream: req.stream.unwrap_or(false),
}
}
}

impl From<ChatCompletionRequest> for CompatGenerateRequest {
fn from(req: ChatCompletionRequest) -> Self {
CompatGenerateRequest {
inputs: serde_json::to_string(&req.messages).unwrap(),
parameters: GenerateParameters {
adapter_id: req.model.parse().ok(),
adapter_source: None,
api_token: None,
best_of: req.n.map(|x| x as usize),
temperature: req.temperature,
repetition_penalty: None,
top_k: None,
top_p: req.top_p,
typical_p: None,
do_sample: !req.n.is_none(),
max_new_tokens: req
.max_tokens
.map(|x| x as u32)
.unwrap_or(default_max_new_tokens()),
return_full_text: None,
stop: req.stop,
truncate: None,
watermark: false,
details: true,
decoder_input_details: false,
apply_chat_template: true,
seed: None,
},
stream: req.stream.unwrap_or(false),
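
The From<ChatCompletionRequest> conversion above deliberately does not render the chat template in the router: the OpenAI-style messages are serialized to a JSON string and forwarded as inputs, with apply_chat_template set so the Python server can expand them with the tokenizer's chat template (including a per-adapter tokenizer when one exists). A minimal Python illustration of the payload shape the router forwards; the message contents are invented for the example:

import json

# OpenAI-style chat messages (role/content pairs), as accepted by ChatCompletionRequest.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What does LoRAX do?"},
]

# Mirrors serde_json::to_string(&req.messages): the router forwards the raw JSON string
# as `inputs` and sets apply_chat_template instead of templating the prompt itself.
inputs = json.dumps(messages)
print(inputs)

Keeping templating on the Python side means the prompt format always comes from the tokenizer that actually serves the request, including adapter-specific tokenizers loaded later.
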
1 change: 1 addition & 0 deletions router/src/scheduler.rs
@@ -334,6 +334,7 @@ impl AdapterSchedulerState {
parameters: Some(entry.request.parameters.clone()),
stopping_parameters: Some(entry.request.stopping_parameters.clone()),
adapter_index: adapter.index(),
apply_chat_template: entry.request.apply_chat_template,
});
// Set batch_time
entry.batch_time = Some(Instant::now());
71 changes: 66 additions & 5 deletions router/src/server.rs
@@ -3,10 +3,10 @@ use crate::health::Health;
use crate::infer::{InferError, InferResponse, InferStreamResponse};
use crate::validation::ValidationError;
use crate::{
BestOfSequence, CompatGenerateRequest, CompletionRequest, CompletionResponse,
CompletionStreamResponse, Details, ErrorResponse, FinishReason, GenerateParameters,
GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken, StreamDetails,
StreamResponse, Token, Validation,
BestOfSequence, ChatCompletionRequest, CompatGenerateRequest, CompletionRequest,
CompletionResponse, CompletionStreamResponse, Details, ErrorResponse, FinishReason,
GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, Info, PrefillToken,
StreamDetails, StreamResponse, Token, Validation,
};
use axum::extract::Extension;
use axum::http::{HeaderMap, Method, StatusCode};
@@ -78,7 +78,7 @@ async fn compat_generate(
}
}

/// Generate tokens if `stream == false` or a stream of token if `stream == true`
/// OpenAI compatible completions endpoint
#[utoipa::path(
post,
tag = "LoRAX",
@@ -138,6 +138,66 @@ async fn completions_v1(
}
}

/// OpenAI compatible chat completions endpoint
#[utoipa::path(
post,
tag = "LoRAX",
path = "/v1/chat/completions",
request_body = ChatCompletionRequest,
responses(
(status = 200, description = "Generated Text",
content(
("application/json" = ChatCompletionResponse),
("text/event-stream" = ChatCompletionStreamResponse),
)),
(status = 424, description = "Generation Error", body = ErrorResponse,
example = json ! ({"error": "Request failed during generation"})),
(status = 429, description = "Model is overloaded", body = ErrorResponse,
example = json ! ({"error": "Model is overloaded"})),
(status = 422, description = "Input validation error", body = ErrorResponse,
example = json ! ({"error": "Input validation error"})),
(status = 500, description = "Incomplete generation", body = ErrorResponse,
example = json ! ({"error": "Incomplete generation"})),
)
)]
#[instrument(skip(infer, req))]
async fn chat_completions_v1(
default_return_full_text: Extension<bool>,
infer: Extension<Infer>,
req: Json<ChatCompletionRequest>,
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
let req = req.0;
let mut gen_req = CompatGenerateRequest::from(req);

// default return_full_text given the pipeline_tag
if gen_req.parameters.return_full_text.is_none() {
gen_req.parameters.return_full_text = Some(default_return_full_text.0)
}

// switch on stream
if gen_req.stream {
let callback = move |resp: StreamResponse| {
Event::default()
.json_data(CompletionStreamResponse::from(resp))
.map_or_else(
|err| {
tracing::error!("Failed to serialize CompletionStreamResponse: {err}");
Event::default()
},
|data| data,
)
};

let (headers, stream) =
generate_stream_with_callback(infer, Json(gen_req.into()), callback).await;
Ok((headers, Sse::new(stream).keep_alive(KeepAlive::default())).into_response())
} else {
let (headers, generation) = generate(infer, Json(gen_req.into())).await?;
// wrap generation inside a Vec to match api-inference
Ok((headers, Json(vec![CompletionResponse::from(generation.0)])).into_response())
}
}

/// LoRAX endpoint info
#[utoipa::path(
get,
@@ -771,6 +831,7 @@ pub async fn run(
.route("/generate", post(generate))
.route("/generate_stream", post(generate_stream))
.route("/v1/completions", post(completions_v1))
.route("/v1/chat/completions", post(chat_completions_v1))
// AWS Sagemaker route
.route("/invocations", post(compat_generate))
// Base Health route
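
With the new route registered, the server exposes an OpenAI-compatible chat endpoint alongside /v1/completions. A usage sketch with Python requests; the host, port, adapter name, and prompt are assumptions, while the path and request fields come from ChatCompletionRequest above:

import requests

# Hypothetical local LoRAX deployment; adjust host/port for your setup.
response = requests.post(
    "http://127.0.0.1:8080/v1/chat/completions",
    json={
        "model": "my-chat-adapter",  # interpreted as the adapter_id by the router
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Write a haiku about LoRA adapters."},
        ],
        "max_tokens": 64,  # mapped to max_new_tokens
        "stream": False,
    },
)
print(response.json())

Setting stream to true instead returns a text/event-stream of incremental tokens, matching the SSE branch of chat_completions_v1.
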
3 changes: 3 additions & 0 deletions router/src/validation.rs
@@ -145,6 +145,7 @@ impl Validation {
watermark,
adapter_id,
decoder_input_details,
apply_chat_template,
..
} = request.parameters;

@@ -270,6 +271,7 @@ impl Validation {
parameters,
stopping_parameters,
adapter,
apply_chat_template,
})
}

@@ -344,6 +346,7 @@ pub(crate) struct ValidGenerateRequest {
pub parameters: NextTokenChooserParameters,
pub stopping_parameters: StoppingCriteriaParameters,
pub adapter: Adapter,
pub apply_chat_template: bool,
}

#[derive(Error, Debug)]
4 changes: 3 additions & 1 deletion server/lorax_server/models/bloom.py
@@ -20,6 +20,7 @@
weight_files,
Weights,
)
from lorax_server.utils.tokenizer import TokenizerManager


class BloomCausalLMBatch(CausalLMBatch):
@@ -28,10 +29,11 @@ def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
tokenizers: TokenizerManager,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
batch = super().from_pb(pb=pb, tokenizer=tokenizer, tokenizers=tokenizers, dtype=dtype, device=device)
batch.keys_head_dim_last = False
return batch

6 changes: 5 additions & 1 deletion server/lorax_server/models/causal_lm.py
@@ -1,3 +1,4 @@
import json
import torch
import inspect

@@ -15,6 +16,7 @@
)
from lorax_server.pb import generate_pb2
from lorax_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from lorax_server.utils.tokenizer import TokenizerManager

tracer = trace.get_tracer(__name__)

@@ -69,6 +71,7 @@ def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
tokenizers: TokenizerManager,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
@@ -86,7 +89,8 @@ def from_pb(
adapter_indices_list = []
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
req_inputs = tokenizers.get_inputs(r, tokenizer)
inputs.append(req_inputs)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
14 changes: 10 additions & 4 deletions server/lorax_server/models/flash_causal_lm.py
@@ -1,4 +1,5 @@
from collections import defaultdict
import json
import math
import itertools
from loguru import logger
@@ -29,11 +30,11 @@
from lorax_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID, load_module_map
from lorax_server.utils.dist import MEMORY_FRACTION
from lorax_server.utils.lora import LM_HEAD, AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights
from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights, MergedLoraWeights
from lorax_server.utils.segments import SegmentConcatBuilder, find_segments
from lorax_server.utils.weights import shard_on_dim
from lorax_server.utils.graph import GraphCache
from lorax_server.utils.sgmv import get_tmp_tensor
from lorax_server.utils.tokenizer import TokenizerManager

tracer = trace.get_tracer(__name__)

@@ -114,13 +115,15 @@ def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
tokenizers: TokenizerManager,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
batch_inputs = []
max_truncation = 0
for r in pb.requests:
batch_inputs.append(r.inputs)
inputs = tokenizers.get_inputs(r, tokenizer)
batch_inputs.append(inputs)
max_truncation = max(max_truncation, r.truncate)

batch_tokenized_inputs = tokenizer(
@@ -746,7 +749,7 @@ def load_adapter(self, adapter_id, adapter_source, adapter_index):
elif adapter_id != BASE_MODEL_ADAPTER_ID:
logger.info(f"Loading adapter weights into model: {adapter_id}")
weight_names = tuple([v[0] for v in self.target_to_layer.values()])
module_map, adapter_config, adapter_weight_names = load_module_map(
module_map, adapter_config, adapter_weight_names, adapter_tokenizer = load_module_map(
self.model_id, adapter_id, adapter_source, weight_names
)

@@ -758,6 +761,9 @@ def load_adapter(self, adapter_id, adapter_source, adapter_index):

if len(unused_weight_names) > 0:
logger.warning(f"{adapter_id} unused adapter weights: {unused_weight_names}")

if adapter_tokenizer is not None:
self.tokenizers.add_tokenizer(adapter_index, adapter_tokenizer)

self.adapter_id = adapter_id

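
Each from_pb now receives a TokenizerManager in addition to the base tokenizer, and load_adapter registers adapter-specific tokenizers with it. The class itself is not among the files shown here, so the following is only a minimal sketch inferred from the call sites (get_inputs, add_tokenizer) and the Hugging Face apply_chat_template API, not the commit's actual implementation:

import json
from typing import Dict

from transformers import PreTrainedTokenizerBase


class TokenizerManager:
    # Sketch: resolves per-adapter tokenizers and expands chat messages into prompts.

    def __init__(self):
        self.tokenizers: Dict[int, PreTrainedTokenizerBase] = {}

    def add_tokenizer(self, adapter_index: int, tokenizer: PreTrainedTokenizerBase):
        # Called from load_adapter when an adapter ships its own tokenizer
        # (and therefore possibly its own chat template).
        self.tokenizers[adapter_index] = tokenizer

    def get_inputs(self, r, base_tokenizer: PreTrainedTokenizerBase) -> str:
        # r is a generate_pb2.Request. Without the flag, inputs pass through untouched.
        if not r.apply_chat_template:
            return r.inputs
        # With the flag, the router sent a JSON-encoded list of role/content messages;
        # expand them with the adapter's tokenizer if registered, else the base tokenizer.
        tokenizer = self.tokenizers.get(r.adapter_index, base_tokenizer)
        messages = json.loads(r.inputs)
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
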
6 changes: 5 additions & 1 deletion server/lorax_server/models/flash_mistral.py
@@ -1,3 +1,4 @@
import json
import math
import torch
import torch.distributed
@@ -32,6 +33,7 @@
from lorax_server.utils.adapter import BASE_MODEL_ADAPTER_ID
from lorax_server.utils.lora import DOWN_PROJ, GATE_PROJ, K_PROJ, LM_HEAD, O_PROJ, Q_PROJ, UP_PROJ, V_PROJ, AdapterBatchData, AdapterBatchMetadata
from lorax_server.utils.segments import find_segments
from lorax_server.utils.tokenizer import TokenizerManager

tracer = trace.get_tracer(__name__)

@@ -55,6 +57,7 @@ def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
tokenizers: TokenizerManager,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
@@ -64,7 +67,8 @@ def from_pb(
batch_inputs = []
max_truncation = 0
for r in pb.requests:
batch_inputs.append(r.inputs)
inputs = tokenizers.get_inputs(r, tokenizer)
batch_inputs.append(inputs)
max_truncation = max(max_truncation, r.truncate)

batch_tokenized_inputs = tokenizer(
(12 more changed files not shown)