From ba5c7b90c5180125ff083da731ba59e78438cfff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Tamarelle?= Date: Tue, 26 Nov 2024 19:13:15 +0100 Subject: [PATCH] Implement $vectorSearch stage builder --- generator/config/search/queryString.yaml | 2 +- generator/config/search/wildcard.yaml | 2 +- generator/config/stage/vectorSearch.yaml | 115 +++++++++++++ src/Builder/Stage/FactoryTrait.php | 24 +++ src/Builder/Stage/FluentFactoryTrait.php | 26 +++ src/Builder/Stage/VectorSearchStage.php | 93 ++++++++++ tests/Builder/Stage/Pipelines.php | 159 ++++++++++++++++++ tests/Builder/Stage/VectorSearchStageTest.php | 85 ++++++++++ 8 files changed, 504 insertions(+), 2 deletions(-) create mode 100644 generator/config/stage/vectorSearch.yaml create mode 100644 src/Builder/Stage/VectorSearchStage.php create mode 100644 tests/Builder/Stage/VectorSearchStageTest.php diff --git a/generator/config/search/queryString.yaml b/generator/config/search/queryString.yaml index 0bebdf45a..8202771c9 100644 --- a/generator/config/search/queryString.yaml +++ b/generator/config/search/queryString.yaml @@ -32,4 +32,4 @@ tests: - $project: _id: 0 - title: 1 \ No newline at end of file + title: 1 diff --git a/generator/config/search/wildcard.yaml b/generator/config/search/wildcard.yaml index 1dd256e1a..d17fb4803 100644 --- a/generator/config/search/wildcard.yaml +++ b/generator/config/search/wildcard.yaml @@ -57,4 +57,4 @@ tests: - $project: _id: 0 - title: 1 \ No newline at end of file + title: 1 diff --git a/generator/config/stage/vectorSearch.yaml b/generator/config/stage/vectorSearch.yaml new file mode 100644 index 000000000..5e4fa4605 --- /dev/null +++ b/generator/config/stage/vectorSearch.yaml @@ -0,0 +1,115 @@ +# $schema: ../schema.json +name: $vectorSearch +link: 'https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/' +type: + - stage +encode: object +description: | + The $vectorSearch stage performs an ANN or ENN search on a vector in the specified field. +arguments: + - + name: index + type: + - string + - + name: limit + type: + - int + - + name: path + type: + - searchPath + - + name: queryVector + type: + - array # of numbers + - + name: exact + optional: true + type: + - bool + - + name: filter + optional: true + type: + - query + - + name: numCandidates + optional: true + type: + - int +tests: + - + name: 'ANN Basic' + link: 'https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#ann-examples' + pipeline: + - + $vectorSearch: + index: 'vector_index' + path: 'plot_embedding' + queryVector: + - -0.0016261312 + - -0.028070757 + - -0.011342932 + # skip other numbers, not relevant to the test + numCandidates: 150 + limit: 10 + - + $project: + _id: 0 + plot: 1 + title: 1 + score: + $meta: 'vectorSearchScore' + + - + name: 'ANN Filter' + link: 'https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#ann-examples' + pipeline: + - + $vectorSearch: + index: 'vector_index' + path: 'plot_embedding' + filter: + $and: + - + year: + $lt: 1975 + queryVector: + - 0.02421053 + - -0.022372592 + - -0.006231137 + # skip other numbers, not relevant to the test + numCandidates: 150 + limit: 10 + - + $project: + _id: 0 + title: 1 + plot: 1 + year: 1 + score: + $meta: 'vectorSearchScore' + + - + name: 'ENN' + link: 'https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#enn-examples' + pipeline: + - + $vectorSearch: + index: 'vector_index' + path: 'plot_embedding' + queryVector: + - -0.006954097 + - -0.009932499 + - -0.001311474 + # skip other numbers, not relevant to the test + exact: true + limit: 10 + - + $project: + _id: 0 + plot: 1 + title: 1 + score: + $meta: 'vectorSearchScore' diff --git a/src/Builder/Stage/FactoryTrait.php b/src/Builder/Stage/FactoryTrait.php index abf8f10a0..94430da02 100644 --- a/src/Builder/Stage/FactoryTrait.php +++ b/src/Builder/Stage/FactoryTrait.php @@ -705,4 +705,28 @@ public static function unwind( ): UnwindStage { return new UnwindStage($path, $includeArrayIndex, $preserveNullAndEmptyArrays); } + + /** + * The $vectorSearch stage performs an ANN or ENN search on a vector in the specified field. + * + * @see https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/ + * @param string $index + * @param int $limit + * @param array|string $path + * @param BSONArray|PackedArray|array $queryVector + * @param Optional|bool $exact + * @param Optional|QueryInterface|array $filter + * @param Optional|int $numCandidates + */ + public static function vectorSearch( + string $index, + int $limit, + array|string $path, + PackedArray|BSONArray|array $queryVector, + Optional|bool $exact = Optional::Undefined, + Optional|QueryInterface|array $filter = Optional::Undefined, + Optional|int $numCandidates = Optional::Undefined, + ): VectorSearchStage { + return new VectorSearchStage($index, $limit, $path, $queryVector, $exact, $filter, $numCandidates); + } } diff --git a/src/Builder/Stage/FluentFactoryTrait.php b/src/Builder/Stage/FluentFactoryTrait.php index ab0793f53..5d865ae9b 100644 --- a/src/Builder/Stage/FluentFactoryTrait.php +++ b/src/Builder/Stage/FluentFactoryTrait.php @@ -794,4 +794,30 @@ public function unwind( return $this; } + + /** + * The $vectorSearch stage performs an ANN or ENN search on a vector in the specified field. + * + * @see https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/ + * @param string $index + * @param int $limit + * @param array|string $path + * @param BSONArray|PackedArray|array $queryVector + * @param Optional|bool $exact + * @param Optional|QueryInterface|array $filter + * @param Optional|int $numCandidates + */ + public function vectorSearch( + string $index, + int $limit, + array|string $path, + PackedArray|BSONArray|array $queryVector, + Optional|bool $exact = Optional::Undefined, + Optional|QueryInterface|array $filter = Optional::Undefined, + Optional|int $numCandidates = Optional::Undefined, + ): static { + $this->pipeline[] = Stage::vectorSearch($index, $limit, $path, $queryVector, $exact, $filter, $numCandidates); + + return $this; + } } diff --git a/src/Builder/Stage/VectorSearchStage.php b/src/Builder/Stage/VectorSearchStage.php new file mode 100644 index 000000000..1a84c1c39 --- /dev/null +++ b/src/Builder/Stage/VectorSearchStage.php @@ -0,0 +1,93 @@ +index = $index; + $this->limit = $limit; + $this->path = $path; + if (is_array($queryVector) && ! array_is_list($queryVector)) { + throw new InvalidArgumentException('Expected $queryVector argument to be a list, got an associative array.'); + } + + $this->queryVector = $queryVector; + $this->exact = $exact; + if (is_array($filter)) { + $filter = QueryObject::create($filter); + } + + $this->filter = $filter; + $this->numCandidates = $numCandidates; + } + + public function getOperator(): string + { + return '$vectorSearch'; + } +} diff --git a/tests/Builder/Stage/Pipelines.php b/tests/Builder/Stage/Pipelines.php index eefa4e714..850013c3e 100644 --- a/tests/Builder/Stage/Pipelines.php +++ b/tests/Builder/Stage/Pipelines.php @@ -3553,4 +3553,163 @@ enum Pipelines: string } ] JSON; + + /** + * ANN Basic + * + * @see https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#ann-examples + */ + case VectorSearchANNBasic = <<<'JSON' + [ + { + "$vectorSearch": { + "index": "vector_index", + "path": "plot_embedding", + "queryVector": [ + { + "$numberDouble": "-0.0016261311999999999121" + }, + { + "$numberDouble": "-0.028070756999999998266" + }, + { + "$numberDouble": "-0.011342932000000000015" + } + ], + "numCandidates": { + "$numberInt": "150" + }, + "limit": { + "$numberInt": "10" + } + } + }, + { + "$project": { + "_id": { + "$numberInt": "0" + }, + "plot": { + "$numberInt": "1" + }, + "title": { + "$numberInt": "1" + }, + "score": { + "$meta": "vectorSearchScore" + } + } + } + ] + JSON; + + /** + * ANN Filter + * + * @see https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#ann-examples + */ + case VectorSearchANNFilter = <<<'JSON' + [ + { + "$vectorSearch": { + "index": "vector_index", + "path": "plot_embedding", + "filter": { + "$and": [ + { + "year": { + "$lt": { + "$numberInt": "1975" + } + } + } + ] + }, + "queryVector": [ + { + "$numberDouble": "0.024210530000000000939" + }, + { + "$numberDouble": "-0.022372592000000000173" + }, + { + "$numberDouble": "-0.0062311370000000003075" + } + ], + "numCandidates": { + "$numberInt": "150" + }, + "limit": { + "$numberInt": "10" + } + } + }, + { + "$project": { + "_id": { + "$numberInt": "0" + }, + "title": { + "$numberInt": "1" + }, + "plot": { + "$numberInt": "1" + }, + "year": { + "$numberInt": "1" + }, + "score": { + "$meta": "vectorSearchScore" + } + } + } + ] + JSON; + + /** + * ENN + * + * @see https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#enn-examples + */ + case VectorSearchENN = <<<'JSON' + [ + { + "$vectorSearch": { + "index": "vector_index", + "path": "plot_embedding", + "queryVector": [ + { + "$numberDouble": "-0.0069540970000000002296" + }, + { + "$numberDouble": "-0.009932498999999999148" + }, + { + "$numberDouble": "-0.0013114739999999999731" + } + ], + "exact": true, + "limit": { + "$numberInt": "10" + } + } + }, + { + "$project": { + "_id": { + "$numberInt": "0" + }, + "plot": { + "$numberInt": "1" + }, + "title": { + "$numberInt": "1" + }, + "score": { + "$meta": "vectorSearchScore" + } + } + } + ] + JSON; } diff --git a/tests/Builder/Stage/VectorSearchStageTest.php b/tests/Builder/Stage/VectorSearchStageTest.php new file mode 100644 index 000000000..f6259510a --- /dev/null +++ b/tests/Builder/Stage/VectorSearchStageTest.php @@ -0,0 +1,85 @@ + 'vectorSearchScore'], + ), + ); + + $this->assertSamePipeline(Pipelines::VectorSearchANNBasic, $pipeline); + } + + public function testANNFilter(): void + { + $pipeline = new Pipeline( + Stage::vectorSearch( + index: 'vector_index', + limit: 10, + path: 'plot_embedding', + queryVector: [0.02421053, -0.022372592, -0.006231137], + filter: Query::and( + Query::query( + year: Query::lt(1975), + ), + ), + numCandidates: 150, + ), + Stage::project( + _id: 0, + title: 1, + plot: 1, + year: 1, + score: ['$meta' => 'vectorSearchScore'], + ), + ); + + $this->assertSamePipeline(Pipelines::VectorSearchANNFilter, $pipeline); + } + + public function testENN(): void + { + $pipeline = new Pipeline( + Stage::vectorSearch( + index: 'vector_index', + limit: 10, + path: 'plot_embedding', + queryVector: [-0.006954097, -0.009932499, -0.001311474], + exact: true, + ), + Stage::project( + _id: 0, + title: 1, + plot: 1, + score: ['$meta' => 'vectorSearchScore'], + ), + ); + + $this->assertSamePipeline(Pipelines::VectorSearchENN, $pipeline); + } +}