Skip to content

Commit

Permalink
feat(core): Add support for boolean metadata attributes in `Functio…
Browse files Browse the repository at this point in the history
…nalTranslator` (#7407)

Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
nick-w-nick and jacoblee93 authored Dec 30, 2024
1 parent 21f3b2d commit 53feade
Show file tree
Hide file tree
Showing 9 changed files with 346 additions and 13 deletions.
2 changes: 1 addition & 1 deletion langchain-core/src/structured_query/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ export class BasicTranslator<
this.allowedComparators.indexOf(func as Comparator) === -1
) {
throw new Error(
`Comparator ${func} not allowed. Allowed operators: ${this.allowedComparators.join(
`Comparator ${func} not allowed. Allowed comparators: ${this.allowedComparators.join(
", "
)}`
);
Expand Down
51 changes: 48 additions & 3 deletions langchain-core/src/structured_query/functional.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ import { castValue, isFilterEmpty } from "./utils.js";
* the result of a comparison operation.
*/
type ValueType = {
eq: string | number;
ne: string | number;
eq: string | number | boolean;
ne: string | number | boolean;
lt: string | number;
lte: string | number;
gt: string | number;
Expand Down Expand Up @@ -66,6 +66,42 @@ export class FunctionalTranslator extends BaseTranslator {
throw new Error("Not implemented");
}

/**
* Returns the allowed comparators for a given data type.
* @param input The input value to get the allowed comparators for.
* @returns An array of allowed comparators for the input data type.
*/
getAllowedComparatorsForType(inputType: string): Comparator[] {
switch (inputType) {
case "string": {
return [
Comparators.eq,
Comparators.ne,
Comparators.gt,
Comparators.gte,
Comparators.lt,
Comparators.lte,
];
}
case "number": {
return [
Comparators.eq,
Comparators.ne,
Comparators.gt,
Comparators.gte,
Comparators.lt,
Comparators.lte,
];
}
case "boolean": {
return [Comparators.eq, Comparators.ne];
}
default: {
throw new Error(`Unsupported data type: ${inputType}`);
}
}
}

/**
* Returns a function that performs a comparison based on the provided
* comparator.
Expand Down Expand Up @@ -155,10 +191,19 @@ export class FunctionalTranslator extends BaseTranslator {
* @param comparison The comparison part of a structured query.
* @returns A function that takes a `Document` as an argument and returns a boolean based on the comparison.
*/
visitComparison(comparison: Comparison): this["VisitComparisonOutput"] {
visitComparison(
comparison: Comparison<string | number | boolean>
): this["VisitComparisonOutput"] {
const { comparator, attribute, value } = comparison;
const undefinedTrue = [Comparators.ne];
if (this.allowedComparators.includes(comparator)) {
if (
!this.getAllowedComparatorsForType(typeof value).includes(comparator)
) {
throw new Error(
`'${comparator}' comparator not allowed to be used with ${typeof value}`
);
}
const comparatorFunction = this.getComparatorFunction(comparator);
return (document: Document) => {
const documentValue = document.metadata[attribute];
Expand Down
6 changes: 3 additions & 3 deletions langchain-core/src/structured_query/ir.ts
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ export type VisitorOperationResult = {
*/
export type VisitorComparisonResult = {
[attr: string]: {
[comparator: string]: string | number;
[comparator: string]: string | number | boolean;
};
};

Expand Down Expand Up @@ -149,13 +149,13 @@ export abstract class FilterDirective extends Expression {}
* Class representing a comparison filter directive. It extends the
* FilterDirective class.
*/
export class Comparison extends FilterDirective {
export class Comparison<ValueTypes = string | number> extends FilterDirective {
exprName = "Comparison" as const;

constructor(
public comparator: Comparator,
public attribute: string,
public value: string | number
public value: ValueTypes
) {
super();
}
Expand Down
254 changes: 254 additions & 0 deletions langchain-core/src/structured_query/tests/functional.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
import { test, expect, describe } from "@jest/globals";
import { Document } from "../../documents/document.js";
import { FunctionalTranslator } from "../functional.js";
import { Comparators, Visitor } from "../ir.js";

describe("FunctionalTranslator", () => {
const translator = new FunctionalTranslator();

describe("getAllowedComparatorsForType", () => {
test("string", () => {
expect(translator.getAllowedComparatorsForType("string")).toEqual([
Comparators.eq,
Comparators.ne,
Comparators.gt,
Comparators.gte,
Comparators.lt,
Comparators.lte,
]);
});
test("number", () => {
expect(translator.getAllowedComparatorsForType("number")).toEqual([
Comparators.eq,
Comparators.ne,
Comparators.gt,
Comparators.gte,
Comparators.lt,
Comparators.lte,
]);
});
test("boolean", () => {
expect(translator.getAllowedComparatorsForType("boolean")).toEqual([
Comparators.eq,
Comparators.ne,
]);
});
test("unsupported", () => {
expect(() =>
translator.getAllowedComparatorsForType("unsupported")
).toThrow("Unsupported data type: unsupported");
});
});

describe("visitComparison", () => {
describe("returns true or false for valid comparisons", () => {
const attributesByType = {
string: "stringValue",
number: "numberValue",
boolean: "booleanValue",
};

const inputValuesByAttribute: {
[key in string]: string | number | boolean;
} = {
stringValue: "value",
numberValue: 1,
booleanValue: true,
};

// documents that will match against the comparison
const validDocumentsByComparator: {
[key in string]: Document<Record<string, unknown>>[];
} = {
[Comparators.eq]: [
new Document({
pageContent: "",
metadata: {
stringValue: "value",
numberValue: 1,
booleanValue: true,
},
}),
],
[Comparators.ne]: [
new Document({
pageContent: "",
metadata: {
stringValue: "not-value",
numberValue: 0,
booleanValue: false,
},
}),
],
[Comparators.gt]: [
new Document({
pageContent: "",
metadata: {
stringValue: "valueee",
numberValue: 2,
booleanValue: true,
},
}),
],
[Comparators.gte]: [
// test for greater than
new Document({
pageContent: "",
metadata: {
stringValue: "valueee",
numberValue: 2,
booleanValue: true,
},
}),
// test for equal to
new Document({
pageContent: "",
metadata: {
stringValue: "value",
numberValue: 1,
booleanValue: true,
},
}),
],
[Comparators.lt]: [
new Document({
pageContent: "",
metadata: {
stringValue: "val",
numberValue: 0,
booleanValue: true,
},
}),
],
[Comparators.lte]: [
// test for less than
new Document({
pageContent: "",
metadata: {
stringValue: "val",
numberValue: 0,
booleanValue: true,
},
}),
// test for equal to
new Document({
pageContent: "",
metadata: {
stringValue: "value",
numberValue: 1,
booleanValue: true,
},
}),
],
};

// documents that will not match against the comparison
const invalidDocumentsByComparator: {
[key in string]: Document<Record<string, unknown>>[];
} = {
[Comparators.eq]: [
new Document({
pageContent: "",
metadata: {
stringValue: "not-value",
numberValue: 0,
booleanValue: false,
},
}),
],
[Comparators.ne]: [
new Document({
pageContent: "",
metadata: {
stringValue: "value",
numberValue: 1,
booleanValue: true,
},
}),
],
[Comparators.gt]: [
new Document({
pageContent: "",
metadata: {
stringValue: "value",
numberValue: 1,
booleanValue: true,
},
}),
],
[Comparators.gte]: [
new Document({
pageContent: "",
metadata: {
stringValue: "val",
numberValue: 0,
booleanValue: true,
},
}),
],
[Comparators.lt]: [
new Document({
pageContent: "",
metadata: {
stringValue: "valueee",
numberValue: 2,
booleanValue: true,
},
}),
],
[Comparators.lte]: [
new Document({
pageContent: "",
metadata: {
stringValue: "valueee",
numberValue: 2,
booleanValue: true,
},
}),
],
};

function generateComparatorTestsForType(
type: "string" | "number" | "boolean"
) {
const comparators = translator.getAllowedComparatorsForType(type);
for (const comparator of comparators) {
const attribute = attributesByType[type];
const value = inputValuesByAttribute[attribute];
const validDocuments = validDocumentsByComparator[comparator];
for (const validDocument of validDocuments) {
test(`${value} -> ${comparator} -> ${validDocument.metadata[attribute]}`, () => {
const comparison = translator.visitComparison({
attribute,
comparator,
value,
exprName: "Comparison",
accept: (visitor: Visitor) => visitor,
});
const result = comparison(validDocument);
expect(result).toBeTruthy();
});
}
const invalidDocuments = invalidDocumentsByComparator[comparator];
for (const invalidDocument of invalidDocuments) {
test(`${value} -> ${comparator} -> ${invalidDocument.metadata[attribute]}`, () => {
const comparison = translator.visitComparison({
attribute,
comparator,
value,
exprName: "Comparison",
accept: (visitor: Visitor) => visitor,
});
const result = comparison(invalidDocument);
expect(result).toBeFalsy();
});
}
}
}

generateComparatorTestsForType("string");
generateComparatorTestsForType("number");
generateComparatorTestsForType("boolean");
});
});
});
9 changes: 8 additions & 1 deletion langchain-core/src/structured_query/tests/utils.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* eslint-disable no-process-env */
import { test, expect } from "@jest/globals";
import { castValue, isFloat, isInt, isString } from "../utils.js";
import { castValue, isFloat, isInt, isString, isBoolean } from "../utils.js";

test("Casting values correctly", () => {
const stringString = [
Expand Down Expand Up @@ -28,6 +28,8 @@ test("Casting values correctly", () => {

const floatFloat = ["1.1", 2.2, 3.3];

const booleanBoolean = [true, false];

stringString.map(castValue).forEach((value) => {
expect(typeof value).toBe("string");
expect(isString(value)).toBe(true);
Expand All @@ -54,4 +56,9 @@ test("Casting values correctly", () => {
expect(typeof value).toBe("number");
expect(isFloat(value)).toBe(true);
});

booleanBoolean.map(castValue).forEach((value) => {
expect(typeof value).toBe("boolean");
expect(isBoolean(value)).toBe(true);
});
});
Loading

0 comments on commit 53feade

Please sign in to comment.