Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chunked prefill #653

Merged
merged 34 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
5ddcb30
WIP: router chunked prefill
tgaddair Oct 17, 2024
1d76aa8
WIP: Seqlen
tgaddair Oct 17, 2024
77b0ce6
WIP: server
tgaddair Oct 17, 2024
b02a01d
Graph update
tgaddair Oct 17, 2024
652eb8f
WIP: generate_tokens
tgaddair Oct 17, 2024
996fd2d
Fix prepare_for_prefill
tgaddair Oct 17, 2024
582be6c
Plumbing
tgaddair Oct 17, 2024
445b2c2
InfoResponse
tgaddair Oct 17, 2024
b87fe87
Seqlen for llama
tgaddair Oct 17, 2024
b95d6ed
Rename
tgaddair Oct 17, 2024
5720b3f
Responses are correct
tgaddair Oct 17, 2024
9d75bb0
Fixed compile
tgaddair Oct 17, 2024
8d413db
Fix flashinfer
tgaddair Oct 17, 2024
053c3e5
Fix prefill
tgaddair Oct 17, 2024
2834f5c
Fix chunking
tgaddair Oct 17, 2024
9a45299
Docker build
tgaddair Oct 17, 2024
147e8f9
Fix id
tgaddair Oct 17, 2024
8a69aaf
Added missing file
tgaddair Oct 17, 2024
5467678
Docker
tgaddair Oct 17, 2024
80f47ce
Fix
tgaddair Oct 18, 2024
a07b9b1
Warnings
tgaddair Oct 18, 2024
982fd52
Fix flashinfer graph retrace
tgaddair Oct 18, 2024
2621897
In-place softmax
tgaddair Oct 18, 2024
e0b0177
Fix concatenate
tgaddair Oct 18, 2024
46b06b7
Rename prefill chunking -> chunked prefill
tgaddair Oct 18, 2024
0fc2641
Added launcher args
tgaddair Oct 18, 2024
536ab5f
Revert docker
tgaddair Oct 18, 2024
646463e
Fix vlms
tgaddair Oct 18, 2024
65076e2
Seqlen
tgaddair Oct 18, 2024
787be03
Fixed llava next
tgaddair Oct 18, 2024
fcfa679
Mllama
tgaddair Oct 18, 2024
df5ae30
Fixed embeddings
tgaddair Oct 18, 2024
401f1ae
input_lengths -> seqlen
tgaddair Oct 18, 2024
1d9f5d4
ruff
tgaddair Oct 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,11 @@ struct Args {
#[clap(long, env)]
eager_prefill: Option<bool>,

/// Split prefill requests into multiple chunks and batch them with decode requests. For high QPS scenarios, this
/// can greatly improve throughput by overlapping request types. See: https://arxiv.org/pdf/2308.16369.
#[clap(long, env)]
chunked_prefill: Option<bool>,

/// Whether to use the prefix caching mechanism. This will skip computing attention on previously cached prefixes
/// in the prompt. Useful in cases where many queries need to be run over a shared context, or for long multi-turn
/// chat conversations.
Expand Down Expand Up @@ -496,6 +501,7 @@ fn shard_manager(
cuda_memory_fraction: f32,
adapter_memory_fraction: f32,
prefix_caching: Option<bool>,
chunked_prefill: Option<bool>,
merge_adapter_weights: bool,
backend: Backend,
otlp_endpoint: Option<String>,
Expand Down Expand Up @@ -639,6 +645,11 @@ fn shard_manager(
envs.push(("PREFIX_CACHING".into(), prefix_caching.to_string().into()));
}

// Chunked prefill
if let Some(chunked_prefill) = chunked_prefill {
envs.push(("CHUNKED_PREFILL".into(), chunked_prefill.to_string().into()));
}

// Backend
if backend == Backend::FlashInfer {
envs.push(("FLASH_INFER".into(), "1".into()));
Expand Down Expand Up @@ -1093,6 +1104,7 @@ fn spawn_shards(
let cuda_memory_fraction = args.cuda_memory_fraction;
let adapter_memory_fraction = args.adapter_memory_fraction;
let prefix_caching = args.prefix_caching;
let chunked_prefill = args.chunked_prefill;
let merge_adapter_weights = args.merge_adapter_weights;
let backend = args.backend;
let embedding_dim = args.embedding_dim;
Expand Down Expand Up @@ -1125,6 +1137,7 @@ fn spawn_shards(
cuda_memory_fraction,
adapter_memory_fraction,
prefix_caching,
chunked_prefill,
merge_adapter_weights,
backend,
otlp_endpoint,
Expand Down
29 changes: 24 additions & 5 deletions proto/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ message InfoResponse {
bool supports_generation = 8;
bool supports_embeddings = 9;
bool supports_classification = 10;
bool chunked_prefill = 11;
}

/// Empty request
Expand Down Expand Up @@ -156,8 +157,12 @@ message Request {
repeated uint32 blocks = 9;
/// Paged attention slots
repeated uint32 slots = 10;
/// Prefix length that can be retrieved from the KV cache
uint32 prefix_len = 11;
/// Tokens that can be retrieved from the KV cache.
/// This value is set for the first prefill and never reset
uint32 cache_len = 11;
/// Chunk of tokens that must be computed for the first prefill
/// This value is set for the first prefill and never reset
optional uint32 chunk_len = 12;
}

message Batch {
Expand All @@ -182,6 +187,8 @@ message CachedBatch {
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
/// Number of tokens in the next forward
uint32 current_tokens = 5;
}

enum FinishReason {
Expand Down Expand Up @@ -261,13 +268,25 @@ message FilterBatchResponse {
/// Request for a prefill forward pass.
message PrefillRequest {
/// Batch of requests to run prefill on
Batch batch = 1;
/// Optional cached batch — NOTE(review): presumably the in-flight batch to run
/// alongside this prefill when chunked prefill is enabled; confirm against the server handler
CachedBatch cached_batch = 2;
}

/// Response to a PrefillRequest.
message PrefillResponse {
/// Generations produced by this forward pass
repeated Generation generations = 1;
/// Next batch (cached), if any requests remain in flight
optional CachedBatch batch = 2;

// TODO(travis): add timings
// /// Forward elapsed time in nanoseconds
// uint64 forward_ns = 3;
// /// Decode elapsed time in nanoseconds
// uint64 decode_ns = 4;
// /// Total elapsed time in nanoseconds
// uint64 total_ns = 5;
// /// Concatenate elapsed time in nanoseconds
// optional uint64 concat_ns = 6;
}

message DecodeRequest {
Expand Down Expand Up @@ -344,9 +363,9 @@ message ClassifyResponse {
/// Request to warm up the model before serving.
message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
/// Maximum input (prompt) length to warm up for
uint32 max_input_length = 2;
/// Maximum number of prefill tokens per forward pass
uint32 max_prefill_tokens = 3;
/// Maximum number of new tokens to warmup
uint32 max_new_tokens = 4;
}

/// Empty response
Expand Down
12 changes: 10 additions & 2 deletions router/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
prefix_len: 0,
cache_len: 0,
chunk_len: None,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
Expand Down Expand Up @@ -180,6 +181,8 @@ impl Client {
let max_new_tokens = max_total_tokens - max_input_length;
let request = tonic::Request::new(WarmupRequest {
batch: Some(batch),
max_input_length,
max_prefill_tokens,
max_new_tokens,
})
.inject_context();
Expand All @@ -195,8 +198,13 @@ impl Client {
/// Run a prefill forward pass over `batch` on this shard.
///
/// `cached_batch` is forwarded to the server so an already-cached batch can be
/// processed in the same forward pass as the new prefill chunk.
///
/// Returns the generations produced by this pass together with the resulting
/// cached batch (None when no requests remain in flight).
pub async fn prefill(
    &mut self,
    batch: Batch,
    cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
    // inject_context attaches the current tracing context so the server can
    // correlate its spans with this request.
    let request = tonic::Request::new(PrefillRequest {
        batch: Some(batch),
        cached_batch,
    })
    .inject_context();
    let response = self.stub.prefill(request).await?.into_inner();
    Ok((response.generations, response.batch))
}
Expand Down
3 changes: 2 additions & 1 deletion router/client/src/sharded_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,12 @@ impl ShardedClient {
pub async fn prefill(
&mut self,
batch: Batch,
cached_batch: Option<CachedBatch>,
) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
let futures: Vec<_> = self
.clients
.iter_mut()
.map(|client| Box::pin(client.prefill(batch.clone())))
.map(|client| Box::pin(client.prefill(batch.clone(), cached_batch.clone())))
.collect();
let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
join_all(futures).await.into_iter().collect();
Expand Down
14 changes: 11 additions & 3 deletions router/src/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ pub(crate) trait BatchEntries: Sync + Send + Debug {
&mut self,
client: &mut ShardedClient,
batch: Batch,
cached_batch: Option<CachedBatch>,
span: Span,
generation_health: &Arc<AtomicBool>,
) -> Option<CachedBatch>;
Expand Down Expand Up @@ -341,7 +342,8 @@ impl BatchEntries for GenerateBatchEntries {
adapter_index: adapter.index(),
blocks,
slots,
prefix_len,
cache_len: prefix_len,
chunk_len: None,
};

self.state.add(id, entry, adapter, request_proto);
Expand Down Expand Up @@ -385,12 +387,14 @@ impl BatchEntries for GenerateBatchEntries {
&mut self,
client: &mut ShardedClient,
batch: Batch,
cached_batch: Option<CachedBatch>,
span: Span,
generation_health: &Arc<AtomicBool>,
) -> Option<CachedBatch> {
prefill(
client,
batch,
cached_batch,
&mut self.state.batch_entries,
&generation_health,
)
Expand Down Expand Up @@ -470,7 +474,8 @@ impl BatchEntries for EmbedBatchEntries {
adapter_index: adapter.index(),
blocks,
slots,
prefix_len,
cache_len: prefix_len,
chunk_len: None,
};

self.state.add(id, entry, adapter, request_proto);
Expand Down Expand Up @@ -514,6 +519,7 @@ impl BatchEntries for EmbedBatchEntries {
&mut self,
client: &mut ShardedClient,
batch: Batch,
_cached_batch: Option<CachedBatch>,
span: Span,
generation_health: &Arc<AtomicBool>,
) -> Option<CachedBatch> {
Expand Down Expand Up @@ -593,7 +599,8 @@ impl BatchEntries for ClassifyBatchEntries {
adapter_index: adapter.index(),
blocks,
slots,
prefix_len,
cache_len: prefix_len,
chunk_len: None,
};

self.state.add(id, entry, adapter, request_proto);
Expand Down Expand Up @@ -637,6 +644,7 @@ impl BatchEntries for ClassifyBatchEntries {
&mut self,
client: &mut ShardedClient,
batch: Batch,
_cached_batch: Option<CachedBatch>,
span: Span,
generation_health: &Arc<AtomicBool>,
) -> Option<CachedBatch> {
Expand Down
5 changes: 3 additions & 2 deletions router/src/health.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ impl Health {
// Block 0 is reserved for health checks
blocks: vec![0],
slots: (0..16).collect(),
prefix_len: 0,
cache_len: 0,
chunk_len: None,
};
let batch = Batch {
id: BATCH_ID,
Expand All @@ -83,7 +84,7 @@ impl Health {
max_blocks: 1,
};
// Skips the queue
let value = self.client.prefill(batch).await.is_ok();
let value = self.client.prefill(batch, None).await.is_ok();
// Update generation health
self.generation_health.store(value, Ordering::SeqCst);
value
Expand Down
Loading
Loading