
Commit

optimize healthcheck
noyoshi committed Nov 1, 2024
1 parent bd92e52 · commit a2e14ce
Showing 2 changed files with 81 additions and 129 deletions.
107 changes: 71 additions & 36 deletions router/src/health.rs
@@ -43,51 +43,86 @@ impl Health {
             self.client.health().await.is_ok()
         } else {
             // Generation is unhealthy, or no generation request has been sent yet
+            // Default to generation
+            let mut liveness_request = Request {
+                id: LIVENESS_ID,
+                inputs: "liveness".to_string(),
+                tokenized_inputs: None,
+                truncate: 10,
+                prefill_logprobs: false,
+                parameters: Some(NextTokenChooserParameters {
+                    temperature: 1.0,
+                    top_k: 0,
+                    top_p: 1.0,
+                    typical_p: 1.0,
+                    do_sample: false,
+                    seed: 0,
+                    repetition_penalty: 1.0,
+                    watermark: false,
+                    adapter_id: "".to_string(),
+                    schema: None,
+                    return_k_alternatives: 0,
+                }),
+                stopping_parameters: Some(StoppingCriteriaParameters {
+                    max_new_tokens: 1,
+                    stop_sequences: vec![],
+                    ignore_eos_token: false,
+                }),
+                adapter_index: 0,
+                // Block 0 is reserved for health checks
+                blocks: vec![0],
+                slots: (0..16).collect(),
+                cache_len: 0,
+                chunk_len: None,
+            };
+            // Create different requests based on the type of model this is
+            if self.shard_info().supports_embeddings {
+                liveness_request = EmbedRequest {
+                    inputs: "San Francisco".to_string(),
+                    parameters: EmbedParameters {
+                        adapter_id: None,
+                        adapter_source: None,
+                        adapter_parameters: None,
+                        api_token: None,
+                    },
+                };
+            }
+
+            if self.shard_info().supports_classification {
+                liveness_request = ClassifyRequest {
+                    inputs: "San Francisco".to_string(),
+                };
+            }
+
-            // Dummy batch of 1 token and 1 generated token
-            let liveness_request = Request {
-                id: LIVENESS_ID,
-                inputs: "liveness".to_string(),
-                tokenized_inputs: None,
-                truncate: 10,
-                prefill_logprobs: false,
-                parameters: Some(NextTokenChooserParameters {
-                    temperature: 1.0,
-                    top_k: 0,
-                    top_p: 1.0,
-                    typical_p: 1.0,
-                    do_sample: false,
-                    seed: 0,
-                    repetition_penalty: 1.0,
-                    watermark: false,
-                    adapter_id: "".to_string(),
-                    schema: None,
-                    return_k_alternatives: 0,
-                }),
-                stopping_parameters: Some(StoppingCriteriaParameters {
-                    max_new_tokens: 1,
-                    stop_sequences: vec![],
-                    ignore_eos_token: false,
-                }),
-                adapter_index: 0,
-                // Block 0 is reserved for health checks
-                blocks: vec![0],
-                slots: (0..16).collect(),
-                cache_len: 0,
-                chunk_len: None,
-            };
             let batch = Batch {
                 id: BATCH_ID,
                 requests: vec![liveness_request],
                 size: 1,
                 max_tokens: 2,
                 max_blocks: 1,
             };
 
             // Skips the queue
-            let value = self.client.prefill(batch, None).await.is_ok();
-            // Update generation health
-            self.generation_health.store(value, Ordering::SeqCst);
-            value
+            if self.shard_info().supports_generation {
+                let value = self.client.prefill(batch, None).await.is_ok();
+                // Update generation health
+                self.generation_health.store(value, Ordering::SeqCst);
+                return value;
+            }
+
+            if self.shard_info().supports_embeddings {
+                let value = self.client.embed(batch).await.is_ok();
+                // Update generation health
+                self.generation_health.store(value, Ordering::SeqCst);
+                return value;
+            }
+
+            if self.shard_info().supports_classification {
+                let value = self.client.classify(batch).await.is_ok();
+                // Update generation health
+                self.generation_health.store(value, Ordering::SeqCst);
+                return value;
+            }
         }
     }
 }
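The net effect in health.rs is that all liveness probing now lives in Health::check(): one dummy batch is built, and the shard's capability flags decide whether it goes to prefill, embed, or classify, with the outcome cached in the shared generation_health flag. Below is a minimal sketch of that dispatch pattern, assuming only the capability flags and client method names visible in the diff; every type in the sketch is an illustrative stand-in, not the repository's real definition.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Capability flags, as implied by the shard_info() calls in the diff.
struct ShardInfo {
    supports_generation: bool,
    supports_embeddings: bool,
    supports_classification: bool,
}

// Illustrative stand-ins for the gRPC client and batch types.
struct Batch;
struct Client;
impl Client {
    async fn prefill(&self, _batch: Batch) -> Result<(), ()> { Ok(()) }
    async fn embed(&self, _batch: Batch) -> Result<(), ()> { Ok(()) }
    async fn classify(&self, _batch: Batch) -> Result<(), ()> { Ok(()) }
}

struct Health {
    client: Client,
    shard_info: ShardInfo,
    generation_health: Arc<AtomicBool>,
}

impl Health {
    // Send one dummy batch to the first RPC the shard supports, then
    // cache the result so later checks can short-circuit to a cheap
    // gRPC health() call.
    async fn check_liveness(&self, batch: Batch) -> bool {
        let value = if self.shard_info.supports_generation {
            self.client.prefill(batch).await.is_ok()
        } else if self.shard_info.supports_embeddings {
            self.client.embed(batch).await.is_ok()
        } else if self.shard_info.supports_classification {
            self.client.classify(batch).await.is_ok()
        } else {
            false
        };
        self.generation_health.store(value, Ordering::SeqCst);
        value
    }
}

Compared with the previous approach, which could issue up to three full inference calls per check from the HTTP layer, this sends at most one dummy request, and subsequent checks short-circuit to a plain gRPC health() call while generation_health remains true.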
103 changes: 10 additions & 93 deletions router/src/server.rs
@@ -553,99 +553,16 @@ async fn health(
     infer: Extension<Infer>,
     health: Extension<Health>,
 ) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
-    if health.shard_info().supports_classification {
-        let classify_request = ClassifyRequest {
-            inputs: "San Francisco".to_string(),
-        };
-        match infer.classify(classify_request).await {
-            Ok(_) => {}
-            Err(error) => {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(ErrorResponse {
-                        error: error.to_string(),
-                        error_type: error.error_type().to_string(),
-                    }),
-                ));
-            }
-        }
-    }
-    if health.shard_info().supports_embeddings {
-        let embed_request = EmbedRequest {
-            inputs: "San Francisco".to_string(),
-            parameters: EmbedParameters {
-                adapter_id: None,
-                adapter_source: None,
-                adapter_parameters: None,
-                api_token: None,
-            },
-        };
-        match infer.embed(embed_request).await {
-            Ok(_) => {}
-            Err(error) => {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(ErrorResponse {
-                        error: error.to_string(),
-                        error_type: error.error_type().to_string(),
-                    }),
-                ));
-            }
-        }
-    }
-    if health.shard_info().supports_generation {
-        let generate_request = GenerateRequest {
-            inputs: "Who?".to_string(),
-            parameters: GenerateParameters {
-                adapter_id: None,
-                adapter_source: None,
-                adapter_parameters: None,
-                api_token: None,
-                best_of: None,
-                temperature: None,
-                top_k: None,
-                top_p: None,
-                typical_p: None,
-                do_sample: false,
-                seed: None,
-                repetition_penalty: None,
-                watermark: false,
-                return_full_text: None,
-                stop: vec![],
-                truncate: None,
-                details: false,
-                decoder_input_details: false,
-                return_k_alternatives: None,
-                apply_chat_template: false,
-                response_format: None,
-                max_new_tokens: Some(1),
-                ignore_eos_token: false,
-            },
-        };
-        match infer.generate(generate_request).await {
-            Ok(response) => {
-                if response.generated_text.text.len() == 0 {
-                    return Err((
-                        StatusCode::INTERNAL_SERVER_ERROR,
-                        Json(ErrorResponse {
-                            error: "Empty generation".to_string(),
-                            error_type: "failed healthcheck".to_string(),
-                        }),
-                    ));
-                }
-            }
-            Err(error) => {
-                return Err((
-                    StatusCode::INTERNAL_SERVER_ERROR,
-                    Json(ErrorResponse {
-                        error: error.to_string(),
-                        error_type: error.error_type().to_string(),
-                    }),
-                ));
-            }
-        }
-    }
-    Ok(())
+    match health.check().await {
+        true => Ok(()),
+        false => Err((
+            StatusCode::SERVICE_UNAVAILABLE,
+            Json(ErrorResponse {
+                error: "unhealthy".to_string(),
+                error_type: "healthcheck".to_string(),
+            }),
+        )),
+    }
 }
 
 /// Generate tokens
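With the probes relocated into Health::check(), the HTTP handler collapses to a delegation plus an error mapping, and a failing check now surfaces as 503 Service Unavailable rather than the previous 500 Internal Server Error. A minimal sketch of wiring such a handler in axum follows; the stub Health type, its check() signature, and the route path are assumptions for the example, not the router's real definitions.

use axum::{extract::Extension, http::StatusCode, routing::get, Json, Router};
use serde::Serialize;

#[derive(Serialize)]
struct ErrorResponse {
    error: String,
    error_type: String,
}

// Stub standing in for the router's Health type.
#[derive(Clone)]
struct Health;

impl Health {
    async fn check(&self) -> bool {
        true // the real check probes the shards, as in health.rs above
    }
}

// Map a boolean health result onto HTTP semantics, mirroring the diff.
async fn health(
    Extension(health): Extension<Health>,
) -> Result<(), (StatusCode, Json<ErrorResponse>)> {
    match health.check().await {
        true => Ok(()),
        false => Err((
            StatusCode::SERVICE_UNAVAILABLE,
            Json(ErrorResponse {
                error: "unhealthy".to_string(),
                error_type: "healthcheck".to_string(),
            }),
        )),
    }
}

// Hypothetical wiring; the real router registers many more routes.
fn app() -> Router {
    Router::new()
        .route("/health", get(health))
        .layer(Extension(Health))
}

503 is the conventional status for a service that is temporarily unable to serve requests, which is what load balancers and readiness probes generally expect from a failing health endpoint.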
