From 654a31d3b399871faa05bca2637cb17d211f7a31 Mon Sep 17 00:00:00 2001
From: Magdy Saleh <17618143+magdyksaleh@users.noreply.github.com>
Date: Fri, 13 Sep 2024 15:49:34 -0400
Subject: [PATCH] Disable healthcheck tracing and add metrics to classify + classify_batch endpoints (#603)

---
 router/src/infer.rs  |  48 ++++++++----
 router/src/server.rs | 169 ++++++++++++++++++++++++++++++++++++-------
 2 files changed, 177 insertions(+), 40 deletions(-)

diff --git a/router/src/infer.rs b/router/src/infer.rs
index d7a9bca73..0b234550c 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -584,7 +584,7 @@ impl Infer {
     pub(crate) async fn classify(
         &self,
         request: ClassifyRequest,
-    ) -> Result<Vec<Entity>, InferError> {
+    ) -> Result<InferClassifyResponse, InferError> {
         // Limit concurrent requests by acquiring a permit from the semaphore
         let _permit = self
             .clone()
@@ -639,6 +639,8 @@ impl Infer {
 
         // Return values
         let mut return_entities = None;
+        let mut result_start = None;
+        let mut result_queued = None;
 
         let mut stream = response_rx.into_stream();
         while let Some(response) = stream.next().await {
@@ -661,8 +663,8 @@ impl Infer {
                 }
                 InferStreamResponse::Classify {
                     predictions,
-                    start: _,
-                    queued: _,
+                    start,
+                    queued,
                     id: _,
                 } => {
                     let entities = aggregate_ner_output_simple(
@@ -671,12 +673,18 @@ impl Infer {
                         self.tokenizer.clone().unwrap(),
                     );
                     return_entities = Some(entities);
+                    result_start = Some(start);
+                    result_queued = Some(queued);
                 }
             }
         }
 
         if let Some(return_entities) = return_entities {
-            Ok(return_entities.into_iter().map(Entity::from).collect())
+            Ok(InferClassifyResponse {
+                predictions: return_entities,
+                queued: result_queued.unwrap(),
+                start: result_start.unwrap(),
+            })
         } else {
             let err = InferError::ClassificationFailure;
             metrics::increment_counter!("lorax_request_failure", "err" => "classification_failure");
@@ -689,7 +697,7 @@ impl Infer {
     pub(crate) async fn classify_batch(
         &self,
         request: BatchClassifyRequest,
-    ) -> Result<Vec<Vec<Entity>>, InferError> {
+    ) -> Result<Vec<InferClassifyResponse>, InferError> {
         // Limit concurrent requests by acquiring a permit from the semaphore
         let _permit = self
             .clone()
@@ -762,8 +770,8 @@ impl Infer {
                     // Add prefill tokens
                     InferStreamResponse::Classify {
                         predictions,
-                        start: _,
-                        queued: _,
+                        start,
+                        queued,
                         id,
                     } => {
                         let request_inputs = request_id_map.get(&id.unwrap()).unwrap().clone();
@@ -772,7 +780,14 @@ impl Infer {
                             predictions.clone(),
                             self.tokenizer.clone().unwrap(),
                         );
-                        all_entities.insert(id.unwrap(), entities);
+                        all_entities.insert(
+                            id.unwrap(),
+                            InferClassifyResponse {
+                                predictions: entities,
+                                queued,
+                                start,
+                            },
+                        );
                     }
                     _ => {
                         tracing::error!(
@@ -787,15 +802,15 @@ impl Infer {
             tracing::error!("{err}");
             Err(err)
         } else {
-            let mut sorted_entries: Vec<_> = all_entities.into_iter().collect();
-            sorted_entries.sort_by_key(|&(id, _)| id);
+            let mut sorted_responses: Vec<_> = all_entities.into_iter().collect();
+            sorted_responses.sort_by_key(|&(id, _)| id);
 
-            let sorted_entities: Vec<Vec<Entity>> = sorted_entries
+            let sorted_responses: Vec<InferClassifyResponse> = sorted_responses
                 .into_iter()
-                .map(|(_, entities)| entities.into_iter().map(Entity::from).collect())
+                .map(|(_, response)| response)
                 .collect();
 
-            Ok(sorted_entities)
+            Ok(sorted_responses)
         }
     }
 
@@ -1489,6 +1504,13 @@ impl InferError {
     }
 }
 
+#[derive(Debug)]
+pub(crate) struct InferClassifyResponse {
+    pub(crate) predictions: Vec<Entity>,
+    pub(crate) queued: Instant,
+    pub(crate) start: Instant,
+}
+
 fn get_tag(token_class: &str) -> (String, String) {
     // TODO: don't make the null tag hardcoded
     let parts: Vec<&str> = token_class.split('-').collect();
diff --git a/router/src/server.rs b/router/src/server.rs
index 8dd50ec92..552681d32 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -38,7 +38,7 @@ use std::sync::Mutex;
 use tokenizers::Tokenizer;
 use tokio::signal;
 use tokio::sync::mpsc;
-use tokio::time::Instant;
+use tokio::time::{Duration, Instant};
 use tower_http::cors::{
     AllowCredentials, AllowHeaders, AllowMethods, AllowOrigin, CorsLayer, ExposeHeaders,
 };
@@ -345,19 +345,22 @@ example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
 /// Health check method
 async fn health(
     infer: Extension<Infer>,
-    info: Extension<Info>,
-    request_logger_sender: Extension>>,
-    req_headers: HeaderMap,
     health: Extension<Health>,
 ) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
     if health.shard_info().supports_classification {
         let classify_request = ClassifyRequest {
             inputs: "San Francisco".to_string(),
         };
-        match classify(infer.clone(), Json(classify_request)).await {
+        match infer.classify(classify_request).await {
             Ok(_) => {}
-            Err((status, error)) => {
-                return Err((status, error));
+            Err(error) => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(ErrorResponse {
+                        error: error.to_string(),
+                        error_type: error.error_type().to_string(),
+                    }),
+                ));
             }
         }
     }
@@ -365,10 +368,16 @@ async fn health(
         let embed_request = EmbedRequest {
             inputs: "San Francisco".to_string(),
         };
-        match embed(infer.clone(), Json(embed_request)).await {
+        match infer.embed(embed_request).await {
             Ok(_) => {}
-            Err((status, error)) => {
-                return Err((status, error));
+            Err(error) => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(ErrorResponse {
+                        error: error.to_string(),
+                        error_type: error.error_type().to_string(),
+                    }),
+                ));
             }
         }
     }
@@ -401,17 +410,9 @@ async fn health(
                 ignore_eos_token: false,
             },
         };
-        match generate(
-            infer,
-            info,
-            request_logger_sender,
-            req_headers,
-            Json(generate_request),
-        )
-        .await
-        {
+        match infer.generate(generate_request).await {
             Ok(response) => {
-                if response.1.generated_text.len() == 0 {
+                if response.generated_text.text.len() == 0 {
                     return Err((
                         StatusCode::INTERNAL_SERVER_ERROR,
                         Json(ErrorResponse {
@@ -421,8 +422,14 @@ async fn health(
                     ));
                 }
             }
-            Err((status, error)) => {
-                return Err((status, error));
+            Err(error) => {
+                return Err((
+                    StatusCode::INTERNAL_SERVER_ERROR,
+                    Json(ErrorResponse {
+                        error: error.to_string(),
+                        error_type: error.error_type().to_string(),
+                    }),
+                ));
             }
         }
     }
@@ -1469,11 +1476,59 @@ async fn embed(
 async fn classify(
     infer: Extension<Infer>,
     Json(req): Json<ClassifyRequest>,
-) -> Result<Json<Vec<Entity>>, (StatusCode, Json<ErrorResponse>)> {
+) -> Result<(HeaderMap, Json<Vec<Entity>>), (StatusCode, Json<ErrorResponse>)> {
+    let span = tracing::Span::current();
+    let start_time = Instant::now();
     metrics::increment_counter!("lorax_request_count");
     tracing::debug!("Input: {}", req.inputs);
     let response = infer.classify(req).await?;
-    Ok(Json(response))
+
+    // Timings
+    let total_time = start_time.elapsed();
+    let validation_time = response.queued - start_time;
+    let queue_time = response.start - response.queued;
+    let inference_time = Instant::now() - response.start;
+
+    // Rust Tracing metadata
+    span.record("total_time", format!("{total_time:?}"));
+    span.record("validation_time", format!("{validation_time:?}"));
+    span.record("queue_time", format!("{queue_time:?}"));
+    span.record("inference_time", format!("{inference_time:?}"));
+
+    // Headers
+    let mut headers = HeaderMap::new();
+    headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
+    headers.insert(
+        "x-compute-time",
+        total_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-validation-time",
+        validation_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-queue-time",
+        queue_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-inference-time",
+        inference_time.as_millis().to_string().parse().unwrap(),
+    );
+
+    // Metrics
+    metrics::increment_counter!("lorax_request_success");
+    metrics::histogram!("lorax_request_duration", total_time.as_secs_f64());
+    metrics::histogram!(
+        "lorax_request_validation_duration",
+        validation_time.as_secs_f64()
+    );
+    metrics::histogram!("lorax_request_queue_duration", queue_time.as_secs_f64());
+    metrics::histogram!(
+        "lorax_request_inference_duration",
+        inference_time.as_secs_f64()
+    );
+
+    Ok((headers, Json(response.predictions)))
 }
 
 #[utoipa::path(
@@ -1490,11 +1545,71 @@ async fn classify(
 async fn classify_batch(
     infer: Extension<Infer>,
     Json(req): Json<BatchClassifyRequest>,
-) -> Result<Json<Vec<Vec<Entity>>>, (StatusCode, Json<ErrorResponse>)> {
+) -> Result<(HeaderMap, Json<Vec<Vec<Entity>>>), (StatusCode, Json<ErrorResponse>)> {
+    let span = tracing::Span::current();
+    let start_time = Instant::now();
     metrics::increment_counter!("lorax_request_count");
     tracing::debug!("Inputs: {:?}", req.inputs);
-    let response = infer.classify_batch(req).await?;
-    Ok(Json(response))
+    let responses = infer.classify_batch(req).await?;
+
+    // Timings
+    let now = Instant::now();
+    let total_time = start_time.elapsed();
+    let mut validation_times = Vec::with_capacity(responses.len());
+    let mut queue_times = Vec::with_capacity(responses.len());
+    let mut inference_times = Vec::with_capacity(responses.len());
+
+    for r in &responses {
+        validation_times.push(r.queued - r.start);
+        queue_times.push(r.start - r.queued);
+        inference_times.push(now - r.start);
+    }
+
+    let validation_time = validation_times.iter().sum::<Duration>() / responses.len() as u32;
+    let queue_time = queue_times.iter().sum::<Duration>() / responses.len() as u32;
+    let inference_time = inference_times.iter().sum::<Duration>() / responses.len() as u32;
+
+    // Rust Tracing metadata
+    span.record("total_time", format!("{total_time:?}"));
+    span.record("validation_time", format!("{validation_time:?}"));
+    span.record("queue_time", format!("{queue_time:?}"));
+    span.record("inference_time", format!("{inference_time:?}"));
+
+    // Headers
+    let mut headers = HeaderMap::new();
+    headers.insert("x-compute-type", "gpu+optimized".parse().unwrap());
+    headers.insert(
+        "x-compute-time",
+        total_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-validation-time",
+        validation_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-queue-time",
+        queue_time.as_millis().to_string().parse().unwrap(),
+    );
+    headers.insert(
+        "x-inference-time",
+        inference_time.as_millis().to_string().parse().unwrap(),
+    );
+
+    // Metrics
+    metrics::increment_counter!("lorax_request_success");
+    metrics::histogram!("lorax_request_duration", total_time.as_secs_f64());
+    metrics::histogram!(
+        "lorax_request_validation_duration",
+        validation_time.as_secs_f64()
+    );
+    metrics::histogram!("lorax_request_queue_duration", queue_time.as_secs_f64());
+    metrics::histogram!(
+        "lorax_request_inference_duration",
+        inference_time.as_secs_f64()
+    );
+
+    let batch_entity_vec = responses.into_iter().map(|r| r.predictions).collect();
+    Ok((headers, Json(batch_entity_vec)))
 }
 
 /// Tokenize inputs
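
For reference, a minimal client-side sketch of how the new per-request timing headers could be consumed. This is not part of the patch: it assumes the router is reachable at http://localhost:8080, that the classify handler shown above is routed at POST /classify with an {"inputs": ...} body matching ClassifyRequest, and that the reqwest, serde_json, and tokio crates are available.

// Hypothetical client-side check of the timing headers added by this patch.
// Assumptions (not in the patch): base URL, route path, and the reqwest/tokio
// dependencies; header names and millisecond units come from the handler above.
use reqwest::Client;
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = Client::new();
    let resp = client
        .post("http://localhost:8080/classify")
        .json(&json!({ "inputs": "San Francisco" }))
        .send()
        .await?;

    // Each duration is reported in milliseconds via these response headers.
    for name in [
        "x-compute-time",
        "x-validation-time",
        "x-queue-time",
        "x-inference-time",
    ] {
        if let Some(value) = resp.headers().get(name) {
            println!("{name}: {} ms", value.to_str()?);
        }
    }

    // The body is the Vec<Entity> predictions serialized as JSON.
    let entities: serde_json::Value = resp.json().await?;
    println!("predictions: {entities}");
    Ok(())
}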