From 80730b74c609fafbfa0f8db3540ca8b780833147 Mon Sep 17 00:00:00 2001
From: Travis Addair
Date: Mon, 18 Nov 2024 10:48:06 -0800
Subject: [PATCH] Record TTFT and TPOT in response headers (#684)

---
 router/src/infer.rs  |  9 +++++++++
 router/src/server.rs | 21 +++++++++++++++++++++
 2 files changed, 30 insertions(+)

diff --git a/router/src/infer.rs b/router/src/infer.rs
index 04afdb2d4..52f7dc2a0 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -411,6 +411,7 @@ impl Infer {
         let mut result_generated_text = None;
         let mut result_start = None;
         let mut result_queued = None;
+        let mut result_prefill_time = None;

         // Iterate on stream
         while let Some(response) = stream.next().await {
@@ -419,6 +420,7 @@
                 InferStreamResponse::Prefill {
                     tokens,
                     tokens_length,
+                    prefill_time,
                 } => {
                     // Create Token objects
                     // We do that here instead of in the Python code as Rust for loops are faster
@@ -432,6 +434,7 @@
                             .collect();
                     }
                     result_prefill_length = tokens_length;
+                    result_prefill_time = Some(prefill_time);
                 }
                 // Push last token
                 InferStreamResponse::Token(token) => result_tokens.push(token),
@@ -463,6 +466,7 @@
         if let (Some(generated_text), Some(queued), Some(start)) =
             (result_generated_text, result_queued, result_start)
         {
+            let prefill_time = result_prefill_time.unwrap_or(Instant::now());
             Ok(InferResponse {
                 prefill: result_prefill,
                 tokens: result_tokens,
@@ -470,6 +474,7 @@
                 generated_text,
                 queued,
                 start,
+                prefill_time,
             })
         } else {
             let err = InferError::IncompleteGeneration;
@@ -1369,9 +1374,11 @@ fn send_responses(

     if generation.prefill_tokens_length > 0 {
         // Send message
+        let prefill_time = Instant::now();
         entry.response_tx.send(Ok(InferStreamResponse::Prefill {
             tokens: generation.prefill_tokens,
             tokens_length: generation.prefill_tokens_length,
+            prefill_time,
         }))?;
     }

@@ -1510,6 +1517,7 @@ pub(crate) enum InferStreamResponse {
     Prefill {
         tokens: Option,
         tokens_length: u32,
+        prefill_time: Instant,
     },
     // Intermediate messages
     Token(Token),
@@ -1550,6 +1558,7 @@ pub(crate) struct InferResponse {
     pub(crate) generated_text: GeneratedText,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
+    pub(crate) prefill_time: Instant,
 }

 #[derive(Debug, Error)]
diff --git a/router/src/server.rs b/router/src/server.rs
index e38468573..7717d5335 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -33,6 +33,7 @@ use once_cell::sync::OnceCell;
 use reqwest_middleware::ClientBuilder;
 use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
 use serde_json::Value;
+use std::cmp;
 use std::convert::Infallible;
 use std::net::SocketAddr;
 use std::sync::atomic::AtomicBool;
@@ -697,6 +698,9 @@ async fn generate(
     let queue_time = response.start - response.queued;
     let inference_time = Instant::now() - response.start;
     let time_per_token = inference_time / response.generated_text.generated_tokens;
+    let time_to_first_token = response.prefill_time - response.start;
+    let time_per_output_token = (inference_time - time_to_first_token)
+        / cmp::max(response.generated_text.generated_tokens - 1, 1);

     // Rust Tracing metadata
     span.record("total_time", format!("{total_time:?}"));
@@ -704,6 +708,11 @@
     span.record("queue_time", format!("{queue_time:?}"));
     span.record("inference_time", format!("{inference_time:?}"));
     span.record("time_per_token", format!("{time_per_token:?}"));
+    span.record("time_to_first_token", format!("{time_to_first_token:?}"));
+    span.record(
+        "time_per_output_token",
+        format!("{time_per_output_token:?}"),
+    );
     span.record("seed", format!("{:?}", response.generated_text.seed));
     span.record("prompt_tokens", format!("{prompt_tokens:?}"));
     span.record("generated_tokens", format!("{generated_tokens:?}"));
@@ -753,6 +762,18 @@
     headers.insert(
         "x-time-per-token",
         time_per_token.as_millis().to_string().parse().unwrap(),
     );
+    headers.insert(
+        "x-time-to-first-token",
+        time_to_first_token.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-time-per-output-token",
+        time_per_output_token
+            .as_millis()
+            .to_string()
+            .parse()
+            .unwrap(),
+    );
     headers.insert("x-model-id", info.model_id.parse().unwrap());