Skip to content

Commit

Permalink
Speed up best_of when using prefix caching (#698)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair authored Nov 27, 2024
1 parent 8f51d0e commit da95224
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 2 deletions.
3 changes: 3 additions & 0 deletions clients/python/lorax/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,9 @@ def generate(

if resp.status_code != 200:
raise parse_error(resp.status_code, payload, resp.headers if LORAX_DEBUG_MODE else None)

if LORAX_DEBUG_MODE:
print(resp.headers)

return Response(**payload[0])

Expand Down
2 changes: 1 addition & 1 deletion docs/guides/contributing/development_env.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ We'll be working out of three different terminals during development, each servi
Install development dependencies:

```shell
- DEBIAN_FRONTEND=noninteractive apt install pkg-config rsync tmux rust-gdb git -y && \
+ apt update && DEBIAN_FRONTEND=noninteractive apt install pkg-config rsync tmux rust-gdb git -y && \
PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
Expand Down
12 changes: 12 additions & 0 deletions router/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ pub struct Infer {
limit_concurrent_requests: Arc<Semaphore>,
/// tokenizer for NER processing
tokenizer: Option<Arc<Tokenizer>>,
/// Whether prefix caching is enabled
pub prefix_caching: bool,
}

impl Infer {
Expand Down Expand Up @@ -268,6 +270,7 @@ impl Infer {
chat_template,
limit_concurrent_requests: semaphore,
tokenizer: tokenizer,
prefix_caching,
}
}

Expand Down Expand Up @@ -861,10 +864,19 @@ impl Infer {
&self,
request: GenerateRequest,
best_of: usize,
prefix_caching: bool,
) -> Result<(InferResponse, Vec<InferResponse>), InferError> {
// validate best_of parameter separately
let best_of = self.validation.validate_best_of(best_of)?;

// If prefix caching is enabled, first generate a single token to cache the prefix, then generate
// subsequent responses.
if prefix_caching {
let mut prefix_request = request.clone();
prefix_request.parameters.max_new_tokens = Some(1);
self.generate(prefix_request).await?;
}

// create multiple generate requests
let mut infer_responses: Vec<InferResponse> =
try_join_all((0..best_of).map(|_| self.generate(request.clone()))).await?;
Expand Down
4 changes: 3 additions & 1 deletion router/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -639,7 +639,9 @@ async fn generate(
// Inference
let (response, best_of_responses) = match req.0.parameters.best_of {
Some(best_of) if best_of > 1 => {
- let (response, best_of_responses) = infer.generate_best_of(req.0, best_of).await?;
+ let (response, best_of_responses) = infer
+ .generate_best_of(req.0, best_of, infer.prefix_caching)
+ .await?;
(response, Some(best_of_responses))
}
_ => (infer.generate(req.0).await?, None),
Expand Down

0 comments on commit da95224

Please sign in to comment.