diff --git a/docs/reference/python_client.md b/docs/reference/python_client.md
index 9d635541d..c9e9ed735 100644
--- a/docs/reference/python_client.md
+++ b/docs/reference/python_client.md
@@ -181,6 +181,8 @@ class BestOfSequence:
 class Details:
     # Generation finish reason
     finish_reason: FinishReason
+    # Number of prompt tokens
+    prompt_tokens: int
     # Number of generated tokens
     generated_tokens: int
     # Sampling seed if sampling was activated
diff --git a/proto/generate.proto b/proto/generate.proto
index ae4caa3f9..c5b52d1dd 100644
--- a/proto/generate.proto
+++ b/proto/generate.proto
@@ -167,6 +167,8 @@ message Generation {
     bool token_is_special = 6;
     /// Complete generated text
     optional GeneratedText generated_text = 7;
+    /// Prefill tokens length
+    uint32 prefill_tokens_length = 8;
 }
 
 message FilterBatchRequest {
diff --git a/router/src/infer.rs b/router/src/infer.rs
index df20e0061..d46ab0471 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -189,6 +189,7 @@ impl Infer {
         // Return values
         let mut result_prefill = Vec::new();
         let mut result_tokens = Vec::new();
+        let mut result_prefill_length = 0;
         let mut result_generated_text = None;
         let mut result_start = None;
         let mut result_queued = None;
@@ -197,16 +198,22 @@ impl Infer {
         while let Some(response) = stream.next().await {
             match response? {
                 // Add prefill tokens
-                InferStreamResponse::Prefill(tokens) => {
+                InferStreamResponse::Prefill {
+                    tokens,
+                    tokens_length,
+                } => {
                     // Create Token objects
                     // We do that here instead of in the Python code as Rust for loops are faster
-                    result_prefill = tokens
-                        .ids
-                        .into_iter()
-                        .zip(tokens.logprobs.into_iter())
-                        .zip(tokens.texts.into_iter())
-                        .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
-                        .collect();
+                    if let Some(tokens_val) = tokens {
+                        result_prefill = tokens_val
+                            .ids
+                            .into_iter()
+                            .zip(tokens_val.logprobs.into_iter())
+                            .zip(tokens_val.texts.into_iter())
+                            .map(|((id, logprob), text)| PrefillToken { id, text, logprob })
+                            .collect();
+                    }
+                    result_prefill_length = tokens_length;
                 }
                 // Push last token
                 InferStreamResponse::Token(token) => result_tokens.push(token),
@@ -233,6 +240,7 @@ impl Infer {
             Ok(InferResponse {
                 prefill: result_prefill,
                 tokens: result_tokens,
+                prompt_tokens: result_prefill_length,
                 generated_text,
                 queued,
                 start,
@@ -569,10 +577,13 @@ fn send_responses(
 
     let mut stopped = false;
 
-    if let Some(prefill_tokens) = generation.prefill_tokens {
+    if generation.prefill_tokens_length > 0 {
         // Send message
         entry.response_tx.send_timeout(
-            Ok(InferStreamResponse::Prefill(prefill_tokens)),
+            Ok(InferStreamResponse::Prefill {
+                tokens: generation.prefill_tokens,
+                tokens_length: generation.prefill_tokens_length,
+            }),
             Duration::from_millis(10),
         )?;
     }
@@ -629,7 +640,10 @@ fn send_errors(error: ClientError, entries: &mut IntMap<u64, Entry>) {
 #[derive(Debug)]
 pub(crate) enum InferStreamResponse {
     // Optional first message
-    Prefill(PrefillTokens),
+    Prefill {
+        tokens: Option<PrefillTokens>,
+        tokens_length: u32,
+    },
     // Intermediate messages
     Token(Token),
     // Last message
@@ -645,17 +659,12 @@ pub(crate) enum InferStreamResponse {
 pub(crate) struct InferResponse {
     pub(crate) prefill: Vec<PrefillToken>,
     pub(crate) tokens: Vec<Token>,
+    pub(crate) prompt_tokens: u32,
     pub(crate) generated_text: GeneratedText,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
 }
 
-impl InferResponse {
-    pub(crate) fn prompt_tokens(&self) -> u32 {
-        self.prefill.len() as u32
-    }
-}
-
 #[derive(Debug, Error)]
 pub enum InferError {
     #[error("Request failed during generation: {0}")]
diff --git a/router/src/server.rs b/router/src/server.rs
index b6d249ea2..98d30ab9e 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -172,7 +172,7 @@ async fn generate(
     };
 
     let generated_tokens = response.generated_text.generated_tokens;
-    let prompt_tokens = response.prompt_tokens();
+    let prompt_tokens = response.prompt_tokens;
     let total_tokens = prompt_tokens + generated_tokens;
 
     // Token details
@@ -411,7 +411,9 @@ async fn generate_stream(
                         Ok(response) => {
                             match response {
                                 // Prefill is ignored
-                                InferStreamResponse::Prefill(_) => {}
+                                InferStreamResponse::Prefill {
+                                    ..
+                                } => {}
                                 // Yield event for every new token
                                 InferStreamResponse::Token(token) => {
                                     tracing::debug!(parent: &span, "Token: {:?}", token);
diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index 91be95a86..e55d0ddd9 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -1198,6 +1198,7 @@ def generate_token(
                 generation = Generation(
                     request.id,
                     prefill_tokens,
+                    len(all_input_ids[:-1]) if prefill else 0,
                     next_token_id,
                     next_token_logprob,
                     next_token_text,
diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py
index 90d3f55fa..f22ff7da3 100644
--- a/server/lorax_server/models/types.py
+++ b/server/lorax_server/models/types.py
@@ -75,6 +75,7 @@ def __len__(self):
 class Generation:
     request_id: int
     prefill_tokens: Optional[PrefillTokens]
+    prefill_tokens_length: int
     token_id: int
     token_logprob: float
     token_text: str
@@ -87,6 +88,7 @@ def to_pb(self) -> generate_pb2.Generation:
             prefill_tokens=self.prefill_tokens.to_pb()
             if self.prefill_tokens is not None
             else None,
+            prefill_tokens_length=self.prefill_tokens_length,
            token_id=self.token_id,
             token_logprob=self.token_logprob,
             token_text=self.token_text,
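
Taken together, these changes compute the prompt token count on the model server (`prefill_tokens_length` on `Generation`), carry it through the gRPC `Generation` message and the router's `InferResponse`, and surface it to clients as `Details.prompt_tokens`, replacing the old `prompt_tokens()` helper that only counted prefill tokens when their details were returned. Below is a minimal sketch of reading the new field, assuming the LoRAX Python client's `Client.generate` API as documented in `docs/reference/python_client.md`; the endpoint URL and prompt are placeholders.

```python
# Sketch only: assumes the LoRAX Python client (`lorax-client`) and a server
# running locally; the URL and prompt are illustrative.
from lorax import Client

client = Client("http://127.0.0.1:8080")

response = client.generate("What is deep learning?", max_new_tokens=32)
details = response.details

# prompt_tokens is now reported by the server even when prefill token
# details (ids/logprobs/texts) are not requested.
print("prompt tokens:   ", details.prompt_tokens)
print("generated tokens:", details.generated_tokens)
print("total tokens:    ", details.prompt_tokens + details.generated_tokens)
```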