Skip to content

Commit

Permalink
implement ST_ASEWKT
Browse files Browse the repository at this point in the history
refactor and add tests

clean up

restore ignored test
  • Loading branch information
r3stl355 committed May 23, 2023
1 parent bb49756 commit 79b291c
Show file tree
Hide file tree
Showing 47 changed files with 446 additions and 13 deletions.
27 changes: 27 additions & 0 deletions docs/code-example-notebooks/accessors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,30 @@ df.select(st_astext(st_point($"lon", $"lat")).alias("wkt")).show()
// MAGIC %r
// MAGIC df <- createDataFrame(data.frame(lon = 30.0, lat = 10.0))
// MAGIC showDF(select(df, alias(st_aswkt(st_point(column("lon"), column("lat"))), "wkt")), truncate=F)

// COMMAND ----------

// MAGIC %md
// MAGIC ### st_asewkt

// COMMAND ----------

// MAGIC %python
// MAGIC df = spark.createDataFrame([{'lon': 30., 'lat': 10.}])
// MAGIC df.select(st_asewkt(st_point('lon', 'lat')).alias('ewkt')).show()

// COMMAND ----------

val df = List((30.0, 10.0)).toDF("lon", "lat")
df.select(st_asewkt(st_point($"lon", $"lat")).alias("ewkt")).show()

// COMMAND ----------

// MAGIC %sql
// MAGIC SELECT st_asewkt(st_point(30D, 10D)) AS ewkt

// COMMAND ----------

// MAGIC %r
// MAGIC df <- createDataFrame(data.frame(lon = 30.0, lat = 10.0))
// MAGIC showDF(select(df, alias(st_asewkt(st_point(column("lon"), column("lat"))), "ewkt")), truncate=F)
3 changes: 3 additions & 0 deletions docs/code-example-notebooks/kepler.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@

# WKB representation
.withColumn("geom_wkb", mos.st_aswkb(col("geom_internal")))

# WKT representation
.withColumn("geom_ewkt", mos.st_asewkt(col("geom_internal")))

# Limit to only 1 shape
.limit(1)
Expand Down
57 changes: 57 additions & 0 deletions docs/source/api/geometry-accessors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -278,3 +278,60 @@ st_aswkt


.. note:: Alias for :ref:`st_astext`.

st_aswkt
********

.. function:: st_aeswkt(col)

Translate a geometry into its representation in Extended Well-known Text (EWKT) format.

:param col: Geometry column
:type col: Column: BinaryType, HexType, JSONType or InternalGeometryType
:rtype: Column: StringType

:example:

.. tabs::
.. code-tab:: py

>>> df = spark.createDataFrame([{'lon': 30., 'lat': 10.}])
>>> df.select(st_asewkt(st_point('lon', 'lat')).alias('ewkt')).show()
+-----------------------+
| ewkt|
+-----------------------+
|SRID=4326;POINT (30 10)|
+-----------------------+

.. code-tab:: scala

>>> val df = List((30.0, 10.0)).toDF("lon", "lat")
>>> df.select(st_asewkt(st_point($"lon", $"lat")).alias("ewkt")).show()
+-----------------------+
| ewkt|
+-----------------------+
|SRID=4326;POINT (30 10)|
+-----------------------+

.. code-tab:: sql

>>> SELECT st_asewkt(st_point(30.0D, 10.0D)) AS ewkt
+-----------------------+
| ewkt|
+-----------------------+
|SRID=4326;POINT (30 10)|
+-----------------------+

.. code-tab:: r R

>>> df <- createDataFrame(data.frame(lon = 30.0, lat = 10.0))
>>> showDF(select(df, alias(st_asewkt(st_point(column("lon"), column("lat"))), "ewkt")), truncate=F)
+-----------------------+
| ewkt|
+-----------------------+
|SRID=4326;POINT (30 10)|
+-----------------------+


.. note:: Default SRID value of a geometry created without specifying the explicit SRID value may be specific to a chosen geometry API. Currently,
default SRID on ESRI is 4326 (as shown in the examples), whereas it is 0 on JTS.
20 changes: 20 additions & 0 deletions python/mosaic/api/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"as_hex",
"as_json",
"convert_to",
"st_asewkt",
]


Expand Down Expand Up @@ -131,3 +132,22 @@ def convert_to(geom: ColumnOrName) -> Column:
return config.mosaic_context.invoke_function(
"convert_to", pyspark_to_java_column(geom)
)

def st_asewkt(geom: ColumnOrName) -> Column:
"""
Translate a geometry into its Extended Well-known Text (EWKT) representation.
Parameters
----------
geom : Column (BinaryType, HexType, JSONType or InternalGeometryType)
Geometry column
Returns
-------
Column (StringType)
An EWKT geometry
"""
return config.mosaic_context.invoke_function(
"st_asewkt", pyspark_to_java_column(geom)
)
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ object ConvertToCodeGen {
case "JSONOBJECT" => geometryCodeGen.toJSON(ctx, eval, geometryAPI)
case "GEOJSON" => geometryCodeGen.toGeoJSON(ctx, eval, geometryAPI)
case "COORDS" => geometryCodeGen.toInternal(ctx, eval, geometryAPI)
case "EWKT" => geometryCodeGen.toEWKT(ctx, eval, geometryAPI)
case _ => throw new Error(s"Data type unsupported: $outputDataFormatName.")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ trait GeometryIOCodeGen {

def fromInternal(ctx: CodegenContext, eval: String, geometryAPI: GeometryAPI): (String, String)

def fromEWKT(ctx: CodegenContext, eval: String, geometryAPI: GeometryAPI): (String, String)

def toWKT(ctx: CodegenContext, eval: String, geometryAPI: GeometryAPI): (String, String)

def toWKB(ctx: CodegenContext, eval: String, geometryAPI: GeometryAPI): (String, String)
Expand All @@ -28,4 +30,6 @@ trait GeometryIOCodeGen {

def toInternal(ctx: CodegenContext, eval: String, geometryAPI: GeometryAPI): (String, String)

def toEWKT(ctx: CodegenContext, eval: String, geometryAPI: GeometryAPI): (String, String)

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package com.databricks.labs.mosaic.codegen.format

import java.nio.ByteBuffer

import com.databricks.labs.mosaic.core.geometry.MosaicGeometryESRI
import com.databricks.labs.mosaic.core.geometry.{MosaicGeometry, MosaicGeometryESRI}
import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI
import com.databricks.labs.mosaic.core.types.InternalGeometryType
import com.esri.core.geometry.ogc.OGCGeometry
import com.esri.core.geometry.SpatialReference
import org.locationtech.jts.io.{WKBReader, WKBWriter}

import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
Expand All @@ -15,9 +16,10 @@ import org.apache.spark.sql.types.{BinaryType, StringType}
object MosaicGeometryIOCodeGenESRI extends GeometryIOCodeGen {

override def fromWKT(ctx: CodegenContext, eval: String, geometryAPI: GeometryAPI): (String, String) = {
val inputGeom = ctx.freshName("inputGeom")
val ogcGeom = classOf[OGCGeometry].getName
(s"""$ogcGeom $inputGeom = $ogcGeom.fromText($eval.toString());""", inputGeom)
// Technically, fromEWKT can have an implementation which is only a subset of implementation of
// fromWKT but it's not really necessary and both can use the same implementation so long as
// it works for both.
fromEWKT(ctx, eval, geometryAPI)
}

override def fromWKB(ctx: CodegenContext, eval: String, geometryAPI: GeometryAPI): (String, String) = {
Expand Down Expand Up @@ -77,6 +79,30 @@ object MosaicGeometryIOCodeGenESRI extends GeometryIOCodeGen {
)
}

override def fromEWKT(ctx: CodegenContext, eval: String, geometryAPI: GeometryAPI): (String, String) = {
val inputGeom = ctx.freshName("inputGeom")
val geom = ctx.freshName("geom")
val parts = ctx.freshName("parts")
val srid = ctx.freshName("srid")
val ogcGeom = classOf[OGCGeometry].getName
val sptRef = classOf[SpatialReference].getName
(
s"""
|$ogcGeom $inputGeom;
|String $geom = $eval.toString();
|if ($geom.startsWith("SRID=")) {
| String[] $parts = $geom.split(";", 0);
| String $srid = $parts[0].split("=", 0)[1];
| $inputGeom = $ogcGeom.fromText($parts[1]);
| $inputGeom.setSpatialReference($sptRef.create(Integer.parseInt($srid)));
|} else {
| $inputGeom = $ogcGeom.fromText($geom);
|}
|""".stripMargin,
inputGeom
)
}

override def toWKT(ctx: CodegenContext, eval: String, geometryAPI: GeometryAPI): (String, String) = {
val outputGeom = ctx.freshName("outputGeom")
val javaStringType = CodeGenerator.javaType(StringType)
Expand Down Expand Up @@ -153,4 +179,18 @@ object MosaicGeometryIOCodeGenESRI extends GeometryIOCodeGen {
)
}

override def toEWKT(ctx: CodegenContext, eval: String, geometryAPI: GeometryAPI): (String, String) = {
val outputGeom = ctx.freshName("outputGeom")
val srid = ctx.freshName("grid")
val javaStringType = CodeGenerator.javaType(StringType)
(
s"""
|int $srid = 0;
|if ($eval.esriSR != null) $srid = $eval.getEsriSpatialReference().getID();
|$javaStringType $outputGeom = $javaStringType.fromString("SRID=" + Integer.toString($srid) + ";" + $eval.asText());
|""".stripMargin,
outputGeom
)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@ import org.apache.spark.sql.types.{BinaryType, StringType}
object MosaicGeometryIOCodeGenJTS extends GeometryIOCodeGen {

override def fromWKT(ctx: CodegenContext, eval: String, geometryAPI: GeometryAPI): (String, String) = {
val inputGeom = ctx.freshName("inputGeom")
val jtsGeom = classOf[Geometry].getName
val wktReader = classOf[WKTReader].getName
(s"""$jtsGeom $inputGeom = new $wktReader().read($eval.toString());""", inputGeom)
fromEWKT(ctx, eval, geometryAPI)
}

override def fromWKB(ctx: CodegenContext, eval: String, geometryAPI: GeometryAPI): (String, String) = {
Expand Down Expand Up @@ -78,6 +75,30 @@ object MosaicGeometryIOCodeGenJTS extends GeometryIOCodeGen {
)
}

override def fromEWKT(ctx: CodegenContext, eval: String, geometryAPI: GeometryAPI): (String, String) = {
val inputGeom = ctx.freshName("inputGeom")
val geom = ctx.freshName("geom")
val parts = ctx.freshName("parts")
val srid = ctx.freshName("srid")
val jtsGeom = classOf[Geometry].getName
val wktReader = classOf[WKTReader].getName
(
s"""
|$jtsGeom $inputGeom;
|String $geom = $eval.toString();
|if ($geom.startsWith("SRID=")) {
| String[] $parts = $geom.split(";", 0);
| String $srid = $parts[0].split("=", 0)[1];
| $inputGeom = new $wktReader().read($parts[1]);
| $inputGeom.setSRID(Integer.parseInt($srid));
|} else {
| $inputGeom = new $wktReader().read($geom);;
|}
|""".stripMargin,
inputGeom
)
}

override def toWKT(ctx: CodegenContext, eval: String, geometryAPI: GeometryAPI): (String, String) = {
val outputGeom = ctx.freshName("outputGeom")
val javaStringType = CodeGenerator.javaType(StringType)
Expand Down Expand Up @@ -172,4 +193,16 @@ object MosaicGeometryIOCodeGenJTS extends GeometryIOCodeGen {
)
}

override def toEWKT(ctx: CodegenContext, eval: String, geometryAPI: GeometryAPI): (String, String) = {
val outputGeom = ctx.freshName("outputGeom")
val javaStringType = CodeGenerator.javaType(StringType)
val wktWriterClass = classOf[WKTWriter].getName
(
s"""
|$javaStringType $outputGeom = $javaStringType.fromString("SRID=" + Integer.toString($eval.getSRID()) + ";" + new $wktWriterClass().write($eval));
|""".stripMargin,
outputGeom
)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ trait GeometryReader {

def fromSeq[T <: MosaicGeometry](geomSeq: Seq[T], geomType: GeometryTypeEnum.Value): MosaicGeometry

def fromEWKT(ewkt: String): MosaicGeometry
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@ trait GeometryWriter {

def toHEX: String

def toEWKT: String

}
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ abstract class MosaicGeometryESRI(geom: OGCGeometry) extends MosaicGeometry {

override def toWKB: Array[Byte] = geom.asBinary().array()

override def toEWKT: String = s"SRID=${getSpatialReference};${toWKT}"

override def getSpatialReference: Int = if (geom.esriSR == null) 0 else geom.getEsriSpatialReference.getID

override def setSpatialReference(srid: Int): Unit = {
Expand Down Expand Up @@ -268,4 +270,12 @@ object MosaicGeometryESRI extends GeometryReader {
reader(typeId).fromInternal(row).asInstanceOf[MosaicGeometryESRI]
}

override def fromEWKT(ewkt: String): MosaicGeometryESRI = {
val pat = "SRID=(\\d*);(.*)".r
val pat(srid, wkt) = ewkt
val res = MosaicGeometryESRI(OGCGeometry.fromText(wkt))
res.setSpatialReference(srid.toInt)
res
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ abstract class MosaicGeometryJTS(geom: Geometry) extends MosaicGeometry {

override def toWKB: Array[Byte] = new WKBWriter().write(geom)

override def toEWKT: String = s"SRID=${getSpatialReference};${toWKT}"

override def numPoints: Int = geom.getNumPoints

override def getSpatialReference: Int = geom.getSRID
Expand Down Expand Up @@ -277,4 +279,12 @@ object MosaicGeometryJTS extends GeometryReader {
case GEOMETRYCOLLECTION => MosaicGeometryCollectionJTS
}

override def fromEWKT(ewkt: String): MosaicGeometryJTS = {
val pat = "SRID=(\\d*);(.*)".r
val pat(srid, wkt) = ewkt
val res = MosaicGeometryJTS(new WKTReader().read(wkt))
res.setSpatialReference(srid.toInt)
res
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ abstract class GeometryAPI(
case "HEX" => reader.fromHEX(input.asInstanceOf[String])
case "WKB" => reader.fromWKB(input.asInstanceOf[Array[Byte]])
case "GEOJSON" => reader.fromJSON(input.asInstanceOf[String])
case "EWKT" => reader.fromEWKT(input.asInstanceOf[String])
case "COORDS" => throw new Error(s"$typeName not supported.")
case _ => throw new Error(s"$typeName not supported.")
}
Expand All @@ -43,7 +44,7 @@ abstract class GeometryAPI(
def geometry(inputData: InternalRow, dataType: DataType): MosaicGeometry = {
dataType match {
case _: BinaryType => reader.fromWKB(inputData.getBinary(0))
case _: StringType => reader.fromWKT(inputData.getString(0))
case _: StringType => val s = inputData.getString(0); if (s.matches("SRID=\\d*;.*")) reader.fromEWKT(s) else reader.fromWKT(s)
case _: HexType => reader.fromHEX(inputData.get(0, HexType).asInstanceOf[InternalRow].getString(0))
case _: JSONType => reader.fromJSON(inputData.get(0, JSONType).asInstanceOf[InternalRow].getString(0))
case _: InternalGeometryType => reader.fromInternal(inputData.get(0, InternalGeometryType).asInstanceOf[InternalRow])
Expand All @@ -64,7 +65,7 @@ abstract class GeometryAPI(
def geometry(inputData: Any, dataType: DataType): MosaicGeometry =
dataType match {
case _: BinaryType => reader.fromWKB(inputData.asInstanceOf[Array[Byte]])
case _: StringType => reader.fromWKT(inputData.asInstanceOf[UTF8String].toString)
case _: StringType => val s = inputData.asInstanceOf[UTF8String].toString; if (s.matches("SRID=\\d*;.*")) reader.fromEWKT(s) else reader.fromWKT(s)
case _: HexType => reader.fromHEX(inputData.asInstanceOf[InternalRow].getString(0))
case _: JSONType => reader.fromJSON(inputData.asInstanceOf[InternalRow].getString(0))
case _: InternalGeometryType => reader.fromInternal(inputData.asInstanceOf[InternalRow])
Expand All @@ -83,6 +84,7 @@ abstract class GeometryAPI(
case "JSONOBJECT" => InternalRow.fromSeq(Seq(UTF8String.fromString(geometry.toJSON)))
case "GEOJSON" => UTF8String.fromString(geometry.toJSON)
case "COORDS" => geometry.toInternal.serialize
case "EWKT" => UTF8String.fromString(geometry.toEWKT)
case _ => throw new Error(s"$dataFormatName not supported.")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,4 +202,5 @@ object MosaicGeometryCollectionESRI extends GeometryReader {

override def fromHEX(hex: String): MosaicGeometryESRI = MosaicGeometryESRI.fromHEX(hex)

override def fromEWKT(ewkt: String): MosaicGeometryESRI = MosaicGeometryESRI.fromEWKT(ewkt)
}
Original file line number Diff line number Diff line change
Expand Up @@ -201,4 +201,6 @@ object MosaicGeometryCollectionJTS extends GeometryReader {

override def fromHEX(hex: String): MosaicGeometryJTS = MosaicGeometryJTS.fromHEX(hex)

override def fromEWKT(ewkt: String): MosaicGeometryJTS = MosaicGeometryJTS.fromEWKT(ewkt)

}
Loading

0 comments on commit 79b291c

Please sign in to comment.