diff --git a/src/experts/assistant.js b/src/experts/assistant.js index f13963d..d92bd8d 100644 --- a/src/experts/assistant.js +++ b/src/experts/assistant.js @@ -39,6 +39,7 @@ class Assistant { options.temperature !== undefined ? options.temperature : 1.0; this.top_p = options.top_p !== undefined ? options.top_p : 1.0; this.experts = []; + this.expertsFunctionNames = []; this.tools = options.tools || []; this.tool_resources = options.tool_resources || {}; this._metadata = options.metadata; @@ -107,7 +108,10 @@ class Assistant { this.experts.push(assistantTool); if (assistantTool.isParentsTools) { for (const tool of assistantTool.parentsTools) { - this.tools.push(tool); + if (tool.type === "function") { + this.tools.push(tool); + this.expertsFunctionNames.push(tool.function.name); + } } } } diff --git a/src/experts/run.js b/src/experts/run.js index 78d5cbd..1d8abe5 100644 --- a/src/experts/run.js +++ b/src/experts/run.js @@ -126,6 +126,7 @@ class Run { #findExpertByToolName(functionName) { let toolCaller; + // Always ask an expert if they can answer. this.assistant.experts.forEach((expert) => { if (expert.isParentsTools) { expert.parentsTools.forEach((parentTool) => { @@ -138,6 +139,15 @@ class Run { }); } }); + if (toolCaller) return toolCaller; + // Allow this assistant to use its own tool if that tool is not a linked expert. + if (!this.assistant.expertsFunctionNames.includes(functionName)) { + this.assistant.tools.forEach((tool) => { + if (tool.type === "function" && tool.function.name === functionName) { + toolCaller = this.assistant; + } + }); + } return toolCaller; } } diff --git a/test/fixtures/echoTool.js b/test/fixtures/echoTool.js index 407efe2..0cdcad7 100644 --- a/test/fixtures/echoTool.js +++ b/test/fixtures/echoTool.js @@ -12,11 +12,11 @@ class EchoTool extends Tool { type: "function", function: { name: "marco", - description: "Use this tool if you get the /marco command.", + description: "Use this tool if you get the '/marco' message.", parameters: { type: "object", - properties: { invoke: { type: "boolean" } }, - required: ["invoke"], + properties: { invoke_marco: { type: "boolean" } }, + required: ["invoke_marco"], }, }, }, @@ -36,12 +36,17 @@ class EchoTool extends Tool { }, ], }); + this.marcoToolCallCount = 0; } async ask(message, threadID, options = {}) { - const json = JSON.parse(message); - if (json.message === "marco") { - return "polo"; + let json; + try { + json = JSON.parse(message); + } catch (error) {} + if (json?.invoke_marco === true) { + this.marcoToolCallCount++; + return "poolo"; } return await super.ask(message, threadID, options); } diff --git a/test/fixtures/routerAssistant.js b/test/fixtures/routerAssistant.js index 47a1142..aae3023 100644 --- a/test/fixtures/routerAssistant.js +++ b/test/fixtures/routerAssistant.js @@ -7,8 +7,9 @@ class RouterAssistant extends Assistant { super({ name: helperName("Router"), description: "Conversational Router", + temperature: 0.1, instructions: - "Routes messages to the right tool. Send any message starting with the /echo or /marco command to the echo tool. If no tool can be found for the message, reply with one word 'unrouteable' as the error.", + "Routes messages to the right tool. Send any message starting with the '/echo' or '/marco' command to that tool using the exact message you received. If no tool can be found for the message, reply with one word 'unrouteable' as the error.", }); this.addAssistantTool(EchoTool); } diff --git a/test/uat/router.test.js b/test/uat/router.test.js index 1c8a5c9..cca1521 100644 --- a/test/uat/router.test.js +++ b/test/uat/router.test.js @@ -24,8 +24,10 @@ test("each has own thread using metadata links", async () => { expect(thread2.metadata.tool).toMatch(/Experts\.js \(EchoTool\)/); }); -test("commands can pass from tool to tool", async () => { +test("commands can pass from assistant to assistant's tool", async () => { const threadID = await helperThreadID(); const output = await assistant.ask("/marco", threadID); - expect(output).toMatch(/polo/); + const tool = assistant.experts[0]; + expect(tool.marcoToolCallCount).toBe(1); + expect(output).toMatch(/poolo/); });