Skip to content

Commit

Permalink
feat(azure-cosmosdb): add session context for a user mongodb (#7436)
Browse files Browse the repository at this point in the history
Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
crisjy and jacoblee93 authored Jan 18, 2025
1 parent f1dbe28 commit 3a1131a
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 5 deletions.
81 changes: 78 additions & 3 deletions libs/langchain-azure-cosmosdb/src/chat_histories/mongodb.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ export interface AzureCosmosDBMongoChatHistoryDBConfig {
readonly collectionName?: string;
}

export type ChatSessionMongo = {
id: string;
context: Record<string, unknown>;
};

const ID_KEY = "sessionId";
const ID_USER = "userId";

export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHistory {
lc_namespace = ["langchain", "stores", "message", "azurecosmosdb"];
Expand All @@ -33,6 +39,8 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis

private initPromise?: Promise<void>;

private context: Record<string, unknown> = {};

private readonly client: MongoClient | undefined;

private database: Db;
Expand All @@ -41,11 +49,14 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis

private sessionId: string;

private userId: string;

initialize: () => Promise<void>;

constructor(
dbConfig: AzureCosmosDBMongoChatHistoryDBConfig,
sessionId: string
sessionId: string,
userId: string
) {
super();

Expand All @@ -70,6 +81,7 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis
const collectionName = dbConfig.collectionName ?? "chatHistory";

this.sessionId = sessionId;
this.userId = userId ?? "anonymous";

// Deferring initialization to the first call to `initialize`
this.initialize = () => {
Expand Down Expand Up @@ -120,6 +132,7 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis

const document = await this.collection.findOne({
[ID_KEY]: this.sessionId,
[ID_USER]: this.userId,
});
const messages = document?.messages || [];
return mapStoredMessagesToChatMessages(messages);
Expand All @@ -134,10 +147,12 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis
await this.initialize();

const messages = mapChatMessagesToStoredMessages([message]);
const context = await this.getContext();
await this.collection.updateOne(
{ [ID_KEY]: this.sessionId },
{ [ID_KEY]: this.sessionId, [ID_USER]: this.userId },
{
$push: { messages: { $each: messages } } as PushOperator<Document>,
$set: { context },
},
{ upsert: true }
);
Expand All @@ -150,6 +165,66 @@ export class AzureCosmosDBMongoChatMessageHistory extends BaseListChatMessageHis
async clear(): Promise<void> {
await this.initialize();

await this.collection.deleteOne({ [ID_KEY]: this.sessionId });
await this.collection.deleteOne({
[ID_KEY]: this.sessionId,
[ID_USER]: this.userId,
});
}

async getAllSessions(): Promise<ChatSessionMongo[]> {
await this.initialize();
const documents = await this.collection
.find({
[ID_USER]: this.userId,
})
.toArray();

const chatSessions: ChatSessionMongo[] = documents.map((doc) => ({
id: doc[ID_KEY],
user_id: doc[ID_USER],
context: doc.context || {},
}));

return chatSessions;
}

async clearAllSessions() {
await this.initialize();
try {
await this.collection.deleteMany({
[ID_USER]: this.userId,
});
} catch (error) {
console.error("Error clearing chat history sessions:", error);
throw error;
}
}

async getContext(): Promise<Record<string, unknown>> {
await this.initialize();

const document = await this.collection.findOne({
[ID_KEY]: this.sessionId,
[ID_USER]: this.userId,
});
this.context = document?.context || this.context;
return this.context;
}

async setContext(context: Record<string, unknown>): Promise<void> {
await this.initialize();

try {
await this.collection.updateOne(
{ [ID_KEY]: this.sessionId },
{
$set: { context },
},
{ upsert: true }
);
} catch (error) {
console.error("Error setting chat history context", error);
throw error;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ test("Test Azure Cosmos MongoDB history store", async () => {
};

const sessionId = new ObjectId().toString();
const userId = new ObjectId().toString();
const chatHistory = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId
sessionId,
userId
);

const blankResult = await chatHistory.getMessages();
Expand Down Expand Up @@ -70,9 +72,11 @@ test("Test clear Azure Cosmos MongoDB history store", async () => {
};

const sessionId = new ObjectId().toString();
const userId = new ObjectId().toString();
const chatHistory = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId
sessionId,
userId
);

await chatHistory.addUserMessage("Who is the best vocalist?");
Expand All @@ -93,3 +97,50 @@ test("Test clear Azure Cosmos MongoDB history store", async () => {

await mongoClient.close();
});

test("Test getAllSessions and clearAllSessions", async () => {
expect(process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING).toBeDefined();

// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const mongoClient = new MongoClient(
process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING!
);
const dbcfg: AzureCosmosDBMongoChatHistoryDBConfig = {
client: mongoClient,
connectionString: process.env.AZURE_COSMOSDB_MONGODB_CONNECTION_STRING,
databaseName: "langchain",
collectionName: "chathistory",
};

const sessionId1 = new ObjectId().toString();
const userId1 = new ObjectId().toString();
const sessionId2 = new ObjectId().toString();
const userId2 = new ObjectId().toString();

const chatHistory1 = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId1,
userId1
);
const chatHistory2 = new AzureCosmosDBMongoChatMessageHistory(
dbcfg,
sessionId2,
userId2
);

await chatHistory1.addUserMessage("What is AI?");
await chatHistory1.addAIChatMessage("AI stands for Artificial Intelligence.");
await chatHistory2.addUserMessage("What is the best programming language?");
await chatHistory2.addAIChatMessage("It depends on the use case.");

const allSessions = await chatHistory1.getAllSessions();
expect(allSessions.length).toBe(2);
expect(allSessions[0].id).toBe(sessionId1);
expect(allSessions[1].id).toBe(sessionId2);

await chatHistory1.clearAllSessions();
const clearedSessions = await chatHistory1.getAllSessions();
expect(clearedSessions.length).toBe(0);

await mongoClient.close();
});

0 comments on commit 3a1131a

Please sign in to comment.