From a2486815330ad541208d9c3d617c51b52b1837db Mon Sep 17 00:00:00 2001 From: Michal Moskal Date: Fri, 19 Jan 2024 00:24:35 +0000 Subject: [PATCH] organize js stuff into folders --- .vscode/settings.json | 4 +- jsctrl/build.sh | 6 + jsctrl/gen-dts.mjs | 26 ++ jsctrl/samples/aici-types.d.ts | 495 +++++++++++++++++++++++ jsctrl/samples/hello.js | 6 + jsctrl/samples/hellots.ts | 6 + jsctrl/{ts/sample.ts => samples/test.ts} | 6 +- jsctrl/samples/tsconfig.json | 18 + jsctrl/src/jsctrl.rs | 2 +- jsctrl/ts/aici.ts | 51 ++- jsctrl/ts/gen/aici.d.ts | 225 +++++++++++ jsctrl/ts/{ => gen}/aici.js | 34 +- jsctrl/ts/native.d.ts | 116 +++++- jsctrl/ts/sample.js | 128 ------ jsctrl/ts/tsconfig.json | 10 +- 15 files changed, 950 insertions(+), 183 deletions(-) create mode 100755 jsctrl/build.sh create mode 100644 jsctrl/gen-dts.mjs create mode 100644 jsctrl/samples/aici-types.d.ts create mode 100644 jsctrl/samples/hello.js create mode 100644 jsctrl/samples/hellots.ts rename jsctrl/{ts/sample.ts => samples/test.ts} (99%) create mode 100644 jsctrl/samples/tsconfig.json create mode 100644 jsctrl/ts/gen/aici.d.ts rename jsctrl/ts/{ => gen}/aici.js (96%) delete mode 100644 jsctrl/ts/sample.js diff --git a/.vscode/settings.json b/.vscode/settings.json index e5307ea5..10060de7 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -120,6 +120,8 @@ "unpkg.com" ], "files.readonlyInclude": { - "*/ts/*.js": true + "**/ts/gen/*": true, + "**/dist/*": true, + "**/aici-types.d.ts": true } } \ No newline at end of file diff --git a/jsctrl/build.sh b/jsctrl/build.sh new file mode 100755 index 00000000..c969ae76 --- /dev/null +++ b/jsctrl/build.sh @@ -0,0 +1,6 @@ +#!/bin/sh +tsc --version > /dev/null 2>&1 || npm install -g typescript +set -e +tsc -p ts +node gen-dts.mjs + diff --git a/jsctrl/gen-dts.mjs b/jsctrl/gen-dts.mjs new file mode 100644 index 00000000..5ed23e73 --- /dev/null +++ b/jsctrl/gen-dts.mjs @@ -0,0 +1,26 @@ +import * as fs from 'fs'; + +function gen() { + const ts = "./ts/" + const native = fs.readFileSync(ts + '/native.d.ts', 'utf8') + let aici = fs.readFileSync(ts + '/gen/aici.d.ts', 'utf8') + aici = aici.replace(/ pre + aici + '"""') + + const tsconfig = fs.readFileSync("./samples/tsconfig.json", "utf-8") + jssrc = jssrc.replace(/(tsconfig_json = """)[^]*?"""/g, (_, pre) => pre + tsconfig + '"""') + + const hello = fs.readFileSync("./samples/hello.js", "utf-8") + jssrc = jssrc.replace(/(hello_js = """)[^]*?"""/g, (_, pre) => pre + hello + '"""') + + fs.writeFileSync("../pyaici/jssrc.py", jssrc) +} + +gen() \ No newline at end of file diff --git a/jsctrl/samples/aici-types.d.ts b/jsctrl/samples/aici-types.d.ts new file mode 100644 index 00000000..1dbfcdc6 --- /dev/null +++ b/jsctrl/samples/aici-types.d.ts @@ -0,0 +1,495 @@ +// Top-level symbols + +type Token = number; +type Buffer = Uint8Array; + +/** + * Force the exact tokens to be generated; usage: await $`Some text` + */ +declare function $(strings: TemplateStringsArray, ...values: any[]): Promise; + +/** + * Throw an exception if the condition is not met. + */ +declare function assert(cond: boolean, msg?: string): asserts cond; + +/** + * Forces next tokens to be exactly the given text. + */ +declare function fixed(text: string): Promise; + +/** + * Forks the execution into `numForks` branches. + * @param numForks how many branches + * @returns a number from 0 to `numForks`-1, indicating the branch + */ +declare function fork(numForks: number): Promise; + +/** + * Suspends execution until all variables are available. + * @param vars names of variables + * @returns values of the variables + */ +declare function waitVars(...vars: string[]): Promise; + +/** + * Starts the AICI loop. The coroutine may first `await aici.getPrompt()` and + * then can `await aici.gen_*()` or `await aici.FixedTokens()` multiple times. + * @param f async function + */ +declare function start(f: () => Promise): void; + +/** + * Specifies options for gen() and genTokens(). + */ +interface GenOptions { + /** + * Make sure the generated text is one of the options. + */ + options?: string[]; + /** + * Make sure the generated text matches given regular expression. + */ + regex?: string | RegExp; + /** + * Make sure the generated text matches given yacc-like grammar. + */ + yacc?: string; + /** + * Make sure the generated text is a substring of the given string. + */ + substring?: string; + /** + * Used together with `substring` - treat the substring as ending the substring + * (typically '"' or similar). + */ + substringEnd?: string; + /** + * Store result of the generation (as bytes) into a shared variable. + */ + storeVar?: string; + /** + * Stop generation when the string is generated (the result includes the string and any following bytes (from the same token)). + */ + stopAt?: string; + /** + * Stop generation when the given number of tokens have been generated. + */ + maxTokens?: number; +} + +/** + * Generate a string that matches given constraints. + * If the tokens do not map cleanly into strings, it will contain Unicode replacement characters. + */ +declare function gen(options: GenOptions): Promise; + +/** + * Generate a list of tokens that matches given constraints. + */ +declare function genTokens(options: GenOptions): Promise; + +// Extensions of JavaScript built-in types + +interface String { + /** + * UTF-8 encode the current string. + */ + toBuffer(): Uint8Array; +} + +interface StringConstructor { + /** + * Create a string from UTF-8 buffer (with replacement character for invalid sequences) + */ + fromBuffer(buffer: Uint8Array): string; +} + +interface Uint8Array { + /** + * UTF-8 decode the current buffer. + */ + decode(): string; +} + +/** [MDN Reference](https://developer.mozilla.org/docs/Web/API/console) */ +interface Console { + /** [MDN Reference](https://developer.mozilla.org/docs/Web/API/console/debug) */ + debug(...data: any[]): void; + /** [MDN Reference](https://developer.mozilla.org/docs/Web/API/console/error) */ + error(...data: any[]): void; + /** [MDN Reference](https://developer.mozilla.org/docs/Web/API/console/info) */ + info(...data: any[]): void; + /** [MDN Reference](https://developer.mozilla.org/docs/Web/API/console/log) */ + log(...data: any[]): void; + /** [MDN Reference](https://developer.mozilla.org/docs/Web/API/console/trace) */ + trace(...data: any[]): void; + /** [MDN Reference](https://developer.mozilla.org/docs/Web/API/console/warn) */ + warn(...data: any[]): void; +} + +declare var console: Console; + +// native module +declare module "_aici" { + type Buffer = Uint8Array; + + /** + * Return token indices for a given string (or byte sequence). + */ + function tokenize(text: string | Buffer): number[]; + + /** + * Return byte (~string) representation of a given list of token indices. + */ + function detokenize(tokens: number[]): Buffer; + + /** + * Return identifier of the current sequence. + * Most useful with fork_group parameter in mid_process() callback. + * Best use aici.fork() instead. + */ + function selfSeqId(): number; + + /** + * Print out a message of the error and stop the program. + */ + function panic(error: any): never; + + /** + * Get the value of a shared variable. + */ + function getVar(name: string): Buffer | null; + + /** + * Set the value of a shared variable. + */ + function setVar(name: string, value: string | Buffer): void; + + /** + * Append to the value of a shared variable. + */ + function appendVar(name: string, value: string | Buffer): void; + + /** + * Index of the end of sequence token. + */ + function eosToken(): number; + + /** + * UTF-8 encode + */ + function stringToBuffer(s: string): Buffer; + + /** + * UTF-8 decode (with replacement character for invalid sequences) + */ + function bufferToString(b: Buffer): string; + + /** + * Return a string like `b"..."` that represents the given buffer. + */ + function bufferRepr(b: Buffer): string; + + /** + * Represents a set of tokens. + * The value is true at indices corresponding to tokens in the set. + */ + class TokenSet { + /** + * Create an empty set (with .length set to the total number of tokens). + */ + constructor(); + + add(t: number): void; + delete(t: number): void; + has(t: number): boolean; + clear(): void; + + /** + * Number of all tokens (not only in the set). + */ + length: number; + + /** + * Include or exclude all tokens from the set. + */ + setAll(value: boolean): void; + } + + /** + * Initialize a constraint that allows any token. + */ + class Constraint { + constructor(); + + /** + * Check if the constraint allows the generation to end at the current point. + */ + eosAllowed(): boolean; + + /** + * Check if the constraint forces the generation to end at the current point. + */ + eosForced(): boolean; + + /** + * Check if token `t` is allowed by the constraint. + */ + tokenAllowed(t: number): boolean; + + /** + * Update the internal state of the constraint to reflect that token `t` was appended. + */ + appendToken(t: number): void; + + /** + * Set ts[] to True at all tokens that are allowed by the constraint. + */ + allowTokens(ts: TokenSet): void; + } + + /** + * A constraint that allows only tokens that match the regex. + * The regex is implicitly anchored at the start and end of the generation. + */ + function regexConstraint(pattern: string): Constraint; + + /** + * A constraint that allows only tokens that match the specified yacc-like grammar. + */ + function cfgConstraint(yacc_grammar: string): Constraint; + + /** + * A constraint that allows only word-substrings of given string. + */ + function substrConstraint(template: string, stop_at: string): Constraint; +} +declare module 'aici' { +/// +import { TokenSet, tokenize, detokenize, regexConstraint, cfgConstraint, substrConstraint, Constraint, getVar, setVar, appendVar, eosToken, panic } from "_aici"; +export { TokenSet, tokenize, detokenize, getVar, setVar, appendVar, eosToken }; +export type SeqId = number; +type int = number; +export function inspect(v: any): string; +export function log(...args: any[]): void; +export class AssertionError extends Error { +} +/** + * Throw an exception if the condition is not met. + */ +export function assert(cond: boolean, msg?: string): asserts cond; +/** + * Get list of tokens in the current sequence, including the prompt. + */ +export function getTokens(): Token[]; +/** + * Get the length of the prompt in the current sequence. + */ +export function getPromptLen(): number; +export class MidProcessResult { + _n_skip_me: boolean; + _n_stop: boolean; + _n_logit_bias: TokenSet | null; + _n_backtrack: number; + _n_ff_tokens: Token[]; + constructor(); + static stop(): MidProcessResult; + static skipMe(): MidProcessResult; + static bias(bias: TokenSet): MidProcessResult; + static splice(backtrack: number, tokens: Token[]): MidProcessResult; +} +export class PreProcessResult { + _n_suspended: boolean; + _n_ff_tokens: Token[]; + _n_attention_masks: number[][]; + constructor(); + static continue_(): PreProcessResult; + static suspend(): PreProcessResult; + static fork(numForks: number): PreProcessResult; + static ffTokens(toks: Token[]): PreProcessResult; +} +export class PostProcessResult { + _n_stop_seq: boolean; + constructor(stop_seq?: boolean); + static continue_(): PostProcessResult; + static stop(): PostProcessResult; + static fromTokens(tokens: Token[]): PostProcessResult; +} +export class NextToken { + finished: boolean; + currTokens: Token[] | null; + forkGroup: SeqId[]; + _resolve?: (value: Token[]) => void; + constructor(); + /** + * Awaiting this will return generated token (or tokens, if fast-forwarding requested by self.mid_process()). + * You have only ~1ms to process the results before awaiting a new instance of NextToken() again. + */ + run(): Promise; + /** + * Override to suspend, if the model cannot continue generating tokens + * now (for example, not all variables are available to compute bias). + * ~1ms time limit. + */ + preProcess(): PreProcessResult; + /** + * This can be overridden to return a bias, fast-forward tokens, backtrack etc. + * ~20ms time limit. + */ + midProcess(): MidProcessResult; + /** + * This can be overridden to do something with generated tokens. + * ~1ms time limit. + * @param tokens tokens generated in the last step + */ + postProcess(tokens: Token[]): PostProcessResult; + _pre_process(): PreProcessResult; + _mid_process(fork_group: SeqId[]): MidProcessResult; + _post_process(_backtrack: int, tokens: Token[]): PostProcessResult; + private reset; +} +/** + * Forces next tokens to be exactly the given text. + */ +export function fixed(text: string): Promise; +/** + * Force the exact tokens to be generated; usage: await $`Some text` + */ +export function $(strings: TemplateStringsArray, ...values: any[]): Promise; +/** + * Forces next tokens to be exactly the given text. + * If following is given, the text replaces everything that follows the label. + */ +class FixedTokens extends NextToken { + fixedTokens: Token[]; + following: Label | null; + constructor(text: string | Buffer, following?: Label | null); + preProcess(): PreProcessResult; + midProcess(): MidProcessResult; +} +/** + * Indicates that the generation should stop. + */ +class StopToken extends NextToken { + constructor(); + midProcess(): MidProcessResult; + postProcess(_tokens: Token[]): PostProcessResult; +} +/** + * Generates a token that satisfies the given constraint. + * The constraint will be constructed in mid_process() phase, which has slightly longer time limit. + */ +export class ConstrainedToken extends NextToken { + mkConstraint: () => Constraint; + _constraint: Constraint | null; + constructor(mkConstraint: () => Constraint); + midProcess(): MidProcessResult; + postProcess(tokens: Token[]): PostProcessResult; +} +export class PreToken extends NextToken { + midProcess(): MidProcessResult; +} +/** + * Forks the execution into `numForks` branches. + * @param numForks how many branches + * @returns a number from 0 to `numForks`-1, indicating the branch + */ +export function fork(numForks: number): Promise; +/** + * Suspends execution until all variables are available. + * @param vars names of variables + * @returns values of the variables + */ +export function waitVars(...vars: string[]): Promise; +/** + * Low-level interface for AICI. Use aici.start() to wrap a coroutine. + */ +export interface AiciCallbacks { + init_prompt(prompt: Token[]): void; + pre_process(): PreProcessResult; + mid_process(fork_group: SeqId[]): MidProcessResult; + post_process(backtrack: number, tokens: Token[]): PostProcessResult; +} +/** + * Awaiting this returns the prompt passed by the user. + * The code before call to this function has a long time limit (~1000ms). + * Afterwards, the time limit is ~1ms before awaiting NextToken(). + */ +export function getPrompt(): Promise; +class GetPrompt { + _resolve?: (value: Token[]) => void; + run(): Promise; +} +export type CbType = NextToken; +export class AiciAsync implements AiciCallbacks { + static instance: AiciAsync; + _tokens: Token[]; + _prompt_len: number; + private _pendingCb; + private _token; + private _getPrompt; + private midProcessReEntry; + _setGetPrompt(g: GetPrompt): void; + _nextToken(t: NextToken): void; + constructor(f: () => Promise); + step(tokens: Token[]): void; + init_prompt(prompt: Token[]): void; + pre_process(): PreProcessResult; + mid_process(fork_group: SeqId[]): MidProcessResult; + post_process(backtrack: number, tokens: Token[]): PostProcessResult; +} +/** + * Starts the AICI loop. The coroutine may first `await aici.getPrompt()` and + * then can `await aici.gen_*()` or `await aici.FixedTokens()` multiple times. + * @param f async function + */ +export function start(f: () => Promise): AiciAsync; +/** + * Runs the loop as a test. + */ +export function test(f: () => Promise): AiciAsync; +export class Label { + ptr: number; + /** + * Create a new label the indicates the current position in the sequence. + * Can be passed as `following=` argument to `FixedTokens()`. + */ + constructor(); + /** + * Return tokens generated since the label. + */ + tokensSince(): Token[]; + /** + * Return text generated since the label. + */ + textSince(): string; + /** + * Generate given prompt text, replacing all text after the current label. + */ + fixedAfter(text: string): Promise; +} +export class ChooseConstraint extends Constraint { + ptr: number; + options: Token[][]; + constructor(options: string[]); + eosAllowed(): boolean; + eosForced(): boolean; + tokenAllowed(t: Token): boolean; + appendToken(t: Token): void; + allowTokens(ts: TokenSet): void; +} +export function genTokens(options: GenOptions): Promise; +export function gen(options: GenOptions): Promise; +export function checkVar(name: string, value: string): void; +export function checkVars(d: Record): void; +export const helpers: { + regex_constraint: typeof regexConstraint; + cfg_constraint: typeof cfgConstraint; + substr_constraint: typeof substrConstraint; + FixedTokens: typeof FixedTokens; + StopToken: typeof StopToken; + panic: typeof panic; +}; + +} diff --git a/jsctrl/samples/hello.js b/jsctrl/samples/hello.js new file mode 100644 index 00000000..705b7908 --- /dev/null +++ b/jsctrl/samples/hello.js @@ -0,0 +1,6 @@ +async function main() { + await $`Hello` + await gen({ regex: / [A-Z]+/ }) +} + +start(main) diff --git a/jsctrl/samples/hellots.ts b/jsctrl/samples/hellots.ts new file mode 100644 index 00000000..705b7908 --- /dev/null +++ b/jsctrl/samples/hellots.ts @@ -0,0 +1,6 @@ +async function main() { + await $`Hello` + await gen({ regex: / [A-Z]+/ }) +} + +start(main) diff --git a/jsctrl/ts/sample.ts b/jsctrl/samples/test.ts similarity index 99% rename from jsctrl/ts/sample.ts rename to jsctrl/samples/test.ts index 281cd518..942fe511 100644 --- a/jsctrl/ts/sample.ts +++ b/jsctrl/samples/test.ts @@ -1,16 +1,12 @@ import { - $, Label, - assert, checkVars, - fork, - gen, getPrompt, test, waitVars, getVar, setVar, -} from "./aici"; +} from "aici"; async function main() { await $`2 + 2 =`; diff --git a/jsctrl/samples/tsconfig.json b/jsctrl/samples/tsconfig.json new file mode 100644 index 00000000..8c876c3c --- /dev/null +++ b/jsctrl/samples/tsconfig.json @@ -0,0 +1,18 @@ +{ + "compilerOptions": { + /* Visit https://aka.ms/tsconfig to read more about this file */ + "target": "ES2020", + "lib": [ + "ES2020" + ], + "moduleDetection": "force", + "module": "ES2020", + "allowJs": true, + "checkJs": true, + "strict": true, + "noImplicitThis": true, + "noImplicitReturns": true, + "outDir": "./dist", + "skipDefaultLibCheck": true, + } +} \ No newline at end of file diff --git a/jsctrl/src/jsctrl.rs b/jsctrl/src/jsctrl.rs index 81e005ef..c6c39292 100644 --- a/jsctrl/src/jsctrl.rs +++ b/jsctrl/src/jsctrl.rs @@ -419,7 +419,7 @@ impl Runner { context: Context::full(&rt).unwrap(), }; - let aici_js = include_str!("../ts/aici.js"); + let aici_js = include_str!("../ts/gen/aici.js"); s.with_cb("_new", |ctx| { let global = ctx.globals(); diff --git a/jsctrl/ts/aici.ts b/jsctrl/ts/aici.ts index 7e021e65..3894429f 100644 --- a/jsctrl/ts/aici.ts +++ b/jsctrl/ts/aici.ts @@ -19,13 +19,10 @@ export { TokenSet, tokenize, detokenize, getVar, setVar, appendVar, eosToken }; import * as _aici from "_aici"; -export type Token = number; export type SeqId = number; - type int = number; -type Buffer = Uint8Array; -function dbgarg(arg: any, depth: number): string { +function dbgArg(arg: any, depth: number): string { const maxElts = 20; const maxDepth = 2; const maxStr = 200; @@ -42,7 +39,7 @@ function dbgarg(arg: any, depth: number): string { arg = arg.slice(0, maxElts); suff = ", ...]"; } - return "[" + arg.map((x: any) => dbgarg(x, depth + 1)).join(", ") + suff; + return "[" + arg.map((x: any) => dbgArg(x, depth + 1)).join(", ") + suff; } else { let keys = Object.keys(arg); if (depth >= maxDepth && keys.length > 0) return "{...}"; @@ -53,7 +50,7 @@ function dbgarg(arg: any, depth: number): string { } return ( "{" + - keys.map((k) => `${k}: ${dbgarg(arg[k], depth + 1)}`).join(", ") + + keys.map((k) => `${k}: ${dbgArg(arg[k], depth + 1)}`).join(", ") + suff ); } @@ -70,21 +67,18 @@ function dbgarg(arg: any, depth: number): string { } export function inspect(v: any) { - return dbgarg(v, 0); + return dbgArg(v, 0); } export function log(...args: any[]) { (console as any)._print(args.map((x) => inspect(x)).join(" ")); } -console.log = log; -console.info = log; -console.warn = log; -console.debug = log; -console.trace = log; - export class AssertionError extends Error {} +/** + * Throw an exception if the condition is not met. + */ export function assert(cond: boolean, msg = "Assertion failed"): asserts cond { if (!cond) throw new AssertionError(msg); } @@ -272,7 +266,7 @@ export async function fixed(text: string) { } /** - * Same as fixed(); usage: await $`Some text` + * Force the exact tokens to be generated; usage: await $`Some text` */ export async function $(strings: TemplateStringsArray, ...values: any[]) { let result = ""; @@ -623,7 +617,7 @@ export class Label { ptr: number; /** - * Create a new label the indictes the current position in the sequence. + * Create a new label the indicates the current position in the sequence. * Can be passed as `following=` argument to `FixedTokens()`. */ constructor() { @@ -692,18 +686,7 @@ export class ChooseConstraint extends Constraint { } } -export type GenOptions = { - regex?: string | RegExp; - yacc?: string; - substring?: string; - substringEnd?: string; - options?: string[]; - storeVar?: string; - stopAt?: string; - maxTokens?: number; -}; - -export async function gen_tokens(options: GenOptions): Promise { +export async function genTokens(options: GenOptions): Promise { console.log("GEN", options); const res: Token[] = []; const { @@ -763,7 +746,7 @@ export async function gen_tokens(options: GenOptions): Promise { } export async function gen(options: GenOptions): Promise { - const tokens = await gen_tokens(options); + const tokens = await genTokens(options); return detokenize(tokens).decode(); } @@ -806,3 +789,15 @@ Uint8Array.prototype.toString = function (this: Uint8Array) { Uint8Array.prototype.decode = function (this: Uint8Array) { return _aici.bufferToString(this); }; + +console.log = log; +console.info = log; +console.warn = log; +console.debug = log; +console.trace = log; + +globalThis.$ = $; +globalThis.fixed = fixed; +globalThis.assert = assert; +globalThis.gen = gen; +globalThis.genTokens = genTokens; \ No newline at end of file diff --git a/jsctrl/ts/gen/aici.d.ts b/jsctrl/ts/gen/aici.d.ts new file mode 100644 index 00000000..47403943 --- /dev/null +++ b/jsctrl/ts/gen/aici.d.ts @@ -0,0 +1,225 @@ +/// +import { TokenSet, tokenize, detokenize, regexConstraint, cfgConstraint, substrConstraint, Constraint, getVar, setVar, appendVar, eosToken, panic } from "_aici"; +export { TokenSet, tokenize, detokenize, getVar, setVar, appendVar, eosToken }; +export type SeqId = number; +type int = number; +export declare function inspect(v: any): string; +export declare function log(...args: any[]): void; +export declare class AssertionError extends Error { +} +/** + * Throw an exception if the condition is not met. + */ +export declare function assert(cond: boolean, msg?: string): asserts cond; +/** + * Get list of tokens in the current sequence, including the prompt. + */ +export declare function getTokens(): Token[]; +/** + * Get the length of the prompt in the current sequence. + */ +export declare function getPromptLen(): number; +export declare class MidProcessResult { + _n_skip_me: boolean; + _n_stop: boolean; + _n_logit_bias: TokenSet | null; + _n_backtrack: number; + _n_ff_tokens: Token[]; + constructor(); + static stop(): MidProcessResult; + static skipMe(): MidProcessResult; + static bias(bias: TokenSet): MidProcessResult; + static splice(backtrack: number, tokens: Token[]): MidProcessResult; +} +export declare class PreProcessResult { + _n_suspended: boolean; + _n_ff_tokens: Token[]; + _n_attention_masks: number[][]; + constructor(); + static continue_(): PreProcessResult; + static suspend(): PreProcessResult; + static fork(numForks: number): PreProcessResult; + static ffTokens(toks: Token[]): PreProcessResult; +} +export declare class PostProcessResult { + _n_stop_seq: boolean; + constructor(stop_seq?: boolean); + static continue_(): PostProcessResult; + static stop(): PostProcessResult; + static fromTokens(tokens: Token[]): PostProcessResult; +} +export declare class NextToken { + finished: boolean; + currTokens: Token[] | null; + forkGroup: SeqId[]; + _resolve?: (value: Token[]) => void; + constructor(); + /** + * Awaiting this will return generated token (or tokens, if fast-forwarding requested by self.mid_process()). + * You have only ~1ms to process the results before awaiting a new instance of NextToken() again. + */ + run(): Promise; + /** + * Override to suspend, if the model cannot continue generating tokens + * now (for example, not all variables are available to compute bias). + * ~1ms time limit. + */ + preProcess(): PreProcessResult; + /** + * This can be overridden to return a bias, fast-forward tokens, backtrack etc. + * ~20ms time limit. + */ + midProcess(): MidProcessResult; + /** + * This can be overridden to do something with generated tokens. + * ~1ms time limit. + * @param tokens tokens generated in the last step + */ + postProcess(tokens: Token[]): PostProcessResult; + _pre_process(): PreProcessResult; + _mid_process(fork_group: SeqId[]): MidProcessResult; + _post_process(_backtrack: int, tokens: Token[]): PostProcessResult; + private reset; +} +/** + * Forces next tokens to be exactly the given text. + */ +export declare function fixed(text: string): Promise; +/** + * Force the exact tokens to be generated; usage: await $`Some text` + */ +export declare function $(strings: TemplateStringsArray, ...values: any[]): Promise; +/** + * Forces next tokens to be exactly the given text. + * If following is given, the text replaces everything that follows the label. + */ +declare class FixedTokens extends NextToken { + fixedTokens: Token[]; + following: Label | null; + constructor(text: string | Buffer, following?: Label | null); + preProcess(): PreProcessResult; + midProcess(): MidProcessResult; +} +/** + * Indicates that the generation should stop. + */ +declare class StopToken extends NextToken { + constructor(); + midProcess(): MidProcessResult; + postProcess(_tokens: Token[]): PostProcessResult; +} +/** + * Generates a token that satisfies the given constraint. + * The constraint will be constructed in mid_process() phase, which has slightly longer time limit. + */ +export declare class ConstrainedToken extends NextToken { + mkConstraint: () => Constraint; + _constraint: Constraint | null; + constructor(mkConstraint: () => Constraint); + midProcess(): MidProcessResult; + postProcess(tokens: Token[]): PostProcessResult; +} +export declare class PreToken extends NextToken { + midProcess(): MidProcessResult; +} +/** + * Forks the execution into `numForks` branches. + * @param numForks how many branches + * @returns a number from 0 to `numForks`-1, indicating the branch + */ +export declare function fork(numForks: number): Promise; +/** + * Suspends execution until all variables are available. + * @param vars names of variables + * @returns values of the variables + */ +export declare function waitVars(...vars: string[]): Promise; +/** + * Low-level interface for AICI. Use aici.start() to wrap a coroutine. + */ +export interface AiciCallbacks { + init_prompt(prompt: Token[]): void; + pre_process(): PreProcessResult; + mid_process(fork_group: SeqId[]): MidProcessResult; + post_process(backtrack: number, tokens: Token[]): PostProcessResult; +} +/** + * Awaiting this returns the prompt passed by the user. + * The code before call to this function has a long time limit (~1000ms). + * Afterwards, the time limit is ~1ms before awaiting NextToken(). + */ +export declare function getPrompt(): Promise; +declare class GetPrompt { + _resolve?: (value: Token[]) => void; + run(): Promise; +} +export type CbType = NextToken; +export declare class AiciAsync implements AiciCallbacks { + static instance: AiciAsync; + _tokens: Token[]; + _prompt_len: number; + private _pendingCb; + private _token; + private _getPrompt; + private midProcessReEntry; + _setGetPrompt(g: GetPrompt): void; + _nextToken(t: NextToken): void; + constructor(f: () => Promise); + step(tokens: Token[]): void; + init_prompt(prompt: Token[]): void; + pre_process(): PreProcessResult; + mid_process(fork_group: SeqId[]): MidProcessResult; + post_process(backtrack: number, tokens: Token[]): PostProcessResult; +} +/** + * Starts the AICI loop. The coroutine may first `await aici.getPrompt()` and + * then can `await aici.gen_*()` or `await aici.FixedTokens()` multiple times. + * @param f async function + */ +export declare function start(f: () => Promise): AiciAsync; +/** + * Runs the loop as a test. + */ +export declare function test(f: () => Promise): AiciAsync; +export declare class Label { + ptr: number; + /** + * Create a new label the indicates the current position in the sequence. + * Can be passed as `following=` argument to `FixedTokens()`. + */ + constructor(); + /** + * Return tokens generated since the label. + */ + tokensSince(): Token[]; + /** + * Return text generated since the label. + */ + textSince(): string; + /** + * Generate given prompt text, replacing all text after the current label. + */ + fixedAfter(text: string): Promise; +} +export declare class ChooseConstraint extends Constraint { + ptr: number; + options: Token[][]; + constructor(options: string[]); + eosAllowed(): boolean; + eosForced(): boolean; + tokenAllowed(t: Token): boolean; + appendToken(t: Token): void; + allowTokens(ts: TokenSet): void; +} +export declare function genTokens(options: GenOptions): Promise; +export declare function gen(options: GenOptions): Promise; +export declare function checkVar(name: string, value: string): void; +export declare function checkVars(d: Record): void; +export declare const helpers: { + regex_constraint: typeof regexConstraint; + cfg_constraint: typeof cfgConstraint; + substr_constraint: typeof substrConstraint; + FixedTokens: typeof FixedTokens; + StopToken: typeof StopToken; + panic: typeof panic; +}; diff --git a/jsctrl/ts/aici.js b/jsctrl/ts/gen/aici.js similarity index 96% rename from jsctrl/ts/aici.js rename to jsctrl/ts/gen/aici.js index 0489fbba..c2438518 100644 --- a/jsctrl/ts/aici.js +++ b/jsctrl/ts/gen/aici.js @@ -2,7 +2,7 @@ import { TokenSet, tokenize, detokenize, regexConstraint, cfgConstraint, substrConstraint, Constraint, getVar, setVar, appendVar, eosToken, panic, } from "_aici"; export { TokenSet, tokenize, detokenize, getVar, setVar, appendVar, eosToken }; import * as _aici from "_aici"; -function dbgarg(arg, depth) { +function dbgArg(arg, depth) { const maxElts = 20; const maxDepth = 2; const maxStr = 200; @@ -23,7 +23,7 @@ function dbgarg(arg, depth) { arg = arg.slice(0, maxElts); suff = ", ...]"; } - return "[" + arg.map((x) => dbgarg(x, depth + 1)).join(", ") + suff; + return "[" + arg.map((x) => dbgArg(x, depth + 1)).join(", ") + suff; } else { let keys = Object.keys(arg); @@ -35,7 +35,7 @@ function dbgarg(arg, depth) { keys = keys.slice(0, maxElts); } return ("{" + - keys.map((k) => `${k}: ${dbgarg(arg[k], depth + 1)}`).join(", ") + + keys.map((k) => `${k}: ${dbgArg(arg[k], depth + 1)}`).join(", ") + suff); } } @@ -54,18 +54,16 @@ function dbgarg(arg, depth) { } } export function inspect(v) { - return dbgarg(v, 0); + return dbgArg(v, 0); } export function log(...args) { console._print(args.map((x) => inspect(x)).join(" ")); } -console.log = log; -console.info = log; -console.warn = log; -console.debug = log; -console.trace = log; export class AssertionError extends Error { } +/** + * Throw an exception if the condition is not met. + */ export function assert(cond, msg = "Assertion failed") { if (!cond) throw new AssertionError(msg); @@ -223,7 +221,7 @@ export async function fixed(text) { await new FixedTokens(text).run(); } /** - * Same as fixed(); usage: await $`Some text` + * Force the exact tokens to be generated; usage: await $`Some text` */ export async function $(strings, ...values) { let result = ""; @@ -510,7 +508,7 @@ export function test(f) { } export class Label { /** - * Create a new label the indictes the current position in the sequence. + * Create a new label the indicates the current position in the sequence. * Can be passed as `following=` argument to `FixedTokens()`. */ constructor() { @@ -565,7 +563,7 @@ export class ChooseConstraint extends Constraint { } } } -export async function gen_tokens(options) { +export async function genTokens(options) { console.log("GEN", options); const res = []; const { regex, yacc, substring, substringEnd = '"', options: optionList, storeVar, stopAt, maxTokens = 20, } = options; @@ -608,7 +606,7 @@ export async function gen_tokens(options) { return res; } export async function gen(options) { - const tokens = await gen_tokens(options); + const tokens = await genTokens(options); return detokenize(tokens).decode(); } export function checkVar(name, value) { @@ -645,3 +643,13 @@ Uint8Array.prototype.toString = function () { Uint8Array.prototype.decode = function () { return _aici.bufferToString(this); }; +console.log = log; +console.info = log; +console.warn = log; +console.debug = log; +console.trace = log; +globalThis.$ = $; +globalThis.fixed = fixed; +globalThis.assert = assert; +globalThis.gen = gen; +globalThis.genTokens = genTokens; diff --git a/jsctrl/ts/native.d.ts b/jsctrl/ts/native.d.ts index f08ab136..0b2c210c 100644 --- a/jsctrl/ts/native.d.ts +++ b/jsctrl/ts/native.d.ts @@ -1,3 +1,96 @@ +// Top-level symbols + +type Token = number; +type Buffer = Uint8Array; + +/** + * Force the exact tokens to be generated; usage: await $`Some text` + */ +declare function $(strings: TemplateStringsArray, ...values: any[]): Promise; + +/** + * Throw an exception if the condition is not met. + */ +declare function assert(cond: boolean, msg?: string): asserts cond; + +/** + * Forces next tokens to be exactly the given text. + */ +declare function fixed(text: string): Promise; + +/** + * Forks the execution into `numForks` branches. + * @param numForks how many branches + * @returns a number from 0 to `numForks`-1, indicating the branch + */ +declare function fork(numForks: number): Promise; + +/** + * Suspends execution until all variables are available. + * @param vars names of variables + * @returns values of the variables + */ +declare function waitVars(...vars: string[]): Promise; + +/** + * Starts the AICI loop. The coroutine may first `await aici.getPrompt()` and + * then can `await aici.gen_*()` or `await aici.FixedTokens()` multiple times. + * @param f async function + */ +declare function start(f: () => Promise): void; + +/** + * Specifies options for gen() and genTokens(). + */ +interface GenOptions { + /** + * Make sure the generated text is one of the options. + */ + options?: string[]; + /** + * Make sure the generated text matches given regular expression. + */ + regex?: string | RegExp; + /** + * Make sure the generated text matches given yacc-like grammar. + */ + yacc?: string; + /** + * Make sure the generated text is a substring of the given string. + */ + substring?: string; + /** + * Used together with `substring` - treat the substring as ending the substring + * (typically '"' or similar). + */ + substringEnd?: string; + /** + * Store result of the generation (as bytes) into a shared variable. + */ + storeVar?: string; + /** + * Stop generation when the string is generated (the result includes the string and any following bytes (from the same token)). + */ + stopAt?: string; + /** + * Stop generation when the given number of tokens have been generated. + */ + maxTokens?: number; +} + +/** + * Generate a string that matches given constraints. + * If the tokens do not map cleanly into strings, it will contain Unicode replacement characters. + */ +declare function gen(options: GenOptions): Promise; + +/** + * Generate a list of tokens that matches given constraints. + */ +declare function genTokens(options: GenOptions): Promise; + +// Extensions of JavaScript built-in types + interface String { /** * UTF-8 encode the current string. @@ -7,7 +100,7 @@ interface String { interface StringConstructor { /** - * Create a string from UTF-8 buffer (with replacement cheracter for invalid sequences) + * Create a string from UTF-8 buffer (with replacement character for invalid sequences) */ fromBuffer(buffer: Uint8Array): string; } @@ -19,6 +112,25 @@ interface Uint8Array { decode(): string; } +/** [MDN Reference](https://developer.mozilla.org/docs/Web/API/console) */ +interface Console { + /** [MDN Reference](https://developer.mozilla.org/docs/Web/API/console/debug) */ + debug(...data: any[]): void; + /** [MDN Reference](https://developer.mozilla.org/docs/Web/API/console/error) */ + error(...data: any[]): void; + /** [MDN Reference](https://developer.mozilla.org/docs/Web/API/console/info) */ + info(...data: any[]): void; + /** [MDN Reference](https://developer.mozilla.org/docs/Web/API/console/log) */ + log(...data: any[]): void; + /** [MDN Reference](https://developer.mozilla.org/docs/Web/API/console/trace) */ + trace(...data: any[]): void; + /** [MDN Reference](https://developer.mozilla.org/docs/Web/API/console/warn) */ + warn(...data: any[]): void; +} + +declare var console: Console; + +// native module declare module "_aici" { type Buffer = Uint8Array; @@ -70,7 +182,7 @@ declare module "_aici" { function stringToBuffer(s: string): Buffer; /** - * UTF-8 decode (with replacement cheracter for invalid sequences) + * UTF-8 decode (with replacement character for invalid sequences) */ function bufferToString(b: Buffer): string; diff --git a/jsctrl/ts/sample.js b/jsctrl/ts/sample.js deleted file mode 100644 index 905a6490..00000000 --- a/jsctrl/ts/sample.js +++ /dev/null @@ -1,128 +0,0 @@ -import { $, Label, assert, checkVars, fork, gen, getPrompt, test, waitVars, getVar, setVar, } from "./aici"; -async function main() { - await $ `2 + 2 =`; - await gen({ maxTokens: 5 }); -} -async function test_sample() { - // initialization code - console.log("I'm going in the logs!"); - // ... more initialization code, it has long time limit - const _prompt = await getPrompt(); - // here we're out of initialization code - the time limits are tight - // This appends the exact string to the output; similar to adding it to prompt - await $ `The word 'hello' in French is`; - // generate text (tokens) matching the regex - const french = await gen({ regex: / "[^"]+"/, maxTokens: 5 }); - // set a shared variable (they are returned as JSON and are useful with aici.fork()) - setVar("french", french); - await $ ` and in German`; - // shorthand for the above - await gen({ regex: / "[^"]+"/, storeVar: "german" }); - await $ `\nFive`; - // generates one of the strings - await gen({ options: [" pounds", " euros", " dollars"] }); -} -async function test_backtrack_one() { - await $ `3+`; - const l = new Label(); - await $ `2`; - const x = await gen({ regex: /=\d\d?\./, storeVar: "x", maxTokens: 5 }); - console.log("X", x); - await l.fixedAfter("4"); - const y = await gen({ regex: /=\d\d?\./, storeVar: "y", maxTokens: 5 }); - console.log("Y", y); - checkVars({ x: "=5.", y: "=7." }); -} -async function test_fork() { - await $ `The word 'hello' in`; - const id = await fork(3); - if (id === 0) { - const [french, german] = await waitVars("french", "german"); - await $ `${french} is the same as ${german}.`; - await gen({ maxTokens: 5 }); - checkVars({ german: ' "hallo"', french: ' "bonjour"' }); - } - else if (id === 1) { - await $ ` German is`; - await gen({ regex: / "[^"\.]+"/, storeVar: "german", maxTokens: 5 }); - } - else if (id === 2) { - await $ ` French is`; - await gen({ regex: / "[^"\.]+"/, storeVar: "french", maxTokens: 5 }); - } -} -async function test_backtrack_lang() { - await $ `The word 'hello' in`; - const l = new Label(); - await l.fixedAfter(` French is`); - await gen({ regex: / "[^"\.]+"/, storeVar: "french", maxTokens: 5 }); - await l.fixedAfter(` German is`); - await gen({ regex: / "[^"\.]+"/, storeVar: "german", maxTokens: 5 }); - checkVars({ french: ' "bonjour"', german: ' "hallo"' }); -} -async function test_main() { - console.log("start"); - console.log(getVar("test")); - setVar("test", "hello"); - const v = getVar("test"); - console.log(typeof v); - const prompt = await getPrompt(); - console.log(prompt); - await $ `The word 'hello' in French is`; - await gen({ storeVar: "french", maxTokens: 5 }); - await $ `\nIn German it translates to`; - await gen({ regex: / "[^"]+"/, storeVar: "german" }); - await $ `\nFive`; - await gen({ - storeVar: "five", - options: [" pounds", " euros"], - }); - await $ ` is worth about $`; - await gen({ regex: /\d+\.\d/, storeVar: "dollars" }); - checkVars({ - test: "hello", - french: " 'bonjour'.", - german: ' "guten Tag"', - five: " pounds", - dollars: "7.5", - }); -} -async function test_drugs() { - const drug_syn = "\nUse Drug Name syntax for any drug name, for example Advil.\n\n"; - let notes = "The patient should take some tylenol in the evening and aspirin in the morning. Exercise is highly recommended. Get lots of sleep.\n"; - notes = "Start doctor note:\n" + notes + "\nEnd doctor note.\n"; - await $ `[INST] `; - const start = new Label(); - function inst(s) { - return s + drug_syn + notes + " [/INST]\n"; - } - await $ `${inst("List specific drug names in the following doctor's notes.")}\n1. `; - const s = await gen({ maxTokens: 30 }); - const drugs = []; - ("" + s).replace(/([^<]*)<\/drug>/g, (_, d) => { - drugs.push(d); - return ""; - }); - console.log("drugs", drugs); - await start.fixedAfter(inst("Make a list of each drug along with time to take it, based on the following doctor's notes.") + `Take `); - const pos = new Label(); - await gen({ options: drugs.map((d) => d + "") }); - for (let i = 0; i < 5; i++) { - const fragment = await gen({ maxTokens: 20, stopAt: "" }); - console.log(fragment); - if (fragment.includes("")) { - assert(fragment.endsWith("")); - await gen({ options: drugs.map((d) => d + "") }); - } - else { - break; - } - } - setVar("times", "" + pos.textSince()); - checkVars({ - times: "Tylenol in the evening.\n" + - "Take Aspirin in the morning.\n" + - "Exercise is highly recommended.\nGet lots of sleep.", - }); -} -test(test_fork); diff --git a/jsctrl/ts/tsconfig.json b/jsctrl/ts/tsconfig.json index e5f16da2..ecc9a95e 100644 --- a/jsctrl/ts/tsconfig.json +++ b/jsctrl/ts/tsconfig.json @@ -49,13 +49,13 @@ // "maxNodeModuleJsDepth": 1, /* Specify the maximum folder depth used for checking JavaScript files from 'node_modules'. Only applicable with 'allowJs'. */ /* Emit */ - // "declaration": true, /* Generate .d.ts files from TypeScript and JavaScript files in your project. */ + "declaration": true, /* Generate .d.ts files from TypeScript and JavaScript files in your project. */ // "declarationMap": true, /* Create sourcemaps for d.ts files. */ // "emitDeclarationOnly": true, /* Only output d.ts files and not JavaScript files. */ // "sourceMap": true, /* Create source map files for emitted JavaScript files. */ // "inlineSourceMap": true, /* Include sourcemap files inside the emitted JavaScript. */ // "outFile": "./", /* Specify a file that bundles all outputs into one JavaScript file. If 'declaration' is true, also designates a file that bundles all .d.ts output. */ - // "outDir": "./", /* Specify an output folder for all emitted files. */ + "outDir": "./gen", /* Specify an output folder for all emitted files. */ // "removeComments": true, /* Disable emitting comments. */ // "noEmit": true, /* Disable emitting files from a compilation. */ // "importHelpers": true, /* Allow importing helper functions from tslib once per project, instead of including them per-file. */ @@ -88,13 +88,13 @@ // "strictFunctionTypes": true, /* When assigning functions, check to ensure parameters and the return values are subtype-compatible. */ // "strictBindCallApply": true, /* Check that the arguments for 'bind', 'call', and 'apply' methods match the original function. */ // "strictPropertyInitialization": true, /* Check for class properties that are declared but not set in the constructor. */ - // "noImplicitThis": true, /* Enable error reporting when 'this' is given the type 'any'. */ + "noImplicitThis": true, /* Enable error reporting when 'this' is given the type 'any'. */ // "useUnknownInCatchVariables": true, /* Default catch clause variables as 'unknown' instead of 'any'. */ // "alwaysStrict": true, /* Ensure 'use strict' is always emitted. */ // "noUnusedLocals": true, /* Enable error reporting when local variables aren't read. */ // "noUnusedParameters": true, /* Raise an error when a function parameter isn't read. */ // "exactOptionalPropertyTypes": true, /* Interpret optional property types as written, rather than adding 'undefined'. */ - // "noImplicitReturns": true, /* Enable error reporting for codepaths that do not explicitly return in a function. */ + "noImplicitReturns": true, /* Enable error reporting for codepaths that do not explicitly return in a function. */ // "noFallthroughCasesInSwitch": true, /* Enable error reporting for fallthrough cases in switch statements. */ // "noUncheckedIndexedAccess": true, /* Add 'undefined' to a type when accessed using an index. */ // "noImplicitOverride": true, /* Ensure overriding members in derived classes are marked with an override modifier. */ @@ -104,6 +104,6 @@ /* Completeness */ // "skipDefaultLibCheck": true, /* Skip type checking .d.ts files that are included with TypeScript. */ - "skipLibCheck": true /* Skip type checking all .d.ts files. */ + // "skipLibCheck": true /* Skip type checking all .d.ts files. */ } }