
Commit

Fixed token count
tgaddair committed Jan 8, 2024
1 parent: ffed638 · commit: c2762af
Showing 6 changed files with 37 additions and 19 deletions.
docs/reference/python_client.md (2 changes: 2 additions & 0 deletions)
@@ -181,6 +181,8 @@ class BestOfSequence:
class Details:
# Generation finish reason
finish_reason: FinishReason
# Number of prompt tokens
prompt_tokens: int
# Number of generated tokens
generated_tokens: int
# Sampling seed if sampling was activated
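As a quick usage sketch of the new field (a hypothetical example: it assumes the documented lorax Client mirrors this Details schema and that a LoRAX server is reachable at the placeholder URL; the prompt and parameters are illustrative):

from lorax import Client

# Placeholder endpoint; point this at your own LoRAX deployment.
client = Client("http://127.0.0.1:8080")

# The client is assumed to request token details along with the generation.
response = client.generate("What is deep learning?", max_new_tokens=32)

# prompt_tokens is the field added in this commit; generated_tokens already existed.
details = response.details
print("prompt tokens:   ", details.prompt_tokens)
print("generated tokens:", details.generated_tokens)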
proto/generate.proto (2 changes: 2 additions & 0 deletions)
@@ -167,6 +167,8 @@ message Generation {
bool token_is_special = 6;
/// Complete generated text
optional GeneratedText generated_text = 7;
/// Prefill tokens length
uint32 prefill_tokens_length = 8;
}

message FilterBatchRequest {
router/src/infer.rs (43 changes: 26 additions & 17 deletions)
@@ -189,6 +189,7 @@ impl Infer {
// Return values
let mut result_prefill = Vec::new();
let mut result_tokens = Vec::new();
let mut result_prefill_length = 0;
let mut result_generated_text = None;
let mut result_start = None;
let mut result_queued = None;
@@ -197,16 +198,22 @@ impl Infer {
while let Some(response) = stream.next().await {
match response? {
// Add prefill tokens
InferStreamResponse::Prefill(tokens) => {
InferStreamResponse::Prefill {
tokens,
tokens_length,
} => {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
result_prefill = tokens
.ids
.into_iter()
.zip(tokens.logprobs.into_iter())
.zip(tokens.texts.into_iter())
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
if let Some(tokens_val) = tokens {
result_prefill = tokens_val
.ids
.into_iter()
.zip(tokens_val.logprobs.into_iter())
.zip(tokens_val.texts.into_iter())
.map(|((id, logprob), text)| PrefillToken { id, text, logprob })
.collect();
}
result_prefill_length = tokens_length;
}
// Push last token
InferStreamResponse::Token(token) => result_tokens.push(token),
@@ -233,6 +240,7 @@
Ok(InferResponse {
prefill: result_prefill,
tokens: result_tokens,
prompt_tokens: result_prefill_length,
generated_text,
queued,
start,
@@ -569,10 +577,13 @@ fn send_responses(

let mut stopped = false;

if let Some(prefill_tokens) = generation.prefill_tokens {
if generation.prefill_tokens_length > 0 {
// Send message
entry.response_tx.send_timeout(
Ok(InferStreamResponse::Prefill(prefill_tokens)),
Ok(InferStreamResponse::Prefill {
tokens: generation.prefill_tokens,
tokens_length: generation.prefill_tokens_length,
}),
Duration::from_millis(10),
)?;
}
@@ -629,7 +640,10 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
#[derive(Debug)]
pub(crate) enum InferStreamResponse {
// Optional first message
Prefill(PrefillTokens),
Prefill {
tokens: Option<PrefillTokens>,
tokens_length: u32,
},
// Intermediate messages
Token(Token),
// Last message
@@ -645,17 +659,12 @@ pub(crate) enum InferStreamResponse {
pub(crate) struct InferResponse {
pub(crate) prefill: Vec<PrefillToken>,
pub(crate) tokens: Vec<Token>,
pub(crate) prompt_tokens: u32,
pub(crate) generated_text: GeneratedText,
pub(crate) queued: Instant,
pub(crate) start: Instant,
}

impl InferResponse {
pub(crate) fn prompt_tokens(&self) -> u32 {
self.prefill.len() as u32
}
}

#[derive(Debug, Error)]
pub enum InferError {
#[error("Request failed during generation: {0}")]
router/src/server.rs (6 changes: 4 additions & 2 deletions)
@@ -172,7 +172,7 @@ async fn generate(
};

let generated_tokens = response.generated_text.generated_tokens;
let prompt_tokens = response.prompt_tokens();
let prompt_tokens = response.prompt_tokens;
let total_tokens = prompt_tokens + generated_tokens;

// Token details
@@ -411,7 +411,9 @@ async fn generate_stream(
Ok(response) => {
match response {
// Prefill is ignored
InferStreamResponse::Prefill(_) => {}
InferStreamResponse::Prefill {
..
} => {}
// Yield event for every new token
InferStreamResponse::Token(token) => {
tracing::debug!(parent: &span, "Token: {:?}", token);
server/lorax_server/models/flash_causal_lm.py (1 change: 1 addition & 0 deletions)
@@ -1198,6 +1198,7 @@ def generate_token(
generation = Generation(
request.id,
prefill_tokens,
len(all_input_ids[:-1]) if prefill else 0,
next_token_id,
next_token_logprob,
next_token_text,
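The count passed above comes from the model's running token buffer. A minimal sketch of that logic, under the assumption (matching the surrounding loop) that the freshly sampled token has already been appended to all_input_ids by the time Generation is built; names are illustrative:

from typing import List

def prompt_token_count(all_input_ids: List[int], prefill: bool) -> int:
    # On the prefill step, every id except the just-sampled token is a prompt token.
    # On later decode steps the prompt was already counted, so report 0.
    return len(all_input_ids[:-1]) if prefill else 0

# Example: a 4-token prompt plus the first sampled token.
assert prompt_token_count([101, 2023, 2003, 102, 7592], prefill=True) == 4
assert prompt_token_count([101, 2023, 2003, 102, 7592, 999], prefill=False) == 0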
server/lorax_server/models/types.py (2 changes: 2 additions & 0 deletions)
@@ -75,6 +75,7 @@ def __len__(self):
class Generation:
request_id: int
prefill_tokens: Optional[PrefillTokens]
prefill_tokens_length: int
token_id: int
token_logprob: float
token_text: str
@@ -87,6 +88,7 @@ def to_pb(self) -> generate_pb2.Generation:
prefill_tokens=self.prefill_tokens.to_pb()
if self.prefill_tokens is not None
else None,
prefill_tokens_length=self.prefill_tokens_length,
token_id=self.token_id,
token_logprob=self.token_logprob,
token_text=self.token_text,
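A separate length field is needed because prefill_tokens itself is Optional: when prefill token details are not requested it is None, and the router previously derived the prompt token count from the (then empty) prefill list, which is presumably what produced the wrong token count this commit fixes. A self-contained sketch with simplified stand-ins for the real PrefillTokens and Generation classes (the actual types and their protobuf conversion are not imported here):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class PrefillTokens:
    # Simplified stand-in for the real PrefillTokens message.
    token_ids: List[int]
    logprobs: List[float]
    texts: List[str]

@dataclass
class Generation:
    # Simplified stand-in mirroring the fields shown in the diff above.
    request_id: int
    prefill_tokens: Optional[PrefillTokens]
    prefill_tokens_length: int
    token_id: int
    token_logprob: float
    token_text: str

# Even when prefill token details are omitted (prefill_tokens=None), an accurate
# prompt token count can still reach the router via prefill_tokens_length.
gen = Generation(
    request_id=0,
    prefill_tokens=None,
    prefill_tokens_length=4,
    token_id=7592,
    token_logprob=-0.12,
    token_text="hello",
)
assert gen.prefill_tokens_length == 4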
