Skip to content

Commit ffaa20d

Browse files
authored
feat(rust/sedona-functions): Implement native ST_ZMFlag using WKBHeader (#260)
1 parent a6ae54e commit ffaa20d

File tree

6 files changed

+246
-0
lines changed

6 files changed

+246
-0
lines changed

benchmarks/test_functions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,21 @@ def queries():
293293
eng.execute_and_collect(f"SELECT ST_EndPoint(geom1) from {table}")
294294

295295
benchmark(queries)
296+
297+
@pytest.mark.parametrize(
298+
"eng", [SedonaDBSingleThread, PostGISSingleThread, DuckDBSingleThread]
299+
)
300+
@pytest.mark.parametrize(
301+
"table",
302+
[
303+
"collections_simple",
304+
"collections_complex",
305+
],
306+
)
307+
def test_st_zmflag(self, benchmark, eng, table):
308+
eng = self._get_eng(eng)
309+
310+
def queries():
311+
eng.execute_and_collect(f"SELECT ST_ZmFlag(geom1) from {table}")
312+
313+
benchmark(queries)

python/sedonadb/tests/functions/test_functions.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,3 +1694,33 @@ def test_st_simplifypreservetopology(eng, geom, tolerance, expected):
16941694
f"SELECT ST_SimplifyPreserveTopology({geom_or_null(geom)}, {val_or_null(tolerance)})",
16951695
expected,
16961696
)
1697+
1698+
1699+
@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
1700+
@pytest.mark.parametrize(
1701+
("geom", "expected"),
1702+
[
1703+
(None, None),
1704+
("POINT EMPTY", 0),
1705+
("POINT Z EMPTY", 2),
1706+
("POINT M EMPTY", 1),
1707+
("POINT ZM EMPTY", 3),
1708+
("POINT Z (0 0 0)", 2),
1709+
("POINT M (0 0 0)", 1),
1710+
("POINT ZM (0 0 0 0)", 3),
1711+
("LINESTRING EMPTY", 0),
1712+
("LINESTRING Z EMPTY", 2),
1713+
("LINESTRING Z (0 0 0, 1 1 1)", 2),
1714+
("POLYGON EMPTY", 0),
1715+
("MULTIPOINT ((0 0), (1 1))", 0),
1716+
("MULTIPOINT Z ((0 0 0))", 2),
1717+
("MULTIPOINT ZM ((0 0 0 0))", 3),
1718+
("GEOMETRYCOLLECTION EMPTY", 0),
1719+
("GEOMETRYCOLLECTION (POINT Z (0 0 0))", 2),
1720+
("GEOMETRYCOLLECTION Z (POINT Z (0 0 0))", 2),
1721+
("GEOMETRYCOLLECTION (GEOMETRYCOLLECTION (POINT Z (0 0 0)))", 2),
1722+
],
1723+
)
1724+
def test_st_zmflag(eng, geom, expected):
1725+
eng = eng.create_or_skip()
1726+
eng.assert_query_result(f"SELECT ST_ZmFlag({geom_or_null(geom)})", expected)

rust/sedona-functions/benches/native-functions.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ fn criterion_benchmark(c: &mut Criterion) {
157157
benchmark::scalar(c, &f, "native", "st_mmin", LineString(10));
158158
benchmark::scalar(c, &f, "native", "st_mmax", LineString(10));
159159

160+
benchmark::scalar(c, &f, "native", "st_zmflag", Point);
161+
benchmark::scalar(c, &f, "native", "st_zmflag", LineString(10));
162+
160163
benchmark::scalar(
161164
c,
162165
&f,

rust/sedona-functions/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,4 @@ mod st_transform;
5555
pub mod st_union_aggr;
5656
mod st_xyzm;
5757
mod st_xyzm_minmax;
58+
mod st_zmflag;

rust/sedona-functions/src/register.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ pub fn default_function_set() -> FunctionSet {
108108
crate::st_xyzm_minmax::st_mmin_udf,
109109
crate::st_xyzm_minmax::st_mmax_udf,
110110
crate::st_isclosed::st_isclosed_udf,
111+
crate::st_zmflag::st_zmflag_udf,
111112
);
112113

113114
register_aggregate_udfs!(
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
use std::sync::Arc;
18+
19+
use crate::executor::WkbBytesExecutor;
20+
use arrow_array::builder::Int8Builder;
21+
use arrow_schema::DataType;
22+
use datafusion_common::{error::Result, DataFusionError};
23+
use datafusion_expr::{
24+
scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation, Volatility,
25+
};
26+
use geo_traits::Dimensions;
27+
use sedona_common::sedona_internal_err;
28+
use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF};
29+
use sedona_geometry::wkb_header::WkbHeader;
30+
use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher};
31+
32+
pub fn st_zmflag_udf() -> SedonaScalarUDF {
33+
SedonaScalarUDF::new(
34+
"st_zmflag",
35+
vec![Arc::new(STZmFlag {})],
36+
Volatility::Immutable,
37+
Some(st_zmflag_doc()),
38+
)
39+
}
40+
41+
fn st_zmflag_doc() -> Documentation {
42+
Documentation::builder(
43+
DOC_SECTION_OTHER,
44+
"Returns a code indicating the ZM coordinate dimension of a geometry. Values are 0 for 2D, 1 for 3D-M, 2 for 3D-Z, and 3 for 4D.".to_string(),
45+
"ST_ZmFlag (A: Geometry)".to_string(),
46+
)
47+
.with_argument("geom", "geometry: Input geometry")
48+
.with_sql_example("SELECT ST_ZmFlag(ST_GeomFromWKT('POLYGON ((0 0, 1 0, 0 1, 0 0))'))")
49+
.build()
50+
}
51+
52+
#[derive(Debug)]
53+
struct STZmFlag {}
54+
55+
impl SedonaScalarKernel for STZmFlag {
56+
fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
57+
let matcher = ArgMatcher::new(
58+
vec![ArgMatcher::is_geometry()],
59+
SedonaType::Arrow(DataType::Int8),
60+
);
61+
62+
matcher.match_args(args)
63+
}
64+
65+
fn invoke_batch(
66+
&self,
67+
arg_types: &[SedonaType],
68+
args: &[ColumnarValue],
69+
) -> Result<ColumnarValue> {
70+
let executor = WkbBytesExecutor::new(arg_types, args);
71+
let mut builder = Int8Builder::with_capacity(executor.num_iterations());
72+
73+
executor.execute_wkb_void(|maybe_item| {
74+
match maybe_item {
75+
Some(item) => {
76+
builder.append_value(invoke_scalar(item)?);
77+
}
78+
None => builder.append_null(),
79+
}
80+
Ok(())
81+
})?;
82+
83+
executor.finish(Arc::new(builder.finish()))
84+
}
85+
}
86+
87+
fn invoke_scalar(buf: &[u8]) -> Result<i8> {
88+
let header = WkbHeader::try_new(buf).map_err(|e| DataFusionError::External(Box::new(e)))?;
89+
let top_level_dimensions = header
90+
.dimensions()
91+
.map_err(|e| DataFusionError::External(Box::new(e)))?;
92+
93+
// Infer dimension based on first coordinate dimension for cases where it differs from top-level
94+
// e.g GEOMETRYCOLLECTION (POINT Z (1 2 3))
95+
let dimensions;
96+
if let Some(first_geom_dimensions) = header.first_geom_dimensions() {
97+
dimensions = first_geom_dimensions;
98+
} else {
99+
dimensions = top_level_dimensions;
100+
}
101+
102+
match dimensions {
103+
Dimensions::Xy => Ok(0),
104+
Dimensions::Xym => Ok(1),
105+
Dimensions::Xyz => Ok(2),
106+
Dimensions::Xyzm => Ok(3),
107+
_ => sedona_internal_err!("Invalid dimensions: {:?}", dimensions),
108+
}
109+
}
110+
111+
#[cfg(test)]
112+
mod tests {
113+
use datafusion_common::ScalarValue;
114+
use datafusion_expr::ScalarUDF;
115+
use rstest::rstest;
116+
use sedona_schema::datatypes::{WKB_GEOMETRY, WKB_VIEW_GEOMETRY};
117+
use sedona_testing::{
118+
fixtures::MULTIPOINT_WITH_INFERRED_Z_DIMENSION_WKB, testers::ScalarUdfTester,
119+
};
120+
121+
use super::*;
122+
123+
#[test]
124+
fn udf_metadata() {
125+
let udf: ScalarUDF = st_zmflag_udf().into();
126+
assert_eq!(udf.name(), "st_zmflag");
127+
assert!(udf.documentation().is_some());
128+
}
129+
130+
#[rstest]
131+
fn udf(#[values(WKB_GEOMETRY, WKB_VIEW_GEOMETRY)] sedona_type: SedonaType) {
132+
let tester = ScalarUdfTester::new(st_zmflag_udf().into(), vec![sedona_type.clone()]);
133+
134+
tester.assert_return_type(DataType::Int8);
135+
136+
let result = tester.invoke_scalar("POINT ZM (1 2 3 4)").unwrap();
137+
tester.assert_scalar_result_equals(result, 3);
138+
139+
let result = tester.invoke_scalar("POINT (1 2)").unwrap();
140+
tester.assert_scalar_result_equals(result, 0);
141+
142+
let result = tester.invoke_scalar("POINT Z (1 2 3)").unwrap();
143+
tester.assert_scalar_result_equals(result, 2);
144+
145+
let result = tester.invoke_wkb_scalar(None).unwrap();
146+
tester.assert_scalar_result_equals(result, ScalarValue::Null);
147+
148+
// Z-dimension specified only in the nested geometry, but not the geom collection level
149+
let result = tester
150+
.invoke_wkb_scalar(Some("GEOMETRYCOLLECTION (POINT Z (1 2 3))"))
151+
.unwrap();
152+
tester.assert_scalar_result_equals(result, 2);
153+
154+
// Z-dimension specified on both the geom collection and nested geometry level
155+
// Geometry collection with Z dimension both on the geom collection and nested geometry level
156+
let result = tester
157+
.invoke_wkb_scalar(Some("GEOMETRYCOLLECTION Z (POINT Z (1 2 3))"))
158+
.unwrap();
159+
tester.assert_scalar_result_equals(result, 2);
160+
161+
let result = tester
162+
.invoke_wkb_scalar(Some("GEOMETRYCOLLECTION (POINT M (1 2 3))"))
163+
.unwrap();
164+
tester.assert_scalar_result_equals(result, 1);
165+
166+
let result = tester
167+
.invoke_wkb_scalar(Some("GEOMETRYCOLLECTION EMPTY"))
168+
.unwrap();
169+
tester.assert_scalar_result_equals(result, 0);
170+
171+
// Empty geometry collections with Z or M dimensions
172+
let result = tester
173+
.invoke_wkb_scalar(Some("GEOMETRYCOLLECTION Z EMPTY"))
174+
.unwrap();
175+
tester.assert_scalar_result_equals(result, 2);
176+
177+
let result = tester
178+
.invoke_wkb_scalar(Some("GEOMETRYCOLLECTION M EMPTY"))
179+
.unwrap();
180+
tester.assert_scalar_result_equals(result, 1);
181+
}
182+
183+
#[test]
184+
fn multipoint_with_inferred_z_dimension() {
185+
let tester = ScalarUdfTester::new(st_zmflag_udf().into(), vec![WKB_GEOMETRY]);
186+
187+
let scalar = ScalarValue::Binary(Some(MULTIPOINT_WITH_INFERRED_Z_DIMENSION_WKB.to_vec()));
188+
assert_eq!(
189+
tester.invoke_scalar(scalar.clone()).unwrap(),
190+
ScalarValue::Int8(Some(2))
191+
);
192+
}
193+
}

0 commit comments

Comments
 (0)