Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include duckdb in Python bindings for searching #458

Merged
merged 1 commit into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ name = "stacrs"
crate-type = ["cdylib"]

[dependencies]
duckdb = { workspace = true, features = [
"bundled",
] } # we don't use it directly, but we need to ensure it's bundled
geojson = { workspace = true }
pyo3 = { workspace = true, features = ["extension-module"] }
pythonize = { workspace = true }
Expand All @@ -25,4 +28,5 @@ stac = { workspace = true, features = [
"validate-blocking",
] }
stac-api = { workspace = true, features = ["client"] }
stac-duckdb = { workspace = true }
tokio = { workspace = true, features = ["rt"] }
Binary file added python/data/extended-item.parquet
Binary file not shown.
6 changes: 6 additions & 0 deletions python/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ impl From<stac_api::Error> for Error {
}
}

impl From<stac_duckdb::Error> for Error {
fn from(value: stac_duckdb::Error) -> Self {
Error(value.to_string())
}
}

impl From<geojson::Error> for Error {
fn from(value: geojson::Error) -> Self {
Error(value.to_string())
Expand Down
1 change: 1 addition & 0 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod search;
mod validate;
mod write;

use duckdb as _;
use error::Error;
use pyo3::prelude::*;

Expand Down
52 changes: 39 additions & 13 deletions python/src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use pyo3::{
use serde::de::DeserializeOwned;
use stac::Format;
use stac_api::{BlockingClient, Fields, Item, ItemCollection, Items, Search};
use stac_duckdb::Client;
use std::str::FromStr;
use tokio::runtime::Builder;

Expand Down Expand Up @@ -40,6 +41,9 @@ use tokio::runtime::Builder;
/// will be interpreted as cql2-text, dictionaries as cql2-json.
/// query (dict[str, Any] | None): Additional filtering based on properties.
/// It is recommended to use filter instead, if possible.
/// use_duckdb (bool | None): Query with DuckDB. If None and the href has a
/// 'parquet' or 'geoparquet' extension, will be set to True. Defaults
/// to None.
///
/// Returns:
/// list[dict[str, Any]]: A list of the returned STAC items.
Expand All @@ -53,7 +57,7 @@ use tokio::runtime::Builder;
/// ... max_items=1,
/// ... )
#[pyfunction]
#[pyo3(signature = (href, *, intersects=None, ids=None, collections=None, max_items=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None))]
#[pyo3(signature = (href, *, intersects=None, ids=None, collections=None, max_items=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, use_duckdb=None))]
pub fn search<'py>(
py: Python<'py>,
href: String,
Expand All @@ -69,6 +73,7 @@ pub fn search<'py>(
sortby: Option<StringOrList>,
filter: Option<StringOrDict>,
query: Option<Py<PyDict>>,
use_duckdb: Option<bool>,
) -> PyResult<Bound<'py, PyList>> {
let items = search_items(
href,
Expand All @@ -84,6 +89,7 @@ pub fn search<'py>(
sortby,
filter,
query,
use_duckdb,
)?;
pythonize::pythonize(py, &items)
.map_err(PyErr::from)
Expand Down Expand Up @@ -126,6 +132,9 @@ pub fn search<'py>(
/// format (str | None): The output format. If none, will be inferred from
/// the outfile extension, and if that fails will fall back to compact JSON.
/// options (list[tuple[str, str]] | None): Configuration values to pass to the object store backend.
/// use_duckdb (bool | None): Query with DuckDB. If None and the href has a
/// 'parquet' or 'geoparquet' extension, will be set to True. Defaults
/// to None.
///
/// Returns:
/// list[dict[str, Any]]: A list of the returned STAC items.
Expand All @@ -139,7 +148,7 @@ pub fn search<'py>(
/// ... max_items=1,
/// ... )
#[pyfunction]
#[pyo3(signature = (outfile, href, *, intersects=None, ids=None, collections=None, max_items=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, format=None, options=None))]
#[pyo3(signature = (outfile, href, *, intersects=None, ids=None, collections=None, max_items=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, format=None, options=None, use_duckdb=None))]
pub fn search_to(
outfile: String,
href: String,
Expand All @@ -157,6 +166,7 @@ pub fn search_to(
query: Option<Py<PyDict>>,
format: Option<String>,
options: Option<Vec<(String, String)>>,
use_duckdb: Option<bool>,
) -> PyResult<usize> {
let items = search_items(
href,
Expand All @@ -172,6 +182,7 @@ pub fn search_to(
sortby,
filter,
query,
use_duckdb,
)?;
let format = format
.map(|s| s.parse())
Expand Down Expand Up @@ -206,8 +217,8 @@ fn search_items(
sortby: Option<StringOrList>,
filter: Option<StringOrDict>,
query: Option<Py<PyDict>>,
use_duckdb: Option<bool>,
) -> PyResult<Vec<Item>> {
let client = BlockingClient::new(&href).map_err(Error::from)?;
let mut fields = Fields::default();
if let Some(include) = include {
fields.include = include.into();
Expand All @@ -225,7 +236,7 @@ fn search_items(
.map(|q| pythonize::depythonize(&q.into_bound(py)))
.transpose()
})?;
let search = Search {
let mut search = Search {
intersects: intersects.map(|i| i.into()).transpose()?,
ids: ids.map(|ids| ids.into()),
collections: collections.map(|c| c.into()),
Expand All @@ -243,18 +254,33 @@ fn search_items(
..Default::default()
},
};
let items = client.search(search).map_err(Error::from)?;
if let Some(max_items) = max_items {
items
.take(max_items)
.collect::<Result<_, _>>()
if use_duckdb
.unwrap_or_else(|| matches!(Format::infer_from_href(&href), Some(Format::Geoparquet(_))))
{
if let Some(max_items) = max_items {
search.items.limit = Some(max_items.try_into()?);
}
let client = Client::from_href(href).map_err(Error::from)?;
client
.search_to_json(search)
.map(|item_collection| item_collection.items)
.map_err(Error::from)
.map_err(PyErr::from)
} else {
items
.collect::<Result<_, _>>()
.map_err(Error::from)
.map_err(PyErr::from)
let client = BlockingClient::new(&href).map_err(Error::from)?;
let items = client.search(search).map_err(Error::from)?;
if let Some(max_items) = max_items {
items
.take(max_items)
.collect::<Result<_, _>>()
.map_err(Error::from)
.map_err(PyErr::from)
} else {
items
.collect::<Result<_, _>>()
.map_err(Error::from)
.map_err(PyErr::from)
}
}
}

Expand Down
5 changes: 5 additions & 0 deletions python/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,8 @@ def test_search_to_geoparquet(tmp_path: Path) -> None:
table = pyarrow.parquet.read_table(tmp_path / "out.parquet")
items = list(stac_geoparquet.arrow.stac_table_to_items(table))
assert len(items) == 1


def test_search_geoparquet(data: Path) -> None:
    """Searching the ``extended-item.parquet`` fixture should yield exactly one item."""
    # No `use_duckdb` argument is passed, so the call relies on the default behaviour.
    href = str(data / "extended-item.parquet")
    results = stacrs.search(href)
    assert len(results) == 1
Loading