fix: Disallow early stopping during warmup (#290)
tgaddair authored Feb 29, 2024
1 parent e51f078 · commit 36a24ba
Showing 3 changed files with 7 additions and 5 deletions.
docs/reference/launcher.md (2 changes: 1 addition & 1 deletion)
@@ -96,7 +96,7 @@ Options:
This is the maximum allowed input length (expressed in number of tokens) for users. The larger this value, the longer the prompts users can send, which can impact the overall memory required to handle the load. Please note that some models have a finite range of sequences they can handle

[env: MAX_INPUT_LENGTH=]
- [default: 1024]
+ [default: 1792]

--max-total-tokens <MAX_TOTAL_TOKENS>
This is the most important value to set as it defines the "memory budget" for running client requests. Clients will send input sequences and ask to generate `max_new_tokens` on top. With a value of `1512`, users can send either a prompt of `1000` and ask for `512` new tokens, or send a prompt of `1` and ask for `1511` `max_new_tokens`. The larger this value, the larger the amount of RAM each request will take and the less effective batching can be
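In other words, `max_input_length` bounds the prompt alone, while `max_total_tokens` bounds prompt plus generated tokens. A minimal sketch of that budget check, using hypothetical names and the new defaults (1792 input, 2048 total); this is illustrative, not the router's actual validation code:

```python
def check_token_budget(input_length: int, max_new_tokens: int,
                       max_input_length: int = 1792,
                       max_total_tokens: int = 2048) -> None:
    """Reject a request that does not fit the configured token budget."""
    if input_length > max_input_length:
        raise ValueError(
            f"prompt is {input_length} tokens, limit is {max_input_length}")
    if input_length + max_new_tokens > max_total_tokens:
        raise ValueError(
            f"{input_length} prompt tokens + {max_new_tokens} new tokens "
            f"exceeds the {max_total_tokens} total-token budget")

# With the defaults, a 1792-token prompt leaves room for 256 generated tokens.
check_token_budget(1792, 256)
# A 1000-token prompt can request up to 1048 new tokens, and so on.
check_token_budget(1000, 1048)
```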
router/src/main.rs (2 changes: 1 addition & 1 deletion)
@@ -33,7 +33,7 @@ struct Args {
max_best_of: usize,
#[clap(default_value = "4", long, env)]
max_stop_sequences: usize,
- #[clap(default_value = "1024", long, env)]
+ #[clap(default_value = "1792", long, env)]
max_input_length: usize,
#[clap(default_value = "2048", long, env)]
max_total_tokens: usize,
server/lorax_server/models/flash_causal_lm.py (8 changes: 5 additions & 3 deletions)
@@ -735,7 +735,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int):

with tqdm(total=max_new_tokens, desc="Warmup to max_total_tokens") as pbar:
for _ in range(max_new_tokens):
-                 _, batch = self.generate_token(batch)
+                 _, batch = self.generate_token(batch, is_warmup=True)
pbar.update(1)
except RuntimeError as e:
if "CUDA out of memory" in str(e) or isinstance(e, torch.cuda.OutOfMemoryError):
@@ -836,7 +836,7 @@ def forward(self, batch: FlashCausalLMBatch, adapter_data: AdapterBatchData) ->

@tracer.start_as_current_span("generate_token")
def generate_token(
-     self, batch: FlashCausalLMBatch
+     self, batch: FlashCausalLMBatch, is_warmup: bool = False
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
prefill = batch.cu_seqlen_prefill is not None
prefill_logprobs = batch.prefill_next_token_indices is not None
@@ -900,7 +900,9 @@ def generate_token(

# Results
generations: List[Generation] = []
-         stopped = True
+
+         # During warmup, do not allow early stopping
+         stopped = not is_warmup

# Zipped iterator
iterator = zip(
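Taken together, the change makes the warmup loop immune to per-sequence stopping criteria: `stopped` normally starts as `True` and only stays `True` if every sequence in the batch has finished, so seeding it with `not is_warmup` means a warmup batch is never reported as finished and token generation continues for the full budget. A condensed, self-contained sketch of that control flow, using a hypothetical `Seq` type and driver rather than the real `FlashCausalLMBatch` logic:

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Seq:
    generated: int = 0
    max_new_tokens: int = 4

    def finished(self) -> bool:
        return self.generated >= self.max_new_tokens


def generate_token(batch: List[Seq], is_warmup: bool = False
                   ) -> Tuple[List[int], Optional[List[Seq]]]:
    generations: List[int] = []

    # Assume the whole batch has stopped until a live sequence proves otherwise.
    # During warmup that assumption is disabled, so the batch is never reported
    # as finished and the caller keeps generating for the full token budget.
    stopped = not is_warmup

    for seq in batch:
        seq.generated += 1               # stand-in for real token generation
        generations.append(seq.generated)
        if not seq.finished():
            stopped = False              # at least one sequence is still going

    # Returning None signals that the whole batch is done.
    return generations, None if stopped else batch


# Warmup-style driver: with is_warmup=True the loop always runs all steps,
# even though every Seq would have hit its own limit after 4 tokens.
batch = [Seq(), Seq()]
for _ in range(16):
    _, batch = generate_token(batch, is_warmup=True)
    assert batch is not None
```

With `is_warmup=False`, the same driver would receive `None` after the fourth step, which is exactly the early exit the commit disallows during warmup.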
