From 3fb14735095a9a4f229cf4642c686d0b05b5c75a Mon Sep 17 00:00:00 2001 From: Daniel Sparing Date: Thu, 11 Jan 2024 18:31:25 +0100 Subject: [PATCH] add st_within --- docs/code-example-notebooks/predicates.scala | 33 ++++++ docs/source/api/spatial-predicates.rst | 62 ++++++++++- python/mosaic/api/predicates.py | 21 ++++ .../mosaic/core/geometry/MosaicGeometry.scala | 2 + .../core/geometry/MosaicGeometryJTS.scala | 2 + .../expressions/geometry/ST_Within.scala | 64 +++++++++++ .../labs/mosaic/functions/MosaicContext.scala | 2 + .../geometry/ST_WithinBehaviors.scala | 104 ++++++++++++++++++ .../expressions/geometry/ST_WithinTest.scala | 13 +++ 9 files changed, 302 insertions(+), 1 deletion(-) create mode 100644 src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_Within.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_WithinBehaviors.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_WithinTest.scala diff --git a/docs/code-example-notebooks/predicates.scala b/docs/code-example-notebooks/predicates.scala index 9374d4f67..df51a94fb 100644 --- a/docs/code-example-notebooks/predicates.scala +++ b/docs/code-example-notebooks/predicates.scala @@ -87,3 +87,36 @@ df.select(st_intersects($"p1", $"p2")).show(false) // MAGIC %r // MAGIC df <- createDataFrame(data.frame(p1 = "POLYGON ((0 0, 0 3, 3 3, 3 0))", p2 = "POLYGON ((2 2, 2 4, 4 4, 4 2))")) // MAGIC showDF(select(df, st_intersects(column("p1"), column("p2"))), truncate=F) + +// MAGIC %md +// MAGIC ### st_within + +// COMMAND ---------- + +// MAGIC %python +// MAGIC help(st_within) + +// COMMAND ---------- + +// MAGIC %python +// MAGIC df = spark.createDataFrame([{'point': 'POINT (25 15)', 'poly': 'POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))'}]) +// MAGIC df.select(st_within('point', 'poly')).show() + +// COMMAND ---------- + +val df = List(("POINT (25 15)", "POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")).toDF("point", "poly") +df.select(st_within($"point", $"poly")).show() + +// COMMAND ---------- + +// MAGIC %sql +// MAGIC SELECT st_within("POINT (25 15)", "POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))") + +// COMMAND ---------- + +// MAGIC %r +// MAGIC df <- createDataFrame(data.frame(point = c( "POINT (25 15)"), poly = "POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")) +// MAGIC showDF(select(df, st_within(column("point"), column("poly")))) + +// COMMAND ---------- + diff --git a/docs/source/api/spatial-predicates.rst b/docs/source/api/spatial-predicates.rst index 1a45dc745..09fc6fa31 100644 --- a/docs/source/api/spatial-predicates.rst +++ b/docs/source/api/spatial-predicates.rst @@ -58,6 +58,8 @@ st_contains | true| +------------------------+ +.. note:: ST_Within is the inverse of ST_Contains, where ST_Contains(a, b)==ST_Within(b,a). + st_intersects ************* @@ -114,4 +116,62 @@ st_intersects | true| +---------------------+ -.. note:: Intersection logic will be dependent on the chosen geometry API (ESRI or JTS). ESRI is only available for mosaic < 0.4.x series, in mosaic >= 0.4.0 JTS is the only geometry API. \ No newline at end of file +.. note:: Intersection logic will be dependent on the chosen geometry API (ESRI or JTS). ESRI is only available for mosaic < 0.4.x series, in mosaic >= 0.4.0 JTS is the only geometry API. + +st_within +********* + +.. function:: st_within(geom1, geom2) + + Returns `true` if `geom1` 'spatially' is within `geom2`. + + :param geom1: Geometry + :type geom1: Column + :param geom2: Geometry + :type geom2: Column + :rtype: Column: BooleanType + + :example: + +.. tabs:: + .. code-tab:: py + + df = spark.createDataFrame([{'point': 'POINT (25 15)', 'poly': 'POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))'}]) + df.select(st_within('point', 'poly')).show() + +----------------------+ + |st_within(point, poly)| + +----------------------+ + | true| + +----------------------+ + + .. code-tab:: scala + + val df = List(("POINT (25 15)", "POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")).toDF("point", "poly") + df.select(st_within($"point", $"poly")).show() + +----------------------+ + |st_within(point, poly)| + +----------------------+ + | true| + +----------------------+ + + .. code-tab:: sql + + SELECT st_within("POINT (25 15)", "POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))") + +----------------------+ + |st_within(point, poly)| + +----------------------+ + | true| + +----------------------+ + + .. code-tab:: r R + + df <- createDataFrame(data.frame(point = c( "POINT (25 15)"), poly = "POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")) + showDF(select(df, st_within(column("point"), column("poly")))) + +----------------------+ + |st_within(point, poly)| + +----------------------+ + | true| + +----------------------+ + +.. note:: ST_Within is the inverse of ST_Contains, where ST_Contains(a, b)==ST_Within(b,a). + diff --git a/python/mosaic/api/predicates.py b/python/mosaic/api/predicates.py index 0b7d01815..39f856597 100644 --- a/python/mosaic/api/predicates.py +++ b/python/mosaic/api/predicates.py @@ -52,3 +52,24 @@ def st_contains(geom1: ColumnOrName, geom2: ColumnOrName) -> Column: pyspark_to_java_column(geom1), pyspark_to_java_column(geom2), ) + + +def st_within(geom1: ColumnOrName, geom2: ColumnOrName) -> Column: + """ + Returns `true` if geom1 'spatially' is within geom2. + + Parameters + ---------- + geom1 : Column + geom2 : Column + + Returns + ------- + Column (BooleanType) + + """ + return config.mosaic_context.invoke_function( + "st_within", + pyspark_to_java_column(geom1), + pyspark_to_java_column(geom2), + ) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometry.scala b/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometry.scala index 8af8c9996..2063ca7c7 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometry.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometry.scala @@ -83,6 +83,8 @@ trait MosaicGeometry extends GeometryWriter with Serializable { def contains(other: MosaicGeometry): Boolean + def within(other: MosaicGeometry): Boolean + def flatten: Seq[MosaicGeometry] def equals(other: MosaicGeometry): Boolean diff --git a/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometryJTS.scala b/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometryJTS.scala index 17960d423..25e723510 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometryJTS.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/geometry/MosaicGeometryJTS.scala @@ -148,6 +148,8 @@ abstract class MosaicGeometryJTS(geom: Geometry) extends MosaicGeometry { override def contains(geom2: MosaicGeometry): Boolean = geom.contains(geom2.asInstanceOf[MosaicGeometryJTS].getGeom) + override def within(geom2: MosaicGeometry): Boolean = geom.within(geom2.asInstanceOf[MosaicGeometryJTS].getGeom) + def getGeom: Geometry = geom override def isValid: Boolean = geom.isValid diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_Within.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_Within.scala new file mode 100644 index 000000000..c5cfd5a8f --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/geometry/ST_Within.scala @@ -0,0 +1,64 @@ +package com.databricks.labs.mosaic.expressions.geometry + +import com.databricks.labs.mosaic.core.geometry.MosaicGeometry +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.geometry.base.BinaryVectorExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.types.{BooleanType, DataType} + +/** + * Returns true if leftGeom is within rightGeom. + * @param leftGeom + * The left geometry. + * @param rightGeom + * The right geometry. + * @param expressionConfig + * Additional arguments for the expression (expressionConfigs). + */ +case class ST_Within( + leftGeom: Expression, + rightGeom: Expression, + expressionConfig: MosaicExpressionConfig +) extends BinaryVectorExpression[ST_Within]( + leftGeom, + rightGeom, + returnsGeometry = false, + expressionConfig + ) { + + override def dataType: DataType = BooleanType + + override def geometryTransform(leftGeometry: MosaicGeometry, rightGeometry: MosaicGeometry): Any = { + leftGeometry.within(rightGeometry) + } + + override def geometryCodeGen(leftGeometryRef: String, rightGeometryRef: String, ctx: CodegenContext): (String, String) = { + val within = ctx.freshName("within") + val code = s"""boolean $within = $leftGeometryRef.within($rightGeometryRef);""" + (code, within) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object ST_Within extends WithExpressionInfo { + + override def name: String = "st_within" + + override def usage: String = "_FUNC_(expr1, expr2) - Returns true if expr1 is within expr2." + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(A, B); + | true + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[ST_Within](2, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala index fbb0bb922..2cdfeb3a5 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala @@ -177,6 +177,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends mosaicRegistry.registerExpression[ST_UnaryUnion](expressionConfig) mosaicRegistry.registerExpression[ST_Union](expressionConfig) mosaicRegistry.registerExpression[ST_UpdateSRID](expressionConfig) + mosaicRegistry.registerExpression[ST_Within](expressionConfig) mosaicRegistry.registerExpression[ST_X](expressionConfig) mosaicRegistry.registerExpression[ST_Y](expressionConfig) mosaicRegistry.registerExpression[ST_Haversine](expressionConfig) @@ -630,6 +631,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends /** Spatial predicates */ def st_contains(geom1: Column, geom2: Column): Column = ColumnAdapter(ST_Contains(geom1.expr, geom2.expr, expressionConfig)) def st_intersects(left: Column, right: Column): Column = ColumnAdapter(ST_Intersects(left.expr, right.expr, expressionConfig)) + def st_within(geom1: Column, geom2: Column): Column = ColumnAdapter(ST_Within(geom1.expr, geom2.expr, expressionConfig)) /** RasterAPI dependent functions */ def rst_bandmetadata(raster: Column, band: Column): Column = diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_WithinBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_WithinBehaviors.scala new file mode 100644 index 000000000..fba4805b5 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_WithinBehaviors.scala @@ -0,0 +1,104 @@ +package com.databricks.labs.mosaic.expressions.geometry + +import com.databricks.labs.mosaic.functions.MosaicContext +import com.databricks.labs.mosaic.test.MosaicSpatialQueryTest +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator} +import org.apache.spark.sql.execution.WholeStageCodegenExec +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.types._ +import org.scalatest.matchers.must.Matchers.noException +import org.scalatest.matchers.should.Matchers.{an, be, convertToAnyShouldWrapper} + +trait ST_WithinBehaviors extends MosaicSpatialQueryTest { + + def withinBehavior(mosaicContext: MosaicContext): Unit = { + spark.sparkContext.setLogLevel("FATAL") + val mc = mosaicContext + import mc.functions._ + val sc = spark + import sc.implicits._ + mc.register(spark) + + val poly = """POLYGON ((10 10, 110 10, 110 110, 10 110, 10 10), + | (20 20, 20 30, 30 30, 30 20, 20 20), + | (40 20, 40 30, 50 30, 50 20, 40 20))""".stripMargin.filter(_ >= ' ') + + val rows = List( + ("POINT (35 25)", poly, true), + ("POINT (25 25)", poly, false) + ) + + val results = rows + .toDF("leftGeom", "rightGeom", "expected") + .withColumn("result", st_within($"leftGeom", $"rightGeom")) + .where($"expected" === $"result") + + results.count shouldBe 2 + } + + def withinCodegen(mosaicContext: MosaicContext): Unit = { + spark.sparkContext.setLogLevel("FATAL") + val mc = mosaicContext + val sc = spark + import mc.functions._ + import sc.implicits._ + mc.register(spark) + + val poly = """POLYGON ((10 10, 110 10, 110 110, 10 110, 10 10), + | (20 20, 20 30, 30 30, 30 20, 20 20), + | (40 20, 40 30, 50 30, 50 20, 40 20))""".stripMargin.filter(_ >= ' ') + + val rows = List( + ("POINT (35 25)", true), + ("POINT (25 25)", false) + ) + + val polygons = List(poly).toDF("rightGeom") + val points = rows.toDF("leftGeom", "expected") + + val result = polygons + .crossJoin(points) + .withColumn("result", st_within($"leftGeom", $"rightGeom")) + .where($"expected" === $"result") + + val queryExecution = result.queryExecution + val plan = queryExecution.executedPlan + + val wholeStageCodegenExec = plan.find(_.isInstanceOf[WholeStageCodegenExec]) + + wholeStageCodegenExec.isDefined shouldBe true + + val codeGenStage = wholeStageCodegenExec.get.asInstanceOf[WholeStageCodegenExec] + val (_, code) = codeGenStage.doCodeGen() + + noException should be thrownBy CodeGenerator.compile(code) + + val stWithin = ST_Within(lit(rows.head._1).expr, lit(1).expr, mc.expressionConfig) + val ctx = new CodegenContext + an[Error] should be thrownBy stWithin.genCode(ctx) + } + + def auxiliaryMethods(mosaicContext: MosaicContext): Unit = { + spark.sparkContext.setLogLevel("FATAL") + val mc = mosaicContext + mc.register(spark) + + val poly = """POLYGON ((10 10, 110 10, 110 110, 10 110, 10 10), + | (20 20, 20 30, 30 30, 30 20, 20 20), + | (40 20, 40 30, 50 30, 50 20, 40 20))""".stripMargin.filter(_ >= ' ') + + val rows = List( + ("POINT (35 25)", true), + ("POINT (25 25)", false) + ) + + val stWithin = ST_Within(lit(rows.head._1).expr, lit(poly).expr, mc.expressionConfig) + + stWithin.left shouldEqual lit(rows.head._1).expr + stWithin.right shouldEqual lit(poly).expr + stWithin.dataType shouldEqual BooleanType + noException should be thrownBy stWithin.makeCopy(Array(stWithin.left, stWithin.right)) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_WithinTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_WithinTest.scala new file mode 100644 index 000000000..963843c27 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_WithinTest.scala @@ -0,0 +1,13 @@ +package com.databricks.labs.mosaic.expressions.geometry + +import com.databricks.labs.mosaic.test.MosaicSpatialQueryTest +import org.apache.spark.sql.test.SharedSparkSession + +class ST_WithinTest extends MosaicSpatialQueryTest with SharedSparkSession with ST_WithinBehaviors { + + testAllGeometriesNoCodegen("ST_Within behavior") { withinBehavior } + testAllGeometriesCodegen("ST_Within codegen compilation") { withinCodegen } + testAllGeometriesCodegen("ST_Within codegen behavior") { withinBehavior } + testAllGeometriesNoCodegen("ST_Within auxiliary methods") { auxiliaryMethods } + +}