diff --git a/router/src/infer.rs b/router/src/infer.rs
index 0b234550c..b1f751a28 100644
--- a/router/src/infer.rs
+++ b/router/src/infer.rs
@@ -161,6 +161,7 @@ impl Infer {
         speculate: u32,
         preloaded_adapters: Vec,
         prefix_caching: bool,
+        is_causal_lm: bool,
     ) -> Self {
         let adapter_event = Arc::new(AdapterEvent {
             batching_task: Notify::new(),
@@ -178,6 +179,7 @@ impl Infer {
             speculate,
             max_batch_total_tokens,
             prefix_caching,
+            is_causal_lm,
         );
 
         // Initialize with base model adapter (empty) mapping to index 0
@@ -729,13 +731,19 @@ impl Infer {
             .map(|(id, input)| (id as u64, input.clone()))
             .collect();
 
-        for (id, r_inputs) in request.inputs.iter().enumerate() {
-            let inputs = r_inputs.to_string().clone();
-            let (tokenized_inputs, input_length) = self
-                .validation
-                .validate_input(r_inputs.to_string(), None, Some(1))
-                .await?;
+        // Call validate_input on every input in the request and await the results
+        let futures: Vec<_> = request
+            .inputs
+            .iter()
+            .map(|input| self.validation.validate_input(input.clone(), None, Some(1)))
+            .collect();
+        let all_tokenized_inputs = try_join_all(futures).await?;
+
+        for ((id, r_inputs), (tokenized_inputs, input_length)) in
+            request.inputs.iter().enumerate().zip(all_tokenized_inputs)
+        {
+            let inputs = r_inputs.to_string().clone();
 
             let valid_request = ValidClassifyRequest {
                 inputs,
                 tokenized_inputs,
diff --git a/router/src/scheduler.rs b/router/src/scheduler.rs
index 37864c452..08b274f3a 100644
--- a/router/src/scheduler.rs
+++ b/router/src/scheduler.rs
@@ -44,6 +44,7 @@ impl AdapterScheduler {
         speculate: u32,
         max_batch_total_tokens: u32,
         prefix_caching: bool,
+        is_causal_lm: bool,
     ) -> Self {
         let (sender, receiver) = flume::unbounded();
 
@@ -60,6 +61,7 @@ impl AdapterScheduler {
             speculate,
             max_batch_total_tokens,
             prefix_caching,
+            is_causal_lm,
         ));
 
         Self { sender }
@@ -124,6 +126,7 @@ async fn adapter_scheduler_task(
     speculate: u32,
     max_batch_total_tokens: u32,
     prefix_caching: bool,
+    is_causal_lm: bool,
 ) {
     let mut state = AdapterSchedulerState::new(
         client,
@@ -135,6 +138,7 @@
         speculate,
         max_batch_total_tokens,
         prefix_caching,
+        is_causal_lm,
     );
 
     while let Ok(cmd) = receiver.recv_async().await {
@@ -209,6 +213,7 @@ impl AdapterSchedulerState {
         speculate: u32,
         max_batch_total_tokens: u32,
         prefix_caching: bool,
+        is_causal_lm: bool,
     ) -> Self {
         let queues_state = Arc::new(Mutex::new(AdapterQueuesState::new(
             max_active_adapters,
@@ -216,7 +221,8 @@
         )));
         let loader = AdapterLoader::new(client.clone());
 
-        let block_allocator = (!requires_padding).then(|| {
+        // Only causal LMs require the block allocator, due to paged attention
+        let block_allocator = (!requires_padding && is_causal_lm).then(|| {
             BlockAllocator::new(
                 max_batch_total_tokens,
                 block_size,
diff --git a/router/src/server.rs b/router/src/server.rs
index 552681d32..7195ed532 100644
--- a/router/src/server.rs
+++ b/router/src/server.rs
@@ -1134,12 +1134,21 @@ pub async fn run(
         generation_health.clone(),
         shard_info.clone(),
     );
+
+    // For non-causal LMs, the max batch total tokens is equal to the max batch prefill tokens
+    let is_causal_lm = shard_info.supports_generation;
+    let effective_max_batch_total_tokens = if is_causal_lm {
+        max_batch_total_tokens
+    } else {
+        max_batch_prefill_tokens
+    };
+
     let infer = Infer::new(
         client.clone(),
         validation,
         waiting_served_ratio,
         max_batch_prefill_tokens,
-        max_batch_total_tokens,
+        effective_max_batch_total_tokens,
         max_waiting_tokens,
         max_concurrent_requests,
         max_active_adapters,
@@ -1154,6 +1163,7 @@ pub async fn run(
         shard_info.speculate,
         shard_info.preloaded_adapters,
         prefix_caching,
+        is_causal_lm,
     );
 
     // Duration buckets
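
A note on the `infer.rs` change: `try_join_all` (from the `futures` crate) polls all validation futures concurrently, preserves input order in its output, and short-circuits on the first error, so `?` propagates a single validation failure for the whole request. Order preservation is what makes the `enumerate().zip(...)` pairing in the new loop sound. Below is a minimal, self-contained sketch of the pattern, not the router's actual code; it assumes the `futures` and `tokio` crates, and `validate` is a hypothetical stand-in for `Validation::validate_input`.

```rust
use futures::future::try_join_all;

// Hypothetical stand-in for Validation::validate_input: returns a
// "tokenized" form plus a fake token count, or an error string.
async fn validate(input: String) -> Result<(String, usize), String> {
    Ok((format!("tokenized({input})"), input.len()))
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let inputs = vec!["a".to_string(), "bb".to_string(), "ccc".to_string()];

    // Build all futures first, then await them together. try_join_all
    // preserves input order and short-circuits on the first Err.
    let futures: Vec<_> = inputs.iter().cloned().map(validate).collect();
    let results = try_join_all(futures).await?;

    // Order preservation makes this pairing safe: result i corresponds
    // to input i, exactly as in the diff's enumerate().zip(...) loop.
    for ((id, input), (tokenized, len)) in inputs.iter().enumerate().zip(results) {
        println!("{id}: {input} -> {tokenized} ({len})");
    }
    Ok(())
}
```

The payoff is latency: with N inputs, the sequential loop paid for each validation call in turn, while the joined version costs roughly the slowest single call.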
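
On the `scheduler.rs` change: `bool::then` maps the gating condition to an `Option`, so a `BlockAllocator` is only ever constructed for non-padded causal LMs, the only models that allocate paged-attention KV-cache blocks. A sketch of the idiom follows, with a hypothetical `BlockAllocator` struct and `build_allocator` helper standing in for the real types.

```rust
// Hypothetical stand-in for the router's BlockAllocator.
struct BlockAllocator {
    max_batch_total_tokens: u32,
    block_size: u32,
}

fn build_allocator(
    requires_padding: bool,
    is_causal_lm: bool,
    max_batch_total_tokens: u32,
    block_size: u32,
) -> Option<BlockAllocator> {
    // bool::then runs the closure only when the condition is true:
    // Some(allocator) for non-padded causal LMs, None otherwise
    // (e.g. classification or embedding models with no KV cache to page).
    (!requires_padding && is_causal_lm).then(|| BlockAllocator {
        max_batch_total_tokens,
        block_size,
    })
}

fn main() {
    let alloc = build_allocator(false, true, 16384, 16).expect("causal LM gets an allocator");
    assert_eq!((alloc.max_batch_total_tokens, alloc.block_size), (16384, 16));
    // Non-causal models skip block allocation entirely.
    assert!(build_allocator(false, false, 16384, 16).is_none());
    // So do padded (non-paged) models, as before this change.
    assert!(build_allocator(true, true, 16384, 16).is_none());
}
```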
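
On the `server.rs` change: a model that does not support generation (a classifier or embedder) has no decode phase, so every token it processes in a batch is a prefill token, and the total-token budget collapses to the prefill budget. A sketch of that reasoning as a free function, with hypothetical parameter names mirroring the diff:

```rust
fn effective_max_batch_total_tokens(
    is_causal_lm: bool,
    max_batch_total_tokens: u32,
    max_batch_prefill_tokens: u32,
) -> u32 {
    if is_causal_lm {
        // Causal LMs keep generating after prefill, so the total budget
        // must cover prompt tokens plus decoded tokens.
        max_batch_total_tokens
    } else {
        // Non-causal models run a single forward pass and never decode,
        // so only prefill tokens ever count against the batch.
        max_batch_prefill_tokens
    }
}

fn main() {
    assert_eq!(effective_max_batch_total_tokens(true, 32768, 4096), 32768);
    assert_eq!(effective_max_batch_total_tokens(false, 32768, 4096), 4096);
}
```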