Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added support for hnswlib vec store for development / testing #163

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 10 additions & 21 deletions src/features/chat/chat-data/chat-data-api.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import { userHashedId } from "@/features/auth/helpers";
import { CosmosDBChatMessageHistory } from "@/features/langchain/memory/cosmosdb/cosmosdb";
import { initVectorStore } from "@/features/langchain/vector-stores/store";
import { LangChainStream, StreamingTextResponse } from "ai";
import { loadQAMapReduceChain } from "langchain/chains";
import { ChatOpenAI } from "langchain/chat_models/openai";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { BufferWindowMemory } from "langchain/memory";
import {
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
} from "langchain/prompts";
import { AzureCogSearch } from "../../langchain/vector-stores/azure-cog-search/azure-cog-vector-store";
import { insertPromptAndResponse } from "../chat-services/chat-service";
import { initAndGuardChatSession } from "../chat-services/chat-thread-service";
import { FaqDocumentIndex, PromptGPTProps } from "../chat-services/models";
import { PromptGPTProps } from "../chat-services/models";
import { transformConversationStyleToTemperature } from "../chat-services/utils";

export const ChatData = async (props: PromptGPTProps) => {
Expand Down Expand Up @@ -68,12 +67,15 @@ export const ChatData = async (props: PromptGPTProps) => {
};

const findRelevantDocuments = async (query: string, chatThreadId: string) => {
const vectorStore = initVectorStore();
const vectorStore = await initVectorStore();
const userId = await userHashedId();

const relevantDocuments = await vectorStore.similaritySearch(query, 10, {
vectorFields: vectorStore.config.vectorFieldName,
filter: `user eq '${await userHashedId()}' and chatThreadId eq '${chatThreadId}'`,
});
const relevantDocuments = await vectorStore.similaritySearch(
query,
10,
userId,
chatThreadId
);

return relevantDocuments;
};
Expand All @@ -93,16 +95,3 @@ const defineSystemPrompt = () => {

return CHAT_COMBINE_PROMPT;
};

const initVectorStore = () => {
const embedding = new OpenAIEmbeddings();
const azureSearch = new AzureCogSearch<FaqDocumentIndex>(embedding, {
name: process.env.AZURE_SEARCH_NAME,
indexName: process.env.AZURE_SEARCH_INDEX_NAME,
apiKey: process.env.AZURE_SEARCH_API_KEY,
apiVersion: process.env.AZURE_SEARCH_API_VERSION,
vectorFieldName: "embedding",
});

return azureSearch;
};
24 changes: 4 additions & 20 deletions src/features/chat/chat-services/chat-document-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

import { userHashedId } from "@/features/auth/helpers";
import { CosmosDBContainer } from "@/features/common/cosmos";
import { AzureCogSearch } from "@/features/langchain/vector-stores/azure-cog-search/azure-cog-vector-store";
import { initVectorStore } from "@/features/langchain/vector-stores/store";
import {
AzureKeyCredential,
DocumentAnalysisClient,
} from "@azure/ai-form-recognizer";
import { Document } from "langchain/document";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import { nanoid } from "nanoid";
import {
Expand Down Expand Up @@ -105,10 +104,8 @@ const SplitDocuments = async (docs: Array<Document>) => {

export const DeleteDocuments = async (chatThreadId: string) => {
try {

const vectorStore = initAzureSearchVectorStore();
const vectorStore = await initVectorStore();
await vectorStore.deleteDocuments(chatThreadId);

} catch (e) {
console.log("************");
return {
Expand All @@ -125,7 +122,7 @@ export const IndexDocuments = async (
chatThreadId: string
): Promise<ServerActionResponse<FaqDocumentIndex[]>> => {
try {
const vectorStore = initAzureSearchVectorStore();
const vectorStore = await initVectorStore();
const documentsToIndex: FaqDocumentIndex[] = [];
let index = 0;
for (const doc of docs) {
Expand Down Expand Up @@ -159,19 +156,6 @@ export const IndexDocuments = async (
}
};

export const initAzureSearchVectorStore = () => {
const embedding = new OpenAIEmbeddings();
const azureSearch = new AzureCogSearch<FaqDocumentIndex>(embedding, {
name: process.env.AZURE_SEARCH_NAME,
indexName: process.env.AZURE_SEARCH_INDEX_NAME,
apiKey: process.env.AZURE_SEARCH_API_KEY,
apiVersion: process.env.AZURE_SEARCH_API_VERSION,
vectorFieldName: "embedding",
});

return azureSearch;
};

export const initDocumentIntelligence = () => {
const client = new DocumentAnalysisClient(
process.env.AZURE_DOCUMENT_INTELLIGENCE_ENDPOINT,
Expand Down Expand Up @@ -260,6 +244,6 @@ export const ensureSearchIsConfigured = async () => {
throw new Error("Azure openai embedding variables are not configured.");
}

const vectorStore = initAzureSearchVectorStore();
const vectorStore = await initVectorStore();
await vectorStore.ensureIndexIsCreated();
};
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,9 @@ export class AzureCogSearch<
return `https://${this._config.name}.search.windows.net/indexes/${this._config.indexName}/docs`;
}

/**
 * Embeds the page content of every document and hands the vectors, together
 * with the documents, to `addVectors`.
 *
 * @param documents Documents to embed and index.
 * @returns A promise that settles only after `addVectors` has settled, so a
 * failed write rejects this call instead of becoming an unhandled rejection.
 */
async addDocuments(documents: Document<TModel>[]): Promise<void> {
  const texts = documents.map(({ pageContent }) => pageContent);
  const vectors = await this.embeddings.embedDocuments(texts);
  // Must be awaited: without it the method resolves before the vectors are
  // written and any error from addVectors is silently dropped.
  await this.addVectors(vectors, documents);
}

async deleteDocuments(chatThreadId: string): Promise<void> {
Expand Down
51 changes: 51 additions & 0 deletions src/features/langchain/vector-stores/azure-cog-search/doc-store.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import { AzureCogSearch } from "./azure-cog-vector-store";
import { Embeddings } from "langchain/embeddings/base";
import { Document } from "langchain/document";
import { DocumentVectorStore, DocumentType } from "../store";

/**
 * `DocumentVectorStore` implementation that delegates every operation to an
 * `AzureCogSearch` vector store.
 *
 * @template TModel The type of the documents stored in the vector store.
 */
export class AzureCogSearchStore<TModel extends DocumentType>
  implements DocumentVectorStore<TModel>
{
  private vectorStore: AzureCogSearch<TModel>;

  constructor(vectorStore: AzureCogSearch<TModel>) {
    this.vectorStore = vectorStore;
  }

  /**
   * Creates a store configured from the `AZURE_SEARCH_NAME`,
   * `AZURE_SEARCH_INDEX_NAME`, `AZURE_SEARCH_API_KEY` and
   * `AZURE_SEARCH_API_VERSION` environment variables, using `"embedding"` as
   * the vector field name.
   *
   * @param embedding Embeddings implementation passed to `AzureCogSearch`.
   */
  static async build<TModel extends DocumentType>(
    embedding: Embeddings
  ): Promise<DocumentVectorStore<TModel>> {
    const vectorStore = new AzureCogSearch<TModel>(embedding, {
      name: process.env.AZURE_SEARCH_NAME,
      indexName: process.env.AZURE_SEARCH_INDEX_NAME,
      apiKey: process.env.AZURE_SEARCH_API_KEY,
      apiVersion: process.env.AZURE_SEARCH_API_VERSION,
      vectorFieldName: "embedding",
    });
    return new AzureCogSearchStore<TModel>(vectorStore);
  }

  /**
   * Renders a value as an OData string literal. Embedded single quotes are
   * doubled so the value cannot terminate the literal and inject additional
   * filter clauses.
   */
  private static toODataStringLiteral(value: string): string {
    return `'${value.replace(/'/g, "''")}'`;
  }

  /**
   * Runs a similarity search restricted by an OData filter on the `user` and
   * `chatThreadId` fields.
   *
   * @param query Text to search for.
   * @param k Maximum number of documents requested from the underlying store.
   * @param userId Value the `user` field must equal.
   * @param chatThreadId Value the `chatThreadId` field must equal.
   * @returns Whatever the underlying `AzureCogSearch.similaritySearch` returns.
   */
  async similaritySearch(
    query: string,
    k: number,
    userId: string,
    chatThreadId: string
  ) {
    const userLiteral = AzureCogSearchStore.toODataStringLiteral(userId);
    const threadLiteral =
      AzureCogSearchStore.toODataStringLiteral(chatThreadId);

    return await this.vectorStore.similaritySearch(query, k, {
      vectorFields: this.vectorStore.config.vectorFieldName,
      filter: `user eq ${userLiteral} and chatThreadId eq ${threadLiteral}`,
    });
  }

  /** Forwards the documents to the underlying store and waits for it. */
  async addDocuments(documentsToIndex: Document<TModel>[]) {
    await this.vectorStore.addDocuments(documentsToIndex);
  }

  /** Asks the underlying store to delete the documents of a chat thread. */
  async deleteDocuments(chatThreadId: string): Promise<void> {
    await this.vectorStore.deleteDocuments(chatThreadId);
  }

  /** Asks the underlying store to make sure its index exists. */
  async ensureIndexIsCreated(): Promise<void> {
    await this.vectorStore.ensureIndexIsCreated();
  }
}
69 changes: 69 additions & 0 deletions src/features/langchain/vector-stores/hnswlib/doc-store.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import { Document } from "langchain/document";
import { Embeddings } from "langchain/embeddings/base";
import { HNSWLib } from "langchain/vectorstores/hnswlib";
import { DocumentType, DocumentVectorStore } from "../store";

import * as fs from "node:fs";
import os from "node:os";
import * as path from "node:path";

const DB_FILE = path.join(os.tmpdir(), "hnswlib-chatgpt");

/**
 * Document vector store that uses the HNSWLib library for similarity search.
 * The store is loaded from `DB_FILE` when that path exists and is saved back
 * to `DB_FILE` after every `addDocuments` call. No file locking is performed.
 * Only use for testing / development.
 *
 * @template TModel The type of the documents stored in the vector store.
 */
export class HNSWLibStore<TModel extends DocumentType>
  implements DocumentVectorStore<TModel>
{
  private vectorStore: HNSWLib;

  constructor(vectorStore: HNSWLib) {
    this.vectorStore = vectorStore;
  }

  /**
   * Builds a store from the state persisted at `DB_FILE`. When the path does
   * not exist, or loading it throws, an empty store is created instead (the
   * load error is logged).
   *
   * @param embedding Embeddings implementation handed to HNSWLib.
   */
  static async build<TModel extends DocumentType>(
    embedding: Embeddings
  ): Promise<DocumentVectorStore<TModel>> {
    if (fs.existsSync(DB_FILE)) {
      try {
        const loaded = await HNSWLib.load(DB_FILE, embedding);
        return new HNSWLibStore<TModel>(loaded);
      } catch (err) {
        console.error("Error loading vector DB state. Reset state.", err);
      }
    }
    // No persisted state, or it could not be loaded: start empty.
    const empty = await HNSWLib.fromDocuments([], embedding);
    return new HNSWLibStore<TModel>(empty);
  }

  /**
   * Reads `field` from a stored document, looking at the document object
   * itself first and then at its `metadata` when that is an object.
   *
   * NOTE(review): where `user` / `chatThreadId` live on indexed documents is
   * not visible from this file, so both locations are checked — confirm
   * against the code that builds the documents passed to `addDocuments`.
   */
  private static readField(doc: Document, field: string): unknown {
    const direct: unknown = Reflect.get(doc, field);
    if (direct !== undefined) {
      return direct;
    }
    const metadata: unknown = doc.metadata;
    if (typeof metadata === "object" && metadata !== null) {
      return Reflect.get(metadata, field);
    }
    return undefined;
  }

  /**
   * Runs a similarity search and keeps only documents whose `user` equals
   * `userId` and whose `chatThreadId` equals `chatThreadId`, so one user's or
   * thread's documents are not returned to another.
   *
   * @param query Text to search for.
   * @param k Maximum number of documents requested from HNSWLib.
   * @param userId Owner the returned documents must belong to.
   * @param chatThreadId Chat thread the returned documents must belong to.
   */
  async similaritySearch(
    query: string,
    k: number,
    userId: string,
    chatThreadId: string
  ) {
    const belongsToThread = (doc: Document): boolean =>
      HNSWLibStore.readField(doc, "user") === userId &&
      HNSWLibStore.readField(doc, "chatThreadId") === chatThreadId;

    return (await this.vectorStore.similaritySearch(
      query,
      k,
      belongsToThread
    )) as Document<TModel>[];
  }

  /** Adds the documents to the store, then persists it to `DB_FILE`. */
  async addDocuments(documentsToIndex: Document<TModel>[]) {
    await this.vectorStore.addDocuments(documentsToIndex);
    await this.vectorStore.save(DB_FILE);
  }

  /** Not implemented: only logs, no document is removed. */
  async deleteDocuments(chatThreadId: string): Promise<void> {
    console.log("[HNSWLibStore] TODO: implement deleteDocuments");
  }

  /** No-op apart from a log line. */
  async ensureIndexIsCreated(): Promise<void> {
    console.log(
      "[HNSWLibStore] ensureIndexIsCreated called - noop - HNSWLib doesn't need index creation"
    );
  }
}
29 changes: 29 additions & 0 deletions src/features/langchain/vector-stores/store.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import { FaqDocumentIndex } from "@/features/chat/chat-services/models";
import { AzureCogSearchStore } from "./azure-cog-search/doc-store";
import { HNSWLibStore } from "./hnswlib/doc-store";
import { Document } from "langchain/document";

/**
 * Constraint for the model type a vector store is parameterised with: an
 * object with string keys and values of unknown type. Declares no members of
 * its own beyond `Record<string, unknown>`.
 */
export interface DocumentType extends Record<string, unknown> {}

/**
 * Operations the chat feature needs from a vector store, independent of the
 * backing implementation.
 *
 * @template TModel The type of the documents stored in the vector store.
 */
export interface DocumentVectorStore<TModel extends DocumentType> {
  /** Prepares the backing index, if the implementation needs one. */
  ensureIndexIsCreated(): Promise<void>;
  /**
   * Requests deletion of the documents of a chat thread.
   * @param chatThreadId Identifier of the chat thread.
   */
  deleteDocuments(chatThreadId: string): Promise<void>;
  /**
   * Searches for documents similar to `query`.
   * @param query Text to search for.
   * @param k Number of documents requested.
   * @param userId Identifier of the user the search is performed for.
   * @param chatThreadId Identifier of the chat thread the search is performed for.
   * @returns The matching documents.
   */
  similaritySearch(
    query: string,
    k: number,
    userId: string,
    chatThreadId: string
  ): Promise<Document<TModel>[]>;
  /**
   * Adds documents to the store.
   * @param documentsToIndex Documents to add.
   */
  addDocuments(documentsToIndex: Document<TModel>[]): Promise<void>;
}

/**
 * Creates the vector store used by the chat feature, with OpenAI embeddings.
 *
 * Azure Cognitive Search is selected when `AZURE_SEARCH_NAME` is set to a
 * non-empty value; otherwise the HNSWLib store is used as a fallback.
 */
export const initVectorStore = async () => {
  const embedding = new OpenAIEmbeddings();

  // Without an Azure search name there is nothing to connect to: fall back
  // to hnswlib.
  if (!process.env.AZURE_SEARCH_NAME) {
    return await HNSWLibStore.build<FaqDocumentIndex>(embedding);
  }

  return await AzureCogSearchStore.build<FaqDocumentIndex>(embedding);
};
7 changes: 7 additions & 0 deletions src/next.config.js
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
/** @type {import('next').NextConfig} */
const webpack = require("webpack");

/**
 * Next.js configuration: standalone output, server actions enabled with a
 * 22mb body size limit, and `hnswlib-node` excluded from the webpack bundle.
 *
 * @type {import('next').NextConfig}
 */
const nextConfig = {
  output: "standalone",
  experimental: {
    serverActions: true,
    serverActionsBodySizeLimit: "22mb",
  },
  webpack: (config) => {
    // Keep "hnswlib-node" out of the bundle (presumably because it is a
    // native addon webpack cannot bundle — confirm). `externals` is not
    // guaranteed to be an array, so normalise it instead of spreading it
    // blindly: spreading undefined throws, and spreading a string would
    // split it into characters.
    const existing = config.externals;
    if (Array.isArray(existing)) {
      config.externals = [...existing, "hnswlib-node"];
    } else if (existing == null) {
      config.externals = ["hnswlib-node"];
    } else {
      config.externals = [existing, "hnswlib-node"];
    }

    return config;
  },
};

module.exports = nextConfig;
29 changes: 29 additions & 0 deletions src/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"clsx": "^2.0.0",
"eslint": "^8.46.0",
"eslint-config-next": "^13.4.12",
"hnswlib-node": "^1.4.2",
"langchain": "^0.0.123",
"lucide-react": "^0.264.0",
"nanoid": "^4.0.2",
Expand Down