From 4d6881b506cb19695769550383e88818fbbaee73 Mon Sep 17 00:00:00 2001 From: Henry Date: Tue, 30 Jan 2024 14:12:55 +0000 Subject: [PATCH] add return source documents to retriever tool --- .../OpenAIFunctionAgent.ts | 12 ++++++-- .../CustomListOutputParser.ts | 14 ++++++---- .../tools/RetrieverTool/RetrieverTool.ts | 28 +++++++++++++++++-- packages/components/src/agents.ts | 25 ++++++++++++++++- .../Conversational Retrieval Agent.json | 10 ++++++- packages/server/src/index.ts | 11 +++++++- 6 files changed, 86 insertions(+), 14 deletions(-) diff --git a/packages/components/nodes/agents/OpenAIFunctionAgent/OpenAIFunctionAgent.ts b/packages/components/nodes/agents/OpenAIFunctionAgent/OpenAIFunctionAgent.ts index c21c887aaf2..9c25b2a9a5e 100644 --- a/packages/components/nodes/agents/OpenAIFunctionAgent/OpenAIFunctionAgent.ts +++ b/packages/components/nodes/agents/OpenAIFunctionAgent/OpenAIFunctionAgent.ts @@ -64,7 +64,7 @@ class OpenAIFunctionAgent_Agents implements INode { return prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory) } - async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { + async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { const memory = nodeData.inputs?.memory as FlowiseMemory const executor = prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory) @@ -72,12 +72,20 @@ class OpenAIFunctionAgent_Agents implements INode { const callbacks = await additionalCallbacks(nodeData, options) let res: ChainValues = {} + let sourceDocuments: ICommonObject[] = [] if (options.socketIO && options.socketIOClientId) { const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId) res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] }) + if (res.sourceDocuments) { + options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', 
flatten(res.sourceDocuments)) + sourceDocuments = res.sourceDocuments + } } else { res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] }) + if (res.sourceDocuments) { + sourceDocuments = res.sourceDocuments + } } await memory.addChatMessages( @@ -94,7 +102,7 @@ class OpenAIFunctionAgent_Agents implements INode { this.sessionId ) - return res?.output + return sourceDocuments.length ? { text: res?.output, sourceDocuments: flatten(sourceDocuments) } : res?.output } } diff --git a/packages/components/nodes/outputparsers/CustomListOutputParser/CustomListOutputParser.ts b/packages/components/nodes/outputparsers/CustomListOutputParser/CustomListOutputParser.ts index d420a88d99f..1e44acdb96b 100644 --- a/packages/components/nodes/outputparsers/CustomListOutputParser/CustomListOutputParser.ts +++ b/packages/components/nodes/outputparsers/CustomListOutputParser/CustomListOutputParser.ts @@ -29,16 +29,17 @@ class CustomListOutputParser implements INode { label: 'Length', name: 'length', type: 'number', - default: 5, step: 1, - description: 'Number of values to return' + description: 'Number of values to return', + optional: true }, { label: 'Separator', name: 'separator', type: 'string', description: 'Separator between values', - default: ',' + default: ',', + optional: true }, { label: 'Autofix', @@ -54,10 +55,11 @@ class CustomListOutputParser implements INode { const separator = nodeData.inputs?.separator as string const lengthStr = nodeData.inputs?.length as string const autoFix = nodeData.inputs?.autofixParser as boolean - let length = 5 - if (lengthStr) length = parseInt(lengthStr, 10) - const parser = new LangchainCustomListOutputParser({ length: length, separator: separator }) + const parser = new LangchainCustomListOutputParser({ + length: lengthStr ? 
parseInt(lengthStr, 10) : undefined, + separator: separator + }) Object.defineProperty(parser, 'autoFix', { enumerable: true, configurable: true, diff --git a/packages/components/nodes/tools/RetrieverTool/RetrieverTool.ts b/packages/components/nodes/tools/RetrieverTool/RetrieverTool.ts index cc74a015cd7..4e4a4af8a83 100644 --- a/packages/components/nodes/tools/RetrieverTool/RetrieverTool.ts +++ b/packages/components/nodes/tools/RetrieverTool/RetrieverTool.ts @@ -1,8 +1,11 @@ import { INode, INodeData, INodeParams } from '../../../src/Interface' import { getBaseClasses } from '../../../src/utils' import { DynamicTool } from 'langchain/tools' -import { createRetrieverTool } from 'langchain/agents/toolkits' +import { DynamicStructuredTool } from '@langchain/core/tools' +import { CallbackManagerForToolRun } from '@langchain/core/callbacks/manager' import { BaseRetriever } from 'langchain/schema/retriever' +import { z } from 'zod' +import { SOURCE_DOCUMENTS_PREFIX } from '../../../src/agents' class Retriever_Tools implements INode { label: string @@ -19,7 +22,7 @@ class Retriever_Tools implements INode { constructor() { this.label = 'Retriever Tool' this.name = 'retrieverTool' - this.version = 1.0 + this.version = 2.0 this.type = 'RetrieverTool' this.icon = 'retrievertool.svg' this.category = 'Tools' @@ -44,6 +47,12 @@ class Retriever_Tools implements INode { label: 'Retriever', name: 'retriever', type: 'BaseRetriever' + }, + { + label: 'Return Source Documents', + name: 'returnSourceDocuments', + type: 'boolean', + optional: true } ] } @@ -52,12 +61,25 @@ class Retriever_Tools implements INode { const name = nodeData.inputs?.name as string const description = nodeData.inputs?.description as string const retriever = nodeData.inputs?.retriever as BaseRetriever + const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean - const tool = createRetrieverTool(retriever, { + const input = { name, description + } + + const func = async ({ input }: { input: 
string }, runManager?: CallbackManagerForToolRun) => { + const docs = await retriever.getRelevantDocuments(input, runManager?.getChild('retriever')) + const content = docs.map((doc) => doc.pageContent).join('\n\n') + const sourceDocuments = JSON.stringify(docs) + return returnSourceDocuments ? content + SOURCE_DOCUMENTS_PREFIX + sourceDocuments : content + } + + const schema = z.object({ + input: z.string().describe('query to look up in retriever') }) + const tool = new DynamicStructuredTool({ ...input, func, schema }) return tool } } diff --git a/packages/components/src/agents.ts b/packages/components/src/agents.ts index 5e241d505a2..ab08097b0bb 100644 --- a/packages/components/src/agents.ts +++ b/packages/components/src/agents.ts @@ -1,5 +1,6 @@ +import { flatten } from 'lodash' import { AgentExecutorInput, BaseSingleActionAgent, BaseMultiActionAgent, RunnableAgent, StoppingMethod } from 'langchain/agents' -import { ChainValues, AgentStep, AgentFinish, AgentAction, BaseMessage, FunctionMessage, AIMessage } from 'langchain/schema' +import { ChainValues, AgentStep, AgentAction, BaseMessage, FunctionMessage, AIMessage } from 'langchain/schema' import { OutputParserException } from 'langchain/schema/output_parser' import { CallbackManager, CallbackManagerForChainRun, Callbacks } from 'langchain/callbacks' import { ToolInputParsingException, Tool } from '@langchain/core/tools' @@ -7,6 +8,11 @@ import { Runnable } from 'langchain/schema/runnable' import { BaseChain, SerializedLLMChain } from 'langchain/chains' import { Serializable } from '@langchain/core/load/serializable' +export const SOURCE_DOCUMENTS_PREFIX = '\n\n----FLOWISE_SOURCE_DOCUMENTS----\n\n' +type AgentFinish = { + returnValues: Record + log: string +} type AgentExecutorOutput = ChainValues interface AgentExecutorIteratorInput { @@ -315,10 +321,12 @@ export class AgentExecutor extends BaseChain { const steps: AgentStep[] = [] let iterations = 0 + let sourceDocuments: Array = [] const getOutput = async 
(finishStep: AgentFinish): Promise => { const { returnValues } = finishStep const additional = await this.agent.prepareForOutput(returnValues, steps) + if (sourceDocuments.length) additional.sourceDocuments = flatten(sourceDocuments) if (this.returnIntermediateSteps) { return { ...returnValues, intermediateSteps: steps, ...additional } @@ -406,6 +414,17 @@ export class AgentExecutor extends BaseChain { return { action, observation: observation ?? '' } } } + if (observation?.includes(SOURCE_DOCUMENTS_PREFIX)) { + const observationArray = observation.split(SOURCE_DOCUMENTS_PREFIX) + observation = observationArray[0] + const docs = observationArray[1] + try { + const parsedDocs = JSON.parse(docs) + sourceDocuments.push(parsedDocs) + } catch (e) { + console.error('Error parsing source documents from tool') + } + } return { action, observation: observation ?? '' } }) ) @@ -500,6 +519,10 @@ export class AgentExecutor extends BaseChain { chatId: this.chatId, input: this.input }) + if (observation?.includes(SOURCE_DOCUMENTS_PREFIX)) { + const observationArray = observation.split(SOURCE_DOCUMENTS_PREFIX) + observation = observationArray[0] + } } catch (e) { if (e instanceof ToolInputParsingException) { if (this.handleParsingErrors === true) { diff --git a/packages/server/marketplaces/chatflows/Conversational Retrieval Agent.json b/packages/server/marketplaces/chatflows/Conversational Retrieval Agent.json index 810c2b354f7..40c689f50e8 100644 --- a/packages/server/marketplaces/chatflows/Conversational Retrieval Agent.json +++ b/packages/server/marketplaces/chatflows/Conversational Retrieval Agent.json @@ -217,6 +217,13 @@ "rows": 3, "placeholder": "Searches and returns documents regarding the state-of-the-union.", "id": "retrieverTool_0-input-description-string" + }, + { + "label": "Return Source Documents", + "name": "returnSourceDocuments", + "type": "boolean", + "optional": true, + "id": "retrieverTool_0-input-returnSourceDocuments-boolean" } ], "inputAnchors": [ @@ 
-230,7 +237,8 @@ "inputs": { "name": "search_website", "description": "Searches and return documents regarding Jane - a culinary institution that offers top quality coffee, pastries, breakfast, lunch, and a variety of baked goods. They have multiple locations, including Jane on Fillmore, Jane on Larkin, Jane the Bakery, Toy Boat By Jane, and Little Jane on Grant. They emphasize healthy eating with a focus on flavor and quality ingredients. They bake everything in-house and work with local suppliers to source ingredients directly from farmers. They also offer catering services and delivery options.", - "retriever": "{{pinecone_0.data.instance}}" + "retriever": "{{pinecone_0.data.instance}}", + "returnSourceDocuments": true }, "outputAnchors": [ { diff --git a/packages/server/src/index.ts b/packages/server/src/index.ts index b0bb06f55e9..045e40dd722 100644 --- a/packages/server/src/index.ts +++ b/packages/server/src/index.ts @@ -473,6 +473,8 @@ export class App { const endingNodes = nodes.filter((nd) => endingNodeIds.includes(nd.id)) let isStreaming = false + let isEndingNodeExists = endingNodes.find((node) => node.data?.outputs?.output === 'EndingNode') + for (const endingNode of endingNodes) { const endingNodeData = endingNode.data if (!endingNodeData) return res.status(500).send(`Ending node ${endingNode.id} data not found`) @@ -488,7 +490,8 @@ export class App { isStreaming = isEndingNode ? false : isFlowValidForStream(nodes, endingNodeData) } - const obj = { isStreaming } + // Once custom function ending node exists, flow is always unavailable to stream + const obj = { isStreaming: isEndingNodeExists ? 
false : isStreaming } return res.json(obj) }) @@ -1677,6 +1680,9 @@ export class App { if (!endingNodeIds.length) return res.status(500).send(`Ending nodes not found`) const endingNodes = nodes.filter((nd) => endingNodeIds.includes(nd.id)) + + let isEndingNodeExists = endingNodes.find((node) => node.data?.outputs?.output === 'EndingNode') + for (const endingNode of endingNodes) { const endingNodeData = endingNode.data if (!endingNodeData) return res.status(500).send(`Ending node ${endingNode.id} data not found`) @@ -1704,6 +1710,9 @@ export class App { isStreamValid = isFlowValidForStream(nodes, endingNodeData) } + // Once custom function ending node exists, flow is always unavailable to stream + isStreamValid = isEndingNodeExists ? false : isStreamValid + let chatHistory: IMessage[] = incomingInput.history ?? [] // When {{chat_history}} is used in Prompt Template, fetch the chat conversations from memory node