
Commit

Fix outlines compatibility with speculative decoding (#578)
tgaddair authored Aug 16, 2024
1 parent 12dc740 commit ca2e643
Showing 3 changed files with 38 additions and 9 deletions.
server/lorax_server/models/flash_causal_lm.py (3 additions, 1 deletion)

@@ -268,7 +268,9 @@ def from_pb(

adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device)

-request_tokenizers = [tokenizers.get_tokenizer(r.adapter_index, tokenizer) for r in pb.requests]
+# always use the base model tokenizer for the next token chooser until we revisit adding back support
+# for per-request tokenizers
+request_tokenizers = [tokenizer for _ in pb.requests]
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, request_tokenizers, dtype, device
)
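The substance of the fix is the hunk above: rather than looking up a per-adapter tokenizer for each request, every request now reuses the base model tokenizer when building the next token chooser. Below is a minimal, self-contained sketch of that change; the `Request` shape, the `AdapterTokenizers` registry, and the function boundary are assumptions for illustration, not the actual lorax_server API.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Request:
    """Stand-in for the protobuf request; only the field used in the diff."""
    adapter_index: int


@dataclass
class AdapterTokenizers:
    """Hypothetical per-adapter tokenizer registry (the `tokenizers` object in the diff)."""
    per_adapter: Dict[int, object] = field(default_factory=dict)

    def get_tokenizer(self, adapter_index: int, default: object) -> object:
        return self.per_adapter.get(adapter_index, default)


def build_request_tokenizers(requests: List[Request], tokenizer, tokenizers: AdapterTokenizers):
    # Before: each request resolved its adapter's own tokenizer, so a batch
    # could mix tokenizers that disagree with the base model's vocabulary.
    # request_tokenizers = [tokenizers.get_tokenizer(r.adapter_index, tokenizer) for r in requests]

    # After: every request uses the base model tokenizer, keeping the
    # next token chooser consistent across the batch.
    return [tokenizer for _ in requests]
```

Presumably the point is that the outlines-guided chooser and the speculative decoding path then operate against a single shared vocabulary, per the commit title.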
server/poetry.lock (34 additions, 7 deletions)

Some generated files are not rendered by default.

server/pyproject.toml (1 addition, 1 deletion)

@@ -38,7 +38,7 @@ boto3 = "^1.28.34"
urllib3 = "<=1.26.18"
hqq = { version = "^0.1.7", optional = true }
stanford-stk = { version = "^0.7.0", markers = "sys_platform == 'linux'" }
outlines = { version = "^0.0.40", optional = true }
outlines = { version = "^0.0.46", optional = true }
prometheus-client = "^0.20.0"
py-cpuinfo = "^9.0.0"
nvidia-ml-py = "^12.555.43"
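One note on the constraint itself: under Poetry's caret semantics, `^0.0.46` expands to `>=0.0.46, <0.0.47`, so for a `0.0.x` series this bump effectively pins outlines to the 0.0.46 release. A small sketch verifying that expansion with the `packaging` library; the `SpecifierSet` here is the caret written out by hand, not something Poetry emits:

```python
# For versions below 1.0.0, Poetry's caret only allows changes to the right
# of the left-most non-zero digit, so ^0.0.46 == >=0.0.46,<0.0.47.
from packaging.specifiers import SpecifierSet
from packaging.version import Version

caret_0_0_46 = SpecifierSet(">=0.0.46,<0.0.47")

assert Version("0.0.46") in caret_0_0_46
assert Version("0.0.40") not in caret_0_0_46  # the previous floor is excluded
assert Version("0.0.47") not in caret_0_0_46
assert Version("0.1.0") not in caret_0_0_46
```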
