diff --git a/Cargo.toml b/Cargo.toml index fc759c3384..a188fba469 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,6 +53,7 @@ itertools = "0.11.0" measure_time = "0.8.2" async-trait = "0.1.53" arc-swap = "1.5.0" +rstar = { version = "0.11", optional = true } columnar = { version= "0.1", path="./columnar", package ="tantivy-columnar" } sstable = { version= "0.1", path="./sstable", package ="tantivy-sstable", optional = true } @@ -113,6 +114,8 @@ unstable = [] # useful for benches. quickwit = ["sstable", "futures-util"] +spatial = ["rstar"] + [workspace] members = ["query-grammar", "bitpacker", "common", "ownedbytes", "stacker", "sstable", "tokenizer-api", "columnar"] diff --git a/src/lib.rs b/src/lib.rs index ad70857650..9f5026837e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -158,6 +158,9 @@ pub mod space_usage; pub mod store; pub mod termdict; +#[cfg(feature = "spatial")] +pub mod spatial; + mod reader; pub use self::reader::{IndexReader, IndexReaderBuilder, ReloadPolicy, Warmer}; diff --git a/src/spatial.rs b/src/spatial.rs new file mode 100644 index 0000000000..0882217201 --- /dev/null +++ b/src/spatial.rs @@ -0,0 +1,281 @@ +//! TODO +//! +//! ``` +//! # fn main() -> tantivy::Result<()> { +//! use std::sync::{Arc, Weak}; +//! +//! use rstar::{primitives::GeomWithData, RTree, AABB}; +//! use tantivy::{ +//! collector::DocSetCollector, +//! doc, +//! schema::{Schema, STORED}, +//! spatial::{SpatialIndex, SpatialQuery}, +//! DocAddress, Index, Result, Warmer, +//! }; +//! +//! let mut schema = Schema::builder(); +//! let x = schema.add_f64_field("x", STORED); +//! let y = schema.add_f64_field("y", STORED); +//! let schema = schema.build(); +//! +//! let index = Index::create_in_ram(schema); +//! +//! let mut writer = index.writer_with_num_threads(1, 10_000_000)?; +//! writer.add_document(doc!(x => 0.5, y => 0.5))?; +//! writer.add_document(doc!(x => 1.5, y => 0.5))?; +//! writer.add_document(doc!(x => 0.5, y => 1.5))?; +//! writer.add_document(doc!(x => 0.25, y => 0.75))?; +//! writer.add_document(doc!(x => 0.75, y => 0.25))?; +//! writer.commit()?; +//! +//! let spatial_index = Arc::new(SpatialIndex::new(move |reader| { +//! let store_reader = reader.get_store_reader(0)?; +//! +//! Ok(RTree::bulk_load( +//! reader +//! .doc_ids_alive() +//! .map(|doc_id| { +//! let doc = store_reader.get(doc_id)?; +//! let x = doc.get_first(x).unwrap().as_f64().unwrap(); +//! let y = doc.get_first(y).unwrap().as_f64().unwrap(); +//! +//! Ok(GeomWithData::new([x, y], doc_id)) +//! }) +//! .collect::>()?, +//! )) +//! })); +//! +//! let warmers = vec![Arc::downgrade(&spatial_index) as Weak]; +//! let reader = index.reader_builder().warmers(warmers).try_into()?; +//! +//! let spatial_query = +//! SpatialQuery::locate_in_envelope(&spatial_index, AABB::from_corners([0., 0.], [1., 1.])); +//! +//! let searcher = reader.searcher(); +//! let results = searcher.search(&spatial_query, &DocSetCollector)?; +//! +//! assert_eq!( +//! results, +//! [ +//! DocAddress { +//! segment_ord: 0, +//! doc_id: 0, +//! }, +//! DocAddress { +//! segment_ord: 0, +//! doc_id: 3, +//! }, +//! DocAddress { +//! segment_ord: 0, +//! doc_id: 4, +//! }, +//! ] +//! .into(), +//! ); +//! # Ok(()) } +//! ``` +use std::collections::{HashMap, HashSet}; +use std::fmt; +use std::sync::Arc; + +use arc_swap::ArcSwap; +use common::BitSet; +use rstar::primitives::GeomWithData; +use rstar::{Envelope, Point, PointDistance, RTree, RTreeObject}; + +use crate::query::{BitSetDocSet, ConstScorer, EnableScoring, Explanation, Query, Scorer, Weight}; +use crate::{ + DocId, Opstamp, Result, Score, Searcher, SearcherGeneration, SegmentId, SegmentReader, + TantivyError, Warmer, +}; + +type SegmentKey = (SegmentId, Option); + +/// TODO +pub type SegmentTree = RTree>; + +type Trees = HashMap>>; + +type Inner = dyn Fn(&SegmentTree, &mut BitSet) + Send + Sync; + +/// TODO +pub struct SpatialIndex { + trees: ArcSwap>, + builder: Box Result> + Send + Sync>, +} + +impl SpatialIndex { + /// TODO + pub fn new(builder: B) -> Self + where B: Fn(&SegmentReader) -> Result> + Send + Sync + 'static { + Self { + trees: Default::default(), + builder: Box::new(builder), + } + } +} + +impl Warmer for SpatialIndex +where SegmentTree: Send + Sync +{ + fn warm(&self, searcher: &Searcher) -> Result<()> { + let mut trees = self.trees.load_full(); + + for reader in searcher.segment_readers() { + let key = (reader.segment_id(), reader.delete_opstamp()); + + if trees.contains_key(&key) { + continue; + } + + let tree = (self.builder)(reader)?; + + Arc::make_mut(&mut trees).insert(key, Arc::new(tree)); + } + + self.trees.store(trees); + + Ok(()) + } + + fn garbage_collect(&self, live_generations: &[&SearcherGeneration]) { + let live_keys = live_generations + .iter() + .flat_map(|gen| gen.segments()) + .map(|(&segment_id, &opstamp)| (segment_id, opstamp)) + .collect::>(); + + let mut trees = self.trees.load_full(); + + Arc::make_mut(&mut trees).retain(|key, _tree| live_keys.contains(key)); + + self.trees.store(trees); + } +} + +/// TODO +pub struct SpatialQuery { + trees: Arc>, + inner: Arc>, +} + +impl Clone for SpatialQuery { + fn clone(&self) -> Self { + Self { + trees: Arc::clone(&self.trees), + inner: Arc::clone(&self.inner), + } + } +} + +impl fmt::Debug for SpatialQuery { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SpatialQuery").finish_non_exhaustive() + } +} + +impl SpatialQuery { + /// TODO + pub fn locate_all_at_point( + index: &SpatialIndex, + point: ::Point, + ) -> Self + where + T: PointDistance, + ::Point: Send + Sync + 'static, + { + Self { + trees: index.trees.load_full(), + inner: Arc::new(move |tree, bitset| { + for node in tree.locate_all_at_point(&point) { + bitset.insert(node.data) + } + }), + } + } + + /// TODO + pub fn locate_in_envelope(index: &SpatialIndex, envelope: T::Envelope) -> Self + where T::Envelope: Send + Sync + 'static { + Self { + trees: index.trees.load_full(), + inner: Arc::new(move |tree, bitset| { + for node in tree.locate_in_envelope(&envelope) { + bitset.insert(node.data) + } + }), + } + } + + /// TODO + pub fn locate_in_envelope_intersecting(index: &SpatialIndex, envelope: T::Envelope) -> Self + where T::Envelope: Send + Sync + 'static { + Self { + trees: index.trees.load_full(), + inner: Arc::new(move |tree, bitset| { + for node in tree.locate_in_envelope_intersecting(&envelope) { + bitset.insert(node.data) + } + }), + } + } + + /// TODO + pub fn locate_within_distance( + index: &SpatialIndex, + query_point: ::Point, + max_squared_radius: <::Point as Point>::Scalar, + ) -> Self + where + T: PointDistance, + ::Point: Clone + Send + Sync + 'static, + <::Point as Point>::Scalar: Send + Sync + 'static, + { + Self { + trees: index.trees.load_full(), + inner: Arc::new(move |tree, bitset| { + for node in tree.locate_within_distance(query_point.clone(), max_squared_radius) { + bitset.insert(node.data) + } + }), + } + } +} + +impl Query for SpatialQuery +where SegmentTree: Send + Sync +{ + fn weight(&self, _: EnableScoring<'_>) -> Result> { + Ok(Box::new(self.clone())) + } +} + +impl Weight for SpatialQuery +where SegmentTree: Send + Sync +{ + fn scorer(&self, reader: &SegmentReader, boost: Score) -> Result> { + let key = (reader.segment_id(), reader.delete_opstamp()); + + let tree = &self.trees[&key]; + + let mut bitset = BitSet::with_max_value(reader.max_doc()); + + (self.inner)(tree, &mut bitset); + + Ok(Box::new(ConstScorer::new( + BitSetDocSet::from(bitset), + boost, + ))) + } + + fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result { + let mut scorer = self.scorer(reader, 1.0)?; + if scorer.seek(doc) == doc { + Ok(Explanation::new("SpatialQuery", 1.0)) + } else { + Err(TantivyError::InvalidArgument(format!( + "Document #({doc}) does not match" + ))) + } + } +}