From da95224890789210aec105ac0beaba766d1f18cc Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 27 Nov 2024 10:38:46 -0800 Subject: [PATCH] Speed up best_of when using prefix caching (#698) --- clients/python/lorax/client.py | 3 +++ docs/guides/contributing/development_env.md | 2 +- router/src/infer.rs | 12 ++++++++++++ router/src/server.rs | 4 +++- 4 files changed, 19 insertions(+), 2 deletions(-) diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index a8028834c..aa58e800d 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -293,6 +293,9 @@ def generate( if resp.status_code != 200: raise parse_error(resp.status_code, payload, resp.headers if LORAX_DEBUG_MODE else None) + + if LORAX_DEBUG_MODE: + print(resp.headers) return Response(**payload[0]) diff --git a/docs/guides/contributing/development_env.md b/docs/guides/contributing/development_env.md index 5f35eebb1..84c55ce91 100644 --- a/docs/guides/contributing/development_env.md +++ b/docs/guides/contributing/development_env.md @@ -47,7 +47,7 @@ We'll be working out of three different terminals during development, each servi Install development dependencies: ```shell -DEBIAN_FRONTEND=noninteractive apt install pkg-config rsync tmux rust-gdb git -y && \ +apt update && DEBIAN_FRONTEND=noninteractive apt install pkg-config rsync tmux rust-gdb git -y && \ PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \ curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \ unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \ diff --git a/router/src/infer.rs b/router/src/infer.rs index 52f7dc2a0..101360ba1 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -160,6 +160,8 @@ pub struct Infer { limit_concurrent_requests: Arc, /// tokenizer for NER processing tokenizer: Option>, + /// Whether prefix caching is enabled + pub prefix_caching: bool, } impl Infer { @@ -268,6 +270,7 @@ impl Infer { chat_template, limit_concurrent_requests: semaphore, tokenizer: tokenizer, + prefix_caching, } } @@ -861,10 +864,19 @@ impl Infer { &self, request: GenerateRequest, best_of: usize, + prefix_caching: bool, ) -> Result<(InferResponse, Vec), InferError> { // validate best_of parameter separately let best_of = self.validation.validate_best_of(best_of)?; + // If prefix caching is enabled, first generate a single token to cache the prefix, then generate + // subsequent responses. + if prefix_caching { + let mut prefix_request = request.clone(); + prefix_request.parameters.max_new_tokens = Some(1); + self.generate(prefix_request).await?; + } + // create multiple generate requests let mut infer_responses: Vec = try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?; diff --git a/router/src/server.rs b/router/src/server.rs index 7717d5335..d33412aaf 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -639,7 +639,9 @@ async fn generate( // Inference let (response, best_of_responses) = match req.0.parameters.best_of { Some(best_of) if best_of > 1 => { - let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?; + let (response, best_of_responses) = infer + .generate_best_of(req.0, best_of, infer.prefix_caching) + .await?; (response, Some(best_of_responses)) } _ => (infer.generate(req.0).await?, None),