Skip to content

Commit

Permalink
chore(langchain): add test to ensure sort order on MemoryVectorStore
Browse files Browse the repository at this point in the history
  • Loading branch information
benjamincburns committed Mar 3, 2025
1 parent 0d6b66c commit b3304d9
Showing 1 changed file with 89 additions and 0 deletions.
89 changes: 89 additions & 0 deletions langchain/src/vectorstores/tests/memory.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { test, expect } from "@jest/globals";

import { Document, DocumentInterface } from "@langchain/core/documents";
import { SyntheticEmbeddings } from "@langchain/core/utils/testing";
import { Embeddings } from "@langchain/core/embeddings";
import { MemoryVectorStore } from "../memory.js";
import { cosine } from "../../util/ml-distance/similarities.js";

Expand Down Expand Up @@ -165,3 +166,91 @@ test("MemoryVectorStore with max marginal relevance", async () => {
expect(similarityCalledCount).toBe(4);
expect(results).toHaveLength(3);
});

test("MemoryVectorStore sorts results in descending order of similarity", async () => {
// Fixed one-dimensional vector for each document text used by this test.
const embeddings = new Map<string, number[]>()
  .set("Document A", [0])
  .set("Document B", [1])
  .set("Document C", [2])
  .set("Document D", [3]);

// Inverse lookup from vector back to document text. Map keys are compared by
// identity, so only the exact array objects stored in `embeddings` resolve here.
const reverseEmbeddings = new Map<number[], string>();
for (const [text, vector] of embeddings) {
  reverseEmbeddings.set(vector, text);
}
/**
 * Test double for `Embeddings` that serves vectors from the fixed `embeddings`
 * table instead of computing them.
 *
 * Vectors are returned as the very array objects stored in `embeddings`
 * (never copied); `reverseEmbeddings` is keyed by those same objects.
 */
class ContrivedEmbeddings extends Embeddings {
  /**
   * Looks up the vector for each text, preserving input order.
   *
   * @param documents Texts to embed; each must be a key of `embeddings`.
   * @returns One vector per input text.
   * @throws Error if any text has no entry in `embeddings`.
   */
  async embedDocuments(documents: string[]): Promise<number[][]> {
    // Delegate so unknown texts fail loudly here too, rather than yielding
    // `undefined` entries disguised as `number[]`.
    return Promise.all(documents.map((text) => this.embedQuery(text)));
  }

  /**
   * Looks up the vector for a single text.
   *
   * @param text Text to embed; must be a key of `embeddings`.
   * @returns The stored vector for `text`.
   * @throws Error if `text` has no entry in `embeddings`.
   */
  async embedQuery(text: string): Promise<number[]> {
    const vector = embeddings.get(text);
    if (vector === undefined) {
      throw new Error(`Document ${text} not found`);
    }
    return vector;
  }
}

/**
 * Lazily enumerates every ordering of `items`.
 *
 * Orderings are produced by fixing each element in turn (in input order) as
 * the head and recursing on the remaining elements. Every yielded array is a
 * fresh array; `items` itself is never mutated.
 *
 * @param items Elements to permute.
 * @returns A generator over all permutations; an empty or single-element
 *   input yields exactly one permutation (a copy of the input).
 */
function* permutations<T>(items: T[]): Generator<T[]> {
  if (items.length > 1) {
    for (const [index, head] of items.entries()) {
      // Copy, then drop the chosen head, leaving the elements still to arrange.
      const remaining = [...items];
      remaining.splice(index, 1);

      for (const tail of permutations(remaining)) {
        yield [head, ...tail];
      }
    }
    return;
  }

  yield [...items];
}

/**
 * Contrived similarity metric that returns a hard-coded score per document.
 *
 * Both vectors are translated back to their document text through
 * `reverseEmbeddings`, which is keyed by array identity.
 *
 * @param query Vector of the query; must be the stored vector of "Document D".
 * @param vector Vector of the candidate document.
 * @returns The fixed score for Documents A, B or C, and 0 for anything else.
 * @throws Error if `query` does not resolve to "Document D".
 */
function similarity(query: number[], vector: number[]): number {
  if (reverseEmbeddings.get(query) !== "Document D") {
    throw new Error(`Similarity metric only valid for Document D`);
  }

  const candidate = reverseEmbeddings.get(vector);
  if (candidate === "Document A") return 0.23351;
  if (candidate === "Document B") return 0.062168;
  if (candidate === "Document C") return 0.169842;
  return 0;
}


// Run the same search against every insertion order of the three documents:
// the ranking must depend only on the scores, never on insertion order.
for (const insertionOrder of permutations(["Document A", "Document B", "Document C"])) {
  const store = new MemoryVectorStore(new ContrivedEmbeddings({}), { similarity });

  // Insert one document per call so the store receives them in exactly this
  // permutation's order.
  for (const pageContent of insertionOrder) {
    await store.addDocuments([{ pageContent, metadata: { a: 1 } }]);
  }

  const scoredHits = await store.similaritySearchWithScore("Document D", 3);

  // Split the [document, score] pairs into page contents and scores, keeping
  // the order in which the store returned them.
  const returnedTexts: string[] = [];
  const returnedScores: number[] = [];
  for (const [hit, score] of scoredHits) {
    returnedTexts.push(hit.pageContent);
    returnedScores.push(score);
  }

  // Regression guard: a comparator of the form
  // (a, b) => (a.similarity > b.similarity ? -1 : 0) only ever returns -1 or 0,
  // so a store sorting with it may fail this expectation for some insertion orders.
  expect(returnedTexts).toEqual(["Document A", "Document C", "Document B"]);

  // Scores must be strictly descending.
  expect(returnedScores[0]).toBeGreaterThan(returnedScores[1]);
  expect(returnedScores[1]).toBeGreaterThan(returnedScores[2]);
}
});

0 comments on commit b3304d9

Please sign in to comment.