Skip to content

Commit be7fa9b

Browse files
authored
Merge pull request #554 from CommanderStorm/vector-search-embedder
Support `embedders` setting and other vector/hybrid search related configuration
2 parents 6de4acf + 4c11a52 commit be7fa9b

File tree

2 files changed

+671
-26
lines changed

2 files changed

+671
-26
lines changed

src/search.rs

Lines changed: 199 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,30 @@ pub enum Selectors<T> {
153153
All,
154154
}
155155

156+
/// Setting whether to utilise previously defined embedders for semantic searching
157+
#[derive(Debug, Serialize, Clone)]
158+
#[serde(rename_all = "camelCase")]
159+
pub struct HybridSearch<'a> {
160+
/// Indicates one of the embedders configured for the queried index
161+
///
162+
/// **Default: `"default"`**
163+
pub embedder: &'a str,
164+
/// number between `0` and `1`:
165+
/// - `0.0` indicates full keyword search
166+
/// - `1.0` indicates full semantic search
167+
///
168+
/// **Default: `0.5`**
169+
pub semantic_ratio: f32,
170+
}
171+
impl Default for HybridSearch<'_> {
172+
fn default() -> Self {
173+
HybridSearch {
174+
embedder: "default",
175+
semantic_ratio: 0.5,
176+
}
177+
}
178+
}
179+
156180
type AttributeToCrop<'a> = (&'a str, Option<usize>);
157181

158182
/// A struct representing a query.
@@ -361,6 +385,20 @@ pub struct SearchQuery<'a, Http: HttpClient> {
361385

362386
#[serde(skip_serializing_if = "Option::is_none")]
363387
pub(crate) index_uid: Option<&'a str>,
388+
389+
/// Defines whether to utilise previously defined embedders for semantic searching
390+
#[serde(skip_serializing_if = "Option::is_none")]
391+
pub hybrid: Option<HybridSearch<'a>>,
392+
393+
/// Defines what vectors an userprovided embedder has gotten for semantic searching
394+
#[serde(skip_serializing_if = "Option::is_none")]
395+
pub vector: Option<&'a [f32]>,
396+
397+
/// Defines whether vectors for semantic searching are returned in the search results
398+
///
399+
/// Can Significantly increase the response size.
400+
#[serde(skip_serializing_if = "Option::is_none")]
401+
pub retrieve_vectors: Option<bool>,
364402
}
365403

366404
#[allow(missing_docs)]
@@ -390,6 +428,9 @@ impl<'a, Http: HttpClient> SearchQuery<'a, Http> {
390428
show_ranking_score_details: None,
391429
matching_strategy: None,
392430
index_uid: None,
431+
hybrid: None,
432+
vector: None,
433+
retrieve_vectors: None,
393434
distinct: None,
394435
ranking_score_threshold: None,
395436
locales: None,
@@ -485,6 +526,16 @@ impl<'a, Http: HttpClient> SearchQuery<'a, Http> {
485526
self.filter = Some(Filter::new(Either::Right(filter)));
486527
self
487528
}
529+
/// Defines whether vectors for semantic searching are returned in the search results
530+
///
531+
/// Can Significantly increase the response size.
532+
pub fn with_retrieve_vectors<'b>(
533+
&'b mut self,
534+
retrieve_vectors: bool,
535+
) -> &'b mut SearchQuery<'a, Http> {
536+
self.retrieve_vectors = Some(retrieve_vectors);
537+
self
538+
}
488539
pub fn with_facets<'b>(
489540
&'b mut self,
490541
facets: Selectors<&'a [&'a str]>,
@@ -585,6 +636,23 @@ impl<'a, Http: HttpClient> SearchQuery<'a, Http> {
585636
self.index_uid = Some(&self.index.uid);
586637
self
587638
}
639+
/// Defines whether to utilise previously defined embedders for semantic searching
640+
pub fn with_hybrid<'b>(
641+
&'b mut self,
642+
embedder: &'a str,
643+
semantic_ratio: f32,
644+
) -> &'b mut SearchQuery<'a, Http> {
645+
self.hybrid = Some(HybridSearch {
646+
embedder,
647+
semantic_ratio,
648+
});
649+
self
650+
}
651+
/// Defines what vectors an userprovided embedder has gotten for semantic searching
652+
pub fn with_vector<'b>(&'b mut self, vector: &'a [f32]) -> &'b mut SearchQuery<'a, Http> {
653+
self.vector = Some(vector);
654+
self
655+
}
588656
pub fn with_distinct<'b>(&'b mut self, distinct: &'a str) -> &'b mut SearchQuery<'a, Http> {
589657
self.distinct = Some(distinct);
590658
self
@@ -857,6 +925,36 @@ mod tests {
857925
kind: String,
858926
number: i32,
859927
nested: Nested,
928+
#[serde(skip_serializing_if = "Option::is_none", default)]
929+
_vectors: Option<Vectors>,
930+
}
931+
932+
#[derive(Debug, Serialize, Deserialize, PartialEq)]
933+
struct Vector {
934+
embeddings: SingleOrMultipleVectors,
935+
regenerate: bool,
936+
}
937+
938+
#[derive(Serialize, Deserialize, Debug, PartialEq)]
939+
#[serde(untagged)]
940+
enum SingleOrMultipleVectors {
941+
Single(Vec<f32>),
942+
Multiple(Vec<Vec<f32>>),
943+
}
944+
945+
#[derive(Debug, Serialize, Deserialize, PartialEq)]
946+
struct Vectors(HashMap<String, Vector>);
947+
948+
impl From<&[f32; 1]> for Vectors {
949+
fn from(value: &[f32; 1]) -> Self {
950+
Vectors(HashMap::from([(
951+
S("default"),
952+
Vector {
953+
embeddings: SingleOrMultipleVectors::Multiple(Vec::from([value.to_vec()])),
954+
regenerate: false,
955+
},
956+
)]))
957+
}
860958
}
861959

862960
impl PartialEq<Map<String, Value>> for Document {
@@ -870,16 +968,16 @@ mod tests {
870968

871969
async fn setup_test_index(client: &Client, index: &Index) -> Result<(), Error> {
872970
let t0 = index.add_documents(&[
873-
Document { id: 0, kind: "text".into(), number: 0, value: S("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."), nested: Nested { child: S("first") } },
874-
Document { id: 1, kind: "text".into(), number: 10, value: S("dolor sit amet, consectetur adipiscing elit"), nested: Nested { child: S("second") } },
875-
Document { id: 2, kind: "title".into(), number: 20, value: S("The Social Network"), nested: Nested { child: S("third") } },
876-
Document { id: 3, kind: "title".into(), number: 30, value: S("Harry Potter and the Sorcerer's Stone"), nested: Nested { child: S("fourth") } },
877-
Document { id: 4, kind: "title".into(), number: 40, value: S("Harry Potter and the Chamber of Secrets"), nested: Nested { child: S("fift") } },
878-
Document { id: 5, kind: "title".into(), number: 50, value: S("Harry Potter and the Prisoner of Azkaban"), nested: Nested { child: S("sixth") } },
879-
Document { id: 6, kind: "title".into(), number: 60, value: S("Harry Potter and the Goblet of Fire"), nested: Nested { child: S("seventh") } },
880-
Document { id: 7, kind: "title".into(), number: 70, value: S("Harry Potter and the Order of the Phoenix"), nested: Nested { child: S("eighth") } },
881-
Document { id: 8, kind: "title".into(), number: 80, value: S("Harry Potter and the Half-Blood Prince"), nested: Nested { child: S("ninth") } },
882-
Document { id: 9, kind: "title".into(), number: 90, value: S("Harry Potter and the Deathly Hallows"), nested: Nested { child: S("tenth") } },
971+
Document { id: 0, kind: "text".into(), number: 0, value: S("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."), nested: Nested { child: S("first") }, _vectors: Some(Vectors::from(&[1000.0]))},
972+
Document { id: 1, kind: "text".into(), number: 10, value: S("dolor sit amet, consectetur adipiscing elit"), nested: Nested { child: S("second") }, _vectors: Some(Vectors::from(&[2000.0])) },
973+
Document { id: 2, kind: "title".into(), number: 20, value: S("The Social Network"), nested: Nested { child: S("third") }, _vectors: Some(Vectors::from(&[3000.0])) },
974+
Document { id: 3, kind: "title".into(), number: 30, value: S("Harry Potter and the Sorcerer's Stone"), nested: Nested { child: S("fourth") }, _vectors: Some(Vectors::from(&[4000.0])) },
975+
Document { id: 4, kind: "title".into(), number: 40, value: S("Harry Potter and the Chamber of Secrets"), nested: Nested { child: S("fift") }, _vectors: Some(Vectors::from(&[5000.0])) },
976+
Document { id: 5, kind: "title".into(), number: 50, value: S("Harry Potter and the Prisoner of Azkaban"), nested: Nested { child: S("sixth") }, _vectors: Some(Vectors::from(&[6000.0])) },
977+
Document { id: 6, kind: "title".into(), number: 60, value: S("Harry Potter and the Goblet of Fire"), nested: Nested { child: S("seventh") }, _vectors: Some(Vectors::from(&[7000.0])) },
978+
Document { id: 7, kind: "title".into(), number: 70, value: S("Harry Potter and the Order of the Phoenix"), nested: Nested { child: S("eighth") }, _vectors: Some(Vectors::from(&[8000.0])) },
979+
Document { id: 8, kind: "title".into(), number: 80, value: S("Harry Potter and the Half-Blood Prince"), nested: Nested { child: S("ninth") }, _vectors: Some(Vectors::from(&[9000.0])) },
980+
Document { id: 9, kind: "title".into(), number: 90, value: S("Harry Potter and the Deathly Hallows"), nested: Nested { child: S("tenth") }, _vectors: Some(Vectors::from(&[10000.0])) },
883981
], None).await?;
884982
let t1 = index
885983
.set_filterable_attributes(["kind", "value", "number"])
@@ -967,7 +1065,8 @@ mod tests {
9671065
value: S("dolor sit amet, consectetur adipiscing elit"),
9681066
kind: S("text"),
9691067
number: 10,
970-
nested: Nested { child: S("second") }
1068+
nested: Nested { child: S("second") },
1069+
_vectors: None,
9711070
},
9721071
&results.hits[0].result
9731072
);
@@ -1139,7 +1238,8 @@ mod tests {
11391238
value: S("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do…"),
11401239
kind: S("text"),
11411240
number: 0,
1142-
nested: Nested { child: S("first") }
1241+
nested: Nested { child: S("first") },
1242+
_vectors: None,
11431243
},
11441244
results.hits[0].formatted_result.as_ref().unwrap()
11451245
);
@@ -1154,7 +1254,8 @@ mod tests {
11541254
value: S("Lorem ipsum dolor sit amet…"),
11551255
kind: S("text"),
11561256
number: 0,
1157-
nested: Nested { child: S("first") }
1257+
nested: Nested { child: S("first") },
1258+
_vectors: None,
11581259
},
11591260
results.hits[0].formatted_result.as_ref().unwrap()
11601261
);
@@ -1175,7 +1276,8 @@ mod tests {
11751276
value: S("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."),
11761277
kind: S("text"),
11771278
number: 0,
1178-
nested: Nested { child: S("first") }
1279+
nested: Nested { child: S("first") },
1280+
_vectors: None,
11791281
},
11801282
results.hits[0].formatted_result.as_ref().unwrap());
11811283

@@ -1190,7 +1292,8 @@ mod tests {
11901292
value: S("Lorem ipsum dolor sit amet…"),
11911293
kind: S("text"),
11921294
number: 0,
1193-
nested: Nested { child: S("first") }
1295+
nested: Nested { child: S("first") },
1296+
_vectors: None,
11941297
},
11951298
results.hits[0].formatted_result.as_ref().unwrap()
11961299
);
@@ -1215,7 +1318,8 @@ mod tests {
12151318
value: S("(ꈍᴗꈍ)sed do eiusmod tempor incididunt ut(ꈍᴗꈍ)"),
12161319
kind: S("text"),
12171320
number: 0,
1218-
nested: Nested { child: S("first") }
1321+
nested: Nested { child: S("first") },
1322+
_vectors: None,
12191323
},
12201324
results.hits[0].formatted_result.as_ref().unwrap()
12211325
);
@@ -1242,7 +1346,8 @@ mod tests {
12421346
value: S("The (⊃。•́‿•̀。)⊃ Social ⊂(´• ω •`⊂) Network"),
12431347
kind: S("title"),
12441348
number: 20,
1245-
nested: Nested { child: S("third") }
1349+
nested: Nested { child: S("third") },
1350+
_vectors: None,
12461351
},
12471352
results.hits[0].formatted_result.as_ref().unwrap()
12481353
);
@@ -1264,7 +1369,8 @@ mod tests {
12641369
value: S("<em>dolor</em> sit amet, consectetur adipiscing elit"),
12651370
kind: S("<em>text</em>"),
12661371
number: 10,
1267-
nested: Nested { child: S("first") }
1372+
nested: Nested { child: S("second") },
1373+
_vectors: None,
12681374
},
12691375
results.hits[0].formatted_result.as_ref().unwrap(),
12701376
);
@@ -1279,7 +1385,8 @@ mod tests {
12791385
value: S("<em>dolor</em> sit amet, consectetur adipiscing elit"),
12801386
kind: S("text"),
12811387
number: 10,
1282-
nested: Nested { child: S("first") }
1388+
nested: Nested { child: S("second") },
1389+
_vectors: None,
12831390
},
12841391
results.hits[0].formatted_result.as_ref().unwrap()
12851392
);
@@ -1479,6 +1586,22 @@ mod tests {
14791586
Ok(())
14801587
}
14811588

1589+
/// enable vector searching and configure an userProvided embedder
1590+
async fn setup_hybrid_searching(client: &Client, index: &Index) -> Result<(), Error> {
1591+
use crate::settings::{Embedder, UserProvidedEmbedderSettings};
1592+
let embedder_setting =
1593+
Embedder::UserProvided(UserProvidedEmbedderSettings { dimensions: 1 });
1594+
index
1595+
.set_settings(&crate::settings::Settings {
1596+
embedders: Some(HashMap::from([("default".to_string(), embedder_setting)])),
1597+
..crate::settings::Settings::default()
1598+
})
1599+
.await?
1600+
.wait_for_completion(&client, None, None)
1601+
.await?;
1602+
Ok(())
1603+
}
1604+
14821605
#[meilisearch_test]
14831606
async fn test_facet_search_base(client: Client, index: Index) -> Result<(), Error> {
14841607
setup_test_index(&client, &index).await?;
@@ -1540,8 +1663,64 @@ mod tests {
15401663
assert_eq!(res.facet_hits.len(), 1);
15411664
Ok(())
15421665
}
1543-
1666+
15441667
#[meilisearch_test]
1668+
async fn test_with_vectors(client: Client, index: Index) -> Result<(), Error> {
1669+
setup_hybrid_searching(&client, &index).await?;
1670+
setup_test_index(&client, &index).await?;
1671+
1672+
let results: SearchResults<Document> = index
1673+
.search()
1674+
.with_query("lorem ipsum")
1675+
.with_retrieve_vectors(true)
1676+
.execute()
1677+
.await?;
1678+
assert_eq!(results.hits.len(), 1);
1679+
let expected = Vectors::from(&[1000.0]);
1680+
assert_eq!(results.hits[0].result._vectors, Some(expected));
1681+
1682+
let results: SearchResults<Document> = index
1683+
.search()
1684+
.with_query("lorem ipsum")
1685+
.with_retrieve_vectors(false)
1686+
.execute()
1687+
.await?;
1688+
assert_eq!(results.hits.len(), 1);
1689+
assert_eq!(results.hits[0].result._vectors, None);
1690+
Ok(())
1691+
}
1692+
1693+
#[tokio::test]
1694+
async fn test_hybrid() -> Result<(), Error> {
1695+
// this is mocked as I could not get the hybrid searching to work
1696+
// See https://github.com/meilisearch/meilisearch-rust/pull/554 for further context
1697+
let mut s = mockito::Server::new_async().await;
1698+
let mock_server_url = s.url();
1699+
let client = Client::new(mock_server_url, None::<String>)?;
1700+
let index = client.index("mocked_index");
1701+
1702+
let req = r#"{"q":"hello hybrid searching","hybrid":{"embedder":"default","semanticRatio":0.0},"vector":[1000.0]}"#.to_string();
1703+
let response = r#"{"hits":[],"offset":null,"limit":null,"estimatedTotalHits":null,"page":null,"hitsPerPage":null,"totalHits":null,"totalPages":null,"facetDistribution":null,"facetStats":null,"processingTimeMs":0,"query":"","indexUid":null}"#.to_string();
1704+
let mock_res = s
1705+
.mock("POST", "/indexes/mocked_index/search")
1706+
.with_status(200)
1707+
.match_body(mockito::Matcher::Exact(req))
1708+
.with_body(&response)
1709+
.expect(1)
1710+
.create_async()
1711+
.await;
1712+
let results: Result<SearchResults<Document>, Error> = index
1713+
.search()
1714+
.with_query("hello hybrid searching")
1715+
.with_hybrid("default", 0.0)
1716+
.with_vector(&[1000.0])
1717+
.execute()
1718+
.await;
1719+
mock_res.assert_async().await;
1720+
results?; // purposely not done above to have better debugging output
1721+
Ok(())
1722+
}
1723+
15451724
async fn test_facet_search_with_search_query(
15461725
client: Client,
15471726
index: Index,

0 commit comments

Comments
 (0)