From d83a8373af98a9e610a9b1fb12492ee82bf322e7 Mon Sep 17 00:00:00 2001 From: Marcus Schiesser Date: Thu, 14 Sep 2023 12:24:02 +0700 Subject: [PATCH] feat: added support for hnswlib vec store for development --- src/features/chat/chat-data/chat-data-api.ts | 31 +++------ .../chat-services/chat-document-service.ts | 24 ++----- .../azure-cog-vector-store.ts | 7 +- .../azure-cog-search/doc-store.ts | 51 ++++++++++++++ .../vector-stores/hnswlib/doc-store.ts | 69 +++++++++++++++++++ src/features/langchain/vector-stores/store.ts | 29 ++++++++ src/next.config.js | 7 ++ src/package-lock.json | 29 ++++++++ src/package.json | 1 + 9 files changed, 202 insertions(+), 46 deletions(-) create mode 100644 src/features/langchain/vector-stores/azure-cog-search/doc-store.ts create mode 100644 src/features/langchain/vector-stores/hnswlib/doc-store.ts create mode 100644 src/features/langchain/vector-stores/store.ts diff --git a/src/features/chat/chat-data/chat-data-api.ts b/src/features/chat/chat-data/chat-data-api.ts index a685b3473..54a220cb4 100644 --- a/src/features/chat/chat-data/chat-data-api.ts +++ b/src/features/chat/chat-data/chat-data-api.ts @@ -1,19 +1,18 @@ import { userHashedId } from "@/features/auth/helpers"; import { CosmosDBChatMessageHistory } from "@/features/langchain/memory/cosmosdb/cosmosdb"; +import { initVectorStore } from "@/features/langchain/vector-stores/store"; import { LangChainStream, StreamingTextResponse } from "ai"; import { loadQAMapReduceChain } from "langchain/chains"; import { ChatOpenAI } from "langchain/chat_models/openai"; -import { OpenAIEmbeddings } from "langchain/embeddings/openai"; import { BufferWindowMemory } from "langchain/memory"; import { ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate, } from "langchain/prompts"; -import { AzureCogSearch } from "../../langchain/vector-stores/azure-cog-search/azure-cog-vector-store"; import { insertPromptAndResponse } from "../chat-services/chat-service"; import { initAndGuardChatSession } from "../chat-services/chat-thread-service"; -import { FaqDocumentIndex, PromptGPTProps } from "../chat-services/models"; +import { PromptGPTProps } from "../chat-services/models"; import { transformConversationStyleToTemperature } from "../chat-services/utils"; export const ChatData = async (props: PromptGPTProps) => { @@ -68,12 +67,15 @@ export const ChatData = async (props: PromptGPTProps) => { }; const findRelevantDocuments = async (query: string, chatThreadId: string) => { - const vectorStore = initVectorStore(); + const vectorStore = await initVectorStore(); + const userId = await userHashedId(); - const relevantDocuments = await vectorStore.similaritySearch(query, 10, { - vectorFields: vectorStore.config.vectorFieldName, - filter: `user eq '${await userHashedId()}' and chatThreadId eq '${chatThreadId}'`, - }); + const relevantDocuments = await vectorStore.similaritySearch( + query, + 10, + userId, + chatThreadId + ); return relevantDocuments; }; @@ -93,16 +95,3 @@ const defineSystemPrompt = () => { return CHAT_COMBINE_PROMPT; }; - -const initVectorStore = () => { - const embedding = new OpenAIEmbeddings(); - const azureSearch = new AzureCogSearch(embedding, { - name: process.env.AZURE_SEARCH_NAME, - indexName: process.env.AZURE_SEARCH_INDEX_NAME, - apiKey: process.env.AZURE_SEARCH_API_KEY, - apiVersion: process.env.AZURE_SEARCH_API_VERSION, - vectorFieldName: "embedding", - }); - - return azureSearch; -}; diff --git a/src/features/chat/chat-services/chat-document-service.ts b/src/features/chat/chat-services/chat-document-service.ts index 35bf145ec..38335f0a9 100644 --- a/src/features/chat/chat-services/chat-document-service.ts +++ b/src/features/chat/chat-services/chat-document-service.ts @@ -2,13 +2,12 @@ import { userHashedId } from "@/features/auth/helpers"; import { CosmosDBContainer } from "@/features/common/cosmos"; -import { AzureCogSearch } from "@/features/langchain/vector-stores/azure-cog-search/azure-cog-vector-store"; +import { initVectorStore } from "@/features/langchain/vector-stores/store"; import { AzureKeyCredential, DocumentAnalysisClient, } from "@azure/ai-form-recognizer"; import { Document } from "langchain/document"; -import { OpenAIEmbeddings } from "langchain/embeddings/openai"; import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"; import { nanoid } from "nanoid"; import { @@ -105,10 +104,8 @@ const SplitDocuments = async (docs: Array) => { export const DeleteDocuments = async (chatThreadId: string) => { try { - - const vectorStore = initAzureSearchVectorStore(); + const vectorStore = await initVectorStore(); await vectorStore.deleteDocuments(chatThreadId); - } catch (e) { console.log("************"); return { @@ -125,7 +122,7 @@ export const IndexDocuments = async ( chatThreadId: string ): Promise> => { try { - const vectorStore = initAzureSearchVectorStore(); + const vectorStore = await initVectorStore(); const documentsToIndex: FaqDocumentIndex[] = []; let index = 0; for (const doc of docs) { @@ -159,19 +156,6 @@ export const IndexDocuments = async ( } }; -export const initAzureSearchVectorStore = () => { - const embedding = new OpenAIEmbeddings(); - const azureSearch = new AzureCogSearch(embedding, { - name: process.env.AZURE_SEARCH_NAME, - indexName: process.env.AZURE_SEARCH_INDEX_NAME, - apiKey: process.env.AZURE_SEARCH_API_KEY, - apiVersion: process.env.AZURE_SEARCH_API_VERSION, - vectorFieldName: "embedding", - }); - - return azureSearch; -}; - export const initDocumentIntelligence = () => { const client = new DocumentAnalysisClient( process.env.AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT, @@ -260,6 +244,6 @@ export const ensureSearchIsConfigured = async () => { throw new Error("Azure openai embedding variables are not configured."); } - const vectorStore = initAzureSearchVectorStore(); + const vectorStore = await initVectorStore(); await vectorStore.ensureIndexIsCreated(); }; diff --git a/src/features/langchain/vector-stores/azure-cog-search/azure-cog-vector-store.ts b/src/features/langchain/vector-stores/azure-cog-search/azure-cog-vector-store.ts index 767c396b0..60805c2ae 100644 --- a/src/features/langchain/vector-stores/azure-cog-search/azure-cog-vector-store.ts +++ b/src/features/langchain/vector-stores/azure-cog-search/azure-cog-vector-store.ts @@ -82,12 +82,9 @@ export class AzureCogSearch< return `https://${this._config.name}.search.windows.net/indexes/${this._config.indexName}/docs`; } - async addDocuments(documents: Document[]): Promise { + async addDocuments(documents: Document[]): Promise { const texts = documents.map(({ pageContent }) => pageContent); - return this.addVectors( - await this.embeddings.embedDocuments(texts), - documents - ); + this.addVectors(await this.embeddings.embedDocuments(texts), documents); } async deleteDocuments(chatThreadId: string): Promise { diff --git a/src/features/langchain/vector-stores/azure-cog-search/doc-store.ts b/src/features/langchain/vector-stores/azure-cog-search/doc-store.ts new file mode 100644 index 000000000..24223818d --- /dev/null +++ b/src/features/langchain/vector-stores/azure-cog-search/doc-store.ts @@ -0,0 +1,51 @@ +import { AzureCogSearch } from "./azure-cog-vector-store"; +import { Embeddings } from "langchain/embeddings/base"; +import { Document } from "langchain/document"; +import { DocumentVectorStore, DocumentType } from "../store"; + +export class AzureCogSearchStore + implements DocumentVectorStore +{ + private vectorStore: AzureCogSearch; + + constructor(vectorStore: AzureCogSearch) { + this.vectorStore = vectorStore; + } + + static async build( + embedding: Embeddings + ): Promise> { + const vectorStore = new AzureCogSearch(embedding, { + name: process.env.AZURE_SEARCH_NAME, + indexName: process.env.AZURE_SEARCH_INDEX_NAME, + apiKey: process.env.AZURE_SEARCH_API_KEY, + apiVersion: process.env.AZURE_SEARCH_API_VERSION, + vectorFieldName: "embedding", + }); + return new AzureCogSearchStore(vectorStore); + } + + async similaritySearch( + query: string, + k: number, + userId: string, + chatThreadId: string + ) { + return await this.vectorStore.similaritySearch(query, k, { + vectorFields: this.vectorStore.config.vectorFieldName, + filter: `user eq '${userId}' and chatThreadId eq '${chatThreadId}'`, + }); + } + + async addDocuments(documentsToIndex: Document[]) { + await this.vectorStore.addDocuments(documentsToIndex); + } + + async deleteDocuments(chatThreadId: string): Promise { + await this.vectorStore.deleteDocuments(chatThreadId); + } + + async ensureIndexIsCreated(): Promise { + await this.vectorStore.ensureIndexIsCreated(); + } +} diff --git a/src/features/langchain/vector-stores/hnswlib/doc-store.ts b/src/features/langchain/vector-stores/hnswlib/doc-store.ts new file mode 100644 index 000000000..75eaef2d9 --- /dev/null +++ b/src/features/langchain/vector-stores/hnswlib/doc-store.ts @@ -0,0 +1,69 @@ +import { Document } from "langchain/document"; +import { Embeddings } from "langchain/embeddings/base"; +import { HNSWLib } from "langchain/vectorstores/hnswlib"; +import { DocumentType, DocumentVectorStore } from "../store"; + +import * as fs from "node:fs"; +import os from "node:os"; +import * as path from "node:path"; + +const DB_FILE = path.join(os.tmpdir(), "hnswlib-chatgpt"); + +/** + * Represents a document vector store that uses the HNSWLib library for similarity search. Persists the vector store to disk to the DB_FILE. Doesn't support any file locking. Only use for testing / development. + * @template TModel The type of the documents stored in the vector store. + */ +export class HNSWLibStore + implements DocumentVectorStore +{ + private vectorStore: HNSWLib; + + constructor(vectorStore: HNSWLib) { + this.vectorStore = vectorStore; + } + + static async build( + embedding: Embeddings + ): Promise> { + let vectorStore: HNSWLib; + if (fs.existsSync(DB_FILE)) { + try { + vectorStore = await HNSWLib.load(DB_FILE, embedding); + } catch (err) { + console.error("Error loading vector DB state. Reset state.", err); + vectorStore = await HNSWLib.fromDocuments([], embedding); + } + } else { + vectorStore = await HNSWLib.fromDocuments([], embedding); + } + return new HNSWLibStore(vectorStore); + } + + async similaritySearch( + query: string, + k: number, + userId: string, + chatThreadId: string + ) { + // TODO: filter by userId and chatThreadId + return (await this.vectorStore.similaritySearch( + query, + k + )) as Document[]; + } + + async addDocuments(documentsToIndex: Document[]) { + await this.vectorStore.addDocuments(documentsToIndex); + await this.vectorStore.save(DB_FILE); + } + + async deleteDocuments(chatThreadId: string): Promise { + console.log("[HNSWLibStore] TODO: implement deleteDocuments"); + } + + async ensureIndexIsCreated(): Promise { + console.log( + "[HNSWLibStore] ensureIndexIsCreated called - noop - HNSWLib doesn't need index creation" + ); + } +} diff --git a/src/features/langchain/vector-stores/store.ts b/src/features/langchain/vector-stores/store.ts new file mode 100644 index 000000000..8f656b6cf --- /dev/null +++ b/src/features/langchain/vector-stores/store.ts @@ -0,0 +1,29 @@ +import { OpenAIEmbeddings } from "langchain/embeddings/openai"; +import { FaqDocumentIndex } from "@/features/chat/chat-services/models"; +import { AzureCogSearchStore } from "./azure-cog-search/doc-store"; +import { HNSWLibStore } from "./hnswlib/doc-store"; +import { Document } from "langchain/document"; + +export interface DocumentType extends Record {} + +export interface DocumentVectorStore { + ensureIndexIsCreated(): Promise; + deleteDocuments(chatThreadId: string): Promise; + similaritySearch( + query: string, + k: number, + userId: string, + chatThreadId: string + ): Promise[]>; + addDocuments(documentsToIndex: Document[]): Promise; +} + +export const initVectorStore = async () => { + const embedding = new OpenAIEmbeddings(); + if (process.env.AZURE_SEARCH_NAME) { + return await AzureCogSearchStore.build(embedding); + } else { + // if azure cog search is not configured, use hnswlib as a fallback + return await HNSWLibStore.build(embedding); + } +}; diff --git a/src/next.config.js b/src/next.config.js index dc9285a0e..fd9a7d715 100644 --- a/src/next.config.js +++ b/src/next.config.js @@ -1,10 +1,17 @@ /** @type {import('next').NextConfig} */ +const webpack = require("webpack"); + const nextConfig = { output: "standalone", experimental: { serverActions: true, serverActionsBodySizeLimit: "22mb", }, + webpack: (config) => { + config.externals = [...config.externals, "hnswlib-node"]; + + return config; + }, }; module.exports = nextConfig; diff --git a/src/package-lock.json b/src/package-lock.json index 1871a5385..d7ae735c6 100644 --- a/src/package-lock.json +++ b/src/package-lock.json @@ -27,6 +27,7 @@ "clsx": "^2.0.0", "eslint": "^8.46.0", "eslint-config-next": "^13.4.12", + "hnswlib-node": "^1.4.2", "langchain": "^0.0.123", "lucide-react": "^0.264.0", "nanoid": "^4.0.2", @@ -2272,6 +2273,14 @@ "resolved": "https://registry.npmjs.org/binary-search/-/binary-search-1.3.6.tgz", "integrity": "sha512-nbE1WxOTTrUWIfsfZ4aHGYu5DOuNkbxGokjV6Z2kxfJK3uaAb8zNK1muzOeipoLHZjInT4Br88BHpzevc681xA==" }, + "node_modules/bindings": { + "version": "1.5.0", + "resolved": "https://registry.npmjs.org/bindings/-/bindings-1.5.0.tgz", + "integrity": "sha512-p2q/t/mhvuOj/UeLlV6566GD/guowlr0hHxClI0W9m7MWYkL1F0hLo+0Aexs9HSPCtR1SXQ0TD3MMKrXZajbiQ==", + "dependencies": { + "file-uri-to-path": "1.0.0" + } + }, "node_modules/bplist-parser": { "version": "0.2.0", "resolved": "https://registry.npmjs.org/bplist-parser/-/bplist-parser-0.2.0.tgz", @@ -3556,6 +3565,11 @@ "node": "^10.12.0 || >=12.0.0" } }, + "node_modules/file-uri-to-path": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/file-uri-to-path/-/file-uri-to-path-1.0.0.tgz", + "integrity": "sha512-0Zt+s3L7Vf1biwWZ29aARiVYLx7iMGnEUl9x33fbB/j3jR81u/O2LbqK+Bm1CDSNDKVtJ/YjwY7TUd5SkeLQLw==" + }, "node_modules/fill-range": { "version": "7.0.1", "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", @@ -4040,6 +4054,16 @@ "node": "*" } }, + "node_modules/hnswlib-node": { + "version": "1.4.2", + "resolved": "https://registry.npmjs.org/hnswlib-node/-/hnswlib-node-1.4.2.tgz", + "integrity": "sha512-76PIzOaNcX8kOpKwlFPl07uelpctqDMzbiC+Qsk2JWNVkzeU/6iXRk4tfE9z3DoK1RCBrOaFXmQ6RFb1BVF9LA==", + "hasInstallScript": true, + "dependencies": { + "bindings": "^1.5.0", + "node-addon-api": "^6.0.0" + } + }, "node_modules/http-proxy-agent": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/http-proxy-agent/-/http-proxy-agent-5.0.0.tgz", @@ -6260,6 +6284,11 @@ "resolved": "https://registry.npmjs.org/node-abort-controller/-/node-abort-controller-3.1.1.tgz", "integrity": "sha512-AGK2yQKIjRuqnc6VkX2Xj5d+QW8xZ87pa1UK6yA6ouUyuxfHuMP6umE5QK7UmTeOAymo+Zx1Fxiuw9rVx8taHQ==" }, + "node_modules/node-addon-api": { + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-6.1.0.tgz", + "integrity": "sha512-+eawOlIgy680F0kBzPUNFhMZGtJ1YmqM6l4+Crf4IkImjYrO/mqPwRMh352g23uIaQKFItcQ64I7KMaJxHgAVA==" + }, "node_modules/node-domexception": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/node-domexception/-/node-domexception-1.0.0.tgz", diff --git a/src/package.json b/src/package.json index 0711f58b1..a1855c1d7 100644 --- a/src/package.json +++ b/src/package.json @@ -28,6 +28,7 @@ "clsx": "^2.0.0", "eslint": "^8.46.0", "eslint-config-next": "^13.4.12", + "hnswlib-node": "^1.4.2", "langchain": "^0.0.123", "lucide-react": "^0.264.0", "nanoid": "^4.0.2",