From 7601a35b6fda379afe5218d150193ee201db3122 Mon Sep 17 00:00:00 2001 From: Daniel Sparing Date: Tue, 16 Jan 2024 20:42:21 +0100 Subject: [PATCH] dont use mock in test --- .../expressions/geometry/ST_ZBehaviors.scala | 52 +++++++++---------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_ZBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_ZBehaviors.scala index bd160a895..566b74b70 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_ZBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/geometry/ST_ZBehaviors.scala @@ -1,7 +1,7 @@ package com.databricks.labs.mosaic.expressions.geometry import com.databricks.labs.mosaic.functions.MosaicContext -import com.databricks.labs.mosaic.test.{mocks, MosaicSpatialQueryTest} +import com.databricks.labs.mosaic.test.MosaicSpatialQueryTest import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator} import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.functions.lit @@ -19,32 +19,19 @@ trait ST_ZBehaviors extends MosaicSpatialQueryTest { import sc.implicits._ mc.register(spark) - val expected = mocks - .getWKTRowsDf() - .orderBy("id") - .select("wkt") - .where("wkt SUBSTRING(wkt, 1, 5) = 'POINT'") - .as[String] - .collect() - .map(wkt => mc.getGeometryAPI.geometry(wkt, "WKT")) - .map(c => c.getZ) + val rows = List( + ("POINT (2 3 5)", 5), + ("POINT (7 11 13)", 13), + ("POINT (17 19 23)", 23), + ("POINT (29 31 37)", 37) + ) - val result = mocks - .getWKTRowsDf() - .select(st_z($"wkt").alias("z")) - .as[Double] - .collect() + val result = rows + .toDF("wkt", "expected") + .withColumn("result", st_z($"wkt")) + .where($"expected" === $"result") - result.zip(expected).foreach { case (l, r) => l.equals(r) shouldEqual true } - - mocks.getWKTRowsDf().createOrReplaceTempView("source") - - val sqlResult = spark - .sql("""select st_z(wkt) from source""".stripMargin) - .as[Double] - .collect() - - sqlResult.zip(expected).foreach { case (l, r) => l.equals(r) shouldEqual true } + results.count shouldBe 4 } def stzCodegen(mosaicContext: MosaicContext): Unit = { @@ -55,9 +42,18 @@ trait ST_ZBehaviors extends MosaicSpatialQueryTest { import sc.implicits._ mc.register(spark) - val result = mocks - .getWKTRowsDf() - .select(st_z($"wkt").alias("z")) + val rows = List( + ("POINT (2 3 5)", 5), + ("POINT (7 11 13)", 13), + ("POINT (17 19 23)", 23), + ("POINT (29 31 37)", 37) + ) + + val points = rows.toDF("wkt", "expected") + + val result = points + .withColumn("result", st_z($"wkt")) + .where($"expected" === $"result") val queryExecution = result.queryExecution val plan = queryExecution.executedPlan