Skip to content

Commit

Permalink
Move kv cache allocation to router to ensure correct block allocation
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair authored Jul 19, 2024
1 parent 5c25e26 commit 5a7a1be
Show file tree
Hide file tree
Showing 15 changed files with 460 additions and 319 deletions.
8 changes: 8 additions & 0 deletions proto/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ message InfoResponse {
string dtype = 2;
string device_type = 3;
optional uint32 window_size = 4;
uint32 block_size = 5;
uint32 speculate = 6;
}

/// Empty request
Expand Down Expand Up @@ -112,6 +114,10 @@ message Request {
uint32 adapter_index = 7;
/// Apply chat template to inputs
bool apply_chat_template = 8;
/// Paged attention blocks
repeated uint32 blocks = 9;
/// Paged attention slots
repeated uint32 slots = 10;
}

message Batch {
Expand All @@ -123,6 +129,8 @@ message Batch {
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
/// Maximum number of Paged Attention blocks
uint32 max_blocks = 5;
}

message CachedBatch {
Expand Down
6 changes: 5 additions & 1 deletion router/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ impl Client {
id: 0,
inputs: "_test ".to_string().repeat(max_input_length as usize),
truncate: truncate_length,
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
Expand Down Expand Up @@ -148,7 +151,8 @@ impl Client {
id: 0,
size: requests.len() as u32,
requests,
max_tokens: 0,
max_tokens: max_input_length,
max_blocks: 0,
};

let max_new_tokens = max_total_tokens - max_input_length;
Expand Down
37 changes: 25 additions & 12 deletions router/src/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use tracing::{info_span, span, Instrument, Span};

use crate::{
adapter::Adapter,
block_allocator::BlockAllocation,
infer::{classify, decode, embed, prefill, InferError, InferStreamResponse},
};

Expand Down Expand Up @@ -135,6 +136,8 @@ pub(crate) struct Entry {
pub queue_time: Instant,
/// Instant when this entry was added to a batch
pub batch_time: Option<Instant>,
/// Block Allocation
pub block_allocation: Option<BlockAllocation>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -186,7 +189,7 @@ impl BatchEntriesState {
entries
}

fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch {
fn create_batch_data(&self, batch_id: u64, max_tokens: u32, max_blocks: u32) -> Batch {
// Final batch size
let size = self.len() as u32;

Expand All @@ -196,6 +199,7 @@ impl BatchEntriesState {
requests: self.batch_requests.clone(),
size,
max_tokens,
max_blocks,
}
}

Expand All @@ -218,10 +222,10 @@ impl BatchEntriesState {
#[async_trait]
pub(crate) trait BatchEntries: Sync + Send + Debug {
fn can_add(&self, entry: &Entry) -> bool;
fn add(&mut self, id: u64, entry: Entry, adapter: Adapter);
fn add(&mut self, id: u64, entry: Entry, adapter: Adapter, blocks: Vec<u32>, slots: Vec<u32>);
fn extend(&mut self, entries: Box<dyn BatchEntries>);
fn drain(&mut self) -> Vec<(Adapter, u64, Entry)>;
fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch;
fn create_batch_data(&self, batch_id: u64, max_tokens: u32, max_blocks: u32) -> Batch;
fn adapters_in_use(&self) -> HashSet<Adapter>;
fn is_empty(&self) -> bool;
fn len(&self) -> usize;
Expand Down Expand Up @@ -272,7 +276,7 @@ impl BatchEntries for GenerateBatchEntries {
result
}

fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) {
fn add(&mut self, id: u64, entry: Entry, adapter: Adapter, blocks: Vec<u32>, slots: Vec<u32>) {
let valid_request = entry
.request
.as_ref()
Expand All @@ -289,6 +293,8 @@ impl BatchEntries for GenerateBatchEntries {
stopping_parameters: Some(request.stopping_parameters.clone()),
adapter_index: adapter.index(),
apply_chat_template: request.apply_chat_template,
blocks,
slots,
};

self.state.add(id, entry, adapter, request_proto);
Expand All @@ -303,8 +309,9 @@ impl BatchEntries for GenerateBatchEntries {
self.state.drain()
}

fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch {
self.state.create_batch_data(batch_id, max_tokens)
fn create_batch_data(&self, batch_id: u64, max_tokens: u32, max_blocks: u32) -> Batch {
self.state
.create_batch_data(batch_id, max_tokens, max_blocks)
}

fn adapters_in_use(&self) -> HashSet<Adapter> {
Expand Down Expand Up @@ -389,7 +396,7 @@ impl BatchEntries for EmbedBatchEntries {
result
}

fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) {
fn add(&mut self, id: u64, entry: Entry, adapter: Adapter, blocks: Vec<u32>, slots: Vec<u32>) {
let valid_request = entry
.request
.as_ref()
Expand All @@ -406,6 +413,8 @@ impl BatchEntries for EmbedBatchEntries {
stopping_parameters: None,
adapter_index: adapter.index(),
apply_chat_template: false,
blocks,
slots,
};

self.state.add(id, entry, adapter, request_proto);
Expand All @@ -420,8 +429,9 @@ impl BatchEntries for EmbedBatchEntries {
self.state.drain()
}

fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch {
self.state.create_batch_data(batch_id, max_tokens)
fn create_batch_data(&self, batch_id: u64, max_tokens: u32, max_blocks: u32) -> Batch {
self.state
.create_batch_data(batch_id, max_tokens, max_blocks)
}

fn adapters_in_use(&self) -> HashSet<Adapter> {
Expand Down Expand Up @@ -500,7 +510,7 @@ impl BatchEntries for ClassifyBatchEntries {
result
}

fn add(&mut self, id: u64, entry: Entry, adapter: Adapter) {
fn add(&mut self, id: u64, entry: Entry, adapter: Adapter, blocks: Vec<u32>, slots: Vec<u32>) {
let valid_request = entry
.request
.as_ref()
Expand All @@ -517,6 +527,8 @@ impl BatchEntries for ClassifyBatchEntries {
stopping_parameters: None,
adapter_index: adapter.index(),
apply_chat_template: false,
blocks,
slots,
};

self.state.add(id, entry, adapter, request_proto);
Expand All @@ -531,8 +543,9 @@ impl BatchEntries for ClassifyBatchEntries {
self.state.drain()
}

fn create_batch_data(&self, batch_id: u64, max_tokens: u32) -> Batch {
self.state.create_batch_data(batch_id, max_tokens)
fn create_batch_data(&self, batch_id: u64, max_tokens: u32, max_blocks: u32) -> Batch {
self.state
.create_batch_data(batch_id, max_tokens, max_blocks)
}

fn adapters_in_use(&self) -> HashSet<Adapter> {
Expand Down
136 changes: 136 additions & 0 deletions router/src/block_allocator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
use std::cmp::min;
use tokio::sync::{mpsc, oneshot};

/// Blocks and slots handed out by a [`BlockAllocator`] for one request.
///
/// Dropping the value sends `blocks` back to the allocator (see the `Drop` impl below).
///
/// NOTE(review): this type derives `Clone` while its `Drop` impl frees `blocks`, so each
/// clone frees the same block ids again when it is dropped. Confirm that no caller clones
/// a live allocation.
#[derive(Debug, Clone)]
pub(crate) struct BlockAllocation {
/// Ids of the blocks taken from the allocator's free list.
pub blocks: Vec<u32>,
/// Slot indices computed from `blocks` by the allocator task.
pub slots: Vec<u32>,
/// Handle used to return `blocks` when this allocation is dropped.
block_allocator: BlockAllocator,
}

impl Drop for BlockAllocation {
    /// Hands this allocation's blocks back to the allocator.
    fn drop(&mut self) {
        // `self.blocks` is not read again after `drop`, so move the vector out instead of
        // cloning it. This leaves an empty `Vec` behind and avoids an allocation and a copy.
        let blocks = std::mem::take(&mut self.blocks);
        self.block_allocator.free(blocks)
    }
}

/// Cloneable handle to the block allocator background task.
///
/// Every clone shares the same command channel.
#[derive(Debug, Clone)]
pub(crate) struct BlockAllocator {
/// Channel to communicate with the background task
block_allocator: mpsc::UnboundedSender<BlockAllocatorCommand>,
}

impl BlockAllocator {
    /// Creates an allocator handle and spawns its background task.
    ///
    /// The task is started with `max_batch_total_tokens / block_size` (integer division)
    /// as its block count, together with `block_size` and `window_size`.
    ///
    /// # Panics
    ///
    /// Panics if `block_size` is zero. `tokio::spawn` also requires a running Tokio runtime.
    pub(crate) fn new(
        max_batch_total_tokens: u32,
        block_size: u32,
        window_size: Option<u32>,
    ) -> Self {
        // The division below would panic anyway; fail with an explicit message instead.
        assert!(block_size > 0, "block_size must be greater than zero");

        // Create channel
        let (sender, receiver) = mpsc::unbounded_channel();

        // Launch background block allocator task
        tokio::spawn(block_allocator_task(
            max_batch_total_tokens / block_size,
            block_size,
            window_size,
            receiver,
        ));

        Self {
            block_allocator: sender,
        }
    }

    /// Asks the background task for enough blocks to hold `tokens` tokens.
    ///
    /// Returns `None` when the task reports that not enough free blocks are available.
    /// On success the returned [`BlockAllocation`] keeps a clone of this handle so that it
    /// can return its blocks when dropped.
    ///
    /// # Panics
    ///
    /// Panics if the background task is no longer running.
    pub(crate) async fn allocate(&self, tokens: u32) -> Option<BlockAllocation> {
        let (response_sender, response_receiver) = oneshot::channel();
        self.block_allocator
            .send(BlockAllocatorCommand::Allocate {
                tokens,
                response_sender,
            })
            .expect("block allocator task has terminated");

        response_receiver
            .await
            .expect("block allocator task dropped the response channel")
            .map(|(blocks, slots)| BlockAllocation {
                blocks,
                slots,
                block_allocator: self.clone(),
            })
    }

    /// Returns `blocks` to the background task's free list.
    ///
    /// This is called from `BlockAllocation`'s `Drop` impl, so it must never panic: a
    /// panic inside `drop` while another panic is unwinding aborts the process. If the
    /// background task has already stopped there is no free list left to update, so a
    /// failed send is deliberately ignored.
    pub(crate) fn free(&self, blocks: Vec<u32>) {
        let _ = self
            .block_allocator
            .send(BlockAllocatorCommand::Free { blocks });
    }
}

/// Background task that owns the free list of blocks.
///
/// Processes [`BlockAllocatorCommand`]s until every sender has been dropped.
///
/// * `blocks` - total number of blocks; ids `1..blocks` start out free.
/// * `block_size` - number of slots per block.
/// * `window_size` - optional window size; when set, an allocation needs at most
///   `window_size` tokens worth of blocks and reuses them for the remaining tokens.
/// * `receiver` - command channel fed by `BlockAllocator` handles.
async fn block_allocator_task(
    blocks: u32,
    block_size: u32,
    window_size: Option<u32>,
    mut receiver: mpsc::UnboundedReceiver<BlockAllocatorCommand>,
) {
    // Block 0 is reserved for health checks
    let mut free_blocks: Vec<u32> = (1..blocks).collect();
    while let Some(cmd) = receiver.recv().await {
        match cmd {
            BlockAllocatorCommand::Free { blocks } => free_blocks.extend(blocks),
            BlockAllocatorCommand::Allocate {
                tokens,
                response_sender,
            } => {
                let allocation =
                    allocate_blocks(&mut free_blocks, tokens, block_size, window_size);
                // The requester may have gone away before the answer was ready. In that
                // case nobody will ever free these blocks, so put them straight back
                // instead of panicking (which would stop this task and lose the blocks).
                if let Err(Some((blocks, _slots))) = response_sender.send(allocation) {
                    free_blocks.extend(blocks);
                }
            }
        }
    }
}

/// Takes enough blocks off the end of `free_blocks` to hold `tokens` tokens and computes
/// one slot per token.
///
/// Returns `None`, leaving `free_blocks` untouched, when not enough blocks are free.
fn allocate_blocks(
    free_blocks: &mut Vec<u32>,
    tokens: u32,
    block_size: u32,
    window_size: Option<u32>,
) -> Option<(Vec<u32>, Vec<u32>)> {
    // Apply window size: only `window_size` tokens need distinct slots, and the same
    // blocks are walked `repeats` times to cover all tokens.
    let (windowed_tokens, repeats) = match window_size {
        None => (tokens, 1),
        Some(window_size) => {
            let repeats = (tokens + window_size - 1) / window_size;
            (min(tokens, window_size), repeats as usize)
        }
    };

    // Pad to a multiple of block size
    let required_blocks = ((windowed_tokens + block_size - 1) / block_size) as usize;
    if required_blocks > free_blocks.len() {
        return None;
    }

    let blocks = free_blocks.split_off(free_blocks.len() - required_blocks);

    // Exactly one slot per token: walk every slot of every block, `repeats` times over,
    // and stop as soon as `tokens` slots have been produced.
    let mut slots = Vec::with_capacity(tokens as usize);
    slots.extend(
        (0..repeats)
            .flat_map(|_| blocks.iter())
            .flat_map(|&block_id| (block_id * block_size)..((block_id + 1) * block_size))
            .take(tokens as usize),
    );

    Some((blocks, slots))
}

/// Commands sent from `BlockAllocator` handles to the background task.
#[derive(Debug)]
enum BlockAllocatorCommand {
/// Put `blocks` back on the free list.
Free {
blocks: Vec<u32>,
},
/// Request blocks for `tokens` tokens. The task answers on `response_sender` with
/// `None` when there are not enough free blocks, otherwise with `(blocks, slots)`.
Allocate {
tokens: u32,
response_sender: oneshot::Sender<Option<(Vec<u32>, Vec<u32>)>>,
},
}
4 changes: 4 additions & 0 deletions router/src/health.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,16 @@ impl Health {
}),
adapter_index: 0,
apply_chat_template: false,
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
};
let batch = Batch {
id: BATCH_ID,
requests: vec![liveness_request],
size: 1,
max_tokens: 2,
max_blocks: 1,
};
// Skips the queue
let value = self.client.prefill(batch).await.is_ok();
Expand Down
9 changes: 8 additions & 1 deletion router/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ impl Infer {
generation_health: Arc<AtomicBool>,
eager_prefill: bool,
preloaded_adapter_ids: Vec<String>,
block_size: u32,
speculate: u32,
) -> Self {
let adapter_event = Arc::new(AdapterEvent {
batching_task: Notify::new(),
Expand All @@ -70,10 +72,12 @@ impl Infer {
client.clone(),
adapter_event.clone(),
requires_padding,
16,
block_size,
window_size,
max_active_adapters,
adapter_cycle_time_s,
speculate,
max_batch_total_tokens,
);

// Initialize with base model adapter (empty) mapping to index 0
Expand Down Expand Up @@ -194,6 +198,7 @@ impl Infer {
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
block_allocation: None,
},
);

Expand Down Expand Up @@ -385,6 +390,7 @@ impl Infer {
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
block_allocation: None,
},
);

Expand Down Expand Up @@ -482,6 +488,7 @@ impl Infer {
temp_span: None,
queue_time: Instant::now(),
batch_time: None,
block_allocation: None,
},
);

Expand Down
1 change: 1 addition & 0 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/// LoRAX Webserver
mod adapter;
mod batch;
mod block_allocator;
mod health;
mod infer;
mod loader;
Expand Down
Loading

0 comments on commit 5a7a1be

Please sign in to comment.