token splicing in vllm
mmoskal committed Apr 12, 2024
1 parent 9f27a0e commit c299820
Showing 4 changed files with 17 additions and 2 deletions.
7 changes: 6 additions & 1 deletion controllers/pyctrl/samples/test.py
@@ -45,10 +45,15 @@ async def test_backtrack_lang():


 async def test_hello():
-    aici.log_level = 10
+    aici.log_level = 3
     prompt = await aici.GetPrompt()
     print("prompt", prompt)
     await aici.gen_tokens(regex=r"[A-Z].*", max_tokens=5)
+    l = aici.Label()
+    await aici.FixedTokens("\n2 + 2 = ")
+    await aici.gen_tokens(regex=r"\d+", max_tokens=1)
+    await aici.FixedTokens("\n3 + 3 = ", following=l)
+    await aici.gen_tokens(regex=r"\d+", max_tokens=1)
 
 
 async def test_main():
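The new lines in test_hello are what exercise splicing end to end: the controller records a position with aici.Label(), generates through "2 + 2 = ", and then FixedTokens("\n3 + 3 = ", following=l) rewinds to that label and fast-forwards the new prompt instead of continuing linearly. An annotated copy of the new tail of test_hello (the Label/following semantics are assumed from typical pyctrl usage, not spelled out in this commit):

    l = aici.Label()                                   # remember the current position
    await aici.FixedTokens("\n2 + 2 = ")               # forced (fast-forward) tokens
    await aici.gen_tokens(regex=r"\d+", max_tokens=1)  # model fills in one digit
    # backtrack to the label, dropping "\n2 + 2 = <digit>", then splice in new tokens
    await aici.FixedTokens("\n3 + 3 = ", following=l)
    await aici.gen_tokens(regex=r"\d+", max_tokens=1)  # model fills in the digit again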
8 changes: 8 additions & 0 deletions py/pyaici/comms.py
@@ -263,6 +263,14 @@ class Branch:
     mask: Optional[int] = None
     splices: List[Splice] = field(default_factory=list)
 
+    def find_splice(self, token: Optional[int]) -> Optional[Splice]:
+        if self.mask is None:
+            token = None
+        for splice in self.splices:
+            if not splice.when_sampled or token in splice.when_sampled:
+                return splice
+        return None
+
     def is_splice(self) -> bool:
         return len(self.splices) == 1 and self.splices[0].when_sampled == []

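find_splice is the piece the engine integration calls to decide whether a freshly sampled token should be kept or replaced by a splice: with no sampling mask the token is ignored, and the first splice whose when_sampled list is empty or contains the token wins. A hedged usage sketch (the real call site lives in the vllm submodule, which this page only shows as updated; the helper names below are illustrative, not from this commit):

    def on_token_sampled(branch: Branch, sampled_token: int) -> None:
        # With no sampling mask, find_splice ignores the token and returns the
        # first unconditional splice (empty when_sampled), if any.
        splice = branch.find_splice(sampled_token)
        if splice is None:
            keep_token(sampled_token)   # hypothetical helper: continue decoding normally
        else:
            apply_splice(splice)        # hypothetical helper: backtrack and append the
                                        # fast-forward tokens the splice describes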
2 changes: 1 addition & 1 deletion py/vllm
Submodule vllm updated 103 files
2 changes: 2 additions & 0 deletions scripts/vllm-server.sh
@@ -13,6 +13,8 @@ RUST_LOG=info,tokenizers=error,aicirt=trace \
 PYTHONPATH=py:py/vllm \
 python3 -m vllm.entrypoints.openai.api_server \
 --enforce-eager \
+--use-v2-block-manager \
+--enable-chunked-prefill \
 --aici-rt ./target/release/aicirt \
 --aici-tokenizer $AICI_TOK \
 --model $MODEL \
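--use-v2-block-manager and --enable-chunked-prefill are standard vLLM options of this vintage; presumably the splicing path leans on the v2 block manager's bookkeeping and on prefill being schedulable in chunks, so fast-forward tokens can be fed mid-decode. A quick way to confirm the installed vLLM build accepts both flags:

    python3 -m vllm.entrypoints.openai.api_server --help | grep -E "use-v2-block-manager|enable-chunked-prefill"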
