Added eager prefill option #524

Merged 2 commits on Jun 24, 2024
10 changes: 10 additions & 0 deletions launcher/src/main.rs
@@ -287,6 +287,12 @@ struct Args {
#[clap(default_value = "20", long, env)]
max_waiting_tokens: usize,

/// Whether to prioritize running prefill before decode to increase batch size during decode (throughput) over
/// liveness in earlier requests (latency). For batch use cases that are not latency sensitive, this should be set
/// to true.
#[clap(long, env)]
eager_prefill: Option<bool>,

/// Maximum number of adapters that can be placed on the GPU and accept requests at a time.
#[clap(default_value = "1024", long, env)]
max_active_adapters: usize,
@@ -1129,6 +1135,10 @@ fn spawn_webserver(
router_args.push("--embedding-model".to_string());
}

if args.eager_prefill.unwrap_or(false) {
router_args.push("--eager-prefill".to_string());
}

// Ngrok
if args.ngrok {
router_args.push("--ngrok".to_string());
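Because the launcher flag is declared with `#[clap(long, env)]` and typed as `Option<bool>`, it should be settable either as a command-line argument (e.g. `lorax-launcher ... --eager-prefill true`) or via an `EAGER_PREFILL` environment variable; when it resolves to true, the launcher forwards the bare `--eager-prefill` switch to the router as shown above. The `lorax-launcher` binary name and the exact env-var spelling are assumptions based on the project layout, not confirmed by this diff.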
29 changes: 23 additions & 6 deletions router/src/infer.rs
@@ -55,6 +55,7 @@ impl Infer {
requires_padding: bool,
window_size: Option<u32>,
generation_health: Arc<AtomicBool>,
eager_prefill: bool,
) -> Self {
let adapter_event = Arc::new(AdapterEvent {
batching_task: Notify::new(),
@@ -90,6 +91,7 @@
adapter_event,
generation_health,
adapter_scheduler.clone(),
eager_prefill,
));

// Inference limit with a semaphore
@@ -462,6 +464,7 @@ async fn batching_task(
adapter_event: Arc<AdapterEvent>,
generation_health: Arc<AtomicBool>,
adapter_scheduler: AdapterScheduler,
eager_prefill: bool,
) {
// Infinite loop
loop {
@@ -489,7 +492,7 @@
// all requests have met their stopping criteria)
while let Some(batch) = cached_batch {
// Get current batch info
let batch_size = batch.size;
let mut batch_size = batch.size;
let batch_max_tokens = batch.max_tokens;
let mut batches = vec![batch];
metrics::gauge!("lorax_batch_current_size", batch_size as f64);
@@ -499,7 +502,7 @@
// TODO(travis): can execute this more efficiently by making it event-driven
adapter_scheduler.remove_errored_adapters().await;

let min_size = if waiting_tokens >= max_waiting_tokens {
let min_size = if waiting_tokens >= max_waiting_tokens || eager_prefill {
// If we didn't onboard any new requests since >= max_waiting_tokens, we try
// to add a new batch even though its size might be small
None
@@ -508,19 +511,25 @@
Some((batch_size as f32 * waiting_served_ratio).floor() as usize)
};

let token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let adapters_in_use = batch_entries.adapters_in_use();
let mut token_budget = max_batch_total_tokens.saturating_sub(batch_max_tokens);
let mut adapters_in_use = batch_entries.adapters_in_use();

// Try to get a new batch
if let Some((mut new_entries, new_batch, span)) = adapter_scheduler
while let Some((mut new_entries, new_batch, span)) = adapter_scheduler
.next_batch(
adapters_in_use,
adapters_in_use.clone(),
min_size,
max_batch_prefill_tokens,
token_budget,
)
.await
{
let new_batch_size = new_batch.size;
batch_size += new_batch_size;

let new_batch_max_tokens = new_batch.max_tokens;
token_budget = token_budget.saturating_sub(new_batch_max_tokens);

// Tracking metrics
if min_size.is_some() {
metrics::increment_counter!("lorax_batch_concat", "reason" => "backpressure");
@@ -547,13 +556,21 @@
let new_cached_batch = new_entries
.process_first(&mut client, new_batch, span, &generation_health)
.await;

adapters_in_use.extend(new_entries.adapters_in_use());

// Reset waiting counter
waiting_tokens = 1;
// Extend current batch with the new batch
if let Some(new_cached_batch) = new_cached_batch {
batch_entries.extend(new_entries);
batches.push(new_cached_batch);
}

if !eager_prefill {
// Do not continue to loop if we are not in eager prefill mode
break;
}
}

// Create span for this batch to add context to inference calls
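The core of the change in `batching_task` is the switch from `if let` to `while let`: instead of onboarding at most one new prefill batch per decode step, the loop keeps pulling batches from the scheduler, subtracting each batch's `max_tokens` from the remaining token budget, and only breaks after the first batch when eager prefill is disabled. A minimal sketch of that control flow is below; the stand-in types and `next_batch` function are illustrative simplifications, not the actual LoRAX scheduler API.

```rust
struct NewBatch {
    size: u32,
    max_tokens: u32,
}

/// Stand-in for `AdapterScheduler::next_batch`: returns a new prefill batch
/// if one fits within `token_budget`, otherwise `None`. Real logic elided.
fn next_batch(min_size: Option<usize>, token_budget: u32) -> Option<NewBatch> {
    let _ = (min_size, token_budget);
    None
}

fn onboard_new_batches(
    eager_prefill: bool,
    waiting_tokens: usize,
    max_waiting_tokens: usize,
    batch_size: &mut u32,
    mut token_budget: u32,
    waiting_served_ratio: f32,
) {
    // With eager prefill there is no minimum size requirement: a new prefill
    // batch is attempted on every decode step, however small it is.
    let min_size = if waiting_tokens >= max_waiting_tokens || eager_prefill {
        None
    } else {
        Some((*batch_size as f32 * waiting_served_ratio).floor() as usize)
    };

    // `while let` instead of `if let`: keep concatenating new prefill batches
    // until the scheduler has nothing more that fits in the remaining budget.
    while let Some(new_batch) = next_batch(min_size, token_budget) {
        *batch_size += new_batch.size;
        token_budget = token_budget.saturating_sub(new_batch.max_tokens);

        if !eager_prefill {
            // Previous behaviour: onboard at most one new batch per step.
            break;
        }
    }
}

fn main() {
    let mut batch_size = 4u32;
    onboard_new_batches(true, 0, 20, &mut batch_size, 8192, 1.2);
    println!("batch size after eager prefill pass: {batch_size}");
}
```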
2 changes: 2 additions & 0 deletions router/src/lib.rs
@@ -60,6 +60,8 @@ pub struct Info {
pub max_waiting_tokens: usize,
#[schema(example = "2")]
pub validation_workers: usize,
#[schema(example = false)]
pub eager_prefill: bool,
/// Router Info
#[schema(example = "0.5.0")]
pub version: &'static str,
4 changes: 4 additions & 0 deletions router/src/main.rs
@@ -85,6 +85,8 @@ struct Args {
ngrok_edge: Option<String>,
#[clap(default_value = "hub", long, env)]
adapter_source: String,
#[clap(long, env)]
eager_prefill: bool,
}

#[tokio::main]
@@ -122,6 +124,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_authtoken,
ngrok_edge,
adapter_source,
eager_prefill,
} = args;

init_logging(otlp_endpoint, json_output);
@@ -376,6 +379,7 @@ async fn main() -> Result<(), RouterError> {
ngrok_edge,
adapter_source,
embedding_model,
eager_prefill,
)
.await?;
Ok(())
3 changes: 3 additions & 0 deletions router/src/server.rs
@@ -955,6 +955,7 @@ pub async fn run(
ngrok_edge: Option<String>,
adapter_source: String,
embedding_model: bool,
eager_prefill: bool,
) -> Result<(), axum::BoxError> {
// OpenAPI documentation
#[derive(OpenApi)]
@@ -1049,6 +1050,7 @@
shard_info.requires_padding,
shard_info.window_size,
generation_health,
eager_prefill,
);

// Duration buckets
@@ -1145,6 +1147,7 @@
docker_label: option_env!("DOCKER_LABEL"),
request_logger_url: std::env::var("REQUEST_LOGGER_URL").ok(),
embedding_model,
eager_prefill,
};

DEFAULT_ADAPTER_SOURCE