
Cache the user prompt states too.
cryscan committed Feb 14, 2024
1 parent 9cced37 commit 1d40a25
Showing 6 changed files with 137 additions and 111 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "ai00_server"
-version = "0.3.14"
+version = "0.3.15"
 edition = "2021"
 authors = ["Gu ZhenNiu <[email protected]>", "Zhang Zhenyuan <[email protected]>"]
 license = "MIT OR Apache-2.0"
1 change: 1 addition & 0 deletions assets/configs/Config.toml
@@ -5,6 +5,7 @@ quant_type = "Int8" # Quantization t
 turbo = true # Whether to use alternative GEMM kernel to speed-up long prompts.
 token_chunk_size = 32 # Size of token chunk that is inferred at once. For high end GPUs, this could be 64 or 128 (faster).
 head_chunk_size = 8192 # DO NOT modify this if you don't know what you are doing.
+state_chunk_size = 4 # The chunk size of layers in model state.
 max_runtime_batch = 8 # The maximum batches that can be scheduled for inference at the same time.
 max_batch = 16 # The maximum batches that are cached on GPU.
 embed_layer = 2 # The (reversed) layer number whose output is used as embedding.
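
The new state_chunk_size key controls how many layers of the model state are grouped into one chunk when the backend state is allocated; the chunking itself is done by the runtime's StateBuilder, shown in src/main.rs below. As a rough illustration only (the helper below is not code from this repository), the resulting number of state chunks would follow from a ceiling division:

    // Illustrative sketch only, not part of this repository: how a chunk size
    // of `state_chunk_size` layers splits a model state into chunks.
    fn num_state_chunks(num_layers: usize, state_chunk_size: usize) -> usize {
        // Ceiling division: a trailing partial chunk still needs a buffer.
        (num_layers + state_chunk_size - 1) / state_chunk_size
    }

    fn main() {
        // Assuming a hypothetical 32-layer model and the default chunk size of 4.
        assert_eq!(num_state_chunks(32, 4), 8);
    }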
5 changes: 5 additions & 0 deletions src/config.rs
@@ -25,6 +25,7 @@ impl From<Config> for ReloadRequest {
 turbo,
 token_chunk_size,
 head_chunk_size,
+state_chunk_size,
 max_runtime_batch,
 max_batch,
 embed_layer,
@@ -45,6 +46,7 @@
 turbo,
 token_chunk_size,
 head_chunk_size,
+state_chunk_size,
 max_runtime_batch,
 max_batch,
 embed_layer,
@@ -70,6 +72,8 @@ pub struct Model {
 pub token_chunk_size: usize,
 /// The chunk size for each split of the head matrix.
 pub head_chunk_size: usize,
+/// The chunk size of layers in model state.
+pub state_chunk_size: usize,
 /// Maximum number of batches that are active at once.
 pub max_runtime_batch: usize,
 /// Number of states that are cached on GPU.
@@ -89,6 +93,7 @@ impl Default for Model {
 turbo: true,
 token_chunk_size: 32,
 head_chunk_size: 8192,
+state_chunk_size: 4,
 max_runtime_batch: 8,
 max_batch: 16,
 embed_layer: 2,
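
Pieced together from the hunks above, the config struct now carries the new knob with a default of 4, matching the value added to assets/configs/Config.toml. The following is a reduced, self-contained sketch; the real Model struct in src/config.rs has further fields that are not visible in this diff.

    // Reduced sketch of src/config.rs after this commit; fields not shown in
    // the hunks above are omitted here.
    pub struct Model {
        pub turbo: bool,
        pub token_chunk_size: usize,
        /// The chunk size for each split of the head matrix.
        pub head_chunk_size: usize,
        /// The chunk size of layers in model state (new in this commit).
        pub state_chunk_size: usize,
        pub max_runtime_batch: usize,
        pub max_batch: usize,
        pub embed_layer: usize,
    }

    impl Default for Model {
        fn default() -> Self {
            Self {
                turbo: true,
                token_chunk_size: 32,
                head_chunk_size: 8192,
                state_chunk_size: 4,
                max_runtime_batch: 8,
                max_batch: 16,
                embed_layer: 2,
            }
        }
    }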
10 changes: 8 additions & 2 deletions src/main.rs
@@ -45,7 +45,6 @@ mod sampler;
 mod utils;

 pub const MAX_TOKENS: usize = 4096;
-pub const STATE_CHUNK_SIZE: usize = 4;

 #[derive(Debug)]
 pub enum Token {
@@ -182,6 +181,8 @@ pub struct ReloadRequest {
 pub token_chunk_size: usize,
 /// The chunk size for each split of the head matrix.
 pub head_chunk_size: usize,
+/// The chunk size of layers in model state.
+pub state_chunk_size: usize,
 /// Maximum number of batches that are active at once.
 pub max_runtime_batch: usize,
 /// Number of states that are cached on GPU.
@@ -290,7 +291,7 @@ where

 let state: S = StateBuilder::new(context, model.info())
 .with_num_batch(request.max_batch)
-.with_chunk_size(STATE_CHUNK_SIZE)
+.with_chunk_size(request.state_chunk_size)
 .build();
 Ok((model, state))
 }
@@ -397,6 +398,7 @@ async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
 let reload = async move {
 let sender = sender.clone();
 let max_runtime_batch = request.max_runtime_batch;
+let state_chunk_size = request.state_chunk_size;
 let embed_layer = request.embed_layer;

 let file = File::open(&request.model_path)?;
@@ -419,6 +421,7 @@ async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
 model,
 state,
 max_runtime_batch,
+state_chunk_size,
 embed_layer,
 ))
 }
@@ -429,6 +432,7 @@ async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
 model,
 state,
 max_runtime_batch,
+state_chunk_size,
 embed_layer,
 ))
 }
@@ -439,6 +443,7 @@ async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
 model,
 state,
 max_runtime_batch,
+state_chunk_size,
 embed_layer,
 ))
 }
@@ -492,6 +497,7 @@ async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {

 let context = GenerateContext {
 prompt_tokens: tokens.to_vec(),
+prompt_cached: false,
 prefix: Default::default(),
 suffix: tokens,
 penalties,
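
GenerateContext gains a prompt_cached flag, initialized to false when a request is admitted; per the commit title, it lets the server record whether the prompt's state was served from the cache instead of being recomputed. How the flag gets flipped is not part of this diff, so the following is only an illustrative sketch with invented names (CachedState, StateCache, prepare) and a guessed token type of u16:

    use std::collections::HashMap;

    // Invented stand-ins for illustration; the real GenerateContext in
    // src/main.rs carries many more fields (prefix, penalties, sampler, ...).
    #[derive(Clone, Default)]
    struct CachedState;

    struct GenerateContext {
        prompt_tokens: Vec<u16>,
        prompt_cached: bool,
        suffix: Vec<u16>,
    }

    // Hypothetical cache keyed by the exact prompt token sequence.
    type StateCache = HashMap<Vec<u16>, CachedState>;

    fn prepare(tokens: Vec<u16>, cache: &StateCache) -> GenerateContext {
        // Mark the context if this prompt's state is already cached, so the
        // runtime could skip re-prefilling those tokens.
        let prompt_cached = cache.contains_key(&tokens);
        GenerateContext {
            prompt_tokens: tokens.clone(),
            prompt_cached,
            suffix: tokens,
        }
    }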
(The diff of the remaining changed file was not loaded in this view.)
