diff --git a/router/src/health.rs b/router/src/health.rs
index 1c6aaed2d..8d576b99e 100644
--- a/router/src/health.rs
+++ b/router/src/health.rs
@@ -43,51 +43,86 @@ impl Health {
             self.client.health().await.is_ok()
         } else {
             // Generation is unhealthy or we have not sent any generation request yet
+            // Default to a generation-style liveness request
+            let mut liveness_request = Request {
+                id: LIVENESS_ID,
+                inputs: "liveness".to_string(),
+                tokenized_inputs: None,
+                truncate: 10,
+                prefill_logprobs: false,
+                parameters: Some(NextTokenChooserParameters {
+                    temperature: 1.0,
+                    top_k: 0,
+                    top_p: 1.0,
+                    typical_p: 1.0,
+                    do_sample: false,
+                    seed: 0,
+                    repetition_penalty: 1.0,
+                    watermark: false,
+                    adapter_id: "".to_string(),
+                    schema: None,
+                    return_k_alternatives: 0,
+                }),
+                stopping_parameters: Some(StoppingCriteriaParameters {
+                    max_new_tokens: 1,
+                    stop_sequences: vec![],
+                    ignore_eos_token: false,
+                }),
+                adapter_index: 0,
+                // Block 0 is reserved for health checks
+                blocks: vec![0],
+                slots: (0..16).collect(),
+                cache_len: 0,
+                chunk_len: None,
+            };
+            // Embedding and classification models are probed with a plain text
+            // input instead of the generation liveness prompt
+            if self.shard_info().supports_embeddings || self.shard_info().supports_classification {
+                liveness_request.inputs = "San Francisco".to_string();
+            }
             // Dummy batch of 1 token and 1 generated token
-            let liveness_request = Request {
-                id: LIVENESS_ID,
-                inputs: "liveness".to_string(),
-                tokenized_inputs: None,
-                truncate: 10,
-                prefill_logprobs: false,
-                parameters: Some(NextTokenChooserParameters {
-                    temperature: 1.0,
-                    top_k: 0,
-                    top_p: 1.0,
-                    typical_p: 1.0,
-                    do_sample: false,
-                    seed: 0,
-                    repetition_penalty: 1.0,
-                    watermark: false,
-                    adapter_id: "".to_string(),
-                    schema: None,
-                    return_k_alternatives: 0,
-                }),
-                stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: 1,
-                    stop_sequences: vec![],
-                    ignore_eos_token: false,
-                }),
-                adapter_index: 0,
-                // Block 0 is reserved for health checks
-                blocks: vec![0],
-                slots: (0..16).collect(),
-                cache_len: 0,
-                chunk_len: None,
-            };
             let batch = Batch {
                 id: BATCH_ID,
                 requests: vec![liveness_request],
                 size: 1,
                 max_tokens: 2,
-                max_blocks: 1,
             };
+
             // Skips the queue
-            let value = self.client.prefill(batch, None).await.is_ok();
-            // Update generation health
-            self.generation_health.store(value, Ordering::SeqCst);
-            value
+            // Dispatch the probe to whichever endpoint this shard supports
+            let value = if self.shard_info().supports_generation {
+                self.client.prefill(batch, None).await.is_ok()
+            } else if self.shard_info().supports_embeddings {
+                self.client.embed(batch).await.is_ok()
+            } else if self.shard_info().supports_classification {
+                self.client.classify(batch).await.is_ok()
+            } else {
+                // No supported probe for this shard type; report unhealthy
+                false
+            };
+            // Update generation health
+            self.generation_health.store(value, Ordering::SeqCst);
+            value
         }
     }
 }
diff --git a/router/src/server.rs b/router/src/server.rs
index 3370114e9..7e5c9e32a 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -553,99 +553,16 @@ async fn health(
     infer: Extension<Infer>,
     health: Extension<Health>,
 ) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
-    if health.shard_info().supports_classification {
-        let classify_request = ClassifyRequest {
inputs: "San Francisco".to_string(), - }; - match infer.classify(classify_request).await { - Ok(_) => {} - Err(error) => { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: error.to_string(), - error_type: error.error_type().to_string(), - }), - )); - } - } - } - if health.shard_info().supports_embeddings { - let embed_request = EmbedRequest { - inputs: "San Francisco".to_string(), - parameters: EmbedParameters { - adapter_id: None, - adapter_source: None, - adapter_parameters: None, - api_token: None, - }, - }; - match infer.embed(embed_request).await { - Ok(_) => {} - Err(error) => { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: error.to_string(), - error_type: error.error_type().to_string(), - }), - )); - } - } - } - if health.shard_info().supports_generation { - let generate_request = GenerateRequest { - inputs: "Who?".to_string(), - parameters: GenerateParameters { - adapter_id: None, - adapter_source: None, - adapter_parameters: None, - api_token: None, - best_of: None, - temperature: None, - top_k: None, - top_p: None, - typical_p: None, - do_sample: false, - seed: None, - repetition_penalty: None, - watermark: false, - return_full_text: None, - stop: vec![], - truncate: None, - details: false, - decoder_input_details: false, - return_k_alternatives: None, - apply_chat_template: false, - response_format: None, - max_new_tokens: Some(1), - ignore_eos_token: false, - }, - }; - match infer.generate(generate_request).await { - Ok(response) => { - if response.generated_text.text.len() == 0 { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: "Empty generation".to_string(), - error_type: "failed healthcheck".to_string(), - }), - )); - } - } - Err(error) => { - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - Json(ErrorResponse { - error: error.to_string(), - error_type: error.error_type().to_string(), - }), - )); - } - } - } - Ok(()) + match health.check().await { + true => Ok(()), + false => Err(( + StatusCode::SERVICE_UNAVAILABLE, + Json(ErrorResponse { + error: "unhealthy".to_string(), + error_type: "healthcheck".to_string(), + }), + )), + } } /// Generate tokens