Skip to content

Commit

Permalink
Merge branch 'main' into stateless
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Mar 25, 2024
2 parents ecb7c4a + 034f2e4 commit e015274
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 2 deletions.
14 changes: 13 additions & 1 deletion controllers/jsctrl/samples/aici-types.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ declare module "_aici" {
*/
function detokenize(tokens: number[]): Buffer;

/**
* Return debug string representation of a given token index
*/
function tokenRepr(token: number): string;

/**
* Return identifier of the current sequence.
* Most useful with fork_group parameter in mid_process() callback.
Expand Down Expand Up @@ -200,20 +205,27 @@ declare module "_aici" {
*/
constructor();

/**
* Return a debug string representation of the set.
*/
toString(): string;

/**
* Add token `t` to the set.
*/
add(t: number): void;
/**
* Remove token `t` from the set (presumably a no-op if it is not present; confirm against the native implementation).
*/
delete(t: number): void;
/**
* Check whether token `t` is in the set.
*/
has(t: number): boolean;
/**
* Remove all tokens from the set (presumably equivalent to `setAll(false)`; confirm against the native implementation).
*/
clear(): void;

/**
* Number of all possible tokens (regardless of whether they are in the set or not).
*/
length: number;

/**
* Include or exclude all tokens from the set.
*/
setAll(value: boolean): void;

/**
* Number of tokens in the set.
*/
numSet(): number;
}

/**
Expand Down
15 changes: 15 additions & 0 deletions controllers/jsctrl/src/jsctrl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ impl TokenSet {
self.inner.len()
}

/// Render this token set as a debug string.
///
/// Acquires the `GLOBAL_STATE` lock for the duration of the call and
/// delegates the formatting to the trie's `token_set_dbg`.
/// The camelCase name is kept as-is (presumably because it is the name
/// exposed to JavaScript; see the matching `toString()` declaration).
///
/// # Panics
/// Panics if `GLOBAL_STATE.lock()` returns an error (the result is unwrapped).
pub fn toString(&self) -> String {
    // The guard temporary lives until the end of the function body.
    let state = &mut GLOBAL_STATE.lock().unwrap();
    state.trie.token_set_dbg(&self.inner)
}

/// Add token `tok` to the set by calling `allow_token` on the inner set.
pub fn add(&mut self, tok: u32) {
self.inner.allow_token(tok);
}
Expand All @@ -144,6 +149,10 @@ impl TokenSet {
/// Include (`val == true`) or exclude (`val == false`) all tokens at once,
/// by forwarding `val` to `set_all` on the inner set.
pub fn setAll(&mut self, val: bool) {
self.inner.set_all(val);
}

/// Number of tokens currently in the set, as reported by `num_set` on the
/// inner set.
pub fn numSet(&self) -> usize {
self.inner.num_set()
}
}

impl Default for TokenSet {
Expand Down Expand Up @@ -261,6 +270,12 @@ mod aici_mod {
Buffer(bytes)
}

/// Return a debug string representation of the token with index `token`.
///
/// Acquires the `GLOBAL_STATE` lock and delegates to the trie's `token_dbg`.
///
/// # Panics
/// Panics if `GLOBAL_STATE.lock()` returns an error (the result is unwrapped).
#[rquickjs::function]
pub fn tokenRepr(token: TokenId) -> String {
    // The guard temporary lives until the end of the function body.
    let state = &mut GLOBAL_STATE.lock().unwrap();
    state.trie.token_dbg(token)
}

#[rquickjs::function]
pub fn getVar(name: String) -> Option<Buffer> {
let name = name.as_str();
Expand Down
7 changes: 7 additions & 0 deletions controllers/jsctrl/ts/aici.ts
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,11 @@ export class ConstrainedToken extends NextToken {
this._constraint = this.mkConstraint();
}
this._constraint.allowTokens(bias);
console.log("ALLOW:", bias.toString());
if (bias.numSet() === 0) {
console.log("Constraint doesn't allow any tokens; adding EOS")
return MidProcessResult.stop();
}
return MidProcessResult.bias(bias);
}

Expand Down Expand Up @@ -677,6 +682,8 @@ export async function genTokens(options: GenOptions): Promise<Token[]> {
const tokens = await next_token.run();
res.push(...tokens);

console.log("GEN-STEP:", tokens.map(t => _aici.tokenRepr(t)).join(", "));

const text = detokenize(res).decode();

if (stopAt !== undefined && text.includes(stopAt)) {
Expand Down
14 changes: 13 additions & 1 deletion controllers/jsctrl/ts/native.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ declare module "_aici" {
*/
function detokenize(tokens: number[]): Buffer;

/**
* Return debug string representation of a given token index
*/
function tokenRepr(token: number): string;

/**
* Return identifier of the current sequence.
* Most useful with fork_group parameter in mid_process() callback.
Expand Down Expand Up @@ -200,20 +205,27 @@ declare module "_aici" {
*/
constructor();

/**
* Return a debug string representation of the set.
*/
toString(): string;

/**
* Add token `t` to the set.
*/
add(t: number): void;
/**
* Remove token `t` from the set (presumably a no-op if it is not present; confirm against the native implementation).
*/
delete(t: number): void;
/**
* Check whether token `t` is in the set.
*/
has(t: number): boolean;
/**
* Remove all tokens from the set (presumably equivalent to `setAll(false)`; confirm against the native implementation).
*/
clear(): void;

/**
* Number of all possible tokens (regardless of whether they are in the set or not).
*/
length: number;

/**
* Include or exclude all tokens from the set.
*/
setAll(value: boolean): void;

/**
* Number of tokens in the set.
*/
numSet(): number;
}

/**
Expand Down
11 changes: 11 additions & 0 deletions controllers/pyctrl/samples/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# asserts for microsoft/Orca-2-13b


async def test_backtrack_one():
await aici.FixedTokens("3+")
l = aici.Label()
Expand Down Expand Up @@ -124,6 +125,15 @@ def inst(s: str) -> str:
)


async def test_prompt_backtrack():
    """Backtrack to a label placed inside the fixed prompt text.

    Emits an initial prompt, records a label, appends more fixed text and
    generates two tokens; then replaces everything after the label with
    different fixed text and generates two more tokens.
    """
    await aici.FixedTokens("Some test prompt for the model to generate more text.")
    # Position to return to; everything emitted after it is replaced below.
    marker = aici.Label()
    await aici.FixedTokens("And then some more text.")
    await aici.gen_tokens(max_tokens=2)
    # `following=marker` places the new text right after the label.
    await aici.FixedTokens("Now different text.", following=marker)
    await aici.gen_tokens(max_tokens=2)


async def test_sample():
# initialization code
print("I'm going in the logs!")
Expand Down Expand Up @@ -161,6 +171,7 @@ async def test_eos():
await aici.gen_tokens(regex=r' "[^"]+"', max_tokens=6, store_var="french")
aici.check_vars({"french": ' "bonjour"'})


async def test_joke():
await aici.FixedTokens("Do you want a joke or a poem? A")
answer = await aici.gen_text(options=[" joke", " poem"])
Expand Down
2 changes: 2 additions & 0 deletions rllm/rllm-base/src/seq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ impl Sequence {
) {
self.tokens.truncate(self.get_len() - backtrack);
self.output_ptr = std::cmp::min(self.output_ptr, self.get_len());
// backtracking can remove some tokens from the initial prompt
self.prompt_len = std::cmp::min(self.prompt_len, self.get_len());
if backtrack > 0 {
self.output_pending.clear();
self.output_pending.extend_from_slice(" ↩ ".as_bytes());
Expand Down

0 comments on commit e015274

Please sign in to comment.