From 3a1131a156c11d699653b114cf6ccb4155ea54e3 Mon Sep 17 00:00:00 2001 From: crisjy Date: Sat, 18 Jan 2025 16:14:39 +0800 Subject: [PATCH] feat(azure-cosmosdb): add session context for a user mongodb (#7436) Co-authored-by: jacoblee93 --- .../src/chat_histories/mongodb.ts | 81 ++++++++++++++++++- .../tests/chat_histories/mongodb.int.test.ts | 55 ++++++++++++- 2 files changed, 131 insertions(+), 5 deletions(-) diff --git a/libs/langchain-azure-cosmosdb/src/chat_histories/mongodb.ts b/libs/langchain-azure-cosmosdb/src/chat_histories/mongodb.ts index 53104c198d71..8a0ba0264160 100644 --- a/libs/langchain-azure-cosmosdb/src/chat_histories/mongodb.ts +++ b/libs/langchain-azure-cosmosdb/src/chat_histories/mongodb.ts @@ -20,7 +20,13 @@ export interface AzureCosmosDBMongoChatHistoryDBConfig { readonly collectionName?: string; } +export type ChatSessionMongo = { + id: string; + context: Record; +}; + const ID_KEY = "sessionId"; +const ID_USER = "userId"; export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHistory { lc_namespace = ["langchain", "stores", "message", "azurecosmosdb"]; @@ -33,6 +39,8 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis private initPromise?: Promise; + private context: Record = {}; + private readonly client: MongoClient | undefined; private database: Db; @@ -41,11 +49,14 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis private sessionId: string; + private userId: string; + initialize: () => Promise; constructor( dbConfig: AzureCosmosDBMongoChatHistoryDBConfig, - sessionId: string + sessionId: string, + userId: string ) { super(); @@ -70,6 +81,7 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis const collectionName = dbConfig.collectionName ?? "chatHistory"; this.sessionId = sessionId; + this.userId = userId ?? "anonymous"; // Deferring initialization to the first call to `initialize` this.initialize = () => { @@ -120,6 +132,7 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis const document = await this.collection.findOne({ [ID_KEY]: this.sessionId, + [ID_USER]: this.userId, }); const messages = document?.messages || []; return mapStoredMessagesToChatMessages(messages); @@ -134,10 +147,12 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis await this.initialize(); const messages = mapChatMessagesToStoredMessages([message]); + const context = await this.getContext(); await this.collection.updateOne( - { [ID_KEY]: this.sessionId }, + { [ID_KEY]: this.sessionId, [ID_USER]: this.userId }, { $push: { messages: { $each: messages } } as PushOperator, + $set: { context }, }, { upsert: true } ); @@ -150,6 +165,66 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis async clear(): Promise { await this.initialize(); - await this.collection.deleteOne({ [ID_KEY]: this.sessionId }); + await this.collection.deleteOne({ + [ID_KEY]: this.sessionId, + [ID_USER]: this.userId, + }); + } + + async getAllSessions(): Promise { + await this.initialize(); + const documents = await this.collection + .find({ + [ID_USER]: this.userId, + }) + .toArray(); + + const chatSessions: ChatSessionMongo[] = documents.map((doc) => ({ + id: doc[ID_KEY], + user_id: doc[ID_USER], + context: doc.context || {}, + })); + + return chatSessions; + } + + async clearAllSessions() { + await this.initialize(); + try { + await this.collection.deleteMany({ + [ID_USER]: this.userId, + }); + } catch (error) { + console.error("Error clearing chat history sessions:", error); + throw error; + } + } + + async getContext(): Promise> { + await this.initialize(); + + const document = await this.collection.findOne({ + [ID_KEY]: this.sessionId, + [ID_USER]: this.userId, + }); + this.context = document?.context || this.context; + return this.context; + } + + async setContext(context: Record): Promise { + await this.initialize(); + + try { + await this.collection.updateOne( + { [ID_KEY]: this.sessionId }, + { + $set: { context }, + }, + { upsert: true } + ); + } catch (error) { + console.error("Error setting chat history context", error); + throw error; + } } } diff --git a/libs/langchain-azure-cosmosdb/src/tests/chat_histories/mongodb.int.test.ts b/libs/langchain-azure-cosmosdb/src/tests/chat_histories/mongodb.int.test.ts index 35c4a2cf0311..2825b2cafab4 100644 --- a/libs/langchain-azure-cosmosdb/src/tests/chat_histories/mongodb.int.test.ts +++ b/libs/langchain-azure-cosmosdb/src/tests/chat_histories/mongodb.int.test.ts @@ -32,9 +32,11 @@ test("Test Azure Cosmos MongoDB history store", async () => { }; const sessionId = new ObjectId().toString(); + const userId = new ObjectId().toString(); const chatHistory = new AzureCosmosDBMongoChatMessageHistory( dbcfg, - sessionId + sessionId, + userId ); const blankResult = await chatHistory.getMessages(); @@ -70,9 +72,11 @@ test("Test clear Azure Cosmos MongoDB history store", async () => { }; const sessionId = new ObjectId().toString(); + const userId = new ObjectId().toString(); const chatHistory = new AzureCosmosDBMongoChatMessageHistory( dbcfg, - sessionId + sessionId, + userId ); await chatHistory.addUserMessage("Who is the best vocalist?"); @@ -93,3 +97,50 @@ test("Test clear Azure Cosmos MongoDB history store", async () => { await mongoClient.close(); }); + +test("Test getAllSessions and clearAllSessions", async () => { + expect(process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING).toBeDefined(); + + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + const mongoClient = new MongoClient( + process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING! + ); + const dbcfg: AzureCosmosDBMongoChatHistoryDBConfig = { + client: mongoClient, + connectionString: process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING, + databaseName: "langchain", + collectionName: "chathistory", + }; + + const sessionId1 = new ObjectId().toString(); + const userId1 = new ObjectId().toString(); + const sessionId2 = new ObjectId().toString(); + const userId2 = new ObjectId().toString(); + + const chatHistory1 = new AzureCosmosDBMongoChatMessageHistory( + dbcfg, + sessionId1, + userId1 + ); + const chatHistory2 = new AzureCosmosDBMongoChatMessageHistory( + dbcfg, + sessionId2, + userId2 + ); + + await chatHistory1.addUserMessage("What is AI?"); + await chatHistory1.addAIChatMessage("AI stands for Artificial Intelligence."); + await chatHistory2.addUserMessage("What is the best programming language?"); + await chatHistory2.addAIChatMessage("It depends on the use case."); + + const allSessions = await chatHistory1.getAllSessions(); + expect(allSessions.length).toBe(2); + expect(allSessions[0].id).toBe(sessionId1); + expect(allSessions[1].id).toBe(sessionId2); + + await chatHistory1.clearAllSessions(); + const clearedSessions = await chatHistory1.getAllSessions(); + expect(clearedSessions.length).toBe(0); + + await mongoClient.close(); +});