optimize healthcheck #664

Open
wants to merge 4 commits into base: main
149 changes: 105 additions & 44 deletions router/src/health.rs
@@ -38,56 +38,117 @@ impl Health {

#[allow(dead_code)]
pub(crate) async fn check(&mut self) -> bool {
// The server updates self.generation_health whenever we get a response back from the model.
// We fail the health check if there were failures coming back from the model server.
// The "else" branch only runs when generation is unhealthy or before the router has received any traffic.
if self.generation_health.load(Ordering::SeqCst) {
// Generation is healthy, we only check that the shards are answering gRPC calls
self.client.health().await.is_ok()
} else {
// Generation is unhealthy or we have not sent any generation request yet
if self.shard_info().supports_generation {
let mut liveness_request = Request {
id: LIVENESS_ID,
inputs: "liveness".to_string(),
tokenized_inputs: None,
truncate: 10,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
watermark: false,
adapter_id: "".to_string(),
schema: None,
return_k_alternatives: 0,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
stop_sequences: vec![],
ignore_eos_token: false,
}),
adapter_index: 0,
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
cache_len: 0,
chunk_len: None,
};
// Dummy batch of 1 token and 1 generated token
let batch = Batch {
id: BATCH_ID,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
};
let value = self.client.prefill(batch, None).await.is_ok();
// Update generation health
self.generation_health.store(value, Ordering::SeqCst);
return value;
}

// Dummy batch of 1 token and 1 generated token
let liveness_request = Request {
id: LIVENESS_ID,
inputs: "liveness".to_string(),
tokenized_inputs: None,
truncate: 10,
prefill_logprobs: false,
parameters: Some(NextTokenChooserParameters {
temperature: 1.0,
top_k: 0,
top_p: 1.0,
typical_p: 1.0,
do_sample: false,
seed: 0,
repetition_penalty: 1.0,
watermark: false,
adapter_id: "".to_string(),
schema: None,
return_k_alternatives: 0,
}),
stopping_parameters: Some(StoppingCriteriaParameters {
max_new_tokens: 1,
stop_sequences: vec![],
ignore_eos_token: false,
}),
adapter_index: 0,
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
cache_len: 0,
chunk_len: None,
// Create different requests based on the type of model this is
if self.shard_info().supports_embeddings {
let liveness_request = Request {
id: LIVENESS_ID,
prefill_logprobs: false,
inputs: "San Francisco".to_string(),
tokenized_inputs: None, // Tokenization happens on the model server instead
truncate: 0,
parameters: None,
stopping_parameters: None,
adapter_index: 0,
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
cache_len: 0,
chunk_len: None,
};
let batch = Batch {
id: BATCH_ID,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
};
let value = self.client.embed(batch).await.is_ok();
// Update generation health
self.generation_health.store(value, Ordering::SeqCst);
return value;
}
let batch = Batch {
id: BATCH_ID,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
max_blocks: 1,
};
// Skips the queue
let value = self.client.prefill(batch, None).await.is_ok();
// Update generation health
self.generation_health.store(value, Ordering::SeqCst);
value

if self.shard_info().supports_classification {
let liveness_request = Request {
id: LIVENESS_ID,
prefill_logprobs: false,
inputs: "San Francisco".to_string(),
tokenized_inputs: None, // Tokenization happens on the model server instead
truncate: 0,
parameters: None,
stopping_parameters: None,
adapter_index: 0,
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
cache_len: 0,
chunk_len: None,
};
let batch = Batch {
id: BATCH_ID,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
};
let value = self.client.classify(batch).await.is_ok();
// Update generation health
self.generation_health.store(value, Ordering::SeqCst);
return value;
}

// Return false - a health check for this shard type is not implemented yet.
return false;
}
}
}
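For readers following the control flow: the reworked `check()` keys off the shared `generation_health` atomic. When the flag is set, a cheap gRPC `health()` call suffices; when it is not (no traffic yet, or a previous failure), a dummy single-token request is routed through `prefill`, `embed`, or `classify` depending on what the shard supports, and the outcome is stored back into the flag. The sketch below reproduces that pattern in isolation; `HealthFlag`, the generic probe futures, and the tokio runtime (with the `macros` feature) are illustrative assumptions, not the repository's actual types.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

/// Shared flag: the inference loop stores `true`/`false` after every
/// model-server response; the health check reads it to decide whether
/// a cheap liveness call is enough or a full dummy request is needed.
#[derive(Clone)]
struct HealthFlag(Arc<AtomicBool>);

impl HealthFlag {
    fn new() -> Self {
        HealthFlag(Arc::new(AtomicBool::new(false)))
    }

    /// Called from the inference path after each batch completes.
    fn record(&self, ok: bool) {
        self.0.store(ok, Ordering::SeqCst);
    }

    fn is_healthy(&self) -> bool {
        self.0.load(Ordering::SeqCst)
    }
}

/// Simplified check mirroring the control flow of `Health::check`:
/// fast path when generation has been healthy, otherwise run a dummy
/// probe and remember its outcome for the next check.
async fn check<F, P>(flag: &HealthFlag, grpc_health: F, dummy_probe: P) -> bool
where
    F: std::future::Future<Output = bool>,
    P: std::future::Future<Output = bool>,
{
    if flag.is_healthy() {
        // Generation already produced results; only verify the shards answer gRPC.
        grpc_health.await
    } else {
        // No traffic yet (or a previous failure): send a dummy request and
        // store the result so the next check can take the fast path.
        let ok = dummy_probe.await;
        flag.record(ok);
        ok
    }
}

#[tokio::main]
async fn main() {
    let flag = HealthFlag::new();
    // First check: flag is false, so the dummy probe runs and stores `true`.
    assert!(check(&flag, async { true }, async { true }).await);
    // Second check: flag is true, so only the cheap gRPC path is exercised.
    assert!(check(&flag, async { true }, async { false }).await);
}
```

The `Ordering::SeqCst` load/store pairing matches the diff, keeping each probe's result visible to the following check.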
108 changes: 12 additions & 96 deletions router/src/server.rs
@@ -11,7 +11,7 @@ use crate::{
ChatCompletionResponseChoice, ChatCompletionStreamResponse, ChatCompletionStreamResponseChoice,
ChatMessage, ClassifyRequest, CompatGenerateRequest, CompletionFinishReason, CompletionRequest,
CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice,
CompletionStreamResponse, Details, EmbedParameters, EmbedRequest, EmbedResponse, Entity,
CompletionStreamResponse, Details, EmbedRequest, EmbedResponse, Entity,
ErrorResponse, FinishReason, FunctionDefinition, GenerateParameters, GenerateRequest,
GenerateResponse, HubModelInfo, Infer, Info, JsonSchema, LogProbs, Message,
OpenAiResponseFormat, PrefillToken, ResponseFormat, ResponseFormatType, SimpleToken,
@@ -550,102 +550,18 @@ example = json ! ({"error": "unhealthy", "error_type": "healthcheck"})),
)]
/// Health check method
async fn health(
infer: Extension<Infer>,
health: Extension<Health>,
mut health: Extension<Health>,
) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
if health.shard_info().supports_classification {
let classify_request = ClassifyRequest {
inputs: "San Francisco".to_string(),
};
match infer.classify(classify_request).await {
Ok(_) => {}
Err(error) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: error.to_string(),
error_type: error.error_type().to_string(),
}),
));
}
}
}
if health.shard_info().supports_embeddings {
let embed_request = EmbedRequest {
inputs: "San Francisco".to_string(),
parameters: EmbedParameters {
adapter_id: None,
adapter_source: None,
adapter_parameters: None,
api_token: None,
},
};
match infer.embed(embed_request).await {
Ok(_) => {}
Err(error) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: error.to_string(),
error_type: error.error_type().to_string(),
}),
));
}
}
}
if health.shard_info().supports_generation {
let generate_request = GenerateRequest {
inputs: "Who?".to_string(),
parameters: GenerateParameters {
adapter_id: None,
adapter_source: None,
adapter_parameters: None,
api_token: None,
best_of: None,
temperature: None,
top_k: None,
top_p: None,
typical_p: None,
do_sample: false,
seed: None,
repetition_penalty: None,
watermark: false,
return_full_text: None,
stop: vec![],
truncate: None,
details: false,
decoder_input_details: false,
return_k_alternatives: None,
apply_chat_template: false,
response_format: None,
max_new_tokens: Some(1),
ignore_eos_token: false,
},
};
match infer.generate(generate_request).await {
Ok(response) => {
if response.generated_text.text.len() == 0 {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: "Empty generation".to_string(),
error_type: "failed healthcheck".to_string(),
}),
));
}
}
Err(error) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: error.to_string(),
error_type: error.error_type().to_string(),
}),
));
}
}
}
Ok(())
match health.check().await {
true => Ok(()),
false => Err((
StatusCode::SERVICE_UNAVAILABLE,
Json(ErrorResponse {
error: "unhealthy".to_string(),
error_type: "healthcheck".to_string(),
}),
)),
}
}

/// Generate tokens
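After this change the `/health` handler simply surfaces `Health::check()`: `Ok(())` maps to a 200, and a failed check maps to a 503 with the `{"error": "unhealthy", "error_type": "healthcheck"}` body already documented in the endpoint's OpenAPI example. The self-contained sketch below shows that wiring; `AppHealth`, `ErrorResponse`, and the axum 0.7 / tokio / serde setup are assumptions for illustration, not this crate's actual definitions.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use axum::{extract::Extension, http::StatusCode, routing::get, Json, Router};
use serde::Serialize;

/// Hypothetical stand-in for the router's `Health` struct.
#[derive(Clone)]
struct AppHealth {
    generation_health: Arc<AtomicBool>,
}

impl AppHealth {
    /// Mirrors the shape of `Health::check`: report the last known generation
    /// health. (The real implementation also probes the shards when the flag
    /// is unset.)
    async fn check(&mut self) -> bool {
        self.generation_health.load(Ordering::SeqCst)
    }
}

/// Error body matching the shape in the endpoint's OpenAPI example.
#[derive(Serialize)]
struct ErrorResponse {
    error: String,
    error_type: String,
}

/// `/health` handler: 200 on success, 503 plus a JSON error body otherwise,
/// following the same match as the patched `health` function.
async fn health(
    Extension(mut health): Extension<AppHealth>,
) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
    match health.check().await {
        true => Ok(()),
        false => Err((
            StatusCode::SERVICE_UNAVAILABLE,
            Json(ErrorResponse {
                error: "unhealthy".to_string(),
                error_type: "healthcheck".to_string(),
            }),
        )),
    }
}

#[tokio::main]
async fn main() {
    let app_health = AppHealth {
        generation_health: Arc::new(AtomicBool::new(true)),
    };
    let app = Router::new()
        .route("/health", get(health))
        .layer(Extension(app_health));

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
```

Note the status code also changes from the 500 returned by the removed per-endpoint checks to 503, the conventional signal for a temporarily unavailable backend.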