Skip to content

Commit

Permalink
Tokenize inputs in router (#548)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair authored Jul 19, 2024
1 parent 1adc076 commit 452ac73
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 19 deletions.
23 changes: 15 additions & 8 deletions proto/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -97,27 +97,34 @@ message StoppingCriteriaParameters {
bool ignore_eos_token = 3;
}

/// Input text pre-tokenized by the router, so the server can skip
/// re-tokenizing when this field is present on a Request.
message TokenizedInputs {
  /// Token IDs produced by the router-side tokenizer
  repeated uint32 ids = 1;
}

/// A single request as sent to a model shard.
///
/// NOTE(review): the rendered diff interleaved the pre-change field lines
/// (truncate = 3 … slots = 10) with the post-change lines, producing
/// duplicate field names/numbers — invalid protobuf. This is the
/// post-change message only, with the renumbering introduced by the
/// insertion of `tokenized_inputs = 3`.
message Request {
  /// Request ID
  uint64 id = 1;
  /// The generation context
  string inputs = 2;
  /// Tokenized inputs, precomputed by the router; when unset, the server
  /// falls back to tokenizing `inputs` itself
  TokenizedInputs tokenized_inputs = 3;
  /// Context truncation
  uint32 truncate = 4;
  /// Next Token Chooser Parameters
  NextTokenChooserParameters parameters = 5;
  /// Stopping Criteria Parameters
  StoppingCriteriaParameters stopping_parameters = 6;
  /// Return prefill logprobs
  bool prefill_logprobs = 7;
  /// Adapter index
  uint32 adapter_index = 8;
  /// Apply chat template to inputs
  bool apply_chat_template = 9;
  /// Paged attention blocks
  repeated uint32 blocks = 10;
  /// Paged attention slots
  repeated uint32 slots = 11;
}

message Batch {
Expand Down
1 change: 1 addition & 0 deletions router/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ impl Client {
requests.push(Request {
id: 0,
inputs: "_test ".to_string().repeat(max_input_length as usize),
tokenized_inputs: None,
truncate: truncate_length,
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
Expand Down
1 change: 1 addition & 0 deletions router/client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub use pb::generate::v1::{
AdapterParameters, AlternativeTokens, Batch, CachedBatch, DownloadAdapterResponse, Embedding,
Entity, EntityList, FinishReason, GeneratedText, Generation, MajoritySignMethod, MergeStrategy,
NextTokenChooserParameters, NextTokens, PrefillTokens, Request, StoppingCriteriaParameters,
TokenizedInputs,
};
pub use sharded_client::ShardedClient;
use thiserror::Error;
Expand Down
9 changes: 8 additions & 1 deletion router/src/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ use async_trait::async_trait;

use lorax_client::{
Batch, CachedBatch, NextTokenChooserParameters, Request, ShardedClient,
StoppingCriteriaParameters,
StoppingCriteriaParameters, TokenizedInputs,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use tokenizers::Token;
use tokio::time::Instant;
use tracing::{info_span, span, Instrument, Span};

Expand Down Expand Up @@ -54,6 +55,7 @@ impl ValidRequest for ValidGenerateRequest {
/// An embedding request that has passed router-side validation and is
/// ready to be batched and dispatched to a shard.
#[derive(Debug)]
pub(crate) struct ValidEmbedRequest {
    /// Raw input text to embed
    pub inputs: String,
    /// Router-side tokenization of `inputs`; `None` when validation ran
    /// without a fast tokenizer, in which case the server tokenizes
    pub tokenized_inputs: Option<TokenizedInputs>,
    /// Input length in tokens, as measured during validation
    pub input_length: u32,
    /// Adapter this request should be routed to
    pub adapter: Adapter,
}
Expand Down Expand Up @@ -83,6 +85,7 @@ impl ValidRequest for ValidEmbedRequest {
/// A classification request that has passed router-side validation and is
/// ready to be batched and dispatched to a shard.
#[derive(Debug)]
pub(crate) struct ValidClassifyRequest {
    /// Raw input text to classify
    pub inputs: String,
    /// Router-side tokenization of `inputs`; `None` when validation ran
    /// without a fast tokenizer, in which case the server tokenizes
    pub tokenized_inputs: Option<TokenizedInputs>,
    /// Input length in tokens, as measured during validation
    pub input_length: u32,
    /// Adapter this request should be routed to
    pub adapter: Adapter,
}
Expand Down Expand Up @@ -112,6 +115,7 @@ impl ValidRequest for ValidClassifyRequest {
#[derive(Debug)]
pub(crate) struct ValidGenerateRequest {
pub inputs: String,
pub tokenized_inputs: Option<TokenizedInputs>,
pub input_length: u32,
pub truncate: u32,
pub decoder_input_details: bool,
Expand Down Expand Up @@ -288,6 +292,7 @@ impl BatchEntries for GenerateBatchEntries {
id,
prefill_logprobs: request.decoder_input_details,
inputs: request.inputs.clone(),
tokenized_inputs: request.tokenized_inputs.clone(),
truncate: request.truncate,
parameters: Some(request.parameters.clone()),
stopping_parameters: Some(request.stopping_parameters.clone()),
Expand Down Expand Up @@ -408,6 +413,7 @@ impl BatchEntries for EmbedBatchEntries {
id,
prefill_logprobs: false,
inputs: request.inputs.clone(),
tokenized_inputs: request.tokenized_inputs.clone(),
truncate: 0,
parameters: None,
stopping_parameters: None,
Expand Down Expand Up @@ -522,6 +528,7 @@ impl BatchEntries for ClassifyBatchEntries {
id,
prefill_logprobs: false,
inputs: request.inputs.clone(),
tokenized_inputs: request.tokenized_inputs.clone(),
truncate: 0,
parameters: None,
stopping_parameters: None,
Expand Down
6 changes: 6 additions & 0 deletions router/src/block_allocator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ async fn block_allocator_task(
};

let tokens = tokens as usize;
tracing::trace!(
"Allocating {} tokens ({} blocks, {} repeats)",
tokens,
required_blocks,
repeats
);
let allocation = if required_blocks > free_blocks.len() as u32 {
None
} else {
Expand Down
1 change: 1 addition & 0 deletions router/src/health.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ impl Health {
let liveness_request = Request {
id: LIVENESS_ID,
inputs: "liveness".to_string(),
tokenized_inputs: None,
truncate: 10,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
Expand Down
10 changes: 6 additions & 4 deletions router/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,13 +366,14 @@ impl Infer {
// err
// })?;

let (inputs, input_length) = self
let (inputs, tokenized_inputs, input_length) = self
.validation
.validate_input(request.inputs, None, Some(1))
.await?;

let valid_request = ValidEmbedRequest {
inputs: inputs,
inputs,
tokenized_inputs,
input_length: input_length as u32,
adapter: adapter.clone(),
};
Expand Down Expand Up @@ -464,13 +465,14 @@ impl Infer {
None,
);

let (inputs, input_length) = self
let (inputs, tokenized_inputs, input_length) = self
.validation
.validate_input(request.inputs, None, Some(1))
.await?;

let valid_request = ValidClassifyRequest {
inputs: inputs,
inputs,
tokenized_inputs,
input_length: input_length as u32,
adapter: adapter.clone(),
};
Expand Down
8 changes: 8 additions & 0 deletions router/src/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,14 @@ impl AdapterSchedulerState {
+ self.speculate
- 1;

tracing::trace!(
"Scheduling {} tokens ({} input, {} output, {} speculate)",
tokens,
entry.request.input_length(),
entry.request.max_new_tokens(),
self.speculate
);

match block_allocator.allocate(tokens).await {
None => {
// Entry is over budget
Expand Down
15 changes: 10 additions & 5 deletions router/src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::batch::ValidGenerateRequest;
/// Payload validation logic
use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput};
use crate::{GenerateParameters, GenerateRequest};
use lorax_client::{NextTokenChooserParameters, StoppingCriteriaParameters};
use lorax_client::{NextTokenChooserParameters, StoppingCriteriaParameters, TokenizedInputs};
use rand::{thread_rng, Rng};
use thiserror::Error;
use tokenizers::tokenizer::Tokenizer;
Expand Down Expand Up @@ -92,7 +92,7 @@ impl Validation {
inputs: String,
truncate: Option<usize>,
max_new_tokens: Option<u32>,
) -> Result<(String, usize), ValidationError> {
) -> Result<(String, Option<TokenizedInputs>, usize), ValidationError> {
// If we have a fast tokenizer
if let Some((encoding, inputs)) = self.tokenize(inputs.clone(), truncate).await? {
// Create response channel
Expand Down Expand Up @@ -120,8 +120,12 @@ impl Validation {
));
}

let tokenized_inputs = Some(TokenizedInputs {
ids: encoding.get_ids().to_vec(),
});

metrics::histogram!("lorax_request_input_length", input_length as f64);
Ok((inputs, input_length))
Ok((inputs, tokenized_inputs, input_length))
}
// Return inputs without validation
else {
Expand All @@ -139,7 +143,7 @@ impl Validation {
}
}

Ok((inputs, input_length))
Ok((inputs, None, input_length))
}
}

Expand Down Expand Up @@ -291,7 +295,7 @@ impl Validation {
let adapter_id = adapter_id.unwrap_or_else(|| "".to_string());

// Validate inputs
let (inputs, input_length) = self
let (inputs, tokenized_inputs, input_length) = self
.validate_input(request.inputs, truncate, max_new_tokens)
.await?;

Expand Down Expand Up @@ -330,6 +334,7 @@ impl Validation {

Ok(ValidGenerateRequest {
inputs,
tokenized_inputs,
decoder_input_details,
input_length: input_length as u32,
truncate: truncate.unwrap_or(self.max_input_length) as u32,
Expand Down
7 changes: 6 additions & 1 deletion server/lorax_server/models/flash_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,12 @@ def from_pb(
batch_inputs.append(inputs)
max_truncation = max(max_truncation, r.truncate)

batch_tokenized_inputs = tokenizer(batch_inputs, truncation=True, max_length=max_truncation)["input_ids"]
if all(r.HasField("tokenized_inputs") for r in pb.requests):
batch_tokenized_inputs = [
r.tokenized_inputs.ids[-max_truncation :] for r in pb.requests
]
else:
batch_tokenized_inputs = tokenizer(batch_inputs, truncation=True, max_length=max_truncation)["input_ids"]

position_ids = []
cu_seqlen_prefill = [0]
Expand Down

0 comments on commit 452ac73

Please sign in to comment.