Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/Return Source Documens to retriever tool #1644

Merged
merged 1 commit into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,28 @@ class OpenAIFunctionAgent_Agents implements INode {
return prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory)
}

async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string> {
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
const memory = nodeData.inputs?.memory as FlowiseMemory
const executor = prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory)

const loggerHandler = new ConsoleCallbackHandler(options.logger)
const callbacks = await additionalCallbacks(nodeData, options)

let res: ChainValues = {}
let sourceDocuments: ICommonObject[] = []

if (options.socketIO && options.socketIOClientId) {
const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] })
if (res.sourceDocuments) {
options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments))
sourceDocuments = res.sourceDocuments
}
} else {
res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
if (res.sourceDocuments) {
sourceDocuments = res.sourceDocuments
}
}

await memory.addChatMessages(
Expand All @@ -94,7 +102,7 @@ class OpenAIFunctionAgent_Agents implements INode {
this.sessionId
)

return res?.output
return sourceDocuments.length ? { text: res?.output, sourceDocuments: flatten(sourceDocuments) } : res?.output
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,17 @@ class CustomListOutputParser implements INode {
label: 'Length',
name: 'length',
type: 'number',
default: 5,
step: 1,
description: 'Number of values to return'
description: 'Number of values to return',
optional: true
},
{
label: 'Separator',
name: 'separator',
type: 'string',
description: 'Separator between values',
default: ','
default: ',',
optional: true
},
{
label: 'Autofix',
Expand All @@ -54,10 +55,11 @@ class CustomListOutputParser implements INode {
const separator = nodeData.inputs?.separator as string
const lengthStr = nodeData.inputs?.length as string
const autoFix = nodeData.inputs?.autofixParser as boolean
let length = 5
if (lengthStr) length = parseInt(lengthStr, 10)

const parser = new LangchainCustomListOutputParser({ length: length, separator: separator })
const parser = new LangchainCustomListOutputParser({
length: lengthStr ? parseInt(lengthStr, 10) : undefined,
separator: separator
})
Object.defineProperty(parser, 'autoFix', {
enumerable: true,
configurable: true,
Expand Down
28 changes: 25 additions & 3 deletions packages/components/nodes/tools/RetrieverTool/RetrieverTool.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import { INode, INodeData, INodeParams } from '../../../src/Interface'
import { getBaseClasses } from '../../../src/utils'
import { DynamicTool } from 'langchain/tools'
import { createRetrieverTool } from 'langchain/agents/toolkits'
import { DynamicStructuredTool } from '@langchain/core/tools'
import { CallbackManagerForToolRun } from '@langchain/core/callbacks/manager'
import { BaseRetriever } from 'langchain/schema/retriever'
import { z } from 'zod'
import { SOURCE_DOCUMENTS_PREFIX } from '../../../src/agents'

class Retriever_Tools implements INode {
label: string
Expand All @@ -19,7 +22,7 @@ class Retriever_Tools implements INode {
constructor() {
this.label = 'Retriever Tool'
this.name = 'retrieverTool'
this.version = 1.0
this.version = 2.0
this.type = 'RetrieverTool'
this.icon = 'retrievertool.svg'
this.category = 'Tools'
Expand All @@ -44,6 +47,12 @@ class Retriever_Tools implements INode {
label: 'Retriever',
name: 'retriever',
type: 'BaseRetriever'
},
{
label: 'Return Source Documents',
name: 'returnSourceDocuments',
type: 'boolean',
optional: true
}
]
}
Expand All @@ -52,12 +61,25 @@ class Retriever_Tools implements INode {
const name = nodeData.inputs?.name as string
const description = nodeData.inputs?.description as string
const retriever = nodeData.inputs?.retriever as BaseRetriever
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean

const tool = createRetrieverTool(retriever, {
const input = {
name,
description
}

const func = async ({ input }: { input: string }, runManager?: CallbackManagerForToolRun) => {
const docs = await retriever.getRelevantDocuments(input, runManager?.getChild('retriever'))
const content = docs.map((doc) => doc.pageContent).join('\n\n')
const sourceDocuments = JSON.stringify(docs)
return returnSourceDocuments ? content + SOURCE_DOCUMENTS_PREFIX + sourceDocuments : content
}

const schema = z.object({
input: z.string().describe('query to look up in retriever')
})

const tool = new DynamicStructuredTool({ ...input, func, schema })
return tool
}
}
Expand Down
25 changes: 24 additions & 1 deletion packages/components/src/agents.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import { flatten } from 'lodash'
import { AgentExecutorInput, BaseSingleActionAgent, BaseMultiActionAgent, RunnableAgent, StoppingMethod } from 'langchain/agents'
import { ChainValues, AgentStep, AgentFinish, AgentAction, BaseMessage, FunctionMessage, AIMessage } from 'langchain/schema'
import { ChainValues, AgentStep, AgentAction, BaseMessage, FunctionMessage, AIMessage } from 'langchain/schema'
import { OutputParserException } from 'langchain/schema/output_parser'
import { CallbackManager, CallbackManagerForChainRun, Callbacks } from 'langchain/callbacks'
import { ToolInputParsingException, Tool } from '@langchain/core/tools'
import { Runnable } from 'langchain/schema/runnable'
import { BaseChain, SerializedLLMChain } from 'langchain/chains'
import { Serializable } from '@langchain/core/load/serializable'

export const SOURCE_DOCUMENTS_PREFIX = '\n\n----FLOWISE_SOURCE_DOCUMENTS----\n\n'
type AgentFinish = {
returnValues: Record<string, any>
log: string
}
type AgentExecutorOutput = ChainValues

interface AgentExecutorIteratorInput {
Expand Down Expand Up @@ -315,10 +321,12 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {

const steps: AgentStep[] = []
let iterations = 0
let sourceDocuments: Array<Document> = []

const getOutput = async (finishStep: AgentFinish): Promise<AgentExecutorOutput> => {
const { returnValues } = finishStep
const additional = await this.agent.prepareForOutput(returnValues, steps)
if (sourceDocuments.length) additional.sourceDocuments = flatten(sourceDocuments)

if (this.returnIntermediateSteps) {
return { ...returnValues, intermediateSteps: steps, ...additional }
Expand Down Expand Up @@ -406,6 +414,17 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
return { action, observation: observation ?? '' }
}
}
if (observation?.includes(SOURCE_DOCUMENTS_PREFIX)) {
const observationArray = observation.split(SOURCE_DOCUMENTS_PREFIX)
observation = observationArray[0]
const docs = observationArray[1]
try {
const parsedDocs = JSON.parse(docs)
sourceDocuments.push(parsedDocs)
} catch (e) {
console.error('Error parsing source documents from tool')
}
}
return { action, observation: observation ?? '' }
})
)
Expand Down Expand Up @@ -500,6 +519,10 @@ export class AgentExecutor extends BaseChain<ChainValues, AgentExecutorOutput> {
chatId: this.chatId,
input: this.input
})
if (observation?.includes(SOURCE_DOCUMENTS_PREFIX)) {
const observationArray = observation.split(SOURCE_DOCUMENTS_PREFIX)
observation = observationArray[0]
}
} catch (e) {
if (e instanceof ToolInputParsingException) {
if (this.handleParsingErrors === true) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,13 @@
"rows": 3,
"placeholder": "Searches and returns documents regarding the state-of-the-union.",
"id": "retrieverTool_0-input-description-string"
},
{
"label": "Return Source Documents",
"name": "returnSourceDocuments",
"type": "boolean",
"optional": true,
"id": "retrieverTool_0-input-returnSourceDocuments-boolean"
}
],
"inputAnchors": [
Expand All @@ -230,7 +237,8 @@
"inputs": {
"name": "search_website",
"description": "Searches and return documents regarding Jane - a culinary institution that offers top quality coffee, pastries, breakfast, lunch, and a variety of baked goods. They have multiple locations, including Jane on Fillmore, Jane on Larkin, Jane the Bakery, Toy Boat By Jane, and Little Jane on Grant. They emphasize healthy eating with a focus on flavor and quality ingredients. They bake everything in-house and work with local suppliers to source ingredients directly from farmers. They also offer catering services and delivery options.",
"retriever": "{{pinecone_0.data.instance}}"
"retriever": "{{pinecone_0.data.instance}}",
"returnSourceDocuments": true
},
"outputAnchors": [
{
Expand Down
11 changes: 10 additions & 1 deletion packages/server/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,8 @@ export class App {
const endingNodes = nodes.filter((nd) => endingNodeIds.includes(nd.id))

let isStreaming = false
let isEndingNodeExists = endingNodes.find((node) => node.data?.outputs?.output === 'EndingNode')

for (const endingNode of endingNodes) {
const endingNodeData = endingNode.data
if (!endingNodeData) return res.status(500).send(`Ending node ${endingNode.id} data not found`)
Expand All @@ -488,7 +490,8 @@ export class App {
isStreaming = isEndingNode ? false : isFlowValidForStream(nodes, endingNodeData)
}

const obj = { isStreaming }
// Once custom function ending node exists, flow is always unavailable to stream
const obj = { isStreaming: isEndingNodeExists ? false : isStreaming }
return res.json(obj)
})

Expand Down Expand Up @@ -1677,6 +1680,9 @@ export class App {
if (!endingNodeIds.length) return res.status(500).send(`Ending nodes not found`)

const endingNodes = nodes.filter((nd) => endingNodeIds.includes(nd.id))

let isEndingNodeExists = endingNodes.find((node) => node.data?.outputs?.output === 'EndingNode')

for (const endingNode of endingNodes) {
const endingNodeData = endingNode.data
if (!endingNodeData) return res.status(500).send(`Ending node ${endingNode.id} data not found`)
Expand Down Expand Up @@ -1704,6 +1710,9 @@ export class App {
isStreamValid = isFlowValidForStream(nodes, endingNodeData)
}

// Once custom function ending node exists, flow is always unavailable to stream
isStreamValid = isEndingNodeExists ? false : isStreamValid

let chatHistory: IMessage[] = incomingInput.history ?? []

// When {{chat_history}} is used in Prompt Template, fetch the chat conversations from memory node
Expand Down
Loading