Skip to content

Commit

Permalink
feat: include duckdb in python bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
gadomski committed Oct 17, 2024
1 parent 416c648 commit 5fc4540
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 13 deletions.
4 changes: 4 additions & 0 deletions python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ name = "stacrs"
crate-type = ["cdylib"]

[dependencies]
duckdb = { workspace = true, features = [
"bundled",
] } # we don't use it directly, but we need to ensure it's bundled
geojson = { workspace = true }
pyo3 = { workspace = true, features = ["extension-module"] }
pythonize = { workspace = true }
Expand All @@ -25,4 +28,5 @@ stac = { workspace = true, features = [
"validate-blocking",
] }
stac-api = { workspace = true, features = ["client"] }
stac-duckdb = { workspace = true }
tokio = { workspace = true, features = ["rt"] }
Binary file added python/data/extended-item.parquet
Binary file not shown.
6 changes: 6 additions & 0 deletions python/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ impl From<stac_api::Error> for Error {
}
}

impl From<stac_duckdb::Error> for Error {
fn from(value: stac_duckdb::Error) -> Self {
Error(value.to_string())
}
}

impl From<geojson::Error> for Error {
fn from(value: geojson::Error) -> Self {
Error(value.to_string())
Expand Down
1 change: 1 addition & 0 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod search;
mod validate;
mod write;

use duckdb as _;
use error::Error;
use pyo3::prelude::*;

Expand Down
52 changes: 39 additions & 13 deletions python/src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use pyo3::{
use serde::de::DeserializeOwned;
use stac::Format;
use stac_api::{BlockingClient, Fields, Item, ItemCollection, Items, Search};
use stac_duckdb::Client;
use std::str::FromStr;
use tokio::runtime::Builder;

Expand Down Expand Up @@ -40,6 +41,9 @@ use tokio::runtime::Builder;
/// will be interpreted as cql2-text, dictionaries as cql2-json.
/// query (dict[str, Any] | None): Additional filtering based on properties.
/// It is recommended to use filter instead, if possible.
/// use_duckdb (bool | None): Query with DuckDB. If None and the href has a
/// 'parquet' or 'geoparquet' extension, will be set to True. Defaults
/// to None.
///
/// Returns:
/// list[dict[str, Any]]: A list of the returned STAC items.
Expand All @@ -53,7 +57,7 @@ use tokio::runtime::Builder;
/// ... max_items=1,
/// ... )
#[pyfunction]
#[pyo3(signature = (href, *, intersects=None, ids=None, collections=None, max_items=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None))]
#[pyo3(signature = (href, *, intersects=None, ids=None, collections=None, max_items=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, use_duckdb=None))]
pub fn search<'py>(
py: Python<'py>,
href: String,
Expand All @@ -69,6 +73,7 @@ pub fn search<'py>(
sortby: Option<StringOrList>,
filter: Option<StringOrDict>,
query: Option<Py<PyDict>>,
use_duckdb: Option<bool>,
) -> PyResult<Bound<'py, PyList>> {
let items = search_items(
href,
Expand All @@ -84,6 +89,7 @@ pub fn search<'py>(
sortby,
filter,
query,
use_duckdb,
)?;
pythonize::pythonize(py, &items)
.map_err(PyErr::from)
Expand Down Expand Up @@ -126,6 +132,9 @@ pub fn search<'py>(
/// format (str | None): The output format. If none, will be inferred from
/// the outfile extension, and if that fails will fall back to compact JSON.
/// options (list[tuple[str, str]] | None): Configuration values to pass to the object store backend.
/// use_duckdb (bool | None): Query with DuckDB. If None and the href has a
/// 'parquet' or 'geoparquet' extension, will be set to True. Defaults
/// to None.
///
/// Returns:
/// int: The number of items written
Expand All @@ -139,7 +148,7 @@ pub fn search<'py>(
/// ... max_items=1,
/// ... )
#[pyfunction]
#[pyo3(signature = (outfile, href, *, intersects=None, ids=None, collections=None, max_items=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, format=None, options=None))]
#[pyo3(signature = (outfile, href, *, intersects=None, ids=None, collections=None, max_items=None, limit=None, bbox=None, datetime=None, include=None, exclude=None, sortby=None, filter=None, query=None, format=None, options=None, use_duckdb=None))]
pub fn search_to(
outfile: String,
href: String,
Expand All @@ -157,6 +166,7 @@ pub fn search_to(
query: Option<Py<PyDict>>,
format: Option<String>,
options: Option<Vec<(String, String)>>,
use_duckdb: Option<bool>,
) -> PyResult<usize> {
let items = search_items(
href,
Expand All @@ -172,6 +182,7 @@ pub fn search_to(
sortby,
filter,
query,
use_duckdb,
)?;
let format = format
.map(|s| s.parse())
Expand Down Expand Up @@ -206,8 +217,8 @@ fn search_items(
sortby: Option<StringOrList>,
filter: Option<StringOrDict>,
query: Option<Py<PyDict>>,
use_duckdb: Option<bool>,
) -> PyResult<Vec<Item>> {
let client = BlockingClient::new(&href).map_err(Error::from)?;
let mut fields = Fields::default();
if let Some(include) = include {
fields.include = include.into();
Expand All @@ -225,7 +236,7 @@ fn search_items(
.map(|q| pythonize::depythonize(&q.into_bound(py)))
.transpose()
})?;
let search = Search {
let mut search = Search {
intersects: intersects.map(|i| i.into()).transpose()?,
ids: ids.map(|ids| ids.into()),
collections: collections.map(|c| c.into()),
Expand All @@ -243,18 +254,33 @@ fn search_items(
..Default::default()
},
};
let items = client.search(search).map_err(Error::from)?;
if let Some(max_items) = max_items {
items
.take(max_items)
.collect::<Result<_, _>>()
if use_duckdb
.unwrap_or_else(|| matches!(Format::infer_from_href(&href), Some(Format::Geoparquet(_))))
{
if let Some(max_items) = max_items {
search.items.limit = Some(max_items.try_into()?);
}
let client = Client::from_href(href).map_err(Error::from)?;
client
.search_to_json(search)
.map(|item_collection| item_collection.items)
.map_err(Error::from)
.map_err(PyErr::from)
} else {
items
.collect::<Result<_, _>>()
.map_err(Error::from)
.map_err(PyErr::from)
let client = BlockingClient::new(&href).map_err(Error::from)?;
let items = client.search(search).map_err(Error::from)?;
if let Some(max_items) = max_items {
items
.take(max_items)
.collect::<Result<_, _>>()
.map_err(Error::from)
.map_err(PyErr::from)
} else {
items
.collect::<Result<_, _>>()
.map_err(Error::from)
.map_err(PyErr::from)
}
}
}

Expand Down
5 changes: 5 additions & 0 deletions python/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,8 @@ def test_search_to_geoparquet(tmp_path: Path) -> None:
table = pyarrow.parquet.read_table(tmp_path / "out.parquet")
items = list(stac_geoparquet.arrow.stac_table_to_items(table))
assert len(items) == 1


def test_search_geoparquet(data: Path) -> None:
    """Search a local geoparquet file by path and expect exactly one item back."""
    href = str(data / "extended-item.parquet")
    results = stacrs.search(href)
    assert len(results) == 1

0 comments on commit 5fc4540

Please sign in to comment.