-
-
Notifications
You must be signed in to change notification settings - Fork 16.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature/Extract Metadata Retriever (#3579)
add extract metadata retriever
- Loading branch information
1 parent
76ae921
commit 4c1951d
Showing
2 changed files
with
217 additions
and
0 deletions.
There are no files selected for viewing
216 changes: 216 additions & 0 deletions
216
packages/components/nodes/retrievers/ExtractMetadataRetriever/ExtractMetadataRetriever.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,216 @@ | ||
import { Document } from '@langchain/core/documents' | ||
import { VectorStore, VectorStoreRetriever, VectorStoreRetrieverInput } from '@langchain/core/vectorstores' | ||
import { INode, INodeData, INodeParams, INodeOutputsValue } from '../../../src/Interface' | ||
import { handleEscapeCharacters } from '../../../src' | ||
import { z } from 'zod' | ||
import { convertStructuredSchemaToZod, ExtractTool } from '../../sequentialagents/commonUtils' | ||
import { ChatGoogleGenerativeAI } from '@langchain/google-genai' | ||
|
||
const queryPrefix = 'query' | ||
const defaultPrompt = `Extract keywords from the query: {{${queryPrefix}}}` | ||
|
||
class ExtractMetadataRetriever_Retrievers implements INode { | ||
label: string | ||
name: string | ||
version: number | ||
description: string | ||
type: string | ||
icon: string | ||
category: string | ||
badge?: string | ||
baseClasses: string[] | ||
inputs: INodeParams[] | ||
outputs: INodeOutputsValue[] | ||
|
||
constructor() { | ||
this.label = 'Extract Metadata Retriever' | ||
this.name = 'extractMetadataRetriever' | ||
this.version = 1.0 | ||
this.type = 'ExtractMetadataRetriever' | ||
this.icon = 'dynamicMetadataRetriever.svg' | ||
this.category = 'Retrievers' | ||
this.description = 'Extract keywords/metadata from the query and use it to filter documents' | ||
this.baseClasses = [this.type, 'BaseRetriever'] | ||
this.badge = 'BETA' | ||
this.inputs = [ | ||
{ | ||
label: 'Vector Store', | ||
name: 'vectorStore', | ||
type: 'VectorStore' | ||
}, | ||
{ | ||
label: 'Chat Model', | ||
name: 'model', | ||
type: 'BaseChatModel' | ||
}, | ||
{ | ||
label: 'Query', | ||
name: 'query', | ||
type: 'string', | ||
description: 'Query to retrieve documents from retriever. If not specified, user question will be used', | ||
optional: true, | ||
acceptVariable: true | ||
}, | ||
{ | ||
label: 'Prompt', | ||
name: 'dynamicMetadataFilterRetrieverPrompt', | ||
type: 'string', | ||
description: 'Prompt to extract metadata from query', | ||
rows: 4, | ||
additionalParams: true, | ||
default: defaultPrompt | ||
}, | ||
{ | ||
label: 'JSON Structured Output', | ||
name: 'dynamicMetadataFilterRetrieverStructuredOutput', | ||
type: 'datagrid', | ||
description: | ||
'Instruct the model to give output in a JSON structured schema. This output will be used as the metadata filter for connected vector store', | ||
datagrid: [ | ||
{ field: 'key', headerName: 'Key', editable: true }, | ||
{ | ||
field: 'type', | ||
headerName: 'Type', | ||
type: 'singleSelect', | ||
valueOptions: ['String', 'String Array', 'Number', 'Boolean', 'Enum'], | ||
editable: true | ||
}, | ||
{ field: 'enumValues', headerName: 'Enum Values', editable: true }, | ||
{ field: 'description', headerName: 'Description', flex: 1, editable: true } | ||
], | ||
optional: true, | ||
additionalParams: true | ||
}, | ||
{ | ||
label: 'Top K', | ||
name: 'topK', | ||
description: 'Number of top results to fetch. Default to vector store topK', | ||
placeholder: '4', | ||
type: 'number', | ||
additionalParams: true, | ||
optional: true | ||
} | ||
] | ||
this.outputs = [ | ||
{ | ||
label: 'Extract Metadata Retriever', | ||
name: 'retriever', | ||
baseClasses: this.baseClasses | ||
}, | ||
{ | ||
label: 'Document', | ||
name: 'document', | ||
description: 'Array of document objects containing metadata and pageContent', | ||
baseClasses: ['Document', 'json'] | ||
}, | ||
{ | ||
label: 'Text', | ||
name: 'text', | ||
description: 'Concatenated string from pageContent of documents', | ||
baseClasses: ['string', 'json'] | ||
} | ||
] | ||
} | ||
|
||
async init(nodeData: INodeData, input: string): Promise<any> { | ||
const vectorStore = nodeData.inputs?.vectorStore as VectorStore | ||
let llm = nodeData.inputs?.model | ||
const llmStructuredOutput = nodeData.inputs?.dynamicMetadataFilterRetrieverStructuredOutput | ||
const topK = nodeData.inputs?.topK as string | ||
const dynamicMetadataFilterRetrieverPrompt = nodeData.inputs?.dynamicMetadataFilterRetrieverPrompt as string | ||
const query = nodeData.inputs?.query as string | ||
const finalInputQuery = query ? query : input | ||
|
||
const output = nodeData.outputs?.output as string | ||
|
||
if (llmStructuredOutput && llmStructuredOutput !== '[]') { | ||
try { | ||
const structuredOutput = z.object(convertStructuredSchemaToZod(llmStructuredOutput)) | ||
|
||
if (llm instanceof ChatGoogleGenerativeAI) { | ||
const tool = new ExtractTool({ | ||
schema: structuredOutput | ||
}) | ||
// @ts-ignore | ||
const modelWithTool = llm.bind({ | ||
tools: [tool] | ||
}) as any | ||
llm = modelWithTool | ||
} else { | ||
// @ts-ignore | ||
llm = llm.withStructuredOutput(structuredOutput) | ||
} | ||
} catch (exception) { | ||
console.error(exception) | ||
} | ||
} | ||
|
||
const retriever = DynamicMetadataRetriever.fromVectorStore(vectorStore, { | ||
structuredLLM: llm, | ||
prompt: dynamicMetadataFilterRetrieverPrompt, | ||
topK: topK ? parseInt(topK, 10) : (vectorStore as any)?.k ?? 4 | ||
}) | ||
|
||
if (output === 'retriever') return retriever | ||
else if (output === 'document') return await retriever.getRelevantDocuments(finalInputQuery) | ||
else if (output === 'text') { | ||
let finaltext = '' | ||
|
||
const docs = await retriever.getRelevantDocuments(finalInputQuery) | ||
|
||
for (const doc of docs) finaltext += `${doc.pageContent}\n` | ||
|
||
return handleEscapeCharacters(finaltext, false) | ||
} | ||
|
||
return retriever | ||
} | ||
} | ||
|
||
type RetrieverInput<V extends VectorStore> = Omit<VectorStoreRetrieverInput<V>, 'k'> & { | ||
topK?: number | ||
structuredLLM: any | ||
prompt: string | ||
} | ||
|
||
class DynamicMetadataRetriever<V extends VectorStore> extends VectorStoreRetriever<V> { | ||
topK = 4 | ||
structuredLLM: any | ||
prompt = '' | ||
|
||
constructor(input: RetrieverInput<V>) { | ||
super(input) | ||
this.topK = input.topK ?? this.topK | ||
this.structuredLLM = input.structuredLLM ?? this.structuredLLM | ||
this.prompt = input.prompt ?? this.prompt | ||
} | ||
|
||
async getFilter(query: string): Promise<any> { | ||
const structuredResponse = await this.structuredLLM.invoke(this.prompt.replace(`{{${queryPrefix}}}`, query)) | ||
return structuredResponse | ||
} | ||
|
||
async getRelevantDocuments(query: string): Promise<Document[]> { | ||
const newFilter = await this.getFilter(query) | ||
// @ts-ignore | ||
this.filter = { ...this.filter, ...newFilter } | ||
const results = await this.vectorStore.similaritySearchWithScore(query, this.topK, this.filter) | ||
|
||
const finalDocs: Document[] = [] | ||
for (const result of results) { | ||
finalDocs.push( | ||
new Document({ | ||
pageContent: result[0].pageContent, | ||
metadata: result[0].metadata | ||
}) | ||
) | ||
} | ||
return finalDocs | ||
} | ||
|
||
static fromVectorStore<V extends VectorStore>(vectorStore: V, options: Omit<RetrieverInput<V>, 'vectorStore'>) { | ||
return new this<V>({ ...options, vectorStore }) | ||
} | ||
} | ||
|
||
module.exports = { nodeClass: ExtractMetadataRetriever_Retrievers } |
1 change: 1 addition & 0 deletions
1
...mponents/nodes/retrievers/ExtractMetadataRetriever/dynamicMetadataRetriever.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.