Commit 59631a0

Apply chat template in router to properly validate input length (#538)

tgaddair authored Jul 23, 2024
1 parent 240079b commit 59631a0
Showing 13 changed files with 560 additions and 148 deletions.
21 changes: 21 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

11 changes: 11 additions & 0 deletions launcher/src/main.rs
@@ -402,6 +402,11 @@ struct Args {
/// Download model weights only
#[clap(long, env)]
download_only: bool,

/// The path to the tokenizer config file, used to load a tokenizer configuration that may
/// include a `chat_template`. If not provided, the default config from the model hub is used.
#[clap(long, env)]
tokenizer_config_path: Option<String>,
}

#[derive(Debug)]
@@ -1093,6 +1098,12 @@ fn spawn_webserver(
router_args.push("--adapter-source".to_string());
router_args.push(adapter_source.to_string());

// Tokenizer config path
if let Some(ref tokenizer_config_path) = args.tokenizer_config_path {
router_args.push("--tokenizer-config-path".to_string());
router_args.push(tokenizer_config_path.to_string());
}

// Model optional max batch total tokens
if let Some(max_batch_total_tokens) = args.max_batch_total_tokens {
router_args.push("--max-batch-total-tokens".to_string());
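
With this wiring, a local tokenizer config can be supplied at launch. A hypothetical invocation (the model ID and path are illustrative):

    lorax-launcher --model-id mistralai/Mistral-7B-Instruct-v0.1 --tokenizer-config-path /data/tokenizer_config.json

Because the clap attribute includes `env`, setting the TOKENIZER_CONFIG_PATH environment variable is equivalent.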
6 changes: 2 additions & 4 deletions proto/generate.proto
@@ -119,12 +119,10 @@ message Request {
bool prefill_logprobs = 7;
/// Adapter index
uint32 adapter_index = 8;
/// Apply chat template to inputs
bool apply_chat_template = 9;
/// Paged attention blocks
repeated uint32 blocks = 10;
repeated uint32 blocks = 9;
/// Paged attention slots
repeated uint32 slots = 11;
repeated uint32 slots = 10;
}

message Batch {
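
Note that dropping field 9 and renumbering `blocks` and `slots` changes the protobuf wire format: fields are identified by number, not name, so the router's generated client and the shard server must be regenerated from this proto together.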
2 changes: 2 additions & 0 deletions router/Cargo.toml
@@ -47,6 +47,8 @@ ngrok = { version = "0.12.3", features = ["axum"], optional = true }
once_cell = "1.19.0"
itertools = "0.12.1"
async-trait = "0.1.80"
minijinja = { version = "2.0.2" }
minijinja-contrib = { version = "2.0.2", features = ["pycompat"] }

[build-dependencies]
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }
1 change: 0 additions & 1 deletion router/client/src/client.rs
@@ -143,7 +143,6 @@ impl Client {
}),
adapter_index: 0,
prefill_logprobs: true,
apply_chat_template: false,
});
n_tokens += max_input_length;
}
4 changes: 0 additions & 4 deletions router/src/batch.rs
@@ -122,7 +122,6 @@ pub(crate) struct ValidGenerateRequest {
pub parameters: NextTokenChooserParameters,
pub stopping_parameters: StoppingCriteriaParameters,
pub adapter: Adapter,
pub apply_chat_template: bool,
}

/// AdapterLoader entry
@@ -297,7 +296,6 @@ impl BatchEntries for GenerateBatchEntries {
parameters: Some(request.parameters.clone()),
stopping_parameters: Some(request.stopping_parameters.clone()),
adapter_index: adapter.index(),
apply_chat_template: request.apply_chat_template,
blocks,
slots,
};
@@ -418,7 +416,6 @@ impl BatchEntries for EmbedBatchEntries {
parameters: None,
stopping_parameters: None,
adapter_index: adapter.index(),
apply_chat_template: false,
blocks,
slots,
};
@@ -533,7 +530,6 @@ impl BatchEntries for ClassifyBatchEntries {
parameters: None,
stopping_parameters: None,
adapter_index: adapter.index(),
apply_chat_template: false,
blocks,
slots,
};
1 change: 0 additions & 1 deletion router/src/health.rs
@@ -55,7 +55,6 @@ impl Health {
ignore_eos_token: false,
}),
adapter_index: 0,
apply_chat_template: false,
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
132 changes: 130 additions & 2 deletions router/src/infer.rs
@@ -5,8 +5,9 @@ use crate::queue::AdapterEvent;
use crate::scheduler::AdapterScheduler;
use crate::validation::{Validation, ValidationError};
use crate::{
AdapterParameters, AlternativeToken, ClassifyRequest, ClassifyResponse, EmbedRequest,
EmbedResponse, Entity, Entry, Token,
AdapterParameters, AlternativeToken, ChatTemplate, ChatTemplateVersions, ClassifyRequest,
ClassifyResponse, EmbedRequest, EmbedResponse, Entity, Entry, HubTokenizerConfig, Message,
TextMessage, Token, TokenizerConfigToken,
};
use crate::{GenerateRequest, PrefillToken};
use flume::r#async::RecvStream;
@@ -18,7 +19,10 @@ use lorax_client::{
Batch, CachedBatch, ClientError, Embedding, EntityList, GeneratedText, Generation,
PrefillTokens, ShardedClient,
};
use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;
use nohash_hasher::IntMap;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::{
atomic::{AtomicBool, Ordering},
@@ -30,6 +34,91 @@ use tokio::sync::{Mutex, Notify, OwnedSemaphorePermit, Semaphore, TryAcquireError};
use tokio::time::Instant;
use tracing::{info_span, instrument, Span};

#[derive(Clone, Serialize, Deserialize, Default)]
pub(crate) struct ChatTemplateInputs<'a> {
messages: Vec<TextMessage>,
bos_token: Option<&'a str>,
eos_token: Option<&'a str>,
add_generation_prompt: bool,
tools: Option<&'a str>,
tools_prompt: Option<&'a str>,
}

/// Raise an exception (custom function) used in the chat templates
fn raise_exception(err_text: String) -> Result<String, minijinja::Error> {
Err(minijinja::Error::new(ErrorKind::SyntaxError, err_text))
}
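
Chat templates ported from Hugging Face tokenizers call this helper to reject unsupported inputs; a fragment such as {{ raise_exception('Conversation roles must alternate') }} (illustrative) aborts rendering with a minijinja error carrying that message.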

#[derive(Clone)]
struct ChatTemplateRenderer {
template: Template<'static, 'static>,
bos_token: Option<String>,
eos_token: Option<String>,
use_default_tool_template: bool,
}

impl ChatTemplateRenderer {
fn new(
template: String,
bos_token: Option<TokenizerConfigToken>,
eos_token: Option<TokenizerConfigToken>,
) -> Self {
let mut env = Box::new(Environment::new());
// enable things like .strip() or .capitalize()
env.set_unknown_method_callback(pycompat::unknown_method_callback);
let template_str = template.into_boxed_str();
env.add_function("raise_exception", raise_exception);

// TODO(travis): revisit when we add tool usage
// check if contains the tools variable within the template
// let use_default_tool_template =
// !template_str.as_ref().replace(' ', "").contains("{{tools}}");
let use_default_tool_template = false;

// leaking env and template_str as read-only, static resources for performance.
let template = Box::leak(env)
.template_from_str(Box::leak(template_str))
.unwrap();

Self {
template,
bos_token: bos_token.map(|token| token.as_str().to_string()),
eos_token: eos_token.map(|token| token.as_str().to_string()),
use_default_tool_template,
}
}

fn apply(
&self,
mut messages: Vec<Message>,
// grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<String, InferError> {
// TODO(travis): revisit when we add tool usage
// if self.use_default_tool_template {
// if let Some(last_message) = messages.last_mut() {
// if let Some((GrammarType::Json(tools), tool_prompt)) = grammar_with_prompt {
// last_message.content.push(MessageChunk::Text {
// text: format!("\n---\n{}\n{}", tool_prompt, tools),
// });
// }
// }
// }

let messages: Vec<TextMessage> = messages.into_iter().map(|c| c.into()).collect();

self.template
.render(ChatTemplateInputs {
messages,
bos_token: self.bos_token.as_deref(),
eos_token: self.eos_token.as_deref(),
add_generation_prompt: true,
tools: None,
tools_prompt: None,
})
.map_err(InferError::TemplateError)
}
}
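
A minimal usage sketch of the renderer (the ChatML-style template is illustrative rather than a repository default, and `messages` is assumed to be a Vec<Message> parsed from a chat request):

    let template = "{% for message in messages %}<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}";
    let renderer = ChatTemplateRenderer::new(template.to_string(), None, None);
    // `apply` converts each Message into a TextMessage, renders the template
    // with add_generation_prompt = true, and returns the flattened prompt.
    let prompt = renderer.apply(messages)?;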

/// Inference struct
#[derive(Clone)]
pub struct Infer {
@@ -39,6 +128,8 @@ pub struct Infer {
adapter_scheduler: AdapterScheduler,
/// Maps adapter ID to a unique index
adapter_to_index: Arc<Mutex<HashMap<AdapterParameters, u32>>>,
/// Chat template
chat_template: Option<ChatTemplateRenderer>,
/// Inference limit
limit_concurrent_requests: Arc<Semaphore>,
}
@@ -59,6 +150,7 @@ impl Infer {
window_size: Option<u32>,
generation_health: Arc<AtomicBool>,
eager_prefill: bool,
tokenizer_config: HubTokenizerConfig,
preloaded_adapter_ids: Vec<String>,
block_size: u32,
speculate: u32,
@@ -100,6 +192,19 @@ impl Infer {

let adapter_to_index = Arc::new(Mutex::new(adapter_to_index));

let chat_template = tokenizer_config
.chat_template
.and_then(|t| match t {
ChatTemplateVersions::Single(template) => Some(template),
ChatTemplateVersions::Multiple(templates) => templates
.into_iter()
.find(|t| t.name == "default")
.map(|t| t.template),
})
.map(|t| {
ChatTemplateRenderer::new(t, tokenizer_config.bos_token, tokenizer_config.eos_token)
});
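
The two ChatTemplateVersions arms mirror the two shapes `chat_template` takes in a Hugging Face tokenizer_config.json (sketches; the "tool_use" entry is illustrative):

    "chat_template": "{% for message in messages %}...{% endfor %}"

    "chat_template": [
      {"name": "default", "template": "{% for message in messages %}...{% endfor %}"},
      {"name": "tool_use", "template": "..."}
    ]

In the list form, only the entry named "default" is selected here.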

// Spawn batching background task that contains all the inference logic
tokio::spawn(batching_task(
client,
@@ -120,6 +225,7 @@
validation,
adapter_scheduler,
adapter_to_index,
chat_template,
limit_concurrent_requests: semaphore,
}
}
@@ -227,6 +333,25 @@ impl Infer {
// Return Encoding
Ok(encoding.map(|(encoding, _)| encoding))
}

/// Apply the chat template to the chat request
#[instrument(skip_all)]
pub(crate) fn apply_chat_template(
&self,
messages: Vec<Message>,
// grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<String, InferError> {
self.chat_template
.as_ref()
.ok_or_else(|| InferError::TemplateError(ErrorKind::TemplateNotFound.into()))?
.apply(messages)
.map_err(|e| {
metrics::increment_counter!("lorax_request_failure", "err" => "template");
tracing::error!("{e}");
e
})
}
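
Downstream handlers can render chat messages into plain inputs before validation, so length checks run against the fully templated prompt. A hypothetical excerpt (handler and field names are assumptions, not code from this commit):

    // Hypothetical chat completions handler excerpt:
    let inputs = infer.apply_chat_template(chat_request.messages)?;
    // `inputs` then flows through the normal validation path, where its
    // tokenized length is checked against max_input_length.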

/// Add a new request to the queue and return an InferResponse
#[instrument(skip(self))]
pub(crate) async fn generate(
@@ -1196,6 +1321,8 @@ pub enum InferError {
ValidationError(#[from] ValidationError),
#[error("Incomplete generation")]
IncompleteGeneration,
#[error("Failed applying chat template to inputs: {0}")]
TemplateError(#[from] minijinja::Error),
#[error("Embedding Failure")]
EmbeddingFailure,
#[error("Classification Failure")]
@@ -1209,6 +1336,7 @@ impl InferError {
InferError::Overloaded(_) => "overloaded",
InferError::ValidationError(_) => "validation",
InferError::IncompleteGeneration => "incomplete_generation",
InferError::TemplateError(_) => "template_error",
InferError::EmbeddingFailure => "embedding_failure",
InferError::ClassificationFailure => "classification_failure",
}
(The remaining changed files are not rendered here.)