Record TTFT and TPOT in response headers (#684)
tgaddair authored Nov 18, 2024
1 parent e4394d6 commit 80730b7
Showing 2 changed files with 30 additions and 0 deletions.
9 changes: 9 additions & 0 deletions router/src/infer.rs
@@ -411,6 +411,7 @@ impl Infer {
let mut result_generated_text = None;
let mut result_start = None;
let mut result_queued = None;
let mut result_prefill_time = None;

// Iterate on stream
while let Some(response) = stream.next().await {
@@ -419,6 +420,7 @@
InferStreamResponse::Prefill {
tokens,
tokens_length,
prefill_time,
} => {
// Create Token objects
// We do that here instead of in the Python code as Rust for loops are faster
@@ -432,6 +434,7 @@
.collect();
}
result_prefill_length = tokens_length;
result_prefill_time = Some(prefill_time);
}
// Push last token
InferStreamResponse::Token(token) => result_tokens.push(token),
@@ -463,13 +466,15 @@
if let (Some(generated_text), Some(queued), Some(start)) =
(result_generated_text, result_queued, result_start)
{
let prefill_time = result_prefill_time.unwrap_or(Instant::now());
Ok(InferResponse {
prefill: result_prefill,
tokens: result_tokens,
prompt_tokens: result_prefill_length,
generated_text,
queued,
start,
prefill_time,
})
} else {
let err = InferError::IncompleteGeneration;
@@ -1369,9 +1374,11 @@ fn send_responses(

if generation.prefill_tokens_length > 0 {
// Send message
let prefill_time = Instant::now();
entry.response_tx.send(Ok(InferStreamResponse::Prefill {
tokens: generation.prefill_tokens,
tokens_length: generation.prefill_tokens_length,
prefill_time,
}))?;
}

@@ -1510,6 +1517,7 @@ pub(crate) enum InferStreamResponse {
Prefill {
tokens: Option<PrefillTokens>,
tokens_length: u32,
prefill_time: Instant,
},
// Intermediate messages
Token(Token),
@@ -1550,6 +1558,7 @@ pub(crate) struct InferResponse {
pub(crate) generated_text: GeneratedText,
pub(crate) queued: Instant,
pub(crate) start: Instant,
pub(crate) prefill_time: Instant,
}

#[derive(Debug, Error)]
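A note on the fallback added above: the diff in `send_responses` only emits an `InferStreamResponse::Prefill` message when `generation.prefill_tokens_length > 0`, so `result_prefill_time` can stay `None`, in which case the response falls back to a timestamp taken while the final response is assembled. A minimal, self-contained sketch of that fallback semantic (the timings and variable values here are hypothetical, not taken from the commit):

```rust
use std::time::{Duration, Instant};

fn main() {
    let start = Instant::now();

    // Case 1: a Prefill message was seen; its timestamp is carried through.
    let recorded: Option<Instant> = Some(start + Duration::from_millis(120));
    // Case 2: no Prefill message arrived; fall back to "now" at response-assembly time.
    let missing: Option<Instant> = None;

    for result_prefill_time in [recorded, missing] {
        // Mirrors `result_prefill_time.unwrap_or(Instant::now())` above. Note that
        // `unwrap_or` evaluates `Instant::now()` eagerly even when the Option is
        // Some, which is harmless here because taking a timestamp is cheap.
        let prefill_time = result_prefill_time.unwrap_or(Instant::now());
        println!("time to first token: {:?}", prefill_time - start);
    }
}
```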
21 changes: 21 additions & 0 deletions router/src/server.rs
@@ -33,6 +33,7 @@ use once_cell::sync::OnceCell;
use reqwest_middleware::ClientBuilder;
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use serde_json::Value;
use std::cmp;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::atomic::AtomicBool;
@@ -697,13 +698,21 @@ async fn generate(
let queue_time = response.start - response.queued;
let inference_time = Instant::now() - response.start;
let time_per_token = inference_time / response.generated_text.generated_tokens;
let time_to_first_token = response.prefill_time - response.start;
let time_per_output_token = (inference_time - time_to_first_token)
/ cmp::max(response.generated_text.generated_tokens - 1, 1);

// Rust Tracing metadata
span.record("total_time", format!("{total_time:?}"));
span.record("validation_time", format!("{validation_time:?}"));
span.record("queue_time", format!("{queue_time:?}"));
span.record("inference_time", format!("{inference_time:?}"));
span.record("time_per_token", format!("{time_per_token:?}"));
span.record("time_to_first_token", format!("{time_to_first_token:?}"));
span.record(
"time_per_output_token",
format!("{time_per_output_token:?}"),
);
span.record("seed", format!("{:?}", response.generated_text.seed));
span.record("prompt_tokens", format!("{prompt_tokens:?}"));
span.record("generated_tokens", format!("{generated_tokens:?}"));
@@ -753,6 +762,18 @@ async fn generate(
"x-time-per-token",
time_per_token.as_millis().to_string().parse().unwrap(),
);
headers.insert(
"x-time-to-first-token",
time_to_first_token.as_millis().to_string().parse().unwrap(),
);
headers.insert(
"x-time-per-output-token",
time_per_output_token
.as_millis()
.to_string()
.parse()
.unwrap(),
);

headers.insert("x-model-id", info.model_id.parse().unwrap());

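A hedged client-side sketch (not part of this commit) of how the new headers could be consumed. It assumes a LoRAX-style router listening on `localhost:8080` with a `POST /generate` endpoint and an `{"inputs", "parameters"}` request body, plus `reqwest` (with the `blocking` and `json` features) and `serde_json` as client dependencies; adjust the URL and payload to your deployment.

```rust
use reqwest::blocking::Client;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = Client::new();
    let response = client
        .post("http://localhost:8080/generate") // assumed endpoint
        .json(&serde_json::json!({
            "inputs": "What is TTFT?",
            "parameters": { "max_new_tokens": 16 }
        }))
        .send()?;

    // Both values are reported in whole milliseconds, mirroring the existing
    // x-time-per-token header.
    for name in ["x-time-to-first-token", "x-time-per-output-token"] {
        if let Some(value) = response.headers().get(name) {
            println!("{name}: {} ms", value.to_str()?);
        }
    }
    Ok(())
}
```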
