Skip to content

Commit

Permalink
Merge pull request #500 from databrickslabs/feature/st_within
Browse files Browse the repository at this point in the history
add st_within
  • Loading branch information
Milos Colic authored Jan 12, 2024
2 parents 21759ae + 3fb1473 commit cb98f55
Show file tree
Hide file tree
Showing 9 changed files with 302 additions and 1 deletion.
33 changes: 33 additions & 0 deletions docs/code-example-notebooks/predicates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,36 @@ df.select(st_intersects($"p1", $"p2")).show(false)
// MAGIC %r
// MAGIC df <- createDataFrame(data.frame(p1 = "POLYGON ((0 0, 0 3, 3 3, 3 0))", p2 = "POLYGON ((2 2, 2 4, 4 4, 4 2))"))
// MAGIC showDF(select(df, st_intersects(column("p1"), column("p2"))), truncate=F)

// MAGIC %md
// MAGIC ### st_within

// COMMAND ----------

// MAGIC %python
// MAGIC help(st_within)

// COMMAND ----------

// MAGIC %python
// MAGIC df = spark.createDataFrame([{'point': 'POINT (25 15)', 'poly': 'POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))'}])
// MAGIC df.select(st_within('point', 'poly')).show()

// COMMAND ----------

// Scala example: one row holding a point and a polygon as WKT strings,
// then st_within is evaluated with the point as the first argument.
val df = List(("POINT (25 15)", "POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")).toDF("point", "poly")
df.select(st_within($"point", $"poly")).show()

// COMMAND ----------

// MAGIC %sql
// MAGIC SELECT st_within("POINT (25 15)", "POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")

// COMMAND ----------

// MAGIC %r
// MAGIC df <- createDataFrame(data.frame(point = c( "POINT (25 15)"), poly = "POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))"))
// MAGIC showDF(select(df, st_within(column("point"), column("poly"))))

// COMMAND ----------

62 changes: 61 additions & 1 deletion docs/source/api/spatial-predicates.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ st_contains
| true|
+------------------------+

.. note:: ST_Within is the inverse of ST_Contains, where ST_Contains(a, b) == ST_Within(b, a).


st_intersects
*************
Expand Down Expand Up @@ -114,4 +116,62 @@ st_intersects
| true|
+---------------------+

.. note:: Intersection logic will be dependent on the chosen geometry API (ESRI or JTS). ESRI is only available for mosaic < 0.4.x series, in mosaic >= 0.4.0 JTS is the only geometry API.
.. note:: Intersection logic will be dependent on the chosen geometry API (ESRI or JTS). ESRI is only available for mosaic < 0.4.x series, in mosaic >= 0.4.0 JTS is the only geometry API.

st_within
*********

.. function:: st_within(geom1, geom2)

Returns `true` if `geom1` is spatially within `geom2`.

:param geom1: Geometry
:type geom1: Column
:param geom2: Geometry
:type geom2: Column
:rtype: Column: BooleanType

:example:

.. tabs::
.. code-tab:: py

df = spark.createDataFrame([{'point': 'POINT (25 15)', 'poly': 'POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))'}])
df.select(st_within('point', 'poly')).show()
+----------------------+
|st_within(point, poly)|
+----------------------+
| true|
+----------------------+

.. code-tab:: scala

val df = List(("POINT (25 15)", "POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")).toDF("point", "poly")
df.select(st_within($"point", $"poly")).show()
+----------------------+
|st_within(point, poly)|
+----------------------+
| true|
+----------------------+

.. code-tab:: sql

SELECT st_within("POINT (25 15)", "POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))")
+----------------------+
|st_within(point, poly)|
+----------------------+
| true|
+----------------------+

.. code-tab:: r R

df <- createDataFrame(data.frame(point = c( "POINT (25 15)"), poly = "POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))"))
showDF(select(df, st_within(column("point"), column("poly"))))
+----------------------+
|st_within(point, poly)|
+----------------------+
| true|
+----------------------+

.. note:: ST_Within is the inverse of ST_Contains, where ST_Contains(a, b) == ST_Within(b, a).

21 changes: 21 additions & 0 deletions python/mosaic/api/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,24 @@ def st_contains(geom1: ColumnOrName, geom2: ColumnOrName) -> Column:
pyspark_to_java_column(geom1),
pyspark_to_java_column(geom2),
)


def st_within(geom1: ColumnOrName, geom2: ColumnOrName) -> Column:
    """
    Tests whether the first geometry is spatially within the second.

    Parameters
    ----------
    geom1 : Column
        Geometry checked for lying inside `geom2`.
    geom2 : Column
        Geometry checked for enclosing `geom1`.

    Returns
    -------
    Column (BooleanType)
        `true` where `geom1` is within `geom2`.
    """
    # Resolve the dispatcher first, then convert the arguments left to right.
    invoke = config.mosaic_context.invoke_function
    within_subject = pyspark_to_java_column(geom1)
    within_container = pyspark_to_java_column(geom2)
    return invoke("st_within", within_subject, within_container)
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ trait MosaicGeometry extends GeometryWriter with Serializable {

/** Tests whether this geometry contains `other`. */
def contains(other: MosaicGeometry): Boolean

/** Tests whether this geometry is within `other`. */
def within(other: MosaicGeometry): Boolean

/**
  * Returns this geometry as a sequence of geometries.
  * NOTE(review): presumably the individual component geometries - confirm
  * against the implementations.
  */
def flatten: Seq[MosaicGeometry]

/** Compares this geometry with `other`; the notion of equality is left to the implementation. */
def equals(other: MosaicGeometry): Boolean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ abstract class MosaicGeometryJTS(geom: Geometry) extends MosaicGeometry {

/**
  * Delegates to the wrapped JTS geometry's `contains`.
  * `geom2` is cast to [[MosaicGeometryJTS]]; a non-JTS implementation fails
  * with a ClassCastException.
  */
override def contains(geom2: MosaicGeometry): Boolean = geom.contains(geom2.asInstanceOf[MosaicGeometryJTS].getGeom)

/**
  * Delegates to the wrapped JTS geometry's `within`.
  * `geom2` is cast to [[MosaicGeometryJTS]]; a non-JTS implementation fails
  * with a ClassCastException.
  */
override def within(geom2: MosaicGeometry): Boolean = geom.within(geom2.asInstanceOf[MosaicGeometryJTS].getGeom)

/** Returns the JTS geometry this instance wraps. */
def getGeom: Geometry = geom

/** Delegates the validity check to the wrapped JTS geometry. */
override def isValid: Boolean = geom.isValid
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package com.databricks.labs.mosaic.expressions.geometry

import com.databricks.labs.mosaic.core.geometry.MosaicGeometry
import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo}
import com.databricks.labs.mosaic.expressions.geometry.base.BinaryVectorExpression
import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.{BooleanType, DataType}

/**
  * Spatial predicate expression that yields true when the left geometry is
  * within the right geometry.
  * @param leftGeom
  *   Expression providing the geometry tested for being inside.
  * @param rightGeom
  *   Expression providing the geometry tested as the container.
  * @param expressionConfig
  *   Mosaic expression configuration, forwarded to the base expression.
  */
case class ST_Within(
    leftGeom: Expression,
    rightGeom: Expression,
    expressionConfig: MosaicExpressionConfig
) extends BinaryVectorExpression[ST_Within](
      leftGeom,
      rightGeom,
      returnsGeometry = false,
      expressionConfig
    ) {

    /** The predicate produces a boolean value. */
    override def dataType: DataType = BooleanType

    /** Applies the `within` predicate of the left geometry to the right geometry. */
    override def geometryTransform(leftGeometry: MosaicGeometry, rightGeometry: MosaicGeometry): Any =
        leftGeometry.within(rightGeometry)

    /**
      * Builds the Java statement that declares a freshly named boolean
      * variable holding `left.within(right)`.
      * @return
      *   The generated statement paired with the name of the result variable.
      */
    override def geometryCodeGen(leftGeometryRef: String, rightGeometryRef: String, ctx: CodegenContext): (String, String) = {
        val resultVar = ctx.freshName("within")
        val statement = s"boolean $resultVar = $leftGeometryRef.within($rightGeometryRef);"
        (statement, resultVar)
    }

}

/** Metadata and builder used to register [[ST_Within]] as a Spark SQL function. */
object ST_Within extends WithExpressionInfo {

    /** SQL-facing function name. */
    override def name: String = "st_within"

    /** Short usage description of the function. */
    override def usage: String = "_FUNC_(expr1, expr2) - Returns true if expr1 is within expr2."

    /** Example block shown alongside the usage text. */
    override def example: String =
        """
          | Examples:
          | > SELECT _FUNC_(A, B);
          | true
          | """.stripMargin

    /**
      * Obtains the function builder from the generic factory.
      * NOTE(review): the literal 2 is presumably the number of child
      * expressions - confirm against GenericExpressionFactory.
      */
    override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder =
        GenericExpressionFactory.getBaseBuilder[ST_Within](2, expressionConfig)

}
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends
mosaicRegistry.registerExpression[ST_UnaryUnion](expressionConfig)
mosaicRegistry.registerExpression[ST_Union](expressionConfig)
mosaicRegistry.registerExpression[ST_UpdateSRID](expressionConfig)
mosaicRegistry.registerExpression[ST_Within](expressionConfig)
mosaicRegistry.registerExpression[ST_X](expressionConfig)
mosaicRegistry.registerExpression[ST_Y](expressionConfig)
mosaicRegistry.registerExpression[ST_Haversine](expressionConfig)
Expand Down Expand Up @@ -630,6 +631,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends
/** Spatial predicates */
// Wraps an ST_Contains expression over the two columns' expressions in a Column.
def st_contains(geom1: Column, geom2: Column): Column = ColumnAdapter(ST_Contains(geom1.expr, geom2.expr, expressionConfig))
// Wraps an ST_Intersects expression over the two columns' expressions in a Column.
def st_intersects(left: Column, right: Column): Column = ColumnAdapter(ST_Intersects(left.expr, right.expr, expressionConfig))
// Wraps an ST_Within expression in a Column; geom1 is the geometry tested
// for being within geom2.
def st_within(geom1: Column, geom2: Column): Column = ColumnAdapter(ST_Within(geom1.expr, geom2.expr, expressionConfig))

/** RasterAPI dependent functions */
def rst_bandmetadata(raster: Column, band: Column): Column =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package com.databricks.labs.mosaic.expressions.geometry

import com.databricks.labs.mosaic.functions.MosaicContext
import com.databricks.labs.mosaic.test.MosaicSpatialQueryTest
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator}
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._
import org.scalatest.matchers.must.Matchers.noException
import org.scalatest.matchers.should.Matchers.{an, be, convertToAnyShouldWrapper}

/**
  * Reusable test behaviours for the ST_Within expression: result
  * correctness, whole-stage codegen compilation, and the expression's
  * auxiliary members.
  */
trait ST_WithinBehaviors extends MosaicSpatialQueryTest {

    /**
      * WKT of a square (10..110) with two square holes, (20..30, 20..30) and
      * (40..50, 20..30). Control characters left by the multi-line literal
      * are filtered out so the WKT ends up on a single line.
      */
    private val polygonWithHoles: String =
        """POLYGON ((10 10, 110 10, 110 110, 10 110, 10 10),
          | (20 20, 20 30, 30 30, 30 20, 20 20),
          | (40 20, 40 30, 50 30, 50 20, 40 20))""".stripMargin.filter(_ >= ' ')

    /**
      * Point WKT paired with the expected st_within result against
      * [[polygonWithHoles]]: (35 25) sits between the two holes, (25 25)
      * falls inside the first hole.
      */
    private val pointCases: List[(String, Boolean)] = List(
      ("POINT (35 25)", true),
      ("POINT (25 25)", false)
    )

    /** Checks that st_within yields the expected value for every point case. */
    def withinBehavior(mosaicContext: MosaicContext): Unit = {
        spark.sparkContext.setLogLevel("FATAL")
        val mc = mosaicContext
        import mc.functions._
        val sc = spark
        import sc.implicits._
        mc.register(spark)

        val cases = pointCases.map { case (pointWkt, expected) => (pointWkt, polygonWithHoles, expected) }

        // Only rows whose computed result equals the expectation survive the filter.
        val matching = cases
            .toDF("leftGeom", "rightGeom", "expected")
            .withColumn("result", st_within($"leftGeom", $"rightGeom"))
            .where($"expected" === $"result")

        matching.count shouldBe 2
    }

    /**
      * Checks that a query using st_within contains a whole-stage codegen
      * node whose generated code compiles, and that generating code for an
      * ST_Within with an integer literal as right operand throws an Error.
      */
    def withinCodegen(mosaicContext: MosaicContext): Unit = {
        spark.sparkContext.setLogLevel("FATAL")
        val mc = mosaicContext
        val sc = spark
        import mc.functions._
        import sc.implicits._
        mc.register(spark)

        val polygonDf = List(polygonWithHoles).toDF("rightGeom")
        val pointDf = pointCases.toDF("leftGeom", "expected")

        val joined = polygonDf
            .crossJoin(pointDf)
            .withColumn("result", st_within($"leftGeom", $"rightGeom"))
            .where($"expected" === $"result")

        val physicalPlan = joined.queryExecution.executedPlan
        val codegenNode = physicalPlan.find(_.isInstanceOf[WholeStageCodegenExec])

        codegenNode.isDefined shouldBe true

        val (_, generatedCode) = codegenNode.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()

        noException should be thrownBy CodeGenerator.compile(generatedCode)

        // Right operand is an integer literal instead of a geometry.
        val invalidWithin = ST_Within(lit(pointCases.head._1).expr, lit(1).expr, mc.expressionConfig)
        val codegenCtx = new CodegenContext
        an[Error] should be thrownBy invalidWithin.genCode(codegenCtx)
    }

    /** Checks child accessors, the data type and `makeCopy` of the expression. */
    def auxiliaryMethods(mosaicContext: MosaicContext): Unit = {
        spark.sparkContext.setLogLevel("FATAL")
        val mc = mosaicContext
        mc.register(spark)

        val pointWkt = pointCases.head._1

        val expression = ST_Within(lit(pointWkt).expr, lit(polygonWithHoles).expr, mc.expressionConfig)

        expression.left shouldEqual lit(pointWkt).expr
        expression.right shouldEqual lit(polygonWithHoles).expr
        expression.dataType shouldEqual BooleanType
        noException should be thrownBy expression.makeCopy(Array(expression.left, expression.right))

    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.databricks.labs.mosaic.expressions.geometry

import com.databricks.labs.mosaic.test.MosaicSpatialQueryTest
import org.apache.spark.sql.test.SharedSparkSession

/**
  * Registers the behaviours from ST_WithinBehaviors as test cases.
  * NOTE(review): the testAllGeometries* helpers come from
  * MosaicSpatialQueryTest and are not visible here; presumably they run the
  * given behaviour with codegen disabled / enabled - confirm.
  */
class ST_WithinTest extends MosaicSpatialQueryTest with SharedSparkSession with ST_WithinBehaviors {

    // Result correctness through the NoCodegen helper.
    testAllGeometriesNoCodegen("ST_Within behavior") { withinBehavior }
    // Generated code must compile.
    testAllGeometriesCodegen("ST_Within codegen compilation") { withinCodegen }
    // Same correctness check as above, through the Codegen helper.
    testAllGeometriesCodegen("ST_Within codegen behavior") { withinBehavior }
    // Accessors, data type and makeCopy of the expression.
    testAllGeometriesNoCodegen("ST_Within auxiliary methods") { auxiliaryMethods }

}

0 comments on commit cb98f55

Please sign in to comment.