fix backtracking into prompt
mmoskal committed Mar 25, 2024
1 parent c662637 commit 034f2e4
Showing 2 changed files with 13 additions and 0 deletions.
11 changes: 11 additions & 0 deletions controllers/pyctrl/samples/test.py
@@ -3,6 +3,7 @@
 
 # asserts for microsoft/Orca-2-13b
 
+
 async def test_backtrack_one():
     await aici.FixedTokens("3+")
     l = aici.Label()
@@ -124,6 +125,15 @@ def inst(s: str) -> str:
     )
 
 
+async def test_prompt_backtrack():
+    await aici.FixedTokens("Some test prompt for the model to generate more text.")
+    l = aici.Label()
+    await aici.FixedTokens("And then some more text.")
+    await aici.gen_tokens(max_tokens=2)
+    await aici.FixedTokens("Now different text.", following=l)
+    await aici.gen_tokens(max_tokens=2)
+
+
 async def test_sample():
     # initialization code
     print("I'm going in the logs!")
@@ -161,6 +171,7 @@ async def test_eos():
     await aici.gen_tokens(regex=r' "[^"]+"', max_tokens=6, store_var="french")
     aici.check_vars({"french": ' "bonjour"'})
 
+
 async def test_joke():
     await aici.FixedTokens("Do you want a joke or a poem? A")
     answer = await aici.gen_text(options=[" joke", " poem"])
2 changes: 2 additions & 0 deletions rllm/rllm-base/src/seq.rs
@@ -153,6 +153,8 @@ impl Sequence {
     ) {
         self.tokens.truncate(self.get_len() - backtrack);
         self.output_ptr = std::cmp::min(self.output_ptr, self.get_len());
+        // backtracking can remove some tokens from the initial prompt
+        self.prompt_len = std::cmp::min(self.prompt_len, self.get_len());
         if backtrack > 0 {
             self.output_pending.clear();
             self.output_pending.extend_from_slice(" ↩ ".as_bytes());
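
For context on why the one-line clamp is needed: the method containing this hunk truncates tokens when the controller backtracks, and when a backtrack reaches past the prompt boundary (exactly what test_prompt_backtrack exercises via following=l), prompt_len can be left pointing beyond the truncated buffer. Below is a minimal, self-contained sketch of that invariant. The struct and method here are hypothetical stand-ins: only the tokens and prompt_len fields mirror the hunk above, and the rest is illustrative, not the real rllm Sequence.

// Minimal sketch of the invariant this commit fixes. Only `tokens` and
// `prompt_len` are taken from the hunk above; everything else is assumed.

struct Sequence {
    tokens: Vec<u32>,
    prompt_len: usize, // number of leading tokens that came from the prompt
}

impl Sequence {
    fn get_len(&self) -> usize {
        self.tokens.len()
    }

    // Drop `backtrack` tokens from the end, then append `new_tokens`.
    fn splice(&mut self, backtrack: usize, new_tokens: &[u32]) {
        self.tokens.truncate(self.get_len() - backtrack);
        // The fix: backtracking can remove tokens from the initial prompt,
        // so clamp prompt_len to the truncated length. Without this,
        // prompt_len would keep its old value: later slices of the prompt
        // region could panic, or freshly generated tokens would be
        // miscounted as prompt tokens.
        self.prompt_len = std::cmp::min(self.prompt_len, self.get_len());
        self.tokens.extend_from_slice(new_tokens);
    }

    fn prompt_tokens(&self) -> &[u32] {
        &self.tokens[..self.prompt_len]
    }
}

fn main() {
    let mut seq = Sequence {
        tokens: vec![10, 11, 12, 13, 14], // all five tokens are prompt
        prompt_len: 5,
    };
    // Backtrack two tokens into the prompt, then continue with new tokens,
    // mirroring what test_prompt_backtrack does with `following=l`.
    seq.splice(2, &[20, 21]);
    assert_eq!(seq.prompt_len, 3); // clamped from 5
    assert_eq!(seq.prompt_tokens(), &[10, 11, 12][..]);
}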
