From 8cc139cadaccb4779c38b547b2d59d9cb434708b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=97=8D+85CD?= <50108258+kwaa@users.noreply.github.com> Date: Sun, 12 Jan 2025 14:37:19 +0800 Subject: [PATCH 1/6] refactor(generate-text)!: trampoline function --- packages/generate-text/src/index.ts | 146 ++++++++++++++-------------- 1 file changed, 71 insertions(+), 75 deletions(-) diff --git a/packages/generate-text/src/index.ts b/packages/generate-text/src/index.ts index ad0679b..01c08ad 100644 --- a/packages/generate-text/src/index.ts +++ b/packages/generate-text/src/index.ts @@ -10,6 +10,8 @@ import { export interface GenerateTextOptions extends ChatOptions { /** @default 1 */ maxSteps?: number + /** @internal */ + steps?: StepResult[] /** if you want to enable stream, use `@xsai/stream-{text,object}` */ stream?: never } @@ -66,53 +68,51 @@ export interface StepResult { usage: GenerateTextResponseUsage } -export const generateText = async (options: GenerateTextOptions): Promise => { - let currentStep = 0 - - let finishReason: FinishReason = 'error' - let text - let usage: GenerateTextResponseUsage = { - completion_tokens: 0, - prompt_tokens: 0, - total_tokens: 0, - } - - const steps: StepResult[] = [] - const messages: Message[] = options.messages - const toolCalls: ToolCall[] = [] - const toolResults: ToolResult[] = [] - while (currentStep < (options.maxSteps ?? 1)) { - currentStep += 1 - - const data: GenerateTextResponse = await chat({ - ...options, - maxSteps: undefined, - messages, - stream: false, - }).then(res => res.json()) - - const { finish_reason, message } = data.choices[0] - - finishReason = finish_reason - text = message.content - usage = data.usage - - const stepResult: StepResult = { - text: message.content, - toolCalls: [], - toolResults: [], - // type: 'initial', - usage, - } - - // TODO: fix types - messages.push({ ...message, content: message.content! }) - - if (message.tool_calls) { - // execute tools - for (const toolCall of message.tool_calls ?? []) { +/** @internal */ +type RawGenerateTextTrampoline = Promise<(() => RawGenerateTextTrampoline) | T> + +/** @internal */ +type RawGenerateText = (options: GenerateTextOptions) => RawGenerateTextTrampoline + +/** @internal */ +const rawGenerateText: RawGenerateText = async (options: GenerateTextOptions) => + await chat({ + ...options, + maxSteps: undefined, + messages: options.messages, + stream: false, + }) + .then(res => res.json() as Promise) + .then(async ({ choices, usage }) => { + const messages: Message[] = options.messages + const steps: StepResult[] = options.steps ?? [] + const toolCalls: ToolCall[] = [] + const toolResults: ToolResult[] = [] + + const { finish_reason: finishReason, message } = choices[0] + + if (message.content || !message.tool_calls || steps.length >= (options.maxSteps ?? 1)) { + const step: StepResult = { + text: message.content, + toolCalls, + toolResults, + usage, + } + + steps.push(step) + + return { + finishReason, + steps, + ...step, + } + } + + messages.push({ ...message, content: message.content! }) + + for (const toolCall of message.tool_calls) { const tool = (options.tools as Tool[]).find(tool => tool.function.name === toolCall.function.name)! - const parsedArgs: Record = JSON.parse(toolCall.function.arguments) + const parsedArgs: Record = JSON.parse(toolCall.function.arguments) const toolResult = await tool.execute(parsedArgs) const toolMessage = { content: toolResult, @@ -120,48 +120,44 @@ export const generateText = async (options: GenerateTextOptions): Promise await rawGenerateText({ + ...options, + messages, + steps, + }) + }) + +export const generateText = async (options: GenerateTextOptions): Promise => { + let result = await rawGenerateText(options) + + while (result instanceof Function) + result = await result() + + return result } export default generateText From f7a1ca0958af2cdd13423d6967cf960abf1197e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=97=8D+85CD?= <50108258+kwaa@users.noreply.github.com> Date: Sun, 12 Jan 2025 14:45:01 +0800 Subject: [PATCH 2/6] fix(generate-text): clean steps --- packages/generate-text/src/index.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/generate-text/src/index.ts b/packages/generate-text/src/index.ts index 01c08ad..fceadf5 100644 --- a/packages/generate-text/src/index.ts +++ b/packages/generate-text/src/index.ts @@ -80,6 +80,7 @@ const rawGenerateText: RawGenerateText = async (options: GenerateTextOptions) => ...options, maxSteps: undefined, messages: options.messages, + steps: undefined, stream: false, }) .then(res => res.json() as Promise) From d076088fe8bc4663d66d81a1f83fabd8eea90fcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=97=8D+85CD?= <50108258+kwaa@users.noreply.github.com> Date: Sun, 12 Jan 2025 14:56:42 +0800 Subject: [PATCH 3/6] refactor(generate-text): simplify naming --- packages/generate-text/src/index.ts | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/packages/generate-text/src/index.ts b/packages/generate-text/src/index.ts index fceadf5..04de46d 100644 --- a/packages/generate-text/src/index.ts +++ b/packages/generate-text/src/index.ts @@ -113,13 +113,8 @@ const rawGenerateText: RawGenerateText = async (options: GenerateTextOptions) => for (const toolCall of message.tool_calls) { const tool = (options.tools as Tool[]).find(tool => tool.function.name === toolCall.function.name)! - const parsedArgs: Record = JSON.parse(toolCall.function.arguments) - const toolResult = await tool.execute(parsedArgs) - const toolMessage = { - content: toolResult, - role: 'tool', - tool_call_id: toolCall.id, - } satisfies Message + const args: Record = JSON.parse(toolCall.function.arguments) + const result = await tool.execute(args) toolCalls.push({ args: toolCall.function.arguments, @@ -129,13 +124,17 @@ const rawGenerateText: RawGenerateText = async (options: GenerateTextOptions) => }) toolResults.push({ - args: parsedArgs, - result: toolResult, + args, + result, toolCallId: toolCall.id, toolName: toolCall.function.name, }) - messages.push(toolMessage) + messages.push({ + content: result, + role: 'tool', + tool_call_id: toolCall.id, + }) } steps.push({ From 7ca9f0b25c35df5eccef79a491d803b5dac8d380 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=97=8D+85CD?= <50108258+kwaa@users.noreply.github.com> Date: Sun, 12 Jan 2025 15:03:11 +0800 Subject: [PATCH 4/6] refactor(generate-text): simplify naming --- packages/generate-text/src/index.ts | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/packages/generate-text/src/index.ts b/packages/generate-text/src/index.ts index 04de46d..671cb31 100644 --- a/packages/generate-text/src/index.ts +++ b/packages/generate-text/src/index.ts @@ -111,29 +111,33 @@ const rawGenerateText: RawGenerateText = async (options: GenerateTextOptions) => messages.push({ ...message, content: message.content! }) - for (const toolCall of message.tool_calls) { - const tool = (options.tools as Tool[]).find(tool => tool.function.name === toolCall.function.name)! - const args: Record = JSON.parse(toolCall.function.arguments) - const result = await tool.execute(args) + for (const { + function: { arguments: toolArgs, name: toolName }, + id: toolCallId, + type: toolCallType, + } of message.tool_calls) { + const tool = (options.tools as Tool[]).find(tool => tool.function.name === toolName)! + const parsedArgs: Record = JSON.parse(toolArgs) + const result = await tool.execute(parsedArgs) toolCalls.push({ - args: toolCall.function.arguments, - toolCallId: toolCall.id, - toolCallType: toolCall.type, - toolName: toolCall.function.name, + args: toolArgs, + toolCallId, + toolCallType, + toolName, }) toolResults.push({ - args, + args: parsedArgs, result, - toolCallId: toolCall.id, - toolName: toolCall.function.name, + toolCallId, + toolName, }) messages.push({ content: result, role: 'tool', - tool_call_id: toolCall.id, + tool_call_id: toolCallId, }) } From 7ba6a7392ccca426c3ad9604406873c889085a34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=97=8D+85CD?= <50108258+kwaa@users.noreply.github.com> Date: Sun, 12 Jan 2025 16:33:28 +0800 Subject: [PATCH 5/6] chore(tool): update test --- packages/tool/test/index.test.ts | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/packages/tool/test/index.test.ts b/packages/tool/test/index.test.ts index 775d3f4..e28318b 100644 --- a/packages/tool/test/index.test.ts +++ b/packages/tool/test/index.test.ts @@ -50,7 +50,7 @@ describe('@xsai/tool', () => { }), }) - const { text } = await generateText({ + const { steps, text } = await generateText({ ...ollama.chat('mistral-nemo'), maxSteps: 2, messages: [ @@ -69,5 +69,14 @@ describe('@xsai/tool', () => { }) expect(text).toMatchSnapshot() + + const { toolCalls, toolResults } = steps[0] + + expect(toolCalls[0].toolName).toBe('weather') + expect(toolCalls[0].args).toBe('{"location":"San Francisco"}') + + expect(toolCalls[0].toolName).toBe('weather') + expect(toolResults[0].args).toStrictEqual({ location: 'San Francisco' }) + expect(toolResults[0].result).toBe('{"location":"San Francisco","temperature":42}') }, 20000) }) From 0af7ee63602ebdf03bdd5031a1937f97d29b27b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=97=8D+85CD?= <50108258+kwaa@users.noreply.github.com> Date: Sun, 12 Jan 2025 16:40:14 +0800 Subject: [PATCH 6/6] chore(generate-text): update test --- .../test/__snapshots__/index.test.ts.snap | 16 ++++++++++++++++ packages/generate-text/test/index.test.ts | 6 +++++- 2 files changed, 21 insertions(+), 1 deletion(-) create mode 100644 packages/generate-text/test/__snapshots__/index.test.ts.snap diff --git a/packages/generate-text/test/__snapshots__/index.test.ts.snap b/packages/generate-text/test/__snapshots__/index.test.ts.snap new file mode 100644 index 0000000..bffc342 --- /dev/null +++ b/packages/generate-text/test/__snapshots__/index.test.ts.snap @@ -0,0 +1,16 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`@xsai/generate-text > basic 1`] = ` +[ + { + "text": "YES", + "toolCalls": [], + "toolResults": [], + "usage": { + "completion_tokens": 2, + "prompt_tokens": 46, + "total_tokens": 48, + }, + }, +] +`; diff --git a/packages/generate-text/test/index.test.ts b/packages/generate-text/test/index.test.ts index 7b4bcc6..8d12120 100644 --- a/packages/generate-text/test/index.test.ts +++ b/packages/generate-text/test/index.test.ts @@ -5,7 +5,7 @@ import { generateText } from '../src' describe('@xsai/generate-text', () => { it('basic', async () => { - const { text } = await generateText({ + const { finishReason, steps, text, toolCalls, toolResults } = await generateText({ ...ollama.chat('llama3.2'), messages: [ { @@ -20,6 +20,10 @@ describe('@xsai/generate-text', () => { }) expect(text).toStrictEqual('YES') + expect(finishReason).toBe('stop') + expect(toolCalls.length).toBe(0) + expect(toolResults.length).toBe(0) + expect(steps).toMatchSnapshot() }) // TODO: error handling