Skip to content

Commit

Permalink
Add prefix caching (#581)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair authored Aug 21, 2024
1 parent 76c6e71 commit 1d2b514
Show file tree
Hide file tree
Showing 38 changed files with 1,344 additions and 139 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions launcher/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,11 @@ struct Args {
#[clap(long, env)]
eager_prefill: Option<bool>,

/// Whether to use the prefix caching mechanism, allowing the portion of
/// the KV cache computed for a shared prompt prefix to be reused across
/// requests instead of being recomputed.
#[clap(long, env)]
prefix_caching: Option<bool>,

/// Maximum number of adapters that can be placed on the GPU and accept requests at a time.
#[clap(default_value = "1024", long, env)]
max_active_adapters: usize,
Expand Down Expand Up @@ -440,6 +445,7 @@ fn shard_manager(
watermark_delta: Option<f32>,
cuda_memory_fraction: f32,
adapter_memory_fraction: f32,
prefix_caching: Option<bool>,
otlp_endpoint: Option<String>,
status_sender: mpsc::Sender<ShardStatus>,
shutdown: Arc<AtomicBool>,
Expand Down Expand Up @@ -548,6 +554,11 @@ fn shard_manager(
adapter_memory_fraction.to_string().into(),
));

// Prefix caching
if let Some(prefix_caching) = prefix_caching {
envs.push(("PREFIX_CACHING".into(), prefix_caching.to_string().into()));
}

// Safetensors load fast
envs.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));

Expand Down Expand Up @@ -984,6 +995,7 @@ fn spawn_shards(
let watermark_delta = args.watermark_delta;
let cuda_memory_fraction = args.cuda_memory_fraction;
let adapter_memory_fraction = args.adapter_memory_fraction;
let prefix_caching = args.prefix_caching;
thread::spawn(move || {
shard_manager(
model_id,
Expand All @@ -1009,6 +1021,7 @@ fn spawn_shards(
watermark_delta,
cuda_memory_fraction,
adapter_memory_fraction,
prefix_caching,
otlp_endpoint,
status_sender,
shutdown,
Expand Down Expand Up @@ -1164,6 +1177,10 @@ fn spawn_webserver(
router_args.push("--eager-prefill".to_string());
}

if args.prefix_caching.unwrap_or(false) {
router_args.push("--prefix-caching".to_string());
}

// Ngrok
if args.ngrok {
router_args.push("--ngrok".to_string());
Expand Down
2 changes: 2 additions & 0 deletions proto/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ message Request {
repeated uint32 blocks = 9;
/// Paged attention slots
repeated uint32 slots = 10;
/// Prefix length that can be retrieved from the KV cache
uint32 prefix_len = 11;
}

message Batch {
Expand Down
1 change: 1 addition & 0 deletions router/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ reqwest-retry = "0.4.0"
regex = "1.5.4"
serde = "1.0.152"
serde_json = { version = "1.0.93", features = ["preserve_order"] }
slotmap = "1.0.7"
thiserror = "1.0.38"
tokenizers = { version = "0.19.1", features = ["http"] }
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] }
Expand Down
1 change: 1 addition & 0 deletions router/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ impl Client {
// Blocks and slots will be set on the server side if we use paged attention
blocks: vec![],
slots: vec![],
prefix_len: 0,
// Set sampling parameters to also take these ops into account in the max memory
parameters: Some(NextTokenChooserParameters {
temperature: 0.9,
Expand Down
68 changes: 64 additions & 4 deletions router/src/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use crate::{

pub(crate) trait ValidRequest: Sync + Send + Debug + Any {
fn input_length(&self) -> u32;
fn input_ids(&self) -> Option<Arc<Vec<u32>>>;
fn max_new_tokens(&self) -> u32;
fn adapter(&self) -> Adapter;
fn to_batch(&self, num_entries: usize, queue_len: usize) -> Box<dyn BatchEntries>;
Expand All @@ -35,6 +36,14 @@ impl ValidRequest for ValidGenerateRequest {
self.input_length
}

/// Returns the tokenized input ids for this generate request, if
/// tokenization was performed during validation; `None` otherwise.
///
/// NOTE(review): the ids are cloned into a fresh `Arc` on every call;
/// callers needing them repeatedly should cache the result.
fn input_ids(&self) -> Option<Arc<Vec<u32>>> {
    // `if let Some(x) { Some(f(x)) } else { None }` is just `Option::map`.
    self.tokenized_inputs
        .as_ref()
        .map(|tokenized| Arc::new(tokenized.ids.clone()))
}

/// Maximum number of new tokens this generation request is allowed to
/// produce, taken from its validated stopping parameters.
fn max_new_tokens(&self) -> u32 {
    self.stopping_parameters.max_new_tokens
}
Expand Down Expand Up @@ -65,6 +74,14 @@ impl ValidRequest for ValidEmbedRequest {
self.input_length
}

/// Returns the tokenized input ids for this embed request, if
/// tokenization was performed during validation; `None` otherwise.
///
/// NOTE(review): the ids are cloned into a fresh `Arc` on every call;
/// callers needing them repeatedly should cache the result.
fn input_ids(&self) -> Option<Arc<Vec<u32>>> {
    // `if let Some(x) { Some(f(x)) } else { None }` is just `Option::map`.
    self.tokenized_inputs
        .as_ref()
        .map(|tokenized| Arc::new(tokenized.ids.clone()))
}

/// Always 1: embed requests do not generate text, so the token budget
/// is fixed (presumably covering a single forward pass — confirm
/// against the scheduler's accounting).
fn max_new_tokens(&self) -> u32 {
    1
}
Expand Down Expand Up @@ -95,6 +112,14 @@ impl ValidRequest for ValidClassifyRequest {
self.input_length
}

/// Returns the tokenized input ids for this classify request, if
/// tokenization was performed during validation; `None` otherwise.
///
/// NOTE(review): the ids are cloned into a fresh `Arc` on every call;
/// callers needing them repeatedly should cache the result.
fn input_ids(&self) -> Option<Arc<Vec<u32>>> {
    // `if let Some(x) { Some(f(x)) } else { None }` is just `Option::map`.
    self.tokenized_inputs
        .as_ref()
        .map(|tokenized| Arc::new(tokenized.ids.clone()))
}

/// Always 1: classify requests do not generate text, so the token
/// budget is fixed (presumably covering a single forward pass —
/// confirm against the scheduler's accounting).
fn max_new_tokens(&self) -> u32 {
    1
}
Expand Down Expand Up @@ -227,7 +252,15 @@ impl BatchEntriesState {
#[async_trait]
pub(crate) trait BatchEntries: Sync + Send + Debug {
fn can_add(&self, entry: &Entry) -> bool;
fn add(&mut self, id: u64, entry: Entry, adapter: Adapter, blocks: Vec<u32>, slots: Vec<u32>);
fn add(
&mut self,
id: u64,
entry: Entry,
adapter: Adapter,
blocks: Vec<u32>,
slots: Vec<u32>,
prefix_len: u32,
);
fn extend(&mut self, entries: Box<dyn BatchEntries>);
fn drain(&mut self) -> Vec<(Adapter, u64, Entry)>;
fn create_batch_data(&self, batch_id: u64, max_tokens: u32, max_blocks: u32) -> Batch;
Expand Down Expand Up @@ -281,7 +314,15 @@ impl BatchEntries for GenerateBatchEntries {
result
}

fn add(&mut self, id: u64, entry: Entry, adapter: Adapter, blocks: Vec<u32>, slots: Vec<u32>) {
fn add(
&mut self,
id: u64,
entry: Entry,
adapter: Adapter,
blocks: Vec<u32>,
slots: Vec<u32>,
prefix_len: u32,
) {
let valid_request = entry
.request
.as_ref()
Expand All @@ -300,6 +341,7 @@ impl BatchEntries for GenerateBatchEntries {
adapter_index: adapter.index(),
blocks,
slots,
prefix_len,
};

self.state.add(id, entry, adapter, request_proto);
Expand Down Expand Up @@ -401,7 +443,15 @@ impl BatchEntries for EmbedBatchEntries {
result
}

fn add(&mut self, id: u64, entry: Entry, adapter: Adapter, blocks: Vec<u32>, slots: Vec<u32>) {
fn add(
&mut self,
id: u64,
entry: Entry,
adapter: Adapter,
blocks: Vec<u32>,
slots: Vec<u32>,
prefix_len: u32,
) {
let valid_request = entry
.request
.as_ref()
Expand All @@ -420,6 +470,7 @@ impl BatchEntries for EmbedBatchEntries {
adapter_index: adapter.index(),
blocks,
slots,
prefix_len,
};

self.state.add(id, entry, adapter, request_proto);
Expand Down Expand Up @@ -515,7 +566,15 @@ impl BatchEntries for ClassifyBatchEntries {
result
}

fn add(&mut self, id: u64, entry: Entry, adapter: Adapter, blocks: Vec<u32>, slots: Vec<u32>) {
fn add(
&mut self,
id: u64,
entry: Entry,
adapter: Adapter,
blocks: Vec<u32>,
slots: Vec<u32>,
prefix_len: u32,
) {
let valid_request = entry
.request
.as_ref()
Expand All @@ -534,6 +593,7 @@ impl BatchEntries for ClassifyBatchEntries {
adapter_index: adapter.index(),
blocks,
slots,
prefix_len,
};

self.state.add(id, entry, adapter, request_proto);
Expand Down
Loading

0 comments on commit 1d2b514

Please sign in to comment.