From 9c3c7cbbc89ecf0e2ca088eeb6c358200c7d0ba6 Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Mon, 15 Jan 2024 13:29:57 +0000 Subject: [PATCH 01/26] Add capabilities for RST_ stats. Refactor RasterToGrid readers. --- .../raster/gdal/MosaicRasterBandGDAL.scala | 24 ++ .../core/raster/gdal/MosaicRasterGDAL.scala | 4 +- .../core/raster/operator/gdal/GDALInfo.scala | 38 +++ .../operator/retile/BalancedSubdivision.scala | 9 +- .../multiread/RasterAsGridReader.scala | 52 +++- .../mosaic/expressions/raster/RST_Avg.scala | 57 ++++ .../mosaic/expressions/raster/RST_Max.scala | 48 ++++ .../expressions/raster/RST_Median.scala | 60 +++++ .../mosaic/expressions/raster/RST_Min.scala | 48 ++++ .../expressions/raster/RST_PixelCount.scala | 46 ++++ .../labs/mosaic/functions/MosaicContext.scala | 16 +- .../labs/mosaic/utils/FileUtils.scala | 34 +++ .../multiread/RasterAsGridReaderTest.scala | 245 ++++++++++-------- .../expressions/raster/RST_MaxBehaviors.scala | 52 ++++ .../expressions/raster/RST_MaxTest.scala | 32 +++ .../raster/RST_MedianBehaviors.scala | 52 ++++ .../expressions/raster/RST_MedianTest.scala | 32 +++ .../raster/RST_PixelCountBehaviors.scala | 52 ++++ .../raster/RST_PixelCountTest.scala | 32 +++ .../sql/test/MosaicTestSparkSession.scala | 27 ++ .../sql/test/SharedSparkSessionGDAL.scala | 11 +- 21 files changed, 837 insertions(+), 134 deletions(-) create mode 100644 src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALInfo.scala create mode 100644 src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala create mode 100644 src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Max.scala create mode 100644 src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Median.scala create mode 100644 src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Min.scala create mode 100644 src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCount.scala create mode 100644 
src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxBehaviors.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxTest.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianBehaviors.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianTest.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCountBehaviors.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCountTest.scala create mode 100644 src/test/scala/org/apache/spark/sql/test/MosaicTestSparkSession.scala diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala index 3fa45f8e5..a7c9ece10 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala @@ -219,6 +219,30 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) { } } + /** + * Counts the number of pixels in the band. The mask is used to determine + * if a pixel is valid. If pixel value is noData or mask value is 0.0, the + * pixel is not counted. + * + * @return + * Returns the band's pixel count. 
+ */ + def pixelCount: Int = { + val line = Array.ofDim[Double](band.GetXSize()) + val maskLine = Array.ofDim[Double](band.GetXSize()) + var count = 0 + for (y <- 0 until band.GetYSize()) { + band.ReadRaster(0, y, band.GetXSize(), 1, line) + val maskRead = band.GetMaskBand().ReadRaster(0, y, band.GetXSize(), 1, maskLine) + if (maskRead != gdalconstConstants.CE_None) { + count = count + line.count(_ != noDataValue) + } else { + count = count + line.zip(maskLine).count { case (pixel, mask) => pixel != noDataValue && mask != 0.0 } + } + } + count + } + /** * @return * Returns the band's mask flags. diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala index 4f51749dc..3ac467f53 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala @@ -8,7 +8,7 @@ import com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose import com.databricks.labs.mosaic.core.raster.io.{RasterCleaner, RasterReader, RasterWriter} import com.databricks.labs.mosaic.core.raster.operator.clip.RasterClipByVector import com.databricks.labs.mosaic.core.types.model.GeometryTypeEnum.POLYGON -import com.databricks.labs.mosaic.utils.PathUtils +import com.databricks.labs.mosaic.utils.{FileUtils, PathUtils} import org.gdal.gdal.gdal.GDALInfo import org.gdal.gdal.{Dataset, InfoOptions, gdal} import org.gdal.gdalconst.gdalconstConstants._ @@ -405,7 +405,7 @@ case class MosaicRasterGDAL( } else { path } - val byteArray = Files.readAllBytes(Paths.get(readPath)) + val byteArray = FileUtils.readBytes(readPath) if (dispose) RasterCleaner.dispose(this) if (readPath != PathUtils.getCleanPath(parentPath)) { Files.deleteIfExists(Paths.get(readPath)) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALInfo.scala 
b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALInfo.scala new file mode 100644 index 000000000..7a60a837a --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALInfo.scala @@ -0,0 +1,38 @@ +package com.databricks.labs.mosaic.core.raster.operator.gdal + +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import org.gdal.gdal.{InfoOptions, gdal} + +/** GDALInfo is a wrapper for the GDAL Info command. */ +object GDALInfo { + + /** + * Executes the GDAL Info command. For flags check the way gdalinfo.py + * script is called; InfoOptions expects a collection of the same flags. + * + * @param raster + * The raster to get info from. + * @param command + * The GDAL Info command. + * @return + * A result json string. + */ + def executeInfo(raster: MosaicRasterGDAL, command: String): String = { + require(command.startsWith("gdalinfo"), "Not a valid GDAL Info command.") + + val infoOptionsVec = OperatorOptions.parseOptions(command) + val infoOptions = new InfoOptions(infoOptionsVec) + val gdalInfo = gdal.GDALInfo(raster.getRaster, infoOptions) + + if (gdalInfo == null) { + throw new Exception(s""" + |GDAL Info failed. 
+ |Command: $command + |Error: ${gdal.GetLastErrorMsg} + |""".stripMargin) + } + + gdalInfo + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/BalancedSubdivision.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/BalancedSubdivision.scala index 75e59c1fa..daa0e6266 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/BalancedSubdivision.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/BalancedSubdivision.scala @@ -21,9 +21,12 @@ object BalancedSubdivision { */ def getNumSplits(raster: MosaicRasterGDAL, destSize: Int): Int = { val size = raster.getMemSize - val n = size.toDouble / (destSize * 1000 * 1000) - val nInt = Math.ceil(n).toInt - Math.pow(4, Math.ceil(Math.log(nInt) / Math.log(4))).toInt + var n = 1 + while (true) { + n *= 4 + if (size / n <= destSize * 1000 * 1000) return n + } + n } /** diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala index c1f805afa..d6f26caf4 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala @@ -1,5 +1,6 @@ package com.databricks.labs.mosaic.datasource.multiread +import com.databricks.labs.mosaic.MOSAIC_RASTER_READ_STRATEGY import com.databricks.labs.mosaic.functions.MosaicContext import org.apache.spark.sql._ import org.apache.spark.sql.functions._ @@ -25,6 +26,12 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead nPartitions } + private def workerNCores = { + sparkSession.sparkContext.range(0, 1).map(_ => java.lang.Runtime.getRuntime.availableProcessors).collect.head + } + + private def nWorkers = sparkSession.sparkContext.getExecutorMemoryStatus.size + override def load(path: String): 
DataFrame = load(Seq(path): _*) override def load(paths: String*): DataFrame = { @@ -32,11 +39,23 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead val config = getConfig val resolution = config("resolution").toInt val nPartitions = getNPartitions(config) + val readStrategy = config("retile") match { + case "true" => "retile_on_read" + case _ => "in_memory" + } + val tileSize = config("sizeInMB").toInt + + val nCores = nWorkers * workerNCores + val stageCoefficient = math.ceil(math.log(nCores) / math.log(4)) + + val firstStageSize = (tileSize * math.pow(4, stageCoefficient)).toInt val pathsDf = sparkSession.read .format("gdal") .option("extensions", config("extensions")) - .option("raster_storage", "in-memory") + .option(MOSAIC_RASTER_READ_STRATEGY, readStrategy) + .option("vsizip", config("vsizip")) + .option("sizeInMB", firstStageSize) .load(paths: _*) .repartition(nPartitions) @@ -46,7 +65,12 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead val retiledDf = retileRaster(rasterDf, config) - val loadedDf = retiledDf + val loadedDf = rasterDf + .withColumn( + "tile", + rst_tessellate(col("tile"), lit(resolution)) + ) + .repartition(nPartitions) .withColumn( "grid_measures", rasterToGridCombiner(col("tile"), lit(resolution)) @@ -58,6 +82,7 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead .select( posexplode(col("grid_measures")).as(Seq("band_id", "grid_measures")) ) + .repartition(nPartitions) .select( col("band_id"), explode(col("grid_measures")).alias("grid_measures") @@ -88,16 +113,22 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead */ private def retileRaster(rasterDf: DataFrame, config: Map[String, String]) = { val retile = config("retile").toBoolean - val tileSize = config("tileSize").toInt + val tileSize = config.getOrElse("tileSize", "-1").toInt + val memSize = config.getOrElse("sizeInMB", "-1").toInt val nPartitions = 
getNPartitions(config) if (retile) { - rasterDf - .withColumn( - "tile", - rst_retile(col("tile"), lit(tileSize), lit(tileSize)) - ) - .repartition(nPartitions) + if (memSize > 0) { + rasterDf + .withColumn("tile", rst_subdivide(col("tile"), lit(memSize))) + .repartition(nPartitions) + } else if (tileSize > 0) { + rasterDf + .withColumn("tile", rst_retile(col("tile"), lit(tileSize), lit(tileSize))) + .repartition(nPartitions) + } else { + rasterDf + } } else { rasterDf } @@ -200,7 +231,8 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead "resolution" -> this.extraOptions.getOrElse("resolution", "0"), "combiner" -> this.extraOptions.getOrElse("combiner", "mean"), "retile" -> this.extraOptions.getOrElse("retile", "false"), - "tileSize" -> this.extraOptions.getOrElse("tileSize", "256"), + "tileSize" -> this.extraOptions.getOrElse("tileSize", "-1"), + "sizeInMB" -> this.extraOptions.getOrElse("sizeInMB", ""), "kRingInterpolate" -> this.extraOptions.getOrElse("kRingInterpolate", "0") ) } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala new file mode 100644 index 000000000..be82af449 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala @@ -0,0 +1,57 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALInfo +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import 
org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ + + +/** Returns an array containing the mean pixel value of each band of the raster. */ +case class RST_Avg(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Avg](raster, ArrayType(DoubleType), returnsRaster = false, expressionConfig) + with NullIntolerant + with CodegenFallback { + + /** Reads the per-band mean values from the gdalinfo statistics of the tile's raster. */ + override def rasterTransform(tile: MosaicRasterTile): Any = { + import org.json4s._ + import org.json4s.jackson.JsonMethods._ + implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats + + val command = s"gdalinfo -stats -json -mm -nogcp -nomd -norat -noct" + val gdalInfo = GDALInfo.executeInfo(tile.raster, command) + // Parse the JSON output of gdalinfo and pick the "mean" entry of every band. + val json = parse(gdalInfo).extract[Map[String, Any]] + val maxValues = json("bands").asInstanceOf[List[Map[String, Any]]].map { band => + band("mean").asInstanceOf[Double] + } + ArrayData.toArrayData(maxValues.toArray) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Avg extends WithExpressionInfo { + + override def name: String = "rst_mean" + + override def usage: String = "_FUNC_(expr1) - Returns an array containing mean values for each band." 
+ + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile); + | [1.123, 2.123, 3.123] + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Avg](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Max.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Max.scala new file mode 100644 index 000000000..abe042c2b --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Max.scala @@ -0,0 +1,48 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALInfo +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ + + +/** Returns an array containing the max pixel value of each band of the raster. */ +case class RST_Max(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Max](raster, ArrayType(DoubleType), returnsRaster = false, expressionConfig) + with NullIntolerant + with CodegenFallback { + + /** Collects the max pixel value of every band of the tile's raster. 
*/ + override def rasterTransform(tile: MosaicRasterTile): Any = { + val nBands = tile.raster.raster.GetRasterCount() + val maxValues = (1 to nBands).map(tile.raster.getBand(_).maxPixelValue) + ArrayData.toArrayData(maxValues.toArray) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Max extends WithExpressionInfo { + + override def name: String = "rst_max" + + override def usage: String = "_FUNC_(expr1) - Returns an array containing max values for each band." + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile); + | [1.123, 2.123, 3.123] + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Max](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Median.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Median.scala new file mode 100644 index 000000000..091121e91 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Median.scala @@ -0,0 +1,60 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.operator.gdal.{GDALCalc, GDALInfo, GDALWarp} +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import com.databricks.labs.mosaic.utils.PathUtils +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import 
org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ + +/** Returns an array containing the median pixel value of each band of the raster. */ +case class RST_Median(rasterExpr: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Median](rasterExpr, ArrayType(DoubleType), returnsRaster = false, expressionConfig) + with NullIntolerant + with CodegenFallback { + + /** Computes the per-band median by warping the raster with median resampling (gdalwarp -r med) at a target resolution spanning the whole raster. */ + override def rasterTransform(tile: MosaicRasterTile): Any = { + val raster = tile.raster + val width = raster.xSize * raster.pixelXSize + val height = raster.ySize * raster.pixelYSize + val outShortName = raster.getDriversShortName + val resultFileName = PathUtils.createTmpFilePath(GDAL.getExtension(outShortName)) + val medRaster = GDALWarp.executeWarp( + resultFileName, + Seq(raster), + command = s"gdalwarp -r med -tr $width $height -of $outShortName" + ) + // Taking the max pixel is a hack: we get a 1x1 raster back, so its max is its only value. + val maxValues = (1 to medRaster.raster.GetRasterCount()).map(medRaster.getBand(_).maxPixelValue) + ArrayData.toArrayData(maxValues.toArray) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Median extends WithExpressionInfo { + + override def name: String = "rst_median" + + override def usage: String = "_FUNC_(expr1) - Returns an array containing mean values for each band." 
+ + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile); + | [1.123, 2.123, 3.123] + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Median](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Min.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Min.scala new file mode 100644 index 000000000..67fdb30d3 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Min.scala @@ -0,0 +1,48 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALInfo +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ + + +/** Returns an array containing the min pixel value of each band of the raster. */ +case class RST_Min(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Min](raster, ArrayType(DoubleType), returnsRaster = false, expressionConfig) + with NullIntolerant + with CodegenFallback { + + /** Collects the min pixel value of every band of the tile's raster. 
*/ + override def rasterTransform(tile: MosaicRasterTile): Any = { + val nBands = tile.raster.raster.GetRasterCount() + val minValues = (1 to nBands).map(tile.raster.getBand(_).minPixelValue) + ArrayData.toArrayData(minValues.toArray) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Min extends WithExpressionInfo { + + override def name: String = "rst_min" + + override def usage: String = "_FUNC_(expr1) - Returns an array containing min values for each band." + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile); + | [1.123, 2.123, 3.123] + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Min](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCount.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCount.scala new file mode 100644 index 000000000..79f44db03 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCount.scala @@ -0,0 +1,46 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ + +/** Returns an array containing the valid pixel count of each band of the raster. 
*/ +case class RST_PixelCount(rasterExpr: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_PixelCount](rasterExpr, ArrayType(DoubleType), returnsRaster = false, expressionConfig) + with NullIntolerant + with CodegenFallback { + + /** Collects the pixel count of every band of the tile's raster. */ + override def rasterTransform(tile: MosaicRasterTile): Any = { + val bandCount = tile.raster.raster.GetRasterCount() + val pixelCount = (1 to bandCount).map(tile.raster.getBand(_).pixelCount) + ArrayData.toArrayData(pixelCount.toArray) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_PixelCount extends WithExpressionInfo { + + override def name: String = "rst_pixelcount" + + override def usage: String = "_FUNC_(expr1) - Returns an array containing valid pixel count values for each band." + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile); + | [12, 212, 313] + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_PixelCount](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala index 8e483c702..1b59ff9a1 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala @@ -13,6 +13,8 @@ import com.databricks.labs.mosaic.expressions.geometry.ST_MinMaxXYZ._ import com.databricks.labs.mosaic.expressions.index._ import com.databricks.labs.mosaic.expressions.raster._ import com.databricks.labs.mosaic.expressions.util.TrySql +import com.databricks.labs.mosaic.functions.MosaicContext.mosaicVersion +import com.databricks.labs.mosaic.utils.FileUtils import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, 
SparkSession} import 
org.apache.spark.sql.catalyst.FunctionIdentifier @@ -21,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{LongType, StringType} -import java.nio.file.Files import scala.reflect.runtime.universe //noinspection DuplicatedCode @@ -255,6 +256,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends ) /** RasterAPI dependent functions */ + mosaicRegistry.registerExpression[RST_Avg](expressionConfig) mosaicRegistry.registerExpression[RST_BandMetaData](expressionConfig) mosaicRegistry.registerExpression[RST_BoundingBox](expressionConfig) mosaicRegistry.registerExpression[RST_Clip](expressionConfig) @@ -266,6 +268,9 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends mosaicRegistry.registerExpression[RST_Height](expressionConfig) mosaicRegistry.registerExpression[RST_InitNoData](expressionConfig) mosaicRegistry.registerExpression[RST_IsEmpty](expressionConfig) + mosaicRegistry.registerExpression[RST_Max](expressionConfig) + mosaicRegistry.registerExpression[RST_Min](expressionConfig) + mosaicRegistry.registerExpression[RST_Median](expressionConfig) mosaicRegistry.registerExpression[RST_MemSize](expressionConfig) mosaicRegistry.registerExpression[RST_Merge](expressionConfig) mosaicRegistry.registerExpression[RST_FromBands](expressionConfig) @@ -275,6 +280,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends mosaicRegistry.registerExpression[RST_NumBands](expressionConfig) mosaicRegistry.registerExpression[RST_PixelWidth](expressionConfig) mosaicRegistry.registerExpression[RST_PixelHeight](expressionConfig) + mosaicRegistry.registerExpression[RST_PixelCount](expressionConfig) mosaicRegistry.registerExpression[RST_RasterToGridAvg](expressionConfig) mosaicRegistry.registerExpression[RST_RasterToGridMax](expressionConfig) 
mosaicRegistry.registerExpression[RST_RasterToGridMin](expressionConfig) @@ -637,6 +643,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends ColumnAdapter(RST_BandMetaData(raster.expr, lit(band).expr, expressionConfig)) def rst_boundingbox(raster: Column): Column = ColumnAdapter(RST_BoundingBox(raster.expr, expressionConfig)) def rst_clip(raster: Column, geometry: Column): Column = ColumnAdapter(RST_Clip(raster.expr, geometry.expr, expressionConfig)) + def rst_pixelcount(raster: Column): Column = ColumnAdapter(RST_PixelCount(raster.expr, expressionConfig)) def rst_combineavg(rasterArray: Column): Column = ColumnAdapter(RST_CombineAvg(rasterArray.expr, expressionConfig)) def rst_derivedband(raster: Column, pythonFunc: Column, funcName: Column): Column = ColumnAdapter(RST_DerivedBand(raster.expr, pythonFunc.expr, funcName.expr, expressionConfig)) @@ -649,6 +656,10 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends def rst_height(raster: Column): Column = ColumnAdapter(RST_Height(raster.expr, expressionConfig)) def rst_initnodata(raster: Column): Column = ColumnAdapter(RST_InitNoData(raster.expr, expressionConfig)) def rst_isempty(raster: Column): Column = ColumnAdapter(RST_IsEmpty(raster.expr, expressionConfig)) + def rst_max(raster: Column): Column = ColumnAdapter(RST_Max(raster.expr, expressionConfig)) + def rst_min(raster: Column): Column = ColumnAdapter(RST_Min(raster.expr, expressionConfig)) + def rst_median(raster: Column): Column = ColumnAdapter(RST_Median(raster.expr, expressionConfig)) + def rst_avg(raster: Column): Column = ColumnAdapter(RST_Avg(raster.expr, expressionConfig)) def rst_memsize(raster: Column): Column = ColumnAdapter(RST_MemSize(raster.expr, expressionConfig)) def rst_frombands(bandsArray: Column): Column = ColumnAdapter(RST_FromBands(bandsArray.expr, expressionConfig)) def rst_merge(rasterArray: Column): Column = ColumnAdapter(RST_Merge(rasterArray.expr, expressionConfig)) @@ 
-965,8 +976,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends object MosaicContext extends Logging { - val tmpDir: String = Files.createTempDirectory("mosaic").toAbsolutePath.toString - + val tmpDir: String = FileUtils.createMosaicTempDir() val mosaicVersion: String = "0.3.14" private var instance: Option[MosaicContext] = None diff --git a/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala b/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala new file mode 100644 index 000000000..a1aac5c2f --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala @@ -0,0 +1,34 @@ +package com.databricks.labs.mosaic.utils + +import java.io.{BufferedInputStream, FileInputStream} +import java.nio.file.{Files, Paths} + +object FileUtils { + + def readBytes(path: String): Array[Byte] = { + val bufferSize = 1024 * 1024 // 1MB + val inputStream = new BufferedInputStream(new FileInputStream(path)) + val buffer = new Array[Byte](bufferSize) + + var bytesRead = 0 + var bytes = Array.empty[Byte] + + while ({ + bytesRead = inputStream.read(buffer); bytesRead + } != -1) { + bytes = bytes ++ buffer.slice(0, bytesRead) + } + inputStream.close() + bytes + } + + def createMosaicTempDir(): String = { + val tempRoot = Paths.get("/mosaic_tmp/") + if (!Files.exists(tempRoot)) { + Files.createDirectory(tempRoot) + } + val tempDir = Files.createTempDirectory(tempRoot, "mosaic") + tempDir.toFile.getAbsolutePath + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala b/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala index 1f7b4008b..6e99aa1df 100644 --- a/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala @@ -6,134 +6,157 @@ import com.databricks.labs.mosaic.core.index.H3IndexSystem 
import com.databricks.labs.mosaic.test.MosaicSpatialQueryTest import org.apache.spark.sql.test.SharedSparkSessionGDAL import org.scalatest.matchers.must.Matchers.{be, noException} -import org.scalatest.matchers.should.Matchers.an +import org.scalatest.matchers.should.Matchers.{an, convertToAnyShouldWrapper} import java.nio.file.{Files, Paths} class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSessionGDAL { - test("Read netcdf with Raster As Grid Reader") { + test("Read big tif with Raster As Grid Reader") { assume(System.getProperty("os.name") == "Linux") + spark.sparkContext.setLogLevel("INFO") MosaicContext.build(H3IndexSystem, JTS) - val netcdf = "/binary/netcdf-coral/" - val filePath = getClass.getResource(netcdf).getPath - - noException should be thrownBy MosaicContext.read - .format("raster_to_grid") - .option("retile", "true") - .option("tileSize", "10") - .option("readSubdataset", "true") - .option("subdataset", "1") - .option("kRingInterpolate", "3") - .load(filePath) - .select("measure") - .queryExecution - .executedPlan - - } - - test("Read grib with Raster As Grid Reader") { - assume(System.getProperty("os.name") == "Linux") - MosaicContext.build(H3IndexSystem, JTS) - - val grib = "/binary/grib-cams/" - val filePath = getClass.getResource(grib).getPath + val tif = "/binary/big_tiff.tif" + val filePath = getClass.getResource(tif).getPath - noException should be thrownBy MosaicContext.read + val df = MosaicContext.read .format("raster_to_grid") - .option("extensions", "grib") - .option("combiner", "min") .option("retile", "true") - .option("tileSize", "10") - .option("kRingInterpolate", "3") + .option("sizeInMB", "64") + .option("resolution", "1") .load(filePath) .select("measure") - .take(1) - - } - test("Read tif with Raster As Grid Reader") { - assume(System.getProperty("os.name") == "Linux") - MosaicContext.build(H3IndexSystem, JTS) + //df.queryExecution.optimizedPlan - val tif = "/modis/" - val filePath = 
getClass.getResource(tif).getPath - - noException should be thrownBy MosaicContext.read - .format("raster_to_grid") - .option("combiner", "max") - .option("tileSize", "10") - .option("kRingInterpolate", "3") - .load(filePath) - .select("measure") - .take(1) + //noException should be thrownBy df.queryExecution.executedPlan + df.count() } - test("Read zarr with Raster As Grid Reader") { - assume(System.getProperty("os.name") == "Linux") - MosaicContext.build(H3IndexSystem, JTS) - - val zarr = "/binary/zarr-example/" - val filePath = getClass.getResource(zarr).getPath - - noException should be thrownBy MosaicContext.read - .format("raster_to_grid") - .option("combiner", "median") - .option("vsizip", "true") - .option("tileSize", "10") - .load(filePath) - .select("measure") - .take(1) - - noException should be thrownBy MosaicContext.read - .format("raster_to_grid") - .option("combiner", "count") - .option("vsizip", "true") - .load(filePath) - .select("measure") - .take(1) - - noException should be thrownBy MosaicContext.read - .format("raster_to_grid") - .option("combiner", "average") - .option("vsizip", "true") - .load(filePath) - .select("measure") - .take(1) - - noException should be thrownBy MosaicContext.read - .format("raster_to_grid") - .option("combiner", "avg") - .option("vsizip", "true") - .load(filePath) - .select("measure") - .take(1) - - val paths = Files.list(Paths.get(filePath)).toArray.map(_.toString) - - an[Error] should be thrownBy MosaicContext.read - .format("raster_to_grid") - .option("combiner", "count_+") - .option("vsizip", "true") - .load(paths: _*) - .select("measure") - .take(1) - - an[Error] should be thrownBy MosaicContext.read - .format("invalid") - .load(paths: _*) - - an[Error] should be thrownBy MosaicContext.read - .format("invalid") - .load(filePath) - - noException should be thrownBy MosaicContext.read - .format("raster_to_grid") - .option("kRingInterpolate", "3") - .load(filePath) - - } +// test("Read netcdf with Raster As Grid 
Reader") { +// assume(System.getProperty("os.name") == "Linux") +// MosaicContext.build(H3IndexSystem, JTS) +// +// val netcdf = "/binary/netcdf-coral/" +// val filePath = getClass.getResource(netcdf).getPath +// +// noException should be thrownBy MosaicContext.read +// .format("raster_to_grid") +// .option("retile", "true") +// .option("tileSize", "10") +// .option("readSubdataset", "true") +// .option("subdataset", "1") +// .option("kRingInterpolate", "3") +// .load(filePath) +// .select("measure") +// .queryExecution +// .executedPlan +// +// } +// +// test("Read grib with Raster As Grid Reader") { +// assume(System.getProperty("os.name") == "Linux") +// MosaicContext.build(H3IndexSystem, JTS) +// +// val grib = "/binary/grib-cams/" +// val filePath = getClass.getResource(grib).getPath +// +// noException should be thrownBy MosaicContext.read +// .format("raster_to_grid") +// .option("extensions", "grib") +// .option("combiner", "min") +// .option("retile", "true") +// .option("tileSize", "10") +// .option("kRingInterpolate", "3") +// .load(filePath) +// .select("measure") +// .take(1) +// +// } +// +// test("Read tif with Raster As Grid Reader") { +// assume(System.getProperty("os.name") == "Linux") +// MosaicContext.build(H3IndexSystem, JTS) +// +// val tif = "/modis/" +// val filePath = getClass.getResource(tif).getPath +// +// noException should be thrownBy MosaicContext.read +// .format("raster_to_grid") +// .option("combiner", "max") +// .option("tileSize", "10") +// .option("kRingInterpolate", "3") +// .load(filePath) +// .select("measure") +// .take(1) +// +// } +// +// test("Read zarr with Raster As Grid Reader") { +// assume(System.getProperty("os.name") == "Linux") +// MosaicContext.build(H3IndexSystem, JTS) +// +// val zarr = "/binary/zarr-example/" +// val filePath = getClass.getResource(zarr).getPath +// +// noException should be thrownBy MosaicContext.read +// .format("raster_to_grid") +// .option("combiner", "median") +// .option("vsizip", 
"true") +// .option("tileSize", "10") +// .load(filePath) +// .select("measure") +// .take(1) +// +// noException should be thrownBy MosaicContext.read +// .format("raster_to_grid") +// .option("combiner", "count") +// .option("vsizip", "true") +// .load(filePath) +// .select("measure") +// .take(1) +// +// noException should be thrownBy MosaicContext.read +// .format("raster_to_grid") +// .option("combiner", "average") +// .option("vsizip", "true") +// .load(filePath) +// .select("measure") +// .take(1) +// +// noException should be thrownBy MosaicContext.read +// .format("raster_to_grid") +// .option("combiner", "avg") +// .option("vsizip", "true") +// .load(filePath) +// .select("measure") +// .take(1) +// +// val paths = Files.list(Paths.get(filePath)).toArray.map(_.toString) +// +// an[Error] should be thrownBy MosaicContext.read +// .format("raster_to_grid") +// .option("combiner", "count_+") +// .option("vsizip", "true") +// .load(paths: _*) +// .select("measure") +// .take(1) +// +// an[Error] should be thrownBy MosaicContext.read +// .format("invalid") +// .load(paths: _*) +// +// an[Error] should be thrownBy MosaicContext.read +// .format("invalid") +// .load(filePath) +// +// noException should be thrownBy MosaicContext.read +// .format("raster_to_grid") +// .option("kRingInterpolate", "3") +// .load(filePath) +// +// } } diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxBehaviors.scala new file mode 100644 index 000000000..9c095488d --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxBehaviors.scala @@ -0,0 +1,52 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import 
org.apache.spark.sql.functions._ +import org.scalatest.matchers.should.Matchers._ + +trait RST_MaxBehaviors extends QueryTest { + + def behavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .withColumn("result", rst_max($"tile")) + .select("result") + .select(explode($"result").as("result")) + + rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .createOrReplaceTempView("source") + + noException should be thrownBy spark.sql(""" + |select rst_max(tile) from source + |""".stripMargin) + + noException should be thrownBy rastersInMemory + .withColumn("result", rst_rastertogridmax($"tile", lit(3))) + .select("result") + + val result = df.as[Double].collect().max + + result > 0 shouldBe true + + an[Exception] should be thrownBy spark.sql(""" + |select rst_max() from source + |""".stripMargin) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxTest.scala new file mode 100644 index 000000000..0a0b865cd --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_MaxTest extends QueryTest 
with SharedSparkSessionGDAL with RST_MaxBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing rst_max behavior with H3IndexSystem and JTS") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behavior(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianBehaviors.scala new file mode 100644 index 000000000..1b99fbc6f --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianBehaviors.scala @@ -0,0 +1,52 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions._ +import org.scalatest.matchers.should.Matchers._ + +trait RST_MedianBehaviors extends QueryTest { + + def behavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .withColumn("result", rst_median($"tile")) + .select("result") + .select(explode($"result").as("result")) + + rastersInMemory + .withColumn("tile", 
rst_tessellate($"tile", lit(3))) + .createOrReplaceTempView("source") + + noException should be thrownBy spark.sql(""" + |select rst_median(tile) from source + |""".stripMargin) + + noException should be thrownBy rastersInMemory + .withColumn("result", rst_rastertogridmax($"tile", lit(3))) + .select("result") + + val result = df.as[Double].collect().max + + result > 0 shouldBe true + + an[Exception] should be thrownBy spark.sql(""" + |select rst_median() from source + |""".stripMargin) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianTest.scala new file mode 100644 index 000000000..cfe270813 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_MedianTest extends QueryTest with SharedSparkSessionGDAL with RST_MedianBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. 
+ test("Testing rst_median behavior with H3IndexSystem and JTS") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behavior(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCountBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCountBehaviors.scala new file mode 100644 index 000000000..87582df1f --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCountBehaviors.scala @@ -0,0 +1,52 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions._ +import org.scalatest.matchers.should.Matchers._ + +trait RST_PixelCountBehaviors extends QueryTest { + + def behavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .withColumn("result", rst_pixelcount($"tile")) + .select("result") + .select(explode($"result").as("result")) + + rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .createOrReplaceTempView("source") + + noException should be thrownBy spark.sql(""" + |select rst_pixelcount(tile) from source + |""".stripMargin) + + noException should be thrownBy rastersInMemory + .withColumn("result", rst_rastertogridmax($"tile", lit(3))) + .select("result") + + val result = df.as[Double].collect().max + + result > 0 shouldBe true + + an[Exception] should be thrownBy spark.sql(""" + 
|select rst_pixelcount() from source + |""".stripMargin) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCountTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCountTest.scala new file mode 100644 index 000000000..1d24b58a4 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCountTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_PixelCountTest extends QueryTest with SharedSparkSessionGDAL with RST_PixelCountBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. 
+ test("Testing rst_pixelcount behavior with H3IndexSystem and JTS") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behavior(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/org/apache/spark/sql/test/MosaicTestSparkSession.scala b/src/test/scala/org/apache/spark/sql/test/MosaicTestSparkSession.scala new file mode 100644 index 000000000..8029c30a7 --- /dev/null +++ b/src/test/scala/org/apache/spark/sql/test/MosaicTestSparkSession.scala @@ -0,0 +1,27 @@ +package org.apache.spark.sql.test + +import org.apache.spark.{SparkConf, SparkContext} + +class MosaicTestSparkSession(sc: SparkContext) extends TestSparkSession(sc) { + + def this(sparkConf: SparkConf) = { + + this( + new SparkContext( + "local[4]", + "test-sql-context", + sparkConf + .set("spark.sql.adaptive.enabled", "false") + .set("spark.driver.memory", "32g") + .set("spark.executor.memory", "32g") + .set("spark.sql.shuffle.partitions", "4") + .set("spark.sql.testkey", "true") + ) + ) + } + + def this() = { + this(new SparkConf) + } + +} diff --git a/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala b/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala index 36da49694..fad2383fd 100644 --- a/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala +++ b/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala @@ -1,12 +1,13 @@ package org.apache.spark.sql.test import com.databricks.labs.mosaic.gdal.MosaicGDAL +import com.databricks.labs.mosaic.utils.FileUtils import com.databricks.labs.mosaic.{MOSAIC_GDAL_NATIVE, MOSAIC_RASTER_CHECKPOINT} import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.gdal.gdal.gdal -import java.nio.file.Files +import java.nio.file.{Files, Paths} import scala.util.Try trait SharedSparkSessionGDAL extends SharedSparkSession { @@ -18,12 +19,12 @@ trait SharedSparkSessionGDAL extends SharedSparkSession { override def createSparkSession: TestSparkSession = { val 
conf = sparkConf - conf.set(MOSAIC_RASTER_CHECKPOINT, Files.createTempDirectory("mosaic").toFile.getAbsolutePath) + conf.set(MOSAIC_RASTER_CHECKPOINT, FileUtils.createMosaicTempDir()) SparkSession.cleanupAnyExistingSession() - val session = new TestSparkSession(conf) - session.sparkContext.setLogLevel("FATAL") + val session = new MosaicTestSparkSession(conf) + session.sparkContext.setLogLevel("INFO") Try { - val tempPath = Files.createTempDirectory("mosaic-gdal") + //val tempPath = Files.createTempDirectory("mosaic-gdal") // prepareEnvironment no longer exists // - only have python calls now //MosaicGDAL.prepareEnvironment(session, tempPath.toAbsolutePath.toString) From fa2fc5cce6763122dc901c74ad12b6544c0b6a76 Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Wed, 17 Jan 2024 14:02:15 +0000 Subject: [PATCH 02/26] Add tests for RST_stats expressions. --- .../mosaic/expressions/raster/RST_Avg.scala | 2 +- .../expressions/raster/RST_AvgBehaviors.scala | 48 +++++++++++++++++++ .../expressions/raster/RST_AvgTest.scala | 32 +++++++++++++ .../expressions/raster/RST_MaxBehaviors.scala | 4 -- .../expressions/raster/RST_MinBehaviors.scala | 48 +++++++++++++++++++ .../expressions/raster/RST_MinTest.scala | 32 +++++++++++++ 6 files changed, 161 insertions(+), 5 deletions(-) create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgBehaviors.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgTest.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinTest.scala diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala index be82af449..82752cad4 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala +++ 
b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala @@ -39,7 +39,7 @@ case class RST_Avg(raster: Expression, expressionConfig: MosaicExpressionConfig) /** Expression info required for the expression registration for spark SQL. */ object RST_Avg extends WithExpressionInfo { - override def name: String = "rst_mean" + override def name: String = "rst_avg" override def usage: String = "_FUNC_(expr1) - Returns an array containing mean values for each band." diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgBehaviors.scala new file mode 100644 index 000000000..f01ce2d25 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgBehaviors.scala @@ -0,0 +1,48 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions._ +import org.scalatest.matchers.should.Matchers._ + +trait RST_AvgBehaviors extends QueryTest { + + def behavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .withColumn("result", rst_avg($"tile")) + .select("result") + .select(explode($"result").as("result")) + + rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .createOrReplaceTempView("source") + + noException should be thrownBy spark.sql(""" + |select rst_avg(tile) from source + 
|""".stripMargin) + + val result = df.as[Double].collect().max + + result > 0 shouldBe true + + an[Exception] should be thrownBy spark.sql(""" + |select rst_avg() from source + |""".stripMargin) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgTest.scala new file mode 100644 index 000000000..6805f0723 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_AvgTest extends QueryTest with SharedSparkSessionGDAL with RST_AvgBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. 
+ test("Testing rst_avg behavior with H3IndexSystem and JTS") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behavior(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxBehaviors.scala index 9c095488d..daab1ee90 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxBehaviors.scala @@ -35,10 +35,6 @@ trait RST_MaxBehaviors extends QueryTest { |select rst_max(tile) from source |""".stripMargin) - noException should be thrownBy rastersInMemory - .withColumn("result", rst_rastertogridmax($"tile", lit(3))) - .select("result") - val result = df.as[Double].collect().max result > 0 shouldBe true diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala new file mode 100644 index 000000000..bd867ee65 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala @@ -0,0 +1,48 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions._ +import org.scalatest.matchers.should.Matchers._ + +trait RST_MinBehaviors extends QueryTest { + + def behavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + 
+ val df = rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .withColumn("result", rst_min($"tile")) + .select("result") + .select(explode($"result").as("result")) + + rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .createOrReplaceTempView("source") + + noException should be thrownBy spark.sql(""" + |select rst_min(tile) from source + |""".stripMargin) + + val result = df.as[Double].collect().min + + result < 0 shouldBe true + + an[Exception] should be thrownBy spark.sql(""" + |select rst_min() from source + |""".stripMargin) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinTest.scala new file mode 100644 index 000000000..ec09792f9 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_MinTest extends QueryTest with SharedSparkSessionGDAL with RST_MinBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. 
+ test("Testing rst_min behavior with H3IndexSystem and JTS") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behavior(H3IndexSystem, JTS) + } + } + +} From 79ff6e6feeed61d98e542d9db8467bf6a595a4d1 Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Thu, 1 Feb 2024 15:40:17 +0000 Subject: [PATCH 03/26] Fix the format name for grib files in tests. Fix temp location utils. Separate temp on worker location and off worker location. Introduce GDAL Block notion. Implement Kernel filters via GDALBlocks. Add additional params to gdal programs when they run. Fix TILED=YES issue with TIF files. Introduce writeOptions concept for tmp writes. Update expressions to take into account new concepts. Fix Zarr format issues with SerDeser. --- .../labs/mosaic/core/raster/api/GDAL.scala | 9 +- .../mosaic/core/raster/gdal/GDALBlock.scala | 186 ++++++++++++ .../raster/gdal/MosaicRasterBandGDAL.scala | 77 +++++ .../core/raster/gdal/MosaicRasterGDAL.scala | 228 +++++++++++---- .../gdal/MosaicRasterWriteOptions.scala | 55 ++++ .../mosaic/core/raster/gdal/Padding.scala | 58 ++++ .../operator/clip/RasterClipByVector.scala | 10 +- .../raster/operator/gdal/GDALBuildVRT.scala | 16 +- .../core/raster/operator/gdal/GDALCalc.scala | 17 +- .../raster/operator/gdal/GDALTranslate.scala | 25 +- .../core/raster/operator/gdal/GDALWarp.scala | 5 +- .../operator/gdal/OperatorOptions.scala | 35 +++ .../raster/operator/merge/MergeBands.scala | 14 +- .../raster/operator/merge/MergeRasters.scala | 18 +- .../operator/pixel/PixelCombineRasters.scala | 8 +- .../raster/operator/proj/RasterProject.scala | 7 +- .../operator/retile/OverlappingTiles.scala | 5 +- .../operator/retile/RasterTessellate.scala | 5 +- .../core/raster/operator/retile/ReTile.scala | 7 +- .../mosaic/core/types/RasterTileType.scala | 43 ++- .../core/types/model/MosaicRasterTile.scala | 26 +- .../mosaic/datasource/gdal/ReTileOnRead.scala | 15 +- .../mosaic/datasource/gdal/ReadInMemory.scala | 4 +- 
.../multiread/RasterAsGridReader.scala | 40 ++- .../mosaic/expressions/raster/RST_Avg.scala | 6 +- .../expressions/raster/RST_BandMetaData.scala | 3 +- .../expressions/raster/RST_BoundingBox.scala | 4 +- .../mosaic/expressions/raster/RST_Clip.scala | 3 +- .../expressions/raster/RST_CombineAvg.scala | 8 +- .../raster/RST_CombineAvgAgg.scala | 28 +- .../expressions/raster/RST_Convolve.scala | 73 +++++ .../expressions/raster/RST_DerivedBand.scala | 8 +- .../raster/RST_DerivedBandAgg.scala | 19 +- .../expressions/raster/RST_Filter.scala | 77 +++++ .../expressions/raster/RST_FromBands.scala | 8 +- .../expressions/raster/RST_FromContent.scala | 25 +- .../expressions/raster/RST_FromFile.scala | 10 +- .../expressions/raster/RST_GeoReference.scala | 4 +- .../expressions/raster/RST_GetNoData.scala | 5 +- .../raster/RST_GetSubdataset.scala | 13 +- .../expressions/raster/RST_Height.scala | 4 +- .../expressions/raster/RST_InitNoData.scala | 8 +- .../expressions/raster/RST_IsEmpty.scala | 4 +- .../expressions/raster/RST_MakeTiles.scala | 205 +++++++++++++ .../expressions/raster/RST_MapAlgebra.scala | 8 +- .../mosaic/expressions/raster/RST_Max.scala | 4 +- .../expressions/raster/RST_Median.scala | 7 +- .../expressions/raster/RST_MemSize.scala | 4 +- .../mosaic/expressions/raster/RST_Merge.scala | 8 +- .../expressions/raster/RST_MergeAgg.scala | 19 +- .../expressions/raster/RST_MetaData.scala | 4 +- .../mosaic/expressions/raster/RST_Min.scala | 5 +- .../mosaic/expressions/raster/RST_NDVI.scala | 8 +- .../expressions/raster/RST_NumBands.scala | 4 +- .../expressions/raster/RST_PixelCount.scala | 4 +- .../expressions/raster/RST_PixelHeight.scala | 4 +- .../expressions/raster/RST_PixelWidth.scala | 4 +- .../raster/RST_RasterToWorldCoord.scala | 4 +- .../raster/RST_RasterToWorldCoordX.scala | 4 +- .../raster/RST_RasterToWorldCoordY.scala | 4 +- .../expressions/raster/RST_ReTile.scala | 3 + .../expressions/raster/RST_Rotation.scala | 4 +- .../mosaic/expressions/raster/RST_SRID.scala | 4 
+- .../expressions/raster/RST_ScaleX.scala | 4 +- .../expressions/raster/RST_ScaleY.scala | 4 +- .../expressions/raster/RST_SetNoData.scala | 8 +- .../mosaic/expressions/raster/RST_SkewX.scala | 4 +- .../mosaic/expressions/raster/RST_SkewY.scala | 4 +- .../expressions/raster/RST_Subdatasets.scala | 3 +- .../expressions/raster/RST_Summary.scala | 4 +- .../expressions/raster/RST_TryOpen.scala | 4 +- .../expressions/raster/RST_UpperLeftX.scala | 4 +- .../expressions/raster/RST_UpperLeftY.scala | 4 +- .../mosaic/expressions/raster/RST_Width.scala | 4 +- .../raster/RST_WorldToRasterCoord.scala | 5 +- .../raster/RST_WorldToRasterCoordX.scala | 4 +- .../raster/RST_WorldToRasterCoordY.scala | 4 +- .../raster/base/Raster1ArgExpression.scala | 17 +- .../raster/base/Raster2ArgExpression.scala | 16 +- .../base/RasterArray1ArgExpression.scala | 12 +- .../base/RasterArray2ArgExpression.scala | 12 +- .../raster/base/RasterArrayExpression.scala | 10 +- .../raster/base/RasterArrayUtils.scala | 8 +- .../raster/base/RasterBandExpression.scala | 17 +- .../raster/base/RasterExpression.scala | 16 +- .../base/RasterExpressionSerialization.scala | 4 +- .../base/RasterGeneratorExpression.scala | 11 +- .../RasterTessellateGeneratorExpression.scala | 15 +- .../raster/base/RasterToGridExpression.scala | 4 +- .../labs/mosaic/functions/MosaicContext.scala | 24 +- .../functions/MosaicExpressionConfig.scala | 2 + .../labs/mosaic/gdal/MosaicGDAL.scala | 66 +++-- .../com/databricks/labs/mosaic/package.scala | 6 +- .../labs/mosaic/utils/FileUtils.scala | 6 +- .../labs/mosaic/utils/PathUtils.scala | 66 ++++- .../labs/mosaic/utils/SysUtils.scala | 44 ++- ...-041ac051-015d-49b0-95df-b5daa7084c7e.grb} | Bin ...1-015d-49b0-95df-b5daa7084c7e.grb.aux.xml} | 0 ...-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb} | Bin ...6-16ca-4e11-919d-bdbd5a51da35.grb.aux.xml} | 0 ...-0ede0273-89e3-4100-a0f2-48916ca607ed.grb} | Bin ...3-89e3-4100-a0f2-48916ca607ed.grb.aux.xml} | 0 .../core/raster/TestRasterBandGDAL.scala | 4 +- 
.../mosaic/core/raster/TestRasterGDAL.scala | 223 +++++++++++++- .../datasource/GDALFileFormatTest.scala | 19 +- .../multiread/RasterAsGridReaderTest.scala | 272 +++++++++--------- .../raster/RST_CombineAvgBehaviors.scala | 4 +- .../raster/RST_FilterBehaviors.scala | 36 +++ .../expressions/raster/RST_FilterTest.scala | 32 +++ .../expressions/raster/RST_MinBehaviors.scala | 2 +- .../raster/RST_TessellateBehaviors.scala | 11 +- .../databricks/labs/mosaic/test/package.scala | 2 +- .../sql/test/MosaicTestSparkSession.scala | 4 +- .../sql/test/SharedSparkSessionGDAL.scala | 6 +- 114 files changed, 2005 insertions(+), 560 deletions(-) create mode 100644 src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/GDALBlock.scala create mode 100644 src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterWriteOptions.scala create mode 100644 src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/Padding.scala create mode 100644 src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Convolve.scala create mode 100644 src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Filter.scala create mode 100644 src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala rename src/test/resources/binary/grib-cams/{adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib => adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grb} (100%) rename src/test/resources/binary/grib-cams/{adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib.aux.xml => adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grb.aux.xml} (100%) rename src/test/resources/binary/grib-cams/{adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib => adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb} (100%) rename 
src/test/resources/binary/grib-cams/{adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib.aux.xml => adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb.aux.xml} (100%) rename src/test/resources/binary/grib-cams/{adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib => adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grb} (100%) rename src/test/resources/binary/grib-cams/{adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib.aux.xml => adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grb.aux.xml} (100%) create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterBehaviors.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterTest.scala diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala index 66bde39a3..b86489359 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala @@ -4,6 +4,7 @@ import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterBandGDAL, Mosaic import com.databricks.labs.mosaic.core.raster.io.RasterCleaner import com.databricks.labs.mosaic.core.raster.operator.transform.RasterTransform import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import com.databricks.labs.mosaic.gdal.MosaicGDAL import com.databricks.labs.mosaic.gdal.MosaicGDAL.configureGDAL import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.{BinaryType, DataType, StringType} @@ -114,6 +115,8 @@ object GDAL { } else { raster } + case _ => + throw new IllegalArgumentException(s"Unsupported data type: $inputDT") } } @@ -122,19 +125,17 @@ object GDAL { * * @param 
generatedRasters * The rasters to write. - * @param checkpointPath - * The path to write the rasters to. * @return * Returns the paths of the written rasters. */ - def writeRasters(generatedRasters: Seq[MosaicRasterGDAL], checkpointPath: String, rasterDT: DataType): Seq[Any] = { + def writeRasters(generatedRasters: Seq[MosaicRasterGDAL], rasterDT: DataType): Seq[Any] = { generatedRasters.map(raster => if (raster != null) { rasterDT match { case StringType => val uuid = UUID.randomUUID().toString val extension = GDAL.getExtension(raster.getDriversShortName) - val writePath = s"$checkpointPath/$uuid.$extension" + val writePath = s"${MosaicGDAL.checkpointPath}/$uuid.$extension" val outPath = raster.writeToPath(writePath) RasterCleaner.dispose(raster) UTF8String.fromString(outPath) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/GDALBlock.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/GDALBlock.scala new file mode 100644 index 000000000..8c5c7a495 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/GDALBlock.scala @@ -0,0 +1,186 @@ +package com.databricks.labs.mosaic.core.raster.gdal + +import scala.reflect.ClassTag + +case class GDALBlock[T: ClassTag]( + block: Array[T], + maskBlock: Array[Double], + noDataValue: Double, + xOffset: Int, + yOffset: Int, + width: Int, + height: Int, + padding: Padding +)(implicit + num: Numeric[T] +) { + + def elementAt(index: Int): T = block(index) + + def maskAt(index: Int): Double = maskBlock(index) + + def elementAt(x: Int, y: Int): T = block(y * width + x) + + def maskAt(x: Int, y: Int): Double = maskBlock(y * width + x) + + def rasterElementAt(x: Int, y: Int): T = block((y - yOffset) * width + (x - xOffset)) + + def rasterMaskAt(x: Int, y: Int): Double = maskBlock((y - yOffset) * width + (x - xOffset)) + + def valuesAt(x: Int, y: Int, kernelWidth: Int, kernelHeight: Int): Array[Double] = { + val kernelCenterX = kernelWidth / 2 + val kernelCenterY = kernelHeight 
/ 2 + val values = Array.fill[Double](kernelWidth * kernelHeight)(noDataValue) + var n = 0 + for (i <- 0 until kernelHeight) { + for (j <- 0 until kernelWidth) { + val xIndex = x + (j - kernelCenterX) + val yIndex = y + (i - kernelCenterY) + if (xIndex >= 0 && xIndex < width && yIndex >= 0 && yIndex < height) { + val maskValue = maskAt(xIndex, yIndex) + val value = elementAt(xIndex, yIndex) + if (maskValue != 0.0 && num.toDouble(value) != noDataValue) { + values(n) = num.toDouble(value) + n += 1 + } + } + } + } + val result = values.filter(_ != noDataValue) + // always return only one NoDataValue if no valid values are found + // one and only one NoDataValue can be returned + // in all cases that have some valid values, the NoDataValue will be filtered out + if (result.isEmpty) { + Array(noDataValue) + } else { + result + } + } + + // TODO: Test and fix, not tested, other filters work. + def convolveAt(x: Int, y: Int, kernel: Array[Array[Double]]): Double = { + val kernelWidth = kernel.head.length + val kernelHeight = kernel.length + val kernelCenterX = kernelWidth / 2 + val kernelCenterY = kernelHeight / 2 + var sum = 0.0 + for (i <- 0 until kernelHeight) { + for (j <- 0 until kernelWidth) { + val xIndex = x + (j - kernelCenterX) + val yIndex = y + (i - kernelCenterY) + if (xIndex >= 0 && xIndex < width && yIndex >= 0 && yIndex < height) { + val maskValue = maskAt(xIndex, yIndex) + val value = rasterElementAt(xIndex, yIndex) + if (maskValue != 0.0 && num.toDouble(value) != noDataValue) { + sum += num.toDouble(value) * kernel(i)(j) + } + } + } + } + sum + } + + def avgFilterAt(x: Int, y: Int, kernelSize: Int): Double = { + val values = valuesAt(x, y, kernelSize, kernelSize) + values.sum / values.length + } + + def minFilterAt(x: Int, y: Int, kernelSize: Int): Double = { + val values = valuesAt(x, y, kernelSize, kernelSize) + values.min + } + + def maxFilterAt(x: Int, y: Int, kernelSize: Int): Double = { + val values = valuesAt(x, y, kernelSize, kernelSize) + 
values.max + } + + def medianFilterAt(x: Int, y: Int, kernelSize: Int): Double = { + val values = valuesAt(x, y, kernelSize, kernelSize) + val n = values.length + scala.util.Sorting.quickSort(values) + values(n / 2) + } + + def modeFilterAt(x: Int, y: Int, kernelSize: Int): Double = { + val values = valuesAt(x, y, kernelSize, kernelSize) + val counts = values.groupBy(identity).mapValues(_.length) + counts.maxBy(_._2)._1 + } + + def trimBlock(stride: Int): GDALBlock[Double] = { + val resultBlock = padding.removePadding(block.map(num.toDouble), width, stride) + val resultMask = padding.removePadding(maskBlock, width, stride) + + val newOffset = padding.newOffset(xOffset, yOffset, stride) + val newSize = padding.newSize(width, height, stride) + + new GDALBlock[Double]( + resultBlock, + resultMask, + noDataValue, + newOffset._1, + newOffset._2, + newSize._1, + newSize._2, + Padding.NoPadding + ) + } + +} + +object GDALBlock { + + def getSize(offset: Int, maxSize: Int, blockSize: Int, stride: Int, paddingStrides: Int): Int = { + if (offset + blockSize + paddingStrides * stride > maxSize) { + maxSize - offset + } else { + blockSize + paddingStrides * stride + } + } + + def apply( + band: MosaicRasterBandGDAL, + stride: Int, + xOffset: Int, + yOffset: Int, + blockSize: Int + ): GDALBlock[Double] = { + val noDataValue = band.noDataValue + val rasterWidth = band.xSize + val rasterHeight = band.ySize + // Always read blockSize + stride pixels on every side + // This is fine since kernel size is always much smaller than blockSize + + val padding = Padding( + left = xOffset != 0, + right = xOffset + blockSize < rasterWidth - 1, // not sure about -1 + top = yOffset != 0, + bottom = yOffset + blockSize < rasterHeight - 1 + ) + + val xo = Math.max(0, xOffset - stride) + val yo = Math.max(0, yOffset - stride) + + val xs = getSize(xo, rasterWidth, blockSize, stride, padding.horizontalStrides) + val ys = getSize(yo, rasterHeight, blockSize, stride, padding.verticalStrides) + + val 
block = Array.ofDim[Double](xs * ys) + val maskBlock = Array.ofDim[Double](xs * ys) + + band.getBand.ReadRaster(xo, yo, xs, ys, block) + band.getBand.GetMaskBand().ReadRaster(xo, yo, xs, ys, maskBlock) + + GDALBlock( + block, + maskBlock, + noDataValue, + xo, + yo, + xs, + ys, + padding + ) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala index a7c9ece10..281eb8b01 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala @@ -1,5 +1,6 @@ package com.databricks.labs.mosaic.core.raster.gdal +import com.databricks.labs.mosaic.gdal.MosaicGDAL import org.gdal.gdal.Band import org.gdal.gdalconst.gdalconstConstants @@ -255,4 +256,80 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) { */ def isNoDataMask: Boolean = band.GetMaskFlags() == gdalconstConstants.GMF_NODATA + def convolve(kernel: Array[Array[Double]]): Unit = { + val kernelWidth = kernel.head.length + val kernelHeight = kernel.length + val blockSize = MosaicGDAL.defaultBlockSize + val strideX = kernelWidth / 2 + val strideY = kernelHeight / 2 + + val block = Array.ofDim[Double](blockSize * blockSize) + val maskBlock = Array.ofDim[Double](blockSize * blockSize) + val result = Array.ofDim[Double](blockSize * blockSize) + + for (yOffset <- 0 until ySize by blockSize - strideY) { + for (xOffset <- 0 until xSize by blockSize - strideX) { + val xSize = Math.min(blockSize, this.xSize - xOffset) + val ySize = Math.min(blockSize, this.ySize - yOffset) + + band.ReadRaster(xOffset, yOffset, xSize, ySize, block) + band.GetMaskBand().ReadRaster(xOffset, yOffset, xSize, ySize, maskBlock) + + val currentBlock = GDALBlock[Double](block, maskBlock, noDataValue, xOffset, yOffset, xSize, ySize, Padding.NoPadding) + + for (y <- 0 until ySize) { + for 
(x <- 0 until xSize) { + result(y * xSize + x) = currentBlock.convolveAt(x, y, kernel) + } + } + + band.WriteRaster(xOffset, yOffset, xSize, ySize, block) + } + } + } + + def filter(kernelSize: Int, operation: String, outputBand: Band): Unit = { + require(kernelSize % 2 == 1, "Kernel size must be odd") + + val blockSize = MosaicGDAL.blockSize + val stride = kernelSize / 2 + + for (yOffset <- 0 until ySize by blockSize) { + for (xOffset <- 0 until xSize by blockSize) { + + val currentBlock = GDALBlock( + this, + stride, + xOffset, + yOffset, + blockSize + ) + + val result = Array.ofDim[Double](currentBlock.block.length) + + for (y <- 0 until currentBlock.height) { + for (x <- 0 until currentBlock.width) { + result(y * currentBlock.width + x) = operation match { + case "avg" => currentBlock.avgFilterAt(x, y, kernelSize) + case "min" => currentBlock.minFilterAt(x, y, kernelSize) + case "max" => currentBlock.maxFilterAt(x, y, kernelSize) + case "median" => currentBlock.medianFilterAt(x, y, kernelSize) + case "mode" => currentBlock.modeFilterAt(x, y, kernelSize) + case _ => throw new Exception("Invalid operation") + } + } + } + + val trimmedResult = currentBlock.copy(block = result).trimBlock(stride) + + outputBand.WriteRaster(xOffset, yOffset, trimmedResult.width, trimmedResult.height, trimmedResult.block) + outputBand.FlushCache() + outputBand.GetMaskBand().WriteRaster(xOffset, yOffset, trimmedResult.width, trimmedResult.height, trimmedResult.maskBlock) + outputBand.GetMaskBand().FlushCache() + + } + } + + } + } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala index 3ac467f53..b63bd851e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala @@ -8,9 +8,9 @@ import 
com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose import com.databricks.labs.mosaic.core.raster.io.{RasterCleaner, RasterReader, RasterWriter} import com.databricks.labs.mosaic.core.raster.operator.clip.RasterClipByVector import com.databricks.labs.mosaic.core.types.model.GeometryTypeEnum.POLYGON -import com.databricks.labs.mosaic.utils.{FileUtils, PathUtils} -import org.gdal.gdal.gdal.GDALInfo -import org.gdal.gdal.{Dataset, InfoOptions, gdal} +import com.databricks.labs.mosaic.gdal.MosaicGDAL +import com.databricks.labs.mosaic.utils.{FileUtils, PathUtils, SysUtils} +import org.gdal.gdal.{Dataset, gdal} import org.gdal.gdalconst.gdalconstConstants._ import org.gdal.osr import org.gdal.osr.SpatialReference @@ -32,25 +32,41 @@ case class MosaicRasterGDAL( ) extends RasterWriter with RasterCleaner { + def getWriteOptions: MosaicRasterWriteOptions = MosaicRasterWriteOptions(this) + + def getCompression: String = { + val compression = Option(raster.GetMetadata_Dict("IMAGE_STRUCTURE")) + .map(_.asScala.toMap.asInstanceOf[Map[String, String]]) + .getOrElse(Map.empty[String, String]) + .getOrElse("COMPRESSION", "NONE") + compression + } + def getSpatialReference: SpatialReference = { - if (raster != null) { - raster.GetSpatialRef + val spatialRef = + if (raster != null) { + raster.GetSpatialRef + } else { + val tmp = refresh() + val result = tmp.raster.GetSpatialRef + dispose(tmp) + result + } + if (spatialRef == null) { + MosaicGDAL.WSG84 } else { - val tmp = refresh() - val result = tmp.spatialRef - dispose(tmp) - result + spatialRef } } + def isSubDataset: Boolean = { + val isSubdataset = PathUtils.isSubdataset(path) + isSubdataset + } + // Factory for creating CRS objects protected val crsFactory: CRSFactory = new CRSFactory - // Only use this with GDAL rasters - private val wsg84 = new osr.SpatialReference() - wsg84.ImportFromEPSG(4326) - wsg84.SetAxisMappingStrategy(osr.osrConstants.OAMS_TRADITIONAL_GIS_ORDER) - /** * @return * The raster's driver 
short name. @@ -157,6 +173,7 @@ case class MosaicRasterGDAL( .map(_.asScala.toMap.asInstanceOf[Map[String, String]]) .getOrElse(Map.empty[String, String]) val keys = subdatasetsMap.keySet + val sanitizedParentPath = PathUtils.getCleanPath(parentPath) keys.flatMap(key => if (key.toUpperCase(Locale.ROOT).contains("NAME")) { val path = subdatasetsMap(key) @@ -164,7 +181,7 @@ case class MosaicRasterGDAL( Seq( key -> pieces.last, s"${pieces.last}_tmp" -> path, - pieces.last -> s"${pieces.head}:$parentPath:${pieces.last}" + pieces.last -> s"${pieces.head}:$sanitizedParentPath:${pieces.last}" ) } else Seq(key -> subdatasetsMap(key)) ).toMap @@ -253,12 +270,6 @@ case class MosaicRasterGDAL( */ def getRaster: Dataset = this.raster - /** - * @return - * Returns the raster's spatial reference. - */ - def spatialRef: SpatialReference = raster.GetSpatialRef() - /** * Applies a function to each band of the raster. * @param f @@ -272,10 +283,10 @@ case class MosaicRasterGDAL( * @return * Returns MosaicGeometry representing bounding box of the raster. */ - def bbox(geometryAPI: GeometryAPI, destCRS: SpatialReference = wsg84): MosaicGeometry = { + def bbox(geometryAPI: GeometryAPI, destCRS: SpatialReference = MosaicGDAL.WSG84): MosaicGeometry = { val gt = getGeoTransform - val sourceCRS = spatialRef + val sourceCRS = getSpatialReference val transform = new osr.CoordinateTransformation(sourceCRS, destCRS) val bbox = geometryAPI.geometry( @@ -300,23 +311,10 @@ case class MosaicRasterGDAL( * compute since it requires reading the raster and computing statistics. 
*/ def isEmpty: Boolean = { - import org.json4s._ - import org.json4s.jackson.JsonMethods._ - implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats - - val vector = new JVector[String]() - vector.add("-stats") - vector.add("-json") - val infoOptions = new InfoOptions(vector) - val gdalInfo = GDALInfo(raster, infoOptions) - val json = parse(gdalInfo).extract[Map[String, Any]] - - if (json.contains("STATISTICS_VALID_PERCENT")) { - json("STATISTICS_VALID_PERCENT").asInstanceOf[Double] == 0.0 - } else if (subdatasets.nonEmpty) { + if (subdatasets.nonEmpty) { false } else { - getBandStats.values.map(_.getOrElse("mean", 0.0)).forall(_ == 0.0) + getValidCount.values.sum == 0 } } @@ -347,11 +345,18 @@ case class MosaicRasterGDAL( val isSubdataset = PathUtils.isSubdataset(path) val filePath = if (isSubdataset) PathUtils.fromSubdatasetPath(path) else path val pamFilePath = s"$filePath.aux.xml" + val cleanPath = filePath.replace("/vsizip/", "") + val zipPath = if (cleanPath.endsWith("zip")) cleanPath else s"$cleanPath.zip" if (path != PathUtils.getCleanPath(parentPath)) { Try(gdal.GetDriverByName(driverShortName).Delete(path)) + Try(Files.deleteIfExists(Paths.get(cleanPath))) Try(Files.deleteIfExists(Paths.get(path))) Try(Files.deleteIfExists(Paths.get(filePath))) Try(Files.deleteIfExists(Paths.get(pamFilePath))) + if (Files.exists(Paths.get(zipPath))) { + Try(Files.deleteIfExists(Paths.get(zipPath.replace(".zip", "")))) + } + Try(Files.deleteIfExists(Paths.get(zipPath))) } } @@ -382,12 +387,26 @@ case class MosaicRasterGDAL( * A boolean indicating if the write was successful. 
*/ def writeToPath(path: String, dispose: Boolean = true): String = { - val driver = raster.GetDriver() - val ds = driver.CreateCopy(path, this.flushCache().getRaster) - ds.FlushCache() - ds.delete() - if (dispose) RasterCleaner.dispose(this) - path + if (isSubDataset) { + val driver = raster.GetDriver() + val ds = driver.CreateCopy(path, this.flushCache().getRaster, 1) + if (ds == null) { + val error = gdal.GetLastErrorMsg() + throw new Exception(s"Error writing raster to path: $error") + } + ds.FlushCache() + ds.delete() + if (dispose) RasterCleaner.dispose(this) + path + } else { + val thisPath = Paths.get(this.path) + val fromDir = thisPath.getParent + val toDir = Paths.get(path).getParent + val stemRegex = PathUtils.getStemRegex(this.path) + PathUtils.wildcardCopy(fromDir.toString, toDir.toString, stemRegex) + if (dispose) RasterCleaner.dispose(this) + s"$toDir/${thisPath.getFileName}" + } } /** @@ -398,17 +417,33 @@ case class MosaicRasterGDAL( */ def writeToBytes(dispose: Boolean = true): Array[Byte] = { val isSubdataset = PathUtils.isSubdataset(path) - val readPath = - if (isSubdataset) { - val tmpPath = PathUtils.createTmpFilePath(getRasterFileExtension) - writeToPath(tmpPath, dispose = false) + val readPath = { + val tmpPath = + if (isSubdataset) { + val tmpPath = PathUtils.createTmpFilePath(getRasterFileExtension) + writeToPath(tmpPath, dispose = false) + tmpPath + } else { + path + } + if (Files.isDirectory(Paths.get(tmpPath))) { + SysUtils.runCommand(s"zip -r0 $tmpPath.zip $tmpPath") + s"$tmpPath.zip" } else { - path + tmpPath } + } val byteArray = FileUtils.readBytes(readPath) if (dispose) RasterCleaner.dispose(this) if (readPath != PathUtils.getCleanPath(parentPath)) { Files.deleteIfExists(Paths.get(readPath)) + if (readPath.endsWith(".zip")) { + val nonZipPath = readPath.replace(".zip", "") + if (Files.isDirectory(Paths.get(nonZipPath))) { + SysUtils.runCommand(s"rm -rf $nonZipPath") + } + Files.deleteIfExists(Paths.get(readPath.replace(".zip", 
""))) + } } byteArray } @@ -464,6 +499,20 @@ case class MosaicRasterGDAL( .toMap } + /** + * @return + * Returns the raster's band valid pixel count. + */ + def getValidCount: Map[Int, Long] = { + (1 to numBands) + .map(i => { + val band = raster.GetRasterBand(i) + val validCount = band.AsMDArray().GetStatistics().getValid_count + i -> validCount + }) + .toMap + } + /** * @param subsetName * The name of the subdataset to get. @@ -471,24 +520,59 @@ case class MosaicRasterGDAL( * Returns the raster's subdataset with given name. */ def getSubdataset(subsetName: String): MosaicRasterGDAL = { - subdatasets - val path = Option(raster.GetMetadata_Dict("SUBDATASETS")) - .map(_.asScala.toMap.asInstanceOf[Map[String, String]]) - .getOrElse(Map.empty[String, String]) - .values - .find(_.toUpperCase(Locale.ROOT).endsWith(subsetName.toUpperCase(Locale.ROOT))) - .getOrElse(throw new Exception(s""" - |Subdataset $subsetName not found! - |Available subdatasets: - | ${subdatasets.keys.filterNot(_.startsWith("SUBDATASET_")).mkString(", ")} - """.stripMargin)) - val ds = openRaster(path) + val path = subdatasets.getOrElse( + s"${subsetName}_tmp", + throw new Exception(s""" + |Subdataset $subsetName not found! 
+ |Available subdatasets: + | ${subdatasets.keys.filterNot(_.startsWith("SUBDATASET_")).mkString(", ")} + | """.stripMargin) + ) + val sanitized = PathUtils.getCleanPath(path) + val subdatasetPath = PathUtils.getSubdatasetPath(sanitized) + + val ds = openRaster(subdatasetPath) // Avoid costly IO to compute MEM size here // It will be available when the raster is serialized for next operation // If value is needed then it will be computed when getMemSize is called MosaicRasterGDAL(ds, path, parentPath, driverShortName, -1) } + def convolve(kernel: Array[Array[Double]]): MosaicRasterGDAL = { + val resultRasterPath = PathUtils.createTmpFilePath(getRasterFileExtension) + val outputRaster = this.raster + .GetDriver() + .Create(resultRasterPath, this.xSize, this.ySize, this.numBands, this.raster.GetRasterBand(1).getDataType) + + for (bandIndex <- 1 to this.numBands) { + val band = this.getBand(bandIndex) + band.convolve(kernel) + } + + MosaicRasterGDAL(outputRaster, resultRasterPath, parentPath, driverShortName, -1) + + } + + def filter(kernelSize: Int, operation: String): MosaicRasterGDAL = { + val resultRasterPath = PathUtils.createTmpFilePath(getRasterFileExtension) + + this.raster + .GetDriver() + .CreateCopy(resultRasterPath, this.raster, 1) + .delete() + + val outputRaster = gdal.Open(resultRasterPath, GF_Write) + + for (bandIndex <- 1 to this.numBands) { + val band = this.getBand(bandIndex) + val outputBand = outputRaster.GetRasterBand(bandIndex) + band.filter(kernelSize, operation, outputBand) + } + + val result = MosaicRasterGDAL(outputRaster, resultRasterPath, parentPath, driverShortName, this.memSize) + result.flushCache() + } + } //noinspection ZeroIndexToHead @@ -583,11 +667,29 @@ object MosaicRasterGDAL extends RasterReader { // Try reading as a tmp file, if that fails, rename as a zipped file val dataset = openRaster(tmpPath, Some(driverShortName)) if (dataset == null) { - val zippedPath = PathUtils.createTmpFilePath("zip") + val zippedPath = 
s"$tmpPath.zip" Files.move(Paths.get(tmpPath), Paths.get(zippedPath), StandardCopyOption.REPLACE_EXISTING) val readPath = PathUtils.getZipPath(zippedPath) val ds = openRaster(readPath, Some(driverShortName)) - MosaicRasterGDAL(ds, readPath, parentPath, driverShortName, contentBytes.length) + if (ds == null) { + // the way we zip using uuid is not compatible with GDAL + // we need to unzip and read the file if it was zipped by us + val parentDir = Paths.get(zippedPath).getParent + val prompt = SysUtils.runScript(Array("/bin/sh", "-c", s"cd $parentDir && unzip -o $zippedPath -d /")) + // zipped files will have the old uuid name of the raster + // we need to get the last extracted file name, but the last extracted file name is not the raster name + // we can't list folders due to concurrent writes + val extension = GDAL.getExtension(driverShortName) + val lastExtracted = SysUtils.getLastOutputLine(prompt) + val unzippedPath = PathUtils.parseUnzippedPathFromExtracted(lastExtracted, extension) + val dataset = openRaster(unzippedPath, Some(driverShortName)) + if (dataset == null) { + throw new Exception(s"Error reading raster from bytes: ${prompt._3}") + } + MosaicRasterGDAL(dataset, unzippedPath, parentPath, driverShortName, contentBytes.length) + } else { + MosaicRasterGDAL(ds, readPath, parentPath, driverShortName, contentBytes.length) + } } else { MosaicRasterGDAL(dataset, tmpPath, parentPath, driverShortName, contentBytes.length) } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterWriteOptions.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterWriteOptions.scala new file mode 100644 index 000000000..68a7bd75a --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterWriteOptions.scala @@ -0,0 +1,55 @@ +package com.databricks.labs.mosaic.core.raster.gdal + +import com.databricks.labs.mosaic.gdal.MosaicGDAL +import org.gdal.osr.SpatialReference + +case class 
MosaicRasterWriteOptions( + compression: String = "DEFLATE", + format: String = "GTiff", + extension: String = "tif", + resampling: String = "nearest", + crs: SpatialReference = MosaicGDAL.WSG84, // Assume WGS84 + pixelSize: Option[(Double, Double)] = None, + noDataValue: Option[Double] = None, + missingGeoRef: Boolean = false, + options: Map[String, String] = Map.empty[String, String] +) + +object MosaicRasterWriteOptions { + + val VRT: MosaicRasterWriteOptions = + MosaicRasterWriteOptions( + compression = "NONE", + format = "VRT", + extension = "vrt", + crs = MosaicGDAL.WSG84, + pixelSize = None, + noDataValue = None, + options = Map.empty[String, String] + ) + + val GTiff: MosaicRasterWriteOptions = MosaicRasterWriteOptions() + + def noGPCsNoTransform(raster: MosaicRasterGDAL): Boolean = { + val noGPCs = raster.getRaster.GetGCPCount == 0 + val noGeoTransform = raster.getRaster.GetGeoTransform == null || + (raster.getRaster.GetGeoTransform sameElements Array(0.0, 1.0, 0.0, 0.0, 0.0, 1.0)) + noGPCs && noGeoTransform + } + + def apply(): MosaicRasterWriteOptions = new MosaicRasterWriteOptions() + + def apply(raster: MosaicRasterGDAL): MosaicRasterWriteOptions = { + val compression = raster.getCompression + val format = raster.getRaster.GetDriver.getShortName + val extension = raster.getRasterFileExtension + val resampling = "nearest" + val pixelSize = None + val noDataValue = None + val options = Map.empty[String, String] + val crs = raster.getSpatialReference + val missingGeoRef = noGPCsNoTransform(raster) + new MosaicRasterWriteOptions(compression, format, extension, resampling, crs, pixelSize, noDataValue, missingGeoRef, options) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/Padding.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/Padding.scala new file mode 100644 index 000000000..bb32e772f --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/Padding.scala @@ -0,0 +1,58 @@ +package 
com.databricks.labs.mosaic.core.raster.gdal + +case class Padding( + left: Boolean, + right: Boolean, + top: Boolean, + bottom: Boolean +) { + + def removePadding(array: Array[Double], rowWidth: Int, stride: Int): Array[Double] = { + val l = if (left) 1 else 0 + val r = if (right) 1 else 0 + val t = if (top) 1 else 0 + val b = if (bottom) 1 else 0 + + val yStart = t * stride * rowWidth + val yEnd = array.length - b * stride * rowWidth + + val slices = for (i <- yStart until yEnd by rowWidth) yield { + val xStart = i + l * stride + val xEnd = i + rowWidth - r * stride + array.slice(xStart, xEnd) + } + + slices.flatten.toArray + } + + def horizontalStrides: Int = { + if (left && right) 2 + else if (left || right) 1 + else 0 + } + + def verticalStrides: Int = { + if (top && bottom) 2 + else if (top || bottom) 1 + else 0 + } + + def newOffset(xOffset: Int, yOffset: Int, stride: Int): (Int, Int) = { + val x = if (left) xOffset + stride else xOffset + val y = if (top) yOffset + stride else yOffset + (x, y) + } + + def newSize(width: Int, height: Int, stride: Int): (Int, Int) = { + val w = if (left && right) width - 2 * stride else if (left || right) width - stride else width + val h = if (top && bottom) height - 2 * stride else if (top || bottom) height - stride else height + (w, h) + } + +} + +object Padding { + + val NoPadding: Padding = Padding(left = false, right = false, top = false, bottom = false) + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala index 6daabc25c..56c29563f 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala @@ -19,8 +19,7 @@ object RasterClipByVector { * abstractions over GDAL Warp. 
It uses CUTLINE_ALL_TOUCHED=TRUE to ensure * that all pixels that touch the geometry are included. This will avoid * the issue of having a pixel that is half in and half out of the - * geometry, important for tessellation. It also uses COMPRESS=DEFLATE to - * ensure that the output is compressed. The method also uses the geometry + * geometry, important for tessellation. The method also uses the geometry * API to generate a shapefile that is used to clip the raster. The * shapefile is deleted after the clip is complete. * @@ -38,16 +37,19 @@ object RasterClipByVector { def clip(raster: MosaicRasterGDAL, geometry: MosaicGeometry, geomCRS: SpatialReference, geometryAPI: GeometryAPI): MosaicRasterGDAL = { val rasterCRS = raster.getSpatialReference val outShortName = raster.getDriversShortName - val geomSrcCRS = if (geomCRS == null ) rasterCRS else geomCRS + val geomSrcCRS = if (geomCRS == null) rasterCRS else geomCRS val resultFileName = PathUtils.createTmpFilePath(GDAL.getExtension(outShortName)) val shapeFileName = VectorClipper.generateClipper(geometry, geomSrcCRS, rasterCRS, geometryAPI) + // For -wo consult https://gdal.org/doxygen/structGDALWarpOptions.html + // SOURCE_EXTRA=3 is used to ensure that when the raster is clipped, the + // pixels that touch the geometry are included. The default is 1, 3 seems to be a good empirical value. 
val result = GDALWarp.executeWarp( resultFileName, Seq(raster), - command = s"gdalwarp -wo CUTLINE_ALL_TOUCHED=TRUE -of $outShortName -cutline $shapeFileName -crop_to_cutline -co COMPRESS=DEFLATE -dstalpha" + command = s"gdalwarp -wo CUTLINE_ALL_TOUCHED=TRUE -wo SOURCE_EXTRA=3 -cutline $shapeFileName -crop_to_cutline" ) VectorClipper.cleanUpClipper(shapeFileName) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala index 389defad6..9e1e97401 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala @@ -1,6 +1,6 @@ package com.databricks.labs.mosaic.core.raster.operator.gdal -import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterGDAL, MosaicRasterWriteOptions} import org.gdal.gdal.{BuildVRTOptions, gdal} /** GDALBuildVRT is a wrapper for the GDAL BuildVRT command. */ @@ -20,16 +20,16 @@ object GDALBuildVRT { */ def executeVRT(outputPath: String, rasters: Seq[MosaicRasterGDAL], command: String): MosaicRasterGDAL = { require(command.startsWith("gdalbuildvrt"), "Not a valid GDAL Build VRT command.") - val vrtOptionsVec = OperatorOptions.parseOptions(command) + val effectiveCommand = OperatorOptions.appendOptions(command, MosaicRasterWriteOptions.VRT) + val vrtOptionsVec = OperatorOptions.parseOptions(effectiveCommand) val vrtOptions = new BuildVRTOptions(vrtOptionsVec) val result = gdal.BuildVRT(outputPath, rasters.map(_.getRaster).toArray, vrtOptions) if (result == null) { - throw new Exception( - s""" - |Build VRT failed. - |Command: $command - |Error: ${gdal.GetLastErrorMsg} - |""".stripMargin) + throw new Exception(s""" + |Build VRT failed. 
+ |Command: $effectiveCommand + |Error: ${gdal.GetLastErrorMsg} + |""".stripMargin) } // TODO: Figure out multiple parents, should this be an array? // VRT files are just meta files, mem size doesnt make much sense so we keep -1 diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala index 97a273d13..cc9c5e500 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala @@ -1,19 +1,16 @@ package com.databricks.labs.mosaic.core.raster.operator.gdal import com.databricks.labs.mosaic.core.raster.api.GDAL -import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterGDAL, MosaicRasterWriteOptions} import com.databricks.labs.mosaic.utils.SysUtils /** GDALCalc is a helper object for executing GDAL Calc commands. 
*/ object GDALCalc { val gdal_calc: String = { - val calcPath = SysUtils.runCommand("find / -iname gdal_calc.py")._1.split("\n").headOption.getOrElse("") - if (calcPath.isEmpty) { - throw new RuntimeException("Could not find gdal_calc.py.") - } - if (calcPath == "ERROR") { - "/usr/lib/python3/dist-packages/osgeo_utils/gdal_calc.py" + val calcPath = SysUtils.runCommand("""find / -maxdepth 20 -iname gdal_calc.py""")._1.split("\n").headOption.getOrElse("") + if (calcPath.isEmpty || calcPath.startsWith("ERROR")) { + "/usr/local/lib/python3.10/dist-packages/osgeo_utils/gdal_calc.py" } else { calcPath } @@ -30,11 +27,13 @@ object GDALCalc { */ def executeCalc(gdalCalcCommand: String, resultPath: String): MosaicRasterGDAL = { require(gdalCalcCommand.startsWith("gdal_calc"), "Not a valid GDAL Calc command.") - val toRun = gdalCalcCommand.replace("gdal_calc", gdal_calc) + val effectiveCommand = OperatorOptions.appendOptions(gdalCalcCommand, MosaicRasterWriteOptions.GTiff) + val toRun = effectiveCommand.replace("gdal_calc", gdal_calc) val commandRes = SysUtils.runCommand(s"python3 $toRun") - if (commandRes._1 == "ERROR") { + if (commandRes._1.startsWith("ERROR")) { throw new RuntimeException(s""" |GDAL Calc command failed: + |$toRun |STDOUT: |${commandRes._2} |STDERR: diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala index bf266cfbf..fd24a0f73 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala @@ -1,6 +1,6 @@ package com.databricks.labs.mosaic.core.raster.operator.gdal -import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterGDAL, MosaicRasterWriteOptions} import org.gdal.gdal.{TranslateOptions, gdal} import 
java.nio.file.{Files, Paths} @@ -20,21 +20,26 @@ object GDALTranslate { * @return * A MosaicRaster object. */ - def executeTranslate(outputPath: String, raster: MosaicRasterGDAL, command: String): MosaicRasterGDAL = { + def executeTranslate( + outputPath: String, + raster: MosaicRasterGDAL, + command: String, + writeOptions: MosaicRasterWriteOptions + ): MosaicRasterGDAL = { require(command.startsWith("gdal_translate"), "Not a valid GDAL Translate command.") - val translateOptionsVec = OperatorOptions.parseOptions(command) + val effectiveCommand = OperatorOptions.appendOptions(command, writeOptions) + val translateOptionsVec = OperatorOptions.parseOptions(effectiveCommand) val translateOptions = new TranslateOptions(translateOptionsVec) val result = gdal.Translate(outputPath, raster.getRaster, translateOptions) if (result == null) { - throw new Exception( - s""" - |Translate failed. - |Command: $command - |Error: ${gdal.GetLastErrorMsg} - |""".stripMargin) + throw new Exception(s""" + |Translate failed. 
+ |Command: $effectiveCommand + |Error: ${gdal.GetLastErrorMsg} + |""".stripMargin) } val size = Files.size(Paths.get(outputPath)) - raster.copy(raster = result, path = outputPath, memSize = size).flushCache() + raster.copy(raster = result, path = outputPath, memSize = size, driverShortName = writeOptions.format).flushCache() } } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala index 2b13a957b..ba6dce58d 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala @@ -23,7 +23,8 @@ object GDALWarp { def executeWarp(outputPath: String, rasters: Seq[MosaicRasterGDAL], command: String): MosaicRasterGDAL = { require(command.startsWith("gdalwarp"), "Not a valid GDAL Warp command.") // Test: gdal.ParseCommandLine(command) - val warpOptionsVec = OperatorOptions.parseOptions(command) + val effectiveCommand = OperatorOptions.appendOptions(command, rasters.head.getWriteOptions) + val warpOptionsVec = OperatorOptions.parseOptions(effectiveCommand) val warpOptions = new WarpOptions(warpOptionsVec) val result = gdal.Warp(outputPath, rasters.map(_.getRaster).toArray, warpOptions) // TODO: Figure out multiple parents, should this be an array? @@ -31,7 +32,7 @@ object GDALWarp { if (result == null) { throw new Exception(s""" |Warp failed. 
- |Command: $command + |Command: $effectiveCommand |Error: ${gdal.GetLastErrorMsg} |""".stripMargin) } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/OperatorOptions.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/OperatorOptions.scala index b1529d3e7..bc656ec01 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/OperatorOptions.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/OperatorOptions.scala @@ -1,5 +1,7 @@ package com.databricks.labs.mosaic.core.raster.operator.gdal +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterWriteOptions + /** OperatorOptions is a helper object for parsing GDAL command options. */ object OperatorOptions { @@ -18,4 +20,37 @@ object OperatorOptions { optionsVec } + /** + * Add default options to the command. Extract the compression from the + * raster and append it to the command. This operation does not change the + * output format. For changing the output format, use RST_ToFormat. + * + * @param command + * The command to append options to. + * @param writeOptions + * The write options to append. Note that not all available options are + * actually appended. At this point it is up to the below logic to + * decide what is supported and for which format.
+ * @return + */ + def appendOptions(command: String, writeOptions: MosaicRasterWriteOptions): String = { + val compression = writeOptions.compression + if (command.startsWith("gdal_calc")) { + writeOptions.format match { + case f @ "GTiff" => command + s" --format $f --co TILED=YES --co COMPRESS=$compression" + case f @ "COG" => command + s" --format $f --co TILED=YES --co COMPRESS=$compression" + case f @ _ => command + s" --format $f --co COMPRESS=$compression" + } + } else { + writeOptions.format match { + case f @ "GTiff" => command + s" -of $f -co TILED=YES -co COMPRESS=$compression" + case f @ "COG" => command + s" -of $f -co TILED=YES -co COMPRESS=$compression" + case "VRT" => command + case f @ "Zarr" if writeOptions.missingGeoRef => + command + s" -of $f -co COMPRESS=$compression -to SRC_METHOD=NO_GEOTRANSFORM" + case f @ _ => command + s" -of $f -co COMPRESS=$compression" + } + } + } + } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeBands.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeBands.scala index 6333c50c8..8a82d1238 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeBands.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeBands.scala @@ -19,10 +19,10 @@ object MergeBands { * A MosaicRaster object. 
*/ def merge(rasters: Seq[MosaicRasterGDAL], resampling: String): MosaicRasterGDAL = { - val outShortName = rasters.head.getRaster.GetDriver.getShortName + val outOptions = rasters.head.getWriteOptions val vrtPath = PathUtils.createTmpFilePath("vrt") - val rasterPath = PathUtils.createTmpFilePath("tif") + val rasterPath = PathUtils.createTmpFilePath(outOptions.extension) val vrtRaster = GDALBuildVRT.executeVRT( vrtPath, @@ -33,7 +33,8 @@ object MergeBands { val result = GDALTranslate.executeTranslate( rasterPath, vrtRaster, - command = s"gdal_translate -r $resampling -of $outShortName -co COMPRESS=DEFLATE" + command = s"gdal_translate -r $resampling", + outOptions ) dispose(vrtRaster) @@ -55,10 +56,10 @@ object MergeBands { * A MosaicRaster object. */ def merge(rasters: Seq[MosaicRasterGDAL], pixel: (Double, Double), resampling: String): MosaicRasterGDAL = { - val outShortName = rasters.head.getRaster.GetDriver.getShortName + val outOptions = rasters.head.getWriteOptions val vrtPath = PathUtils.createTmpFilePath("vrt") - val rasterPath = PathUtils.createTmpFilePath("tif") + val rasterPath = PathUtils.createTmpFilePath(outOptions.extension) val vrtRaster = GDALBuildVRT.executeVRT( vrtPath, @@ -69,7 +70,8 @@ object MergeBands { val result = GDALTranslate.executeTranslate( rasterPath, vrtRaster, - command = s"gdalwarp -r $resampling -of $outShortName -co COMPRESS=DEFLATE -overwrite" + command = s"gdalwarp -r $resampling", + outOptions ) dispose(vrtRaster) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeRasters.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeRasters.scala index 694d9940a..fafaffbc4 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeRasters.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeRasters.scala @@ -17,21 +17,22 @@ object MergeRasters { * A MosaicRaster object. 
*/ def merge(rasters: Seq[MosaicRasterGDAL]): MosaicRasterGDAL = { - val outShortName = rasters.head.getRaster.GetDriver.getShortName + val outOptions = rasters.head.getWriteOptions val vrtPath = PathUtils.createTmpFilePath("vrt") - val rasterPath = PathUtils.createTmpFilePath("tif") + val rasterPath = PathUtils.createTmpFilePath(outOptions.extension) val vrtRaster = GDALBuildVRT.executeVRT( - vrtPath, - rasters, - command = s"gdalbuildvrt -resolution highest" + vrtPath, + rasters, + command = s"gdalbuildvrt -resolution highest" ) val result = GDALTranslate.executeTranslate( - rasterPath, - vrtRaster, - command = s"gdal_translate -r bilinear -of $outShortName -co COMPRESS=DEFLATE" + rasterPath, + vrtRaster, + command = s"gdal_translate", + outOptions ) dispose(vrtRaster) @@ -39,5 +40,4 @@ object MergeRasters { result } - } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/pixel/PixelCombineRasters.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/pixel/PixelCombineRasters.scala index 5bf49fb96..cda9824dc 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/pixel/PixelCombineRasters.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/pixel/PixelCombineRasters.scala @@ -3,6 +3,7 @@ package com.databricks.labs.mosaic.core.raster.operator.pixel import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose import com.databricks.labs.mosaic.core.raster.operator.gdal.{GDALBuildVRT, GDALTranslate} +import com.databricks.labs.mosaic.gdal.MosaicGDAL.defaultBlockSize import com.databricks.labs.mosaic.utils.PathUtils import java.io.File @@ -20,10 +21,10 @@ object PixelCombineRasters { * A MosaicRaster object. 
*/ def combine(rasters: Seq[MosaicRasterGDAL], pythonFunc: String, pythonFuncName: String): MosaicRasterGDAL = { - val outShortName = rasters.head.getRaster.GetDriver.getShortName + val outOptions = rasters.head.getWriteOptions val vrtPath = PathUtils.createTmpFilePath("vrt") - val rasterPath = PathUtils.createTmpFilePath("tif") + val rasterPath = PathUtils.createTmpFilePath(outOptions.extension) val vrtRaster = GDALBuildVRT.executeVRT( vrtPath, @@ -37,7 +38,8 @@ object PixelCombineRasters { val result = GDALTranslate.executeTranslate( rasterPath, vrtRaster.refresh(), - command = s"gdal_translate -r bilinear -of $outShortName -co COMPRESS=DEFLATE" + command = s"gdal_translate", + outOptions ) dispose(vrtRaster) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala index efd7c8c67..5d7c5f5f2 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala @@ -15,8 +15,7 @@ object RasterProject { /** * Projects a raster to a new CRS. The method handles all the abstractions * over GDAL Warp. It uses cubic resampling to ensure that the output is - * smooth. It also uses COMPRESS=DEFLATE to ensure that the output is - * compressed. + * smooth. * * @param raster * The raster to project. 
@@ -33,11 +32,11 @@ object RasterProject { // Note that Null is the right value here val authName = destCRS.GetAuthorityName(null) val authCode = destCRS.GetAuthorityCode(null) - + val result = GDALWarp.executeWarp( resultFileName, Seq(raster), - command = s"gdalwarp -of $outShortName -t_srs $authName:$authCode -r cubic -overwrite -co COMPRESS=DEFLATE" + command = s"gdalwarp -t_srs $authName:$authCode" ) result diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala index c1498ea05..4e9f61c5e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala @@ -48,12 +48,13 @@ object OverlappingTiles { val fileExtension = GDAL.getExtension(tile.getDriver) val rasterPath = PathUtils.createTmpFilePath(fileExtension) - val shortName = raster.getRaster.GetDriver.getShortName + val outOptions = raster.getWriteOptions val result = GDALTranslate.executeTranslate( rasterPath, raster, - command = s"gdal_translate -of $shortName -srcwin $xOff $yOff $width $height" + command = s"gdal_translate -srcwin $xOff $yOff $width $height", + outOptions ) val isEmpty = result.isEmpty diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala index d186de0a5..fa47c6c1d 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala @@ -41,9 +41,10 @@ object RasterTessellate { (false, MosaicRasterTile(cell.index, null, "", "")) } else { val cellRaster = tmpRaster.getRasterForCell(cellID, indexSystem, geometryAPI) - val isValidRaster 
= cellRaster.getBandStats.values.map(_("mean")).sum > 0 && !cellRaster.isEmpty + val isValidRaster = cellRaster.getValidCount.values.sum > 0 && !cellRaster.isEmpty ( - isValidRaster, MosaicRasterTile(cell.index, cellRaster, raster.getParentPath, raster.getDriversShortName) + isValidRaster, + MosaicRasterTile(cell.index, cellRaster, raster.getParentPath, raster.getDriversShortName) ) } }) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala index edaab4720..f25a1f384 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala @@ -1,7 +1,7 @@ package com.databricks.labs.mosaic.core.raster.operator.retile import com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose -import com.databricks.labs.mosaic.core.raster.operator.gdal.{GDALBuildVRT, GDALTranslate} +import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALTranslate import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.utils.PathUtils @@ -39,12 +39,13 @@ object ReTile { val fileExtension = raster.getRasterFileExtension val rasterPath = PathUtils.createTmpFilePath(fileExtension) - val shortDriver = raster.getDriversShortName + val outOptions = raster.getWriteOptions val result = GDALTranslate.executeTranslate( rasterPath, raster, - command = s"gdal_translate -of $shortDriver -srcwin $xMin $yMin $xOffset $yOffset -co COMPRESS=DEFLATE" + command = s"gdal_translate -srcwin $xMin $yMin $xOffset $yOffset", + outOptions ) val isEmpty = result.isEmpty diff --git a/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala b/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala index 1cadf2c9a..5203178e0 100644 --- 
a/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala @@ -1,11 +1,12 @@ package com.databricks.labs.mosaic.core.types +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types._ /** Type definition for the raster tile. */ class RasterTileType(fields: Array[StructField]) extends StructType(fields) { - def rasterType: DataType = fields(1).dataType + def rasterType: DataType = fields.find(_.name == "raster").get.dataType override def simpleString: String = "RASTER_TILE" @@ -19,20 +20,54 @@ object RasterTileType { * Creates a new instance of [[RasterTileType]]. * * @param idType - * Type of the index ID. + * Type of the index ID. Can be one of [[LongType]], [[IntegerType]] or + * [[StringType]]. + * @param rasterType + * Type of the raster. Can be one of [[BinaryType]] or [[StringType]]. Not + * to be confused with the data type of the raster. This is the type of + * the column that contains the raster. + * * @return * An instance of [[RasterTileType]]. */ - def apply(idType: DataType): RasterTileType = { + def apply(idType: DataType, rasterType: DataType): DataType = { require(Seq(LongType, IntegerType, StringType).contains(idType)) new RasterTileType( Array( StructField("index_id", idType), - StructField("raster", BinaryType), + StructField("raster", rasterType), StructField("parentPath", StringType), StructField("driver", StringType) ) ) } + /** + * Creates a new instance of [[RasterTileType]]. + * + * @param idType + * Type of the index ID. Can be one of [[LongType]], [[IntegerType]] or + * [[StringType]]. + * @param tileExpr + * Expression containing a tile. This is used to infer the raster type + * when chaining expressions.
+ * @return + */ + def apply(idType: DataType, tileExpr: Expression): DataType = { + require(Seq(LongType, IntegerType, StringType).contains(idType)) + tileExpr.dataType match { + case st @ StructType(_) => apply(idType, st.find(_.name == "raster").get.dataType) + case _ @ArrayType(elementType: StructType, _) => apply(idType, elementType.find(_.name == "raster").get.dataType) + case _ => throw new IllegalArgumentException("Unsupported raster type.") + } + } + + def apply(tileExpr: Expression): RasterTileType = { + tileExpr.dataType match { + case StructType(fields) => new RasterTileType(fields) + case ArrayType(elementType: StructType, _) => new RasterTileType(elementType.fields) + case _ => throw new IllegalArgumentException("Unsupported raster type.") + } + } + } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala b/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala index e7a8e9218..30a7765c1 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala @@ -106,19 +106,22 @@ case class MosaicRasterTile( * An instance of [[InternalRow]]. */ def serialize( - rasterDataType: DataType = BinaryType, - checkpointLocation: String = "" + rasterDataType: DataType ): InternalRow = { val parentPathUTF8 = UTF8String.fromString(parentPath) val driverUTF8 = UTF8String.fromString(driver) - val encodedRaster = encodeRaster(rasterDataType, checkpointLocation) + val encodedRaster = encodeRaster(rasterDataType) if (Option(index).isDefined) { if (index.isLeft) InternalRow.fromSeq( Seq(index.left.get, encodedRaster, parentPathUTF8, driverUTF8) ) - else InternalRow.fromSeq( - Seq(UTF8String.fromString(index.right.get), encodedRaster, parentPathUTF8, driverUTF8) - ) + else { + // Copy from tmp to checkpoint. + // Have to use GDAL Driver to do this since sidecar files are not copied by spark. 
+ InternalRow.fromSeq( + Seq(UTF8String.fromString(index.right.get), encodedRaster, parentPathUTF8, driverUTF8) + ) + } } else { InternalRow.fromSeq(Seq(null, encodedRaster, parentPathUTF8, driverUTF8)) } @@ -132,10 +135,9 @@ case class MosaicRasterTile( * An instance of [[Array]] of [[Byte]] representing WKB. */ private def encodeRaster( - rasterDataType: DataType = BinaryType, - checkpointLocation: String = "" + rasterDataType: DataType = BinaryType ): Any = { - GDAL.writeRasters(Seq(raster), checkpointLocation, rasterDataType).head + GDAL.writeRasters(Seq(raster), rasterDataType).head } } @@ -153,12 +155,12 @@ object MosaicRasterTile { * @return * An instance of [[MosaicRasterTile]]. */ - def deserialize(row: InternalRow, idDataType: DataType): MosaicRasterTile = { + def deserialize(row: InternalRow, idDataType: DataType, rasterType: DataType): MosaicRasterTile = { val index = row.get(0, idDataType) - val rasterBytes = row.get(1, BinaryType) + val rawRaster = row.get(1, rasterType) val parentPath = row.get(2, StringType).toString val driver = row.get(3, StringType).toString - val raster = GDAL.readRaster(rasterBytes, parentPath, driver, BinaryType) + val raster = GDAL.readRaster(rawRaster, parentPath, driver, rasterType) // noinspection TypeCheckCanBeMatch if (Option(index).isDefined) { if (index.isInstanceOf[Long]) { diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala index 285df2191..a38e76900 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala @@ -19,11 +19,17 @@ import java.nio.file.{Files, Paths} /** An object defining the retiling read strategy for the GDAL file format. 
*/ object ReTileOnRead extends ReadStrategy { + val tileDataType: DataType = StringType + // noinspection DuplicatedCode /** * Returns the schema of the GDAL file format. * @note - * Different read strategies can have different schemas. + * Different read strategies can have different schemas. This is because + * the schema is defined by the read strategy. For retiling we always use + * checkpoint location. In this case rasters are stored off spark rows. + * If you need the tiles in memory please load them from path stored in + * the tile returned by the reader. * * @param options * Options passed to the reader. @@ -54,7 +60,10 @@ object ReTileOnRead extends ReadStrategy { .add(StructField(SUBDATASETS, MapType(StringType, StringType), nullable = false)) .add(StructField(SRID, IntegerType, nullable = false)) .add(StructField(LENGTH, LongType, nullable = false)) - .add(StructField(TILE, RasterTileType(indexSystem.getCellIdDataType), nullable = false)) + // Note that for retiling we always use checkpoint location. + // In this case rasters are stored off spark rows. + // If you need the tiles in memory please load them from path stored in the tile returned by the reader. 
+ .add(StructField(TILE, RasterTileType(indexSystem.getCellIdDataType, tileDataType), nullable = false)) } /** @@ -103,7 +112,7 @@ object ReTileOnRead extends ReadStrategy { case other => throw new RuntimeException(s"Unsupported field name: $other") } // Writing to bytes is destructive so we delay reading content and content length until the last possible moment - val row = Utils.createRow(fields ++ Seq(tile.formatCellId(indexSystem).serialize())) + val row = Utils.createRow(fields ++ Seq(tile.formatCellId(indexSystem).serialize(tileDataType))) RasterCleaner.dispose(tile) row }) diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala index 0517ac1d9..8c0c4a914 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala @@ -49,7 +49,9 @@ object ReadInMemory extends ReadStrategy { .add(StructField(METADATA, MapType(StringType, StringType), nullable = false)) .add(StructField(SUBDATASETS, MapType(StringType, StringType), nullable = false)) .add(StructField(SRID, IntegerType, nullable = false)) - .add(StructField(TILE, RasterTileType(indexSystem.getCellIdDataType), nullable = false)) + // Note, for in memory reads the rasters are stored in the tile. + // For that we use Binary Columns. 
+ .add(StructField(TILE, RasterTileType(indexSystem.getCellIdDataType, BinaryType), nullable = false)) } /** diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala index d6f26caf4..2f5bf39b6 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala @@ -65,36 +65,32 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead val retiledDf = retileRaster(rasterDf, config) - val loadedDf = rasterDf + val loadedDf = retiledDf .withColumn( "tile", - rst_tessellate(col("tile"), lit(resolution)) + rst_tessellate(col("tile"), lit(resolution)) ) .repartition(nPartitions) + .groupBy("tile.index_id") + .agg(rst_combineavg_agg(col("tile")).alias("tile")) .withColumn( "grid_measures", - rasterToGridCombiner(col("tile"), lit(resolution)) + rasterToGridCombiner(col("tile")) ) .select( "grid_measures", "tile" ) .select( - posexplode(col("grid_measures")).as(Seq("band_id", "grid_measures")) + posexplode(col("grid_measures")).as(Seq("band_id", "measure")), + col("tile").getField("index_id").alias("cell_id") ) .repartition(nPartitions) .select( col("band_id"), - explode(col("grid_measures")).alias("grid_measures") + col("cell_id"), + col("measure") ) - .repartition(nPartitions) - .select( - col("band_id"), - col("grid_measures").getItem("cellID").alias("cell_id"), - col("grid_measures").getItem("measure").alias("measure") - ) - .groupBy("band_id", "cell_id") - .agg(avg("measure").alias("measure")) kRingResample(loadedDf, config) @@ -203,15 +199,15 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead * @return * The raster to grid function. 
*/ - private def getRasterToGridFunc(combiner: String): (Column, Column) => Column = { + private def getRasterToGridFunc(combiner: String): Column => Column = { combiner match { - case "mean" => rst_rastertogridavg - case "min" => rst_rastertogridmin - case "max" => rst_rastertogridmax - case "median" => rst_rastertogridmedian - case "count" => rst_rastertogridcount - case "average" => rst_rastertogridavg - case "avg" => rst_rastertogridavg + case "mean" => rst_avg + case "min" => rst_min + case "max" => rst_max + case "median" => rst_median + case "count" => rst_pixelcount + case "average" => rst_avg + case "avg" => rst_avg case _ => throw new Error("Combiner not supported") } } @@ -232,7 +228,7 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead "combiner" -> this.extraOptions.getOrElse("combiner", "mean"), "retile" -> this.extraOptions.getOrElse("retile", "false"), "tileSize" -> this.extraOptions.getOrElse("tileSize", "-1"), - "sizeInMB" -> this.extraOptions.getOrElse("sizeInMB", ""), + "sizeInMB" -> this.extraOptions.getOrElse("sizeInMB", "-1"), "kRingInterpolate" -> this.extraOptions.getOrElse("kRingInterpolate", "0") ) } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala index 82752cad4..a5907dbe9 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala @@ -13,11 +13,13 @@ import org.apache.spark.sql.types._ /** Returns the upper left x of the raster. 
*/ -case class RST_Avg(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Avg](raster, ArrayType(DoubleType), returnsRaster = false, expressionConfig) +case class RST_Avg(tileExpr: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Avg](tileExpr, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = ArrayType(DoubleType) + /** Returns the upper left x of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { import org.json4s._ diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetaData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetaData.scala index 241d913bc..fec760813 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetaData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetaData.scala @@ -25,13 +25,14 @@ case class RST_BandMetaData(raster: Expression, band: Expression, expressionConf extends RasterBandExpression[RST_BandMetaData]( raster, band, - MapType(StringType, StringType), returnsRaster = false, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = MapType(StringType, StringType) + /** * @param raster * The raster to be used. 
diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBox.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBox.scala index 8fa2d7314..e79a8ec40 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBox.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBox.scala @@ -15,10 +15,12 @@ import org.apache.spark.sql.types._ case class RST_BoundingBox( raster: Expression, expressionConfig: MosaicExpressionConfig -) extends RasterExpression[RST_BoundingBox](raster, BinaryType, returnsRaster = false, expressionConfig = expressionConfig) +) extends RasterExpression[RST_BoundingBox](raster, returnsRaster = false, expressionConfig = expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = BinaryType + /** * Computes the bounding box of the raster. The bbox is returned as a WKB * polygon. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Clip.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Clip.scala index 557565afe..29449a6ef 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Clip.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Clip.scala @@ -19,13 +19,14 @@ case class RST_Clip( ) extends Raster1ArgExpression[RST_Clip]( rastersExpr, geometryExpr, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: org.apache.spark.sql.types.DataType = RasterTileType(expressionConfig.getCellIdType, rastersExpr) + val geometryAPI: GeometryAPI = GeometryAPI(expressionConfig.getGeometryAPI) /** diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala index 1d923fdc1..d63fe6914 
100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala @@ -9,20 +9,22 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** Expression for combining rasters using average of pixels. */ case class RST_CombineAvg( - rastersExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterArrayExpression[RST_CombineAvg]( - rastersExpr, - RasterTileType(expressionConfig.getCellIdType), + tileExpr, returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** Combines the rasters using average of pixels. 
*/ override def rasterTransform(tiles: Seq[MosaicRasterTile]): Any = { val index = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala index f6b3ba1dc..5bbc01a7b 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala @@ -6,6 +6,7 @@ import com.databricks.labs.mosaic.core.raster.io.RasterCleaner import com.databricks.labs.mosaic.core.raster.operator.CombineAVG import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile.{deserialize => deserializeTile} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpressionSerialization import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.InternalRow @@ -13,7 +14,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.GenericArrayData -import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType} +import org.apache.spark.sql.types.{ArrayType, DataType} import scala.collection.mutable.ArrayBuffer @@ -23,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer */ //noinspection DuplicatedCode case class RST_CombineAvgAgg( - rasterExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0 @@ -32,21 +33,23 @@ case class RST_CombineAvgAgg( with 
RasterExpressionSerialization { override lazy val deterministic: Boolean = true - override val child: Expression = rasterExpr + override val child: Expression = tileExpr override val nullable: Boolean = false - override val dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override lazy val dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + lazy val tileType: DataType = dataType.asInstanceOf[RasterTileType].rasterType override def prettyName: String = "rst_combine_avg_agg" + val cellIDType: DataType = expressionConfig.getCellIdType private lazy val projection = UnsafeProjection.create(Array[DataType](ArrayType(elementType = dataType, containsNull = false))) private lazy val row = new UnsafeRow(1) - def update(buffer: ArrayBuffer[Any], input: InternalRow): ArrayBuffer[Any] = { + override def update(buffer: ArrayBuffer[Any], input: InternalRow): ArrayBuffer[Any] = { val value = child.eval(input) buffer += InternalRow.copyValue(value) buffer } - def merge(buffer: ArrayBuffer[Any], input: ArrayBuffer[Any]): ArrayBuffer[Any] = { + override def merge(buffer: ArrayBuffer[Any], input: ArrayBuffer[Any]): ArrayBuffer[Any] = { buffer ++= input } @@ -63,10 +66,15 @@ case class RST_CombineAvgAgg( if (buffer.isEmpty) { null + } else if (buffer.size == 1) { + val result = buffer.head + buffer.clear() + result } else { // Do do move the expression - var tiles = buffer.map(row => MosaicRasterTile.deserialize(row.asInstanceOf[InternalRow], expressionConfig.getCellIdType)) + var tiles = buffer.map(row => deserializeTile(row.asInstanceOf[InternalRow], cellIDType, tileType)) + buffer.clear() // If merging multiple index rasters, the index value is dropped val idx = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null @@ -77,9 +85,9 @@ case class RST_CombineAvgAgg( val result = MosaicRasterTile(idx, combined, parentPath, driver) 
.formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) - .serialize(BinaryType, expressionConfig.getRasterCheckpoint) + .serialize(tileType) - tiles.foreach(RasterCleaner.dispose(_)) + tiles.foreach(RasterCleaner.dispose) RasterCleaner.dispose(result) tiles = null @@ -101,7 +109,7 @@ case class RST_CombineAvgAgg( buffer } - override protected def withNewChildInternal(newChild: Expression): RST_CombineAvgAgg = copy(rasterExpr = newChild) + override protected def withNewChildInternal(newChild: Expression): RST_CombineAvgAgg = copy(tileExpr = newChild) } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Convolve.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Convolve.scala new file mode 100644 index 000000000..db20f8a3a --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Convolve.scala @@ -0,0 +1,73 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.Raster1ArgExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} + +/** The expression for applying kernel filter on a raster. 
*/ +case class RST_Convolve( + rastersExpr: Expression, + kernelExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends Raster1ArgExpression[RST_Convolve]( + rastersExpr, + kernelExpr, + returnsRaster = true, + expressionConfig = expressionConfig + ) + with NullIntolerant + with CodegenFallback { + + override def dataType: org.apache.spark.sql.types.DataType = RasterTileType(expressionConfig.getCellIdType, rastersExpr) + + val geometryAPI: GeometryAPI = GeometryAPI(expressionConfig.getGeometryAPI) + + /** + * Applies a convolution kernel to a raster. + * + * @param tile + * The raster to be used. + * @param arg1 + * The kernel to be applied, as a 2D array of doubles. + * @return + * The tile with the convolved raster. + */ + override def rasterTransform(tile: MosaicRasterTile, arg1: Any): Any = { + val kernel = arg1.asInstanceOf[Array[Array[Double]]] + tile.copy( + raster = tile.getRaster.convolve(kernel) + ) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Convolve extends WithExpressionInfo { + + override def name: String = "rst_convolve" + + override def usage: String = + """ + |_FUNC_(expr1) - Returns a raster with the kernel filter applied. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster, kernel); + | {index_id, clipped_raster, parentPath, driver} + | {index_id, clipped_raster, parentPath, driver} + | ... 
+ | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Convolve](2, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBand.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBand.scala index 822228a1b..fa576427a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBand.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBand.scala @@ -9,25 +9,27 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType import org.apache.spark.unsafe.types.UTF8String /** Expression for combining rasters using average of pixels. */ case class RST_DerivedBand( - rastersExpr: Expression, + tileExpr: Expression, pythonFuncExpr: Expression, funcNameExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterArray2ArgExpression[RST_DerivedBand]( - rastersExpr, + tileExpr, pythonFuncExpr, funcNameExpr, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** Combines the rasters using average of pixels. 
*/ override def rasterTransform(tiles: Seq[MosaicRasterTile], arg1: Any, arg2: Any): Any = { val pythonFunc = arg1.asInstanceOf[UTF8String].toString diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandAgg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandAgg.scala index 47d4aa12a..f02194d62 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandAgg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandAgg.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer */ //noinspection DuplicatedCode case class RST_DerivedBandAgg( - rasterExpr: Expression, + tileExpr: Expression, pythonFuncExpr: Expression, funcNameExpr: Expression, expressionConfig: MosaicExpressionConfig, @@ -36,13 +36,13 @@ case class RST_DerivedBandAgg( override lazy val deterministic: Boolean = true override val nullable: Boolean = false - override val dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override lazy val dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) override def prettyName: String = "rst_combine_avg_agg" private lazy val projection = UnsafeProjection.create(Array[DataType](ArrayType(elementType = dataType, containsNull = false))) private lazy val row = new UnsafeRow(1) - override def first: Expression = rasterExpr + override def first: Expression = tileExpr override def second: Expression = pythonFuncExpr override def third: Expression = funcNameExpr @@ -74,9 +74,16 @@ case class RST_DerivedBandAgg( // This works for Literals only val pythonFunc = pythonFuncExpr.eval(null).asInstanceOf[UTF8String].toString val funcName = funcNameExpr.eval(null).asInstanceOf[UTF8String].toString + val rasterType = RasterTileType(tileExpr).rasterType // Do do move the expression - var tiles = buffer.map(row => MosaicRasterTile.deserialize(row.asInstanceOf[InternalRow], expressionConfig.getCellIdType)) + 
var tiles = buffer.map(row => + MosaicRasterTile.deserialize( + row.asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) + ) // If merging multiple index rasters, the index value is dropped val idx = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null @@ -88,7 +95,7 @@ case class RST_DerivedBandAgg( val result = MosaicRasterTile(idx, combined, parentPath, driver) .formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) - .serialize(BinaryType, expressionConfig.getRasterCheckpoint) + .serialize(BinaryType) tiles.foreach(RasterCleaner.dispose(_)) RasterCleaner.dispose(result) @@ -113,7 +120,7 @@ case class RST_DerivedBandAgg( } override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): RST_DerivedBandAgg = - copy(rasterExpr = newFirst, pythonFuncExpr = newSecond, funcNameExpr = newThird) + copy(tileExpr = newFirst, pythonFuncExpr = newSecond, funcNameExpr = newThird) } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Filter.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Filter.scala new file mode 100644 index 000000000..ee8b34d3b --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Filter.scala @@ -0,0 +1,77 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import 
org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.unsafe.types.UTF8String + +/** The expression for applying NxN filter on a raster. */ +case class RST_Filter( + rastersExpr: Expression, + kernelSizeExpr: Expression, + operationExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends Raster2ArgExpression[RST_Filter]( + rastersExpr, + kernelSizeExpr, + operationExpr, + returnsRaster = true, + expressionConfig = expressionConfig + ) + with NullIntolerant + with CodegenFallback { + + override def dataType: org.apache.spark.sql.types.DataType = RasterTileType(expressionConfig.getCellIdType, rastersExpr) + + val geometryAPI: GeometryAPI = GeometryAPI(expressionConfig.getGeometryAPI) + + /** + * Applies the filter operation named by arg2 to a raster. + * + * @param tile + * The raster to be used. + * @param arg1 + * The kernel size n passed to the filter, as an integer. + * @return + * The tile with the filtered raster. + */ + override def rasterTransform(tile: MosaicRasterTile, arg1: Any, arg2: Any): Any = { + val n = arg1.asInstanceOf[Int] + val operation = arg2.asInstanceOf[UTF8String].toString + tile.copy( + raster = tile.getRaster.filter(n, operation) + ) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Filter extends WithExpressionInfo { + + override def name: String = "rst_filter" + + override def usage: String = + """ + |_FUNC_(expr1) - Returns a raster with the filter applied. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster, kernelSize, operation); + | {index_id, clipped_raster, parentPath, driver} + | {index_id, clipped_raster, parentPath, driver} + | ... 
+ | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Filter](3, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala index 2befb353c..90e49b654 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala @@ -9,6 +9,7 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.ArrayType /** The expression for stacking and resampling input bands. */ case class RST_FromBands( @@ -16,13 +17,18 @@ case class RST_FromBands( expressionConfig: MosaicExpressionConfig ) extends RasterArrayExpression[RST_FromBands]( bandsExpr, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: org.apache.spark.sql.types.DataType = + RasterTileType( + expressionConfig.getCellIdType, + RasterTileType(bandsExpr).rasterType + ) + /** * Stacks and resamples input bands. 
* @param rasters diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromContent.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromContent.scala index bd2926bcb..59c701f71 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromContent.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromContent.scala @@ -15,7 +15,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, Expression, Literal, NullIntolerant} -import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import java.nio.file.{Files, Paths} @@ -25,7 +25,7 @@ import java.nio.file.{Files, Paths} * expression in the expression tree for a raster tile. 
*/ case class RST_FromContent( - rasterExpr: Expression, + contentExpr: Expression, driverExpr: Expression, sizeInMB: Expression, expressionConfig: MosaicExpressionConfig @@ -33,8 +33,10 @@ case class RST_FromContent( with Serializable with NullIntolerant with CodegenFallback { + + val tileType: DataType = BinaryType - override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileType) protected val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI) @@ -46,12 +48,13 @@ case class RST_FromContent( override def inline: Boolean = false - override def children: Seq[Expression] = Seq(rasterExpr, driverExpr, sizeInMB) + override def children: Seq[Expression] = Seq(contentExpr, driverExpr, sizeInMB) override def elementSchema: StructType = StructType(Array(StructField("tile", dataType))) /** - * subdivides raster binary content into tiles of the specified size (in MB). + * subdivides raster binary content into tiles of the specified size (in + * MB). * @param input * The input file path. 
* @return @@ -61,13 +64,13 @@ case class RST_FromContent( GDAL.enable(expressionConfig) val driver = driverExpr.eval(input).asInstanceOf[UTF8String].toString val ext = GDAL.getExtension(driver) - var rasterArr = rasterExpr.eval(input).asInstanceOf[Array[Byte]] + var rasterArr = contentExpr.eval(input).asInstanceOf[Array[Byte]] val targetSize = sizeInMB.eval(input).asInstanceOf[Int] if (targetSize <= 0 || rasterArr.length <= targetSize) { // - no split required var raster = MosaicRasterGDAL.readRaster(rasterArr, PathUtils.NO_PATH_STRING, driver) var tile = MosaicRasterTile(null, raster, PathUtils.NO_PATH_STRING, driver) - val row = tile.formatCellId(indexSystem).serialize() + val row = tile.formatCellId(indexSystem).serialize(tileType) RasterCleaner.dispose(raster) RasterCleaner.dispose(tile) rasterArr = null @@ -84,7 +87,7 @@ case class RST_FromContent( // split to tiles up to specifed threshold var tiles = ReTileOnRead.localSubdivide(rasterPath, PathUtils.NO_PATH_STRING, targetSize) - val rows = tiles.map(_.formatCellId(indexSystem).serialize()) + val rows = tiles.map(_.formatCellId(indexSystem).serialize(tileType)) tiles.foreach(RasterCleaner.dispose(_)) Files.deleteIfExists(Paths.get(rasterPath)) rasterArr = null @@ -118,10 +121,10 @@ object RST_FromContent extends WithExpressionInfo { | ... 
| """.stripMargin - override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { - (children: Seq[Expression]) => { + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { (children: Seq[Expression]) => + { val sizeExpr = if (children.length < 3) new Literal(-1, IntegerType) else children(2) - RST_FromContent(children(0), children(1), sizeExpr, expressionConfig) + RST_FromContent(children.head, children(1), sizeExpr, expressionConfig) } } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala index fbce5bf58..ee5bec721 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala @@ -15,7 +15,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, Expression, Literal, NullIntolerant} -import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String import java.nio.file.{Files, Paths, StandardCopyOption} @@ -32,8 +32,10 @@ case class RST_FromFile( with Serializable with NullIntolerant with CodegenFallback { + + val tileType: DataType = BinaryType - override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileType) protected val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI) @@ -66,7 +68,7 @@ case class RST_FromFile( if (targetSize <= 0 && 
Files.size(Paths.get(readPath)) <= Integer.MAX_VALUE) { var raster = MosaicRasterGDAL.readRaster(readPath, path) var tile = MosaicRasterTile(null, raster, path, raster.getDriversShortName) - val row = tile.formatCellId(indexSystem).serialize() + val row = tile.formatCellId(indexSystem).serialize(tileType) RasterCleaner.dispose(raster) RasterCleaner.dispose(tile) raster = null @@ -79,7 +81,7 @@ case class RST_FromFile( Files.copy(Paths.get(readPath), Paths.get(tmpPath), StandardCopyOption.REPLACE_EXISTING) val size = if (targetSize <= 0) 64 else targetSize var tiles = ReTileOnRead.localSubdivide(tmpPath, path, size) - val rows = tiles.map(_.formatCellId(indexSystem).serialize()) + val rows = tiles.map(_.formatCellId(indexSystem).serialize(tileType)) tiles.foreach(RasterCleaner.dispose(_)) Files.deleteIfExists(Paths.get(tmpPath)) tiles = null diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReference.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReference.scala index f4213eee7..404eb4b90 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReference.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReference.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the georeference of the raster. */ case class RST_GeoReference(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_GeoReference](raster, MapType(StringType, DoubleType), returnsRaster = false, expressionConfig) + extends RasterExpression[RST_GeoReference](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = MapType(StringType, DoubleType) + /** Returns the georeference of the raster. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = { val geoTransform = tile.getRaster.getRaster.GetGeoTransform() diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala index 8f10b89cb..aa07a6637 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala @@ -8,7 +8,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.types.{ArrayType, DoubleType} +import org.apache.spark.sql.types.{ArrayType, DataType, DoubleType} /** The expression for extracting the no data value of a raster. */ case class RST_GetNoData( @@ -16,13 +16,14 @@ case class RST_GetNoData( expressionConfig: MosaicExpressionConfig ) extends RasterExpression[RST_GetNoData]( rastersExpr, - ArrayType(DoubleType), returnsRaster = false, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = ArrayType(DoubleType) + /** * Extracts the no data value of a raster. 
* diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetSubdataset.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetSubdataset.scala index a87f6fa25..8d1fc77f1 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetSubdataset.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetSubdataset.scala @@ -8,20 +8,25 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType import org.apache.spark.unsafe.types.UTF8String /** Returns the subdatasets of the raster. */ -case class RST_GetSubdataset(raster: Expression, subsetName: Expression, expressionConfig: MosaicExpressionConfig) - extends Raster1ArgExpression[RST_GetSubdataset]( - raster, +case class RST_GetSubdataset( + tileExpr: Expression, + subsetName: Expression, + expressionConfig: MosaicExpressionConfig +) extends Raster1ArgExpression[RST_GetSubdataset]( + tileExpr, subsetName, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** Returns the subdatasets of the raster. 
*/ override def rasterTransform(tile: MosaicRasterTile, arg1: Any): Any = { val subsetName = arg1.asInstanceOf[UTF8String].toString diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Height.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Height.scala index ceb638f29..f2508e1e6 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Height.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Height.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the width of the raster. */ case class RST_Height(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Height](raster, IntegerType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_Height](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = IntegerType + /** Returns the width of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = tile.getRaster.ySize diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_InitNoData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_InitNoData.scala index 8cf226664..3b1b806da 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_InitNoData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_InitNoData.scala @@ -11,20 +11,22 @@ import com.databricks.labs.mosaic.utils.PathUtils import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** The expression that initializes no data values of a raster. 
*/ case class RST_InitNoData( - rastersExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterExpression[RST_InitNoData]( - rastersExpr, - RasterTileType(expressionConfig.getCellIdType), + tileExpr, returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** * Initializes no data values of a raster. * diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmpty.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmpty.scala index 4a5f5034f..7d6267bec 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmpty.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmpty.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns true if the raster is empty. */ case class RST_IsEmpty(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_IsEmpty](raster, BooleanType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_IsEmpty](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = BooleanType + /** Returns true if the raster is empty. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = { var raster = tile.getRaster
diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala
new file mode 100644
index 000000000..586337556
--- /dev/null
+++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala
@@ -0,0 +1,236 @@
+package com.databricks.labs.mosaic.expressions.raster
+
+import com.databricks.labs.mosaic.MOSAIC_NO_DRIVER
+import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI
+import com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory}
+import com.databricks.labs.mosaic.core.raster.api.GDAL
+import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL
+import com.databricks.labs.mosaic.core.raster.io.RasterCleaner
+import com.databricks.labs.mosaic.core.types.RasterTileType
+import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile
+import com.databricks.labs.mosaic.datasource.gdal.ReTileOnRead
+import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo}
+import com.databricks.labs.mosaic.functions.MosaicExpressionConfig
+import com.databricks.labs.mosaic.utils.PathUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, Expression, Literal, NullIntolerant}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+import java.nio.file.{Files, Paths}
+import scala.util.Try
+
+/**
+  * Creates raster tiles from the input column.
+  *
+  * @param inputExpr
+  *   The expression for the raster. If the raster is stored on disc, the path
+  *   to the raster is provided. If the raster is stored in memory, the bytes of
+  *   the raster are provided.
+  * @param driverExpr
+  *   The driver to use for reading the raster. If not specified, the driver is
+  *   identified from the input path. If the input is a byte array, the driver
+  *   has to be specified.
+  * @param sizeInMBExpr
+  *   The size of the tiles in MB. If zero or negative, the input is returned
+  *   as a single tile, unless it is larger than Integer.MAX_VALUE bytes, in
+  *   which case it is subdivided into tiles of size 64MB. If positive, the
+  *   input is subdivided into tiles of the specified size.
+  * @param withCheckpointExpr
+  *   Has to be a literal. If true, the raster field of the tile is typed as a
+  *   string (raster referenced via a path); if false, it is typed as binary
+  *   (raster referenced via a byte array).
+  * @param expressionConfig
+  *   Additional arguments for the expression (expressionConfigs).
+  */
+case class RST_MakeTiles(
+    inputExpr: Expression,
+    driverExpr: Expression,
+    sizeInMBExpr: Expression,
+    withCheckpointExpr: Expression,
+    expressionConfig: MosaicExpressionConfig
+) extends CollectionGenerator
+      with Serializable
+      with NullIntolerant
+      with CodegenFallback {
+
+    /** The tile type; its raster field is a string when checkpointing, binary otherwise. */
+    override def dataType: DataType = {
+        require(withCheckpointExpr.isInstanceOf[Literal], "withCheckpoint has to be a literal")
+        if (withCheckpointExpr.eval().asInstanceOf[Boolean]) {
+            // Raster is referenced via a path
+            RasterTileType(expressionConfig.getCellIdType, StringType)
+        } else {
+            // Raster is referenced via a byte array
+            RasterTileType(expressionConfig.getCellIdType, BinaryType)
+        }
+    }
+
+    protected val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI)
+
+    protected val indexSystem: IndexSystem = IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)
+
+    protected val cellIdDataType: DataType = indexSystem.getCellIdDataType
+
+    override def position: Boolean = false
+
+    override def inline: Boolean = false
+
+    override def children: Seq[Expression] = Seq(inputExpr, driverExpr, sizeInMBExpr, withCheckpointExpr)
+
+    override def elementSchema: StructType = StructType(Array(StructField("tile", dataType)))
+
+    /** Resolves the driver: the explicit one if given, otherwise identified from the input path. */
+    private def getDriver(rawInput: Any, rawDriver: String): String = {
+        if (rawDriver == MOSAIC_NO_DRIVER) {
+            if (inputExpr.dataType == StringType) {
+                val path = rawInput.asInstanceOf[UTF8String].toString
+                MosaicRasterGDAL.identifyDriver(path)
+            } else {
+                throw new IllegalArgumentException("Driver has to be specified for byte array input")
+            }
+        } else {
+            rawDriver
+        }
+    }
+
+    /** Size of the input in bytes: the file size for path input, the array length for binary input. */
+    private def getInputSize(rawInput: Any): Long = {
+        if (inputExpr.dataType == StringType) {
+            val path = rawInput.asInstanceOf[UTF8String].toString
+            Files.size(Paths.get(path))
+        } else {
+            val bytes = rawInput.asInstanceOf[Array[Byte]]
+            bytes.length
+        }
+    }
+
+    /**
+      * Stages the unsplit input in a temporary file so that it can be
+      * subdivided. Path input is copied, binary input is written out.
+      *
+      * @param rawInput
+      *   The evaluated input, a UTF8String path or a byte array.
+      * @param driver
+      *   The driver, used to pick the file extension.
+      * @return
+      *   The path of the temporary file; the caller deletes it.
+      */
+    private def stageInput(rawInput: Any, driver: String): String = {
+        val tmpPath = PathUtils.createTmpFilePath(GDAL.getExtension(driver))
+        // createDirectories in case of context isolation
+        Files.createDirectories(Paths.get(tmpPath).getParent)
+        if (inputExpr.dataType == StringType) {
+            // A path cannot be cast to bytes; copy the referenced file instead.
+            val inPath = rawInput.asInstanceOf[UTF8String].toString
+            Files.copy(Paths.get(inPath), Paths.get(tmpPath), java.nio.file.StandardCopyOption.REPLACE_EXISTING)
+        } else {
+            Files.write(Paths.get(tmpPath), rawInput.asInstanceOf[Array[Byte]])
+        }
+        tmpPath
+    }
+
+    /**
+      * Loads a raster from a path or a byte array and, when required,
+      * subdivides it into tiles of the specified size (in MB).
+      * @param input
+      *   The input row.
+      * @return
+      *   The tiles, one row per tile.
+      */
+    override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
+        GDAL.enable(expressionConfig)
+
+        val tileType = dataType.asInstanceOf[StructType].find(_.name == "raster").get.dataType
+
+        val rawDriver = driverExpr.eval(input).asInstanceOf[UTF8String].toString
+        val rawInput = inputExpr.eval(input)
+        val driver = getDriver(rawInput, rawDriver)
+        val targetSize = sizeInMBExpr.eval(input).asInstanceOf[Int]
+        val inputSize = getInputSize(rawInput)
+
+        if (targetSize <= 0 && inputSize <= Integer.MAX_VALUE) {
+            // - no split required
+            val raster = GDAL.readRaster(rawInput, PathUtils.NO_PATH_STRING, driver, inputExpr.dataType)
+            val tile = MosaicRasterTile(null, raster, PathUtils.NO_PATH_STRING, driver)
+            val row = tile.formatCellId(indexSystem).serialize(tileType)
+            RasterCleaner.dispose(raster)
+            RasterCleaner.dispose(tile)
+            Seq(InternalRow.fromSeq(Seq(row)))
+        } else {
+            // target size is > 0, or the input is too large for a single tile
+            // - stage the initial raster in a file (unsplit)
+            val rasterPath = stageInput(rawInput, driver)
+            try {
+                val size = if (targetSize <= 0) 64 else targetSize
+                val tiles = ReTileOnRead.localSubdivide(rasterPath, PathUtils.NO_PATH_STRING, size)
+                val rows = tiles.map(_.formatCellId(indexSystem).serialize(tileType))
+                tiles.foreach(RasterCleaner.dispose(_))
+                rows.map(row => InternalRow.fromSeq(Seq(row)))
+            } finally {
+                // the staged copy is removed even if the subdivision fails
+                Files.deleteIfExists(Paths.get(rasterPath))
+            }
+        }
+    }
+
+    override def makeCopy(newArgs: Array[AnyRef]): Expression =
+        GenericExpressionFactory.makeCopyImpl[RST_MakeTiles](this, newArgs, children.length, expressionConfig)
+
+    override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = makeCopy(newChildren.toArray)
+
+}
+
+/** Expression info required for the expression registration for spark SQL. */
+object RST_MakeTiles extends WithExpressionInfo {
+
+    override def name: String = "rst_maketiles"
+
+    override def usage: String =
+        """
+          |_FUNC_(expr1[, driver][, sizeInMB][, withCheckpoint]) - Returns a set of raster tiles created from a raster path or raster bytes.
+          |""".stripMargin
+
+    override def example: String =
+        """
+          |    Examples:
+          |      > SELECT _FUNC_(raster_path);
+          |        {index_id, raster, parent_path, driver}
+          |        ...
+          |  """.stripMargin
+
+    override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { (children: Seq[Expression]) =>
+        {
+            def checkSize(size: Expression) = Try(size.eval().asInstanceOf[Int]).isSuccess
+            def checkChkpnt(chkpnt: Expression) = Try(chkpnt.eval().asInstanceOf[Boolean]).isSuccess
+            def checkDriver(driver: Expression) = Try(driver.eval().asInstanceOf[UTF8String].toString).isSuccess
+            val noSize = new Literal(-1, IntegerType)
+            // eval casts the driver value to UTF8String, so the default must not be a java String
+            val noDriver = new Literal(UTF8String.fromString(MOSAIC_NO_DRIVER), StringType)
+            val noCheckpoint = new Literal(false, BooleanType)
+
+            children match {
+                // Note type checking only works for literals
+                case Seq(input) => RST_MakeTiles(input, noDriver, noSize, noCheckpoint, expressionConfig)
+                case Seq(input, driver) if checkDriver(driver) => RST_MakeTiles(input, driver, noSize, noCheckpoint, expressionConfig)
+                case Seq(input, size) if checkSize(size) => RST_MakeTiles(input, noDriver, size, noCheckpoint, expressionConfig)
+                case Seq(input, checkpoint) if checkChkpnt(checkpoint) =>
+                    RST_MakeTiles(input, noDriver, noSize, checkpoint, expressionConfig)
+                case Seq(input, size, checkpoint) if checkSize(size) && checkChkpnt(checkpoint) =>
+                    RST_MakeTiles(input, noDriver, size, checkpoint, expressionConfig)
+                case Seq(input, driver, size) if checkDriver(driver) && checkSize(size) =>
+                    RST_MakeTiles(input, driver, size, noCheckpoint, expressionConfig)
+                case Seq(input, driver, checkpoint) if checkDriver(driver) && checkChkpnt(checkpoint) =>
+                    RST_MakeTiles(input, driver, noSize, checkpoint, expressionConfig)
+                case Seq(input, driver, size, checkpoint, _*) =>
+                    // fully specified call, literal or not; arguments past the fourth are ignored
+                    RST_MakeTiles(input, driver, size, checkpoint, expressionConfig)
+                case _ =>
+                    throw new IllegalArgumentException(s"rst_maketiles expects (input[, driver][, sizeInMB][, withCheckpoint]) but could not resolve ${children.length} argument(s)")
+            }
+        }
+    }
+
+}
diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MapAlgebra.scala
b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MapAlgebra.scala index 53e84d96b..1c74e1f0a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MapAlgebra.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MapAlgebra.scala @@ -11,23 +11,25 @@ import com.databricks.labs.mosaic.utils.PathUtils import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType import org.apache.spark.unsafe.types.UTF8String /** The expression for map algebra. */ case class RST_MapAlgebra( - rastersExpr: Expression, + tileExpr: Expression, jsonSpecExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterArray1ArgExpression[RST_MapAlgebra]( - rastersExpr, + tileExpr, jsonSpecExpr, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** * Map Algebra. * @param tiles diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Max.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Max.scala index abe042c2b..434be4a68 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Max.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Max.scala @@ -14,10 +14,12 @@ import org.apache.spark.sql.types._ /** Returns the upper left x of the raster. 
*/ case class RST_Max(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Max](raster, ArrayType(DoubleType), returnsRaster = false, expressionConfig) + extends RasterExpression[RST_Max](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = ArrayType(DoubleType) + /** Returns the upper left x of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { val nBands = tile.raster.raster.GetRasterCount() diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Median.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Median.scala index 091121e91..19d3fc0a6 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Median.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Median.scala @@ -1,8 +1,7 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.raster.api.GDAL -import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL -import com.databricks.labs.mosaic.core.raster.operator.gdal.{GDALCalc, GDALInfo, GDALWarp} +import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALWarp import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression @@ -16,10 +15,12 @@ import org.apache.spark.sql.types._ /** Returns the upper left x of the raster. 
*/ case class RST_Median(rasterExpr: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Median](rasterExpr, ArrayType(DoubleType), returnsRaster = false, expressionConfig) + extends RasterExpression[RST_Median](rasterExpr, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = ArrayType(DoubleType) + /** Returns the upper left x of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { val raster = tile.raster diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSize.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSize.scala index 804c4f195..f77058a65 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSize.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSize.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the memory size of the raster in bytes. */ case class RST_MemSize(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_MemSize](raster, LongType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_MemSize](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = LongType + /** Returns the memory size of the raster in bytes. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = tile.getRaster.getMemSize diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Merge.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Merge.scala index cb9907848..c8ef6846d 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Merge.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Merge.scala @@ -9,20 +9,22 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** Returns a raster that is a result of merging an array of rasters. */ case class RST_Merge( - rastersExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterArrayExpression[RST_Merge]( - rastersExpr, - RasterTileType(expressionConfig.getCellIdType), + tileExpr, returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** * Merges an array of rasters. * @param tiles diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala index 5902eac3b..88705dfe0 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala @@ -20,7 +20,7 @@ import scala.collection.mutable.ArrayBuffer /** Merges rasters into a single raster. 
*/ //noinspection DuplicatedCode case class RST_MergeAgg( - rasterExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0 @@ -29,9 +29,9 @@ case class RST_MergeAgg( with RasterExpressionSerialization { override lazy val deterministic: Boolean = true - override val child: Expression = rasterExpr + override val child: Expression = tileExpr override val nullable: Boolean = false - override val dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override lazy val dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) override def prettyName: String = "rst_merge_agg" private lazy val projection = UnsafeProjection.create(Array[DataType](ArrayType(elementType = dataType, containsNull = false))) @@ -66,8 +66,15 @@ case class RST_MergeAgg( // This is a trick to get the rasters sorted by their parent path to ensure more consistent results // when merging rasters with large overlaps + val rasterType = RasterTileType(tileExpr).rasterType var tiles = buffer - .map(row => MosaicRasterTile.deserialize(row.asInstanceOf[InternalRow], expressionConfig.getCellIdType)) + .map(row => + MosaicRasterTile.deserialize( + row.asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) + ) .sortBy(_.getParentPath) // If merging multiple index rasters, the index value is dropped @@ -79,7 +86,7 @@ case class RST_MergeAgg( val result = MosaicRasterTile(idx, merged, parentPath, driver) .formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) - .serialize(BinaryType, expressionConfig.getRasterCheckpoint) + .serialize(BinaryType) tiles.foreach(RasterCleaner.dispose(_)) RasterCleaner.dispose(merged) @@ -103,7 +110,7 @@ case class RST_MergeAgg( buffer } - override protected def withNewChildInternal(newChild: Expression): RST_MergeAgg = copy(rasterExpr = newChild) + override protected def withNewChildInternal(newChild: 
Expression): RST_MergeAgg = copy(tileExpr = newChild) } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetaData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetaData.scala index 8a96ff0d1..0b6754ebe 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetaData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetaData.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the metadata of the raster. */ case class RST_MetaData(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_MetaData](raster, MapType(StringType, StringType), returnsRaster = false, expressionConfig) + extends RasterExpression[RST_MetaData](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = MapType(StringType, StringType) + /** Returns the metadata of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = buildMapString(tile.getRaster.metadata) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Min.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Min.scala index 67fdb30d3..ea62e106f 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Min.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Min.scala @@ -1,6 +1,5 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALInfo import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression @@ -14,10 +13,12 @@ import org.apache.spark.sql.types._ /** Returns the upper left x of the raster. 
*/ case class RST_Min(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Min](raster, ArrayType(DoubleType), returnsRaster = false, expressionConfig) + extends RasterExpression[RST_Min](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = ArrayType(DoubleType) + /** Returns the upper left x of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { val nBands = tile.raster.raster.GetRasterCount() diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala index fa595fd4b..67b580f0c 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala @@ -9,24 +9,26 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** The expression for computing NDVI index. */ case class RST_NDVI( - rastersExpr: Expression, + tileExpr: Expression, redIndex: Expression, nirIndex: Expression, expressionConfig: MosaicExpressionConfig ) extends Raster2ArgExpression[RST_NDVI]( - rastersExpr, + tileExpr, redIndex, nirIndex, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** * Computes NDVI index. 
* @param tile diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBands.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBands.scala index f5dd09551..e0a8c8d9e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBands.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBands.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the number of bands in the raster. */ case class RST_NumBands(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_NumBands](raster, IntegerType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_NumBands](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = IntegerType + /** Returns the number of bands in the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = tile.getRaster.numBands diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCount.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCount.scala index 79f44db03..b2543a87e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCount.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCount.scala @@ -12,10 +12,12 @@ import org.apache.spark.sql.types._ /** Returns the upper left x of the raster. */ case class RST_PixelCount(rasterExpr: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_PixelCount](rasterExpr, ArrayType(DoubleType), returnsRaster = false, expressionConfig) + extends RasterExpression[RST_PixelCount](rasterExpr, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = ArrayType(LongType) + /** Returns the upper left x of the raster. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = { val bandCount = tile.raster.raster.GetRasterCount() diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeight.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeight.scala index d1c3713ef..0c34be59b 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeight.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeight.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the pixel height of the raster. */ case class RST_PixelHeight(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_PixelHeight](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_PixelHeight](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the pixel height of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { val gt = tile.getRaster.getGeoTransform diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidth.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidth.scala index 6a4956e9e..b1645696b 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidth.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidth.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the pixel width of the raster. 
*/ case class RST_PixelWidth(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_PixelWidth](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_PixelWidth](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the pixel width of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { val gt = tile.getRaster.getGeoTransform diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala index 42b9a928a..9da0f19ef 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala @@ -17,10 +17,12 @@ case class RST_RasterToWorldCoord( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_RasterToWorldCoord](raster, x, y, StringType, returnsRaster = false, expressionConfig = expressionConfig) +) extends Raster2ArgExpression[RST_RasterToWorldCoord](raster, x, y, returnsRaster = false, expressionConfig = expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = StringType + /** * Returns the world coordinates of the raster (x,y) pixel by applying * GeoTransform. This ensures the projection of the raster is respected. 
diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala index 4bd06646a..5fea59b49 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala @@ -16,10 +16,12 @@ case class RST_RasterToWorldCoordX( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_RasterToWorldCoordX](raster, x, y, DoubleType, returnsRaster = false, expressionConfig) +) extends Raster2ArgExpression[RST_RasterToWorldCoordX](raster, x, y, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** * Returns the world coordinates of the raster x pixel by applying * GeoTransform. This ensures the projection of the raster is respected. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala index 262d6bbad..ae170709c 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala @@ -16,10 +16,12 @@ case class RST_RasterToWorldCoordY( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_RasterToWorldCoordY](raster, x, y, DoubleType, returnsRaster = false, expressionConfig) +) extends Raster2ArgExpression[RST_RasterToWorldCoordY](raster, x, y, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** * Returns the world coordinates of the raster y pixel by applying * GeoTransform. 
This ensures the projection of the raster is respected. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTile.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTile.scala index 4465866dc..939011882 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTile.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTile.scala @@ -8,6 +8,7 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** * Returns a set of new rasters with the specified tile size (tileWidth x @@ -22,6 +23,8 @@ case class RST_ReTile( with NullIntolerant with CodegenFallback { + override def dataType: DataType = rasterExpr.dataType + /** * Returns a set of new rasters with the specified tile size (tileWidth x * tileHeight). diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Rotation.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Rotation.scala index c3cd097c7..5933c7133 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Rotation.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Rotation.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the rotation angle of the raster. */ case class RST_Rotation(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Rotation](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_Rotation](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the rotation angle of the raster. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = { val gt = tile.getRaster.getRaster.GetGeoTransform() diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRID.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRID.scala index c8bce06b7..648260ae5 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRID.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRID.scala @@ -14,10 +14,12 @@ import scala.util.Try /** Returns the SRID of the raster. */ case class RST_SRID(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_SRID](raster, IntegerType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_SRID](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = IntegerType + /** Returns the SRID of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { // Reference: https://gis.stackexchange.com/questions/267321/extracting-epsg-from-a-raster-using-gdal-bindings-in-python diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleX.scala index c16891871..e13af4763 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleX.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the scale x of the raster. 
*/ case class RST_ScaleX(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_ScaleX](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_ScaleX](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the scale x of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(1) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleY.scala index 3b0779763..8defba49a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleY.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the scale y of the raster. */ case class RST_ScaleY(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_ScaleY](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_ScaleY](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the scale y of the raster. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(5) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetNoData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetNoData.scala index 911271d33..f4350e7d3 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetNoData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetNoData.scala @@ -12,22 +12,24 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types.DataType /** Returns a raster with the specified no data values. */ case class RST_SetNoData( - rastersExpr: Expression, + tileExpr: Expression, noDataExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends Raster1ArgExpression[RST_SetNoData]( - rastersExpr, + tileExpr, noDataExpr, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** * Returns a raster with the specified no data values. * @param tile diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewX.scala index ee3d0c4dd..439592e73 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewX.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the skew x of the raster. 
*/ case class RST_SkewX(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_SkewX](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_SkewX](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the skew x of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(2) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewY.scala index ff9903687..1f259b5de 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewY.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the skew y of the raster. */ case class RST_SkewY(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_SkewY](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_SkewY](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the skew y of the raster. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(4) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdatasets.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdatasets.scala index 8c58e7f74..3f1536510 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdatasets.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdatasets.scala @@ -13,13 +13,14 @@ import org.apache.spark.sql.types._ case class RST_Subdatasets(raster: Expression, expressionConfig: MosaicExpressionConfig) extends RasterExpression[RST_Subdatasets]( raster, - MapType(StringType, StringType), returnsRaster = false, expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = MapType(StringType, StringType) + /** Returns the subdatasets of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = buildMapString(tile.getRaster.subdatasets) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Summary.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Summary.scala index 6351d47f2..4900eaab2 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Summary.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Summary.scala @@ -16,10 +16,12 @@ import java.util.{Vector => JVector} /** Returns the summary info the raster. */ case class RST_Summary(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Summary](raster, StringType, returnsRaster = false, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Summary](raster, returnsRaster = false, expressionConfig: MosaicExpressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = StringType + /** Returns the summary info the raster. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = { val vector = new JVector[String]() diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala index b364d39da..16dc25ee0 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns true if the raster is empty. */ case class RST_TryOpen(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_TryOpen](raster, BooleanType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_TryOpen](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = BooleanType + /** Returns true if the raster can be opened. */ override def rasterTransform(tile: MosaicRasterTile): Any = { Option(tile.getRaster.getRaster).isDefined diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftX.scala index 4f050bc7e..143158736 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftX.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the upper left x of the raster. 
*/ case class RST_UpperLeftX(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_UpperLeftX](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_UpperLeftX](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the upper left x of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(0) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftY.scala index 0e052e3ae..702c8a0c4 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftY.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the upper left y of the raster. */ case class RST_UpperLeftY(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_UpperLeftY](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_UpperLeftY](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the upper left y of the raster. 
*/ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(3) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Width.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Width.scala index 4bd56686a..953eb17bd 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Width.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Width.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the width of the raster. */ case class RST_Width(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Width](raster, IntegerType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_Width](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = IntegerType + /** Returns the width of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = tile.getRaster.xSize diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala index 2d0884a81..2d5438c3b 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala @@ -9,6 +9,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** Returns the world coordinate of the raster. 
*/ case class RST_WorldToRasterCoord( @@ -16,10 +17,12 @@ case class RST_WorldToRasterCoord( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_WorldToRasterCoord](raster, x, y, PixelCoordsType, returnsRaster = false, expressionConfig) +) extends Raster2ArgExpression[RST_WorldToRasterCoord](raster, x, y, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = PixelCoordsType + /** * Returns the x and y of the raster by applying GeoTransform as a tuple of * Integers. This will ensure projection of the raster is respected. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala index 26c888fe1..41d6e8b9b 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala @@ -16,10 +16,12 @@ case class RST_WorldToRasterCoordX( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_WorldToRasterCoordX](raster, x, y, IntegerType, returnsRaster = false, expressionConfig) +) extends Raster2ArgExpression[RST_WorldToRasterCoordX](raster, x, y, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: IntegerType = IntegerType + /** * Returns the x coordinate of the raster by applying GeoTransform. This * will ensure projection of the raster is respected. 
diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala index 8bb125faa..62ba72228 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala @@ -16,10 +16,12 @@ case class RST_WorldToRasterCoordY( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_WorldToRasterCoordY](raster, x, y, IntegerType, returnsRaster = false, expressionConfig) +) extends Raster2ArgExpression[RST_WorldToRasterCoordY](raster, x, y, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: IntegerType = IntegerType + /** * Returns the y coordinate of the raster by applying GeoTransform. This * will ensure projection of the raster is respected. 
diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster1ArgExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster1ArgExpression.scala index f01027ff1..35ad927c6 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster1ArgExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster1ArgExpression.scala @@ -2,12 +2,12 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, NullIntolerant} -import org.apache.spark.sql.types.DataType import scala.reflect.ClassTag @@ -21,8 +21,6 @@ import scala.reflect.ClassTag * containing the raster file content. * @param arg1Expr * The expression for the first argument. - * @param outputType - * The output type of the result. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). * @tparam T @@ -31,7 +29,6 @@ import scala.reflect.ClassTag abstract class Raster1ArgExpression[T <: Expression: ClassTag]( rasterExpr: Expression, arg1Expr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends BinaryExpression @@ -43,9 +40,6 @@ abstract class Raster1ArgExpression[T <: Expression: ClassTag]( override def right: Expression = arg1Expr - /** Output Data Type */ - override def dataType: DataType = outputType - /** * The function to be overridden by the extending class. 
It is called when * the expression is evaluated. It provides the raster and the arguments to @@ -75,10 +69,15 @@ abstract class Raster1ArgExpression[T <: Expression: ClassTag]( // noinspection DuplicatedCode override def nullSafeEval(input: Any, arg1: Any): Any = { GDAL.enable(expressionConfig) - val tile = MosaicRasterTile.deserialize(input.asInstanceOf[InternalRow], expressionConfig.getCellIdType) + val rasterType = RasterTileType(rasterExpr).rasterType + val tile = MosaicRasterTile.deserialize( + input.asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) val raster = tile.getRaster val result = rasterTransform(tile, arg1) - val serialized = serialize(result, returnsRaster, outputType, expressionConfig) + val serialized = serialize(result, returnsRaster, rasterType, expressionConfig) RasterCleaner.dispose(raster) RasterCleaner.dispose(result) serialized diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala index ccdc7d5b3..c5be60724 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala @@ -2,6 +2,7 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig @@ -22,8 +23,6 @@ import scala.reflect.ClassTag * The expression for the first argument. * @param arg2Expr * The expression for the second argument. - * @param outputType - * The output type of the result. 
* @param expressionConfig * Additional arguments for the expression (expressionConfigs). * @tparam T @@ -33,7 +32,6 @@ abstract class Raster2ArgExpression[T <: Expression: ClassTag]( rasterExpr: Expression, arg1Expr: Expression, arg2Expr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends TernaryExpression @@ -47,9 +45,6 @@ abstract class Raster2ArgExpression[T <: Expression: ClassTag]( override def third: Expression = arg2Expr - /** Output Data Type */ - override def dataType: DataType = outputType - /** * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the raster and the arguments to @@ -83,9 +78,14 @@ abstract class Raster2ArgExpression[T <: Expression: ClassTag]( // noinspection DuplicatedCode override def nullSafeEval(input: Any, arg1: Any, arg2: Any): Any = { GDAL.enable(expressionConfig) - val tile = MosaicRasterTile.deserialize(input.asInstanceOf[InternalRow], expressionConfig.getCellIdType) + val rasterType = RasterTileType(rasterExpr).rasterType + val tile = MosaicRasterTile.deserialize( + input.asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) val result = rasterTransform(tile, arg1, arg2) - val serialized = serialize(result, returnsRaster, outputType, expressionConfig) + val serialized = serialize(result, returnsRaster, rasterType, expressionConfig) // passed by name makes things re-evaluated RasterCleaner.dispose(tile) serialized diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray1ArgExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray1ArgExpression.scala index d21f96c2d..5dbfd08cc 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray1ArgExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray1ArgExpression.scala @@ -2,11 +2,12 @@ package 
com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, NullIntolerant} -import org.apache.spark.sql.types.{ArrayType, DataType} +import org.apache.spark.sql.types.ArrayType import scala.reflect.ClassTag @@ -18,8 +19,6 @@ import scala.reflect.ClassTag * @param rastersExpr * The rasters expression. It is an array column containing rasters as either * paths or as content byte arrays. - * @param outputType - * The output type of the result. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). * @tparam T @@ -28,7 +27,6 @@ import scala.reflect.ClassTag abstract class RasterArray1ArgExpression[T <: Expression: ClassTag]( rastersExpr: Expression, arg1Expr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends BinaryExpression @@ -36,9 +34,6 @@ abstract class RasterArray1ArgExpression[T <: Expression: ClassTag]( with Serializable with RasterExpressionSerialization { - /** Output Data Type */ - override def dataType: DataType = if (returnsRaster) rastersExpr.dataType.asInstanceOf[ArrayType].elementType else outputType - override def left: Expression = rastersExpr override def right: Expression = arg1Expr @@ -72,7 +67,8 @@ abstract class RasterArray1ArgExpression[T <: Expression: ClassTag]( GDAL.enable(expressionConfig) val tiles = RasterArrayUtils.getTiles(input, rastersExpr, expressionConfig) val result = rasterTransform(tiles, arg1) - val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + val 
resultType = if (returnsRaster) RasterTileType(rastersExpr).rasterType else dataType + val serialized = serialize(result, returnsRaster, resultType, expressionConfig) tiles.foreach(t => RasterCleaner.dispose(t)) serialized } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray2ArgExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray2ArgExpression.scala index a26082f2d..9de963684 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray2ArgExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray2ArgExpression.scala @@ -2,11 +2,12 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant, TernaryExpression} -import org.apache.spark.sql.types.{ArrayType, DataType} +import org.apache.spark.sql.types.ArrayType import scala.reflect.ClassTag @@ -18,8 +19,6 @@ import scala.reflect.ClassTag * @param rastersExpr * The rasters expression. It is an array column containing rasters as either * paths or as content byte arrays. - * @param outputType - * The output type of the result. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). 
* @tparam T @@ -29,7 +28,6 @@ abstract class RasterArray2ArgExpression[T <: Expression: ClassTag]( rastersExpr: Expression, arg1Expr: Expression, arg2Expr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends TernaryExpression @@ -37,9 +35,6 @@ abstract class RasterArray2ArgExpression[T <: Expression: ClassTag]( with Serializable with RasterExpressionSerialization { - /** Output Data Type */ - override def dataType: DataType = if (returnsRaster) rastersExpr.dataType.asInstanceOf[ArrayType].elementType else outputType - override def first: Expression = rastersExpr override def second: Expression = arg1Expr @@ -77,7 +72,8 @@ abstract class RasterArray2ArgExpression[T <: Expression: ClassTag]( GDAL.enable(expressionConfig) val tiles = RasterArrayUtils.getTiles(input, rastersExpr, expressionConfig) val result = rasterTransform(tiles, arg1, arg2) - val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + val resultType = if (returnsRaster) RasterTileType(rastersExpr).rasterType else dataType + val serialized = serialize(result, returnsRaster, resultType, expressionConfig) tiles.foreach(t => RasterCleaner.dispose(t)) serialized } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayExpression.scala index b8ad9fc12..8c3a52d9a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayExpression.scala @@ -2,11 +2,12 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import 
com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant, UnaryExpression} -import org.apache.spark.sql.types.{ArrayType, DataType} +import org.apache.spark.sql.types.ArrayType import scala.reflect.ClassTag @@ -27,7 +28,6 @@ import scala.reflect.ClassTag */ abstract class RasterArrayExpression[T <: Expression: ClassTag]( rastersExpr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends UnaryExpression @@ -37,9 +37,6 @@ abstract class RasterArrayExpression[T <: Expression: ClassTag]( override def child: Expression = rastersExpr - /** Output Data Type */ - override def dataType: DataType = if (returnsRaster) rastersExpr.dataType.asInstanceOf[ArrayType].elementType else outputType - /** * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the rasters to the expression. 
@@ -67,7 +64,8 @@ abstract class RasterArrayExpression[T <: Expression: ClassTag]( GDAL.enable(expressionConfig) val tiles = RasterArrayUtils.getTiles(input, rastersExpr, expressionConfig) val result = rasterTransform(tiles) - val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + val resultType = if (returnsRaster) RasterTileType(rastersExpr).rasterType else dataType + val serialized = serialize(result, returnsRaster, resultType, expressionConfig) tiles.foreach(t => RasterCleaner.dispose(t)) serialized } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayUtils.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayUtils.scala index 3162bb421..f2d399350 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayUtils.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayUtils.scala @@ -1,5 +1,6 @@ package com.databricks.labs.mosaic.expressions.raster.base +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.InternalRow @@ -12,11 +13,16 @@ object RasterArrayUtils { def getTiles(input: Any, rastersExpr: Expression, expressionConfig: MosaicExpressionConfig): Seq[MosaicRasterTile] = { val rasterDT = rastersExpr.dataType.asInstanceOf[ArrayType].elementType val arrayData = input.asInstanceOf[ArrayData] + val rasterType = RasterTileType(rastersExpr).rasterType val n = arrayData.numElements() (0 until n) .map(i => MosaicRasterTile - .deserialize(arrayData.get(i, rasterDT).asInstanceOf[InternalRow], expressionConfig.getCellIdType) + .deserialize( + arrayData.get(i, rasterDT).asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) ) } diff --git 
a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterBandExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterBandExpression.scala index 7cee607ca..97bd3e333 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterBandExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterBandExpression.scala @@ -3,12 +3,12 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterBandGDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, NullIntolerant} -import org.apache.spark.sql.types.DataType import scala.reflect.ClassTag @@ -23,8 +23,6 @@ import scala.reflect.ClassTag * MOSAIC_RASTER_STORAGE is set to MOSAIC_RASTER_STORAGE_BYTE. * @param bandExpr * The expression for the band index. - * @param outputType - * The output type of the result. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). 
* @tparam T @@ -33,7 +31,6 @@ import scala.reflect.ClassTag abstract class RasterBandExpression[T <: Expression: ClassTag]( rasterExpr: Expression, bandExpr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends BinaryExpression @@ -45,9 +42,6 @@ abstract class RasterBandExpression[T <: Expression: ClassTag]( override def right: Expression = bandExpr - /** Output Data Type */ - override def dataType: DataType = outputType - /** * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the raster band to the @@ -79,13 +73,18 @@ abstract class RasterBandExpression[T <: Expression: ClassTag]( // noinspection DuplicatedCode override def nullSafeEval(inputRaster: Any, inputBand: Any): Any = { GDAL.enable(expressionConfig) - val tile = MosaicRasterTile.deserialize(inputRaster.asInstanceOf[InternalRow], expressionConfig.getCellIdType) + val rasterType = RasterTileType(rasterExpr).rasterType + val tile = MosaicRasterTile.deserialize( + inputRaster.asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) val bandIndex = inputBand.asInstanceOf[Int] val band = tile.getRaster.getBand(bandIndex) val result = bandTransform(tile, band) - val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + val serialized = serialize(result, returnsRaster, rasterType, expressionConfig) RasterCleaner.dispose(tile) serialized } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpression.scala index 462d3204b..66435f101 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpression.scala @@ -3,6 +3,7 @@ package com.databricks.labs.mosaic.expressions.raster.base import 
com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory} import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig @@ -20,8 +21,6 @@ import scala.reflect.ClassTag * The expression for the raster. If the raster is stored on disc, the path * to the raster is provided. If the raster is stored in memory, the bytes of * the raster are provided. - * @param outputType - * The output type of the result. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). * @tparam T @@ -29,7 +28,6 @@ import scala.reflect.ClassTag */ abstract class RasterExpression[T <: Expression: ClassTag]( rasterExpr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends UnaryExpression @@ -43,9 +41,6 @@ abstract class RasterExpression[T <: Expression: ClassTag]( override def child: Expression = rasterExpr - /** Output Data Type */ - override def dataType: DataType = outputType - /** * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the raster to the expression. 
@@ -69,9 +64,14 @@ abstract class RasterExpression[T <: Expression: ClassTag]( */ override def nullSafeEval(input: Any): Any = { GDAL.enable(expressionConfig) - val tile = MosaicRasterTile.deserialize(input.asInstanceOf[InternalRow], cellIdDataType) + val rasterType = RasterTileType(rasterExpr).rasterType + val tile = MosaicRasterTile.deserialize( + input.asInstanceOf[InternalRow], + cellIdDataType, + rasterType + ) val result = rasterTransform(tile) - val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + val serialized = serialize(result, returnsRaster, rasterType, expressionConfig) RasterCleaner.dispose(tile) serialized } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpressionSerialization.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpressionSerialization.scala index a9bf17917..dc04cb1c7 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpressionSerialization.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpressionSerialization.scala @@ -35,11 +35,9 @@ trait RasterExpressionSerialization { ): Any = { if (returnsRaster) { val tile = data.asInstanceOf[MosaicRasterTile] - val checkpoint = expressionConfig.getRasterCheckpoint - val rasterType = outputDataType.asInstanceOf[StructType].fields(1).dataType val result = tile .formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) - .serialize(rasterType, checkpoint) + .serialize(outputDataType) RasterCleaner.dispose(tile) result } else { diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGeneratorExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGeneratorExpression.scala index 29c714788..3fc80752d 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGeneratorExpression.scala +++ 
b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGeneratorExpression.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag * rasters based on the input raster. The new rasters are written in the * checkpoint directory. The files are written as GeoTiffs. Subdatasets are not * supported, please flatten beforehand. - * @param rasterExpr + * @param tileExpr * The expression for the raster. If the raster is stored on disc, the path * to the raster is provided. If the raster is stored in memory, the bytes of * the raster are provided. @@ -32,13 +32,13 @@ import scala.reflect.ClassTag * The type of the extending class. */ abstract class RasterGeneratorExpression[T <: Expression: ClassTag]( - rasterExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends CollectionGenerator with NullIntolerant with Serializable { - override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) val uuid: String = java.util.UUID.randomUUID().toString.replace("-", "_") @@ -72,11 +72,12 @@ abstract class RasterGeneratorExpression[T <: Expression: ClassTag]( override def eval(input: InternalRow): TraversableOnce[InternalRow] = { GDAL.enable(expressionConfig) - val tile = MosaicRasterTile.deserialize(rasterExpr.eval(input).asInstanceOf[InternalRow], cellIdDataType) + val rasterType = RasterTileType(tileExpr).rasterType + val tile = MosaicRasterTile.deserialize(tileExpr.eval(input).asInstanceOf[InternalRow], cellIdDataType, rasterType) val generatedRasters = rasterGenerator(tile) // Writing rasters disposes of the written raster - val rows = generatedRasters.map(_.formatCellId(indexSystem).serialize()) + val rows = generatedRasters.map(_.formatCellId(indexSystem).serialize(rasterType)) generatedRasters.foreach(gr => RasterCleaner.dispose(gr)) RasterCleaner.dispose(tile) diff --git 
a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterTessellateGeneratorExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterTessellateGeneratorExpression.scala index f2545942b..98ff86ca7 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterTessellateGeneratorExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterTessellateGeneratorExpression.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag * checkpoint directory. The files are written as GeoTiffs. Subdatasets are not * supported, please flatten beforehand. * - * @param rasterExpr + * @param tileExpr * The expression for the raster. If the raster is stored on disc, the path * to the raster is provided. If the raster is stored in memory, the bytes of * the raster are provided. @@ -33,7 +33,7 @@ import scala.reflect.ClassTag * The type of the extending class. */ abstract class RasterTessellateGeneratorExpression[T <: Expression: ClassTag]( - rasterExpr: Expression, + tileExpr: Expression, resolutionExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends CollectionGenerator @@ -55,7 +55,8 @@ abstract class RasterTessellateGeneratorExpression[T <: Expression: ClassTag]( * needs to be wrapped in a StructType. The actually type is that of the * structs element. */ - override def elementSchema: StructType = StructType(Array(StructField("element", RasterTileType(indexSystem.getCellIdDataType)))) + override def elementSchema: StructType = + StructType(Array(StructField("element", RasterTileType(indexSystem.getCellIdDataType, tileExpr)))) /** * The function to be overridden by the extending class. 
It is called when @@ -71,17 +72,15 @@ abstract class RasterTessellateGeneratorExpression[T <: Expression: ClassTag]( override def eval(input: InternalRow): TraversableOnce[InternalRow] = { GDAL.enable(expressionConfig) + val rasterType = RasterTileType(tileExpr).rasterType val tile = MosaicRasterTile - .deserialize( - rasterExpr.eval(input).asInstanceOf[InternalRow], - indexSystem.getCellIdDataType - ) + .deserialize(tileExpr.eval(input).asInstanceOf[InternalRow], indexSystem.getCellIdDataType, rasterType) val inResolution: Int = indexSystem.getResolution(resolutionExpr.eval(input)) val generatedChips = rasterGenerator(tile, inResolution) .map(chip => chip.formatCellId(indexSystem)) val rows = generatedChips - .map(chip => InternalRow.fromSeq(Seq(chip.formatCellId(indexSystem).serialize()))) + .map(chip => InternalRow.fromSeq(Seq(chip.formatCellId(indexSystem).serialize(rasterType)))) RasterCleaner.dispose(tile) generatedChips.foreach(chip => RasterCleaner.dispose(chip)) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterToGridExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterToGridExpression.scala index 743f9cbd6..e7b04f989 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterToGridExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterToGridExpression.scala @@ -37,11 +37,13 @@ abstract class RasterToGridExpression[T <: Expression: ClassTag, P]( resolution: Expression, measureType: DataType, expressionConfig: MosaicExpressionConfig -) extends Raster1ArgExpression[T](rasterExpr, resolution, RasterToGridType(expressionConfig.getCellIdType, measureType), returnsRaster = false, expressionConfig) +) extends Raster1ArgExpression[T](rasterExpr, resolution, returnsRaster = false, expressionConfig) with RasterGridExpression with NullIntolerant with Serializable { + override def dataType: DataType = 
RasterToGridType(expressionConfig.getCellIdType, measureType) + /** The index system to be used. */ val indexSystem: IndexSystem = IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem) val geometryAPI: GeometryAPI = GeometryAPI(expressionConfig.getGeometryAPI) diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala index 2b85c9785..8398e6882 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala @@ -8,20 +8,19 @@ import com.databricks.labs.mosaic.core.types.ChipType import com.databricks.labs.mosaic.datasource.multiread.MosaicDataFrameReader import com.databricks.labs.mosaic.expressions.constructors._ import com.databricks.labs.mosaic.expressions.format._ -import com.databricks.labs.mosaic.expressions.geometry._ import com.databricks.labs.mosaic.expressions.geometry.ST_MinMaxXYZ._ +import com.databricks.labs.mosaic.expressions.geometry._ import com.databricks.labs.mosaic.expressions.index._ import com.databricks.labs.mosaic.expressions.raster._ import com.databricks.labs.mosaic.expressions.util.TrySql -import com.databricks.labs.mosaic.functions.MosaicContext.mosaicVersion import com.databricks.labs.mosaic.utils.FileUtils import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Column, SparkSession} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{LongType, StringType} +import org.apache.spark.sql.{Column, SparkSession} import scala.reflect.runtime.universe @@ -270,6 +269,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends mosaicRegistry.registerExpression[RST_Height](expressionConfig) 
mosaicRegistry.registerExpression[RST_InitNoData](expressionConfig) mosaicRegistry.registerExpression[RST_IsEmpty](expressionConfig) + mosaicRegistry.registerExpression[RST_MakeTiles](expressionConfig) mosaicRegistry.registerExpression[RST_Max](expressionConfig) mosaicRegistry.registerExpression[RST_Min](expressionConfig) mosaicRegistry.registerExpression[RST_Median](expressionConfig) @@ -655,6 +655,10 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends def rst_combineavg(rasterArray: Column): Column = ColumnAdapter(RST_CombineAvg(rasterArray.expr, expressionConfig)) def rst_derivedband(raster: Column, pythonFunc: Column, funcName: Column): Column = ColumnAdapter(RST_DerivedBand(raster.expr, pythonFunc.expr, funcName.expr, expressionConfig)) + def rst_filter(raster: Column, kernelSize: Column, operation: Column): Column = + ColumnAdapter(RST_Filter(raster.expr, kernelSize.expr, operation.expr, expressionConfig)) + def rst_filter(raster: Column, kernelSize: Int, operation: String): Column = + ColumnAdapter(RST_Filter(raster.expr, lit(kernelSize).expr, lit(operation).expr, expressionConfig)) def rst_georeference(raster: Column): Column = ColumnAdapter(RST_GeoReference(raster.expr, expressionConfig)) def rst_getnodata(raster: Column): Column = ColumnAdapter(RST_GetNoData(raster.expr, expressionConfig)) def rst_getsubdataset(raster: Column, subdatasetName: Column): Column = @@ -664,6 +668,20 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends def rst_height(raster: Column): Column = ColumnAdapter(RST_Height(raster.expr, expressionConfig)) def rst_initnodata(raster: Column): Column = ColumnAdapter(RST_InitNoData(raster.expr, expressionConfig)) def rst_isempty(raster: Column): Column = ColumnAdapter(RST_IsEmpty(raster.expr, expressionConfig)) + def rst_maketiles(input: Column, driver: Column, size: Column, withCheckpoint: Column): Column = + ColumnAdapter(RST_MakeTiles(input.expr, driver.expr, size.expr, 
withCheckpoint.expr, expressionConfig)) + def rst_maketiles(input: Column, driver: String, size: Int, withCheckpoint: Boolean): Column = + ColumnAdapter(RST_MakeTiles(input.expr, lit(driver).expr, lit(size).expr, lit(withCheckpoint).expr, expressionConfig)) + def rst_maketiles(input: Column): Column = + ColumnAdapter(RST_MakeTiles(input.expr, lit(MOSAIC_NO_DRIVER).expr, lit(-1).expr, lit(false).expr, expressionConfig)) + def rst_maketiles(input: Column, size: Int): Column = + ColumnAdapter(RST_MakeTiles(input.expr, lit(MOSAIC_NO_DRIVER).expr, lit(size).expr, lit(false).expr, expressionConfig)) + def rst_maketiles(input: Column, driver: String): Column = + ColumnAdapter(RST_MakeTiles(input.expr, lit(driver).expr, lit(-1).expr, lit(false).expr, expressionConfig)) + def rst_maketiles(input: Column, driver: String, withCheckpoint: Boolean): Column = + ColumnAdapter(RST_MakeTiles(input.expr, lit(driver).expr, lit(-1).expr, lit(withCheckpoint).expr, expressionConfig)) + def rst_maketiles(input: Column, size: Int, withCheckpoint: Boolean): Column = + ColumnAdapter(RST_MakeTiles(input.expr, lit(MOSAIC_NO_DRIVER).expr, lit(size).expr, lit(withCheckpoint).expr, expressionConfig)) def rst_max(raster: Column): Column = ColumnAdapter(RST_Max(raster.expr, expressionConfig)) def rst_min(raster: Column): Column = ColumnAdapter(RST_Min(raster.expr, expressionConfig)) def rst_median(raster: Column): Column = ColumnAdapter(RST_Median(raster.expr, expressionConfig)) diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala index f306d4e9c..d6643f59b 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala @@ -33,6 +33,8 @@ case class MosaicExpressionConfig(configs: Map[String, String]) { def getRasterCheckpoint: String = 
configs.getOrElse(MOSAIC_RASTER_CHECKPOINT, MOSAIC_RASTER_CHECKPOINT_DEFAULT) def getCellIdType: DataType = IndexSystemFactory.getIndexSystem(getIndexSystem).cellIdType + + def getRasterBlockSize: Int = configs.getOrElse(MOSAIC_RASTER_BLOCKSIZE, MOSAIC_RASTER_BLOCKSIZE_DEFAULT).toInt def setGDALConf(conf: RuntimeConfig): MosaicExpressionConfig = { val toAdd = conf.getAll.filter(_._1.startsWith(MOSAIC_GDAL_PREFIX)) diff --git a/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala index 9e8bf1132..92438844a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala @@ -1,11 +1,12 @@ package com.databricks.labs.mosaic.gdal +import com.databricks.labs.mosaic.MOSAIC_RASTER_BLOCKSIZE_DEFAULT import com.databricks.labs.mosaic.functions.{MosaicContext, MosaicExpressionConfig} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.gdal.gdal.gdal +import org.gdal.osr.SpatialReference -import java.io.{BufferedInputStream, File, PrintWriter} import java.nio.file.{Files, Paths} import scala.language.postfixOps import scala.util.Try @@ -22,9 +23,22 @@ object MosaicGDAL extends Logging { private val libjniso3003Path = "/usr/lib/libgdalalljni.so.30.0.3" private val libogdisoPath = "/usr/lib/ogdi/4.1/libgdal.so" + val defaultBlockSize = 1024 + val vrtBlockSize = 128 // This is a must value for VRTs before GDAL 3.7 + var blockSize: Int = MOSAIC_RASTER_BLOCKSIZE_DEFAULT.toInt + // noinspection ScalaWeakerAccess val GDAL_ENABLED = "spark.mosaic.gdal.native.enabled" var isEnabled = false + var checkpointPath: String = _ + + // Only use this with GDAL rasters + val WSG84: SpatialReference = { + val wsg84 = new SpatialReference() + wsg84.ImportFromEPSG(4326) + wsg84.SetAxisMappingStrategy(org.gdal.osr.osrConstants.OAMS_TRADITIONAL_GIS_ORDER) + wsg84 + } /** Returns true if GDAL is 
enabled. */ def wasEnabled(spark: SparkSession): Boolean = @@ -33,15 +47,28 @@ object MosaicGDAL extends Logging { /** Configures the GDAL environment. */ def configureGDAL(mosaicConfig: MosaicExpressionConfig): Unit = { val CPL_TMPDIR = MosaicContext.tmpDir - val GDAL_PAM_PROXY_DIR = MosaicContext.tmpDir gdal.SetConfigOption("GDAL_VRT_ENABLE_PYTHON", "YES") - gdal.SetConfigOption("GDAL_DISABLE_READDIR_ON_OPEN", "EMPTY_DIR") + gdal.SetConfigOption("GDAL_DISABLE_READDIR_ON_OPEN", "TRUE") gdal.SetConfigOption("CPL_TMPDIR", CPL_TMPDIR) - gdal.SetConfigOption("GDAL_PAM_PROXY_DIR", GDAL_PAM_PROXY_DIR) - gdal.SetConfigOption("GDAL_PAM_ENABLED", "NO") - gdal.SetConfigOption("CPL_VSIL_USE_TEMP_FILE_FOR_RANDOM_WRITE", "NO") gdal.SetConfigOption("CPL_LOG", s"$CPL_TMPDIR/gdal.log") + gdal.SetConfigOption("GDAL_CACHEMAX", "512") + gdal.SetConfigOption("GDAL_NUM_THREADS", "ALL_CPUS") mosaicConfig.getGDALConf.foreach { case (k, v) => gdal.SetConfigOption(k.split("\\.").last, v) } + setBlockSize(mosaicConfig) + checkpointPath = mosaicConfig.getRasterCheckpoint + } + + def setBlockSize(mosaicConfig: MosaicExpressionConfig): Unit = { + val blockSize = mosaicConfig.getRasterBlockSize + if (blockSize > 0) { + this.blockSize = blockSize + } + } + + def setBlockSize(size: Int): Unit = { + if (size > 0) { + this.blockSize = size + } } /** Enables the GDAL environment. */ @@ -91,18 +118,19 @@ object MosaicGDAL extends Logging { } } - /** Reads the resource bytes. */ - private def readResourceBytes(name: String): Array[Byte] = { - val bis = new BufferedInputStream(getClass.getResourceAsStream(name)) - try { Stream.continually(bis.read()).takeWhile(-1 !=).map(_.toByte).toArray } - finally bis.close() - } +// /** Reads the resource bytes. 
*/ +// private def readResourceBytes(name: String): Array[Byte] = { +// val bis = new BufferedInputStream(getClass.getResourceAsStream(name)) +// try { Stream.continually(bis.read()).takeWhile(-1 !=).map(_.toByte).toArray } +// finally bis.close() +// } + +// /** Reads the resource lines. */ +// // noinspection SameParameterValue +// private def readResourceLines(name: String): Array[String] = { +// val bytes = readResourceBytes(name) +// val lines = new String(bytes).split("\n") +// lines +// } - /** Reads the resource lines. */ - // noinspection SameParameterValue - private def readResourceLines(name: String): Array[String] = { - val bytes = readResourceBytes(name) - val lines = new String(bytes).split("\n") - lines - } } diff --git a/src/main/scala/com/databricks/labs/mosaic/package.scala b/src/main/scala/com/databricks/labs/mosaic/package.scala index 58ee2f98e..eea63cd79 100644 --- a/src/main/scala/com/databricks/labs/mosaic/package.scala +++ b/src/main/scala/com/databricks/labs/mosaic/package.scala @@ -21,13 +21,17 @@ package object mosaic { val MOSAIC_GDAL_PREFIX = "spark.databricks.labs.mosaic.gdal." 
val MOSAIC_GDAL_NATIVE = "spark.databricks.labs.mosaic.gdal.native" val MOSAIC_RASTER_CHECKPOINT = "spark.databricks.labs.mosaic.raster.checkpoint" - val MOSAIC_RASTER_CHECKPOINT_DEFAULT = "dbfs:/tmp/mosaic/raster/checkpoint" + val MOSAIC_RASTER_CHECKPOINT_DEFAULT = "/dbfs/tmp/mosaic/raster/checkpoint" + val MOSAIC_RASTER_BLOCKSIZE = "spark.databricks.labs.mosaic.raster.blocksize" + val MOSAIC_RASTER_BLOCKSIZE_DEFAULT = "128" val MOSAIC_RASTER_READ_STRATEGY = "raster.read.strategy" val MOSAIC_RASTER_READ_IN_MEMORY = "in_memory" val MOSAIC_RASTER_READ_AS_PATH = "as_path" val MOSAIC_RASTER_RE_TILE_ON_READ = "retile_on_read" + val MOSAIC_NO_DRIVER = "no_driver" + def read: MosaicDataFrameReader = new MosaicDataFrameReader(SparkSession.builder().getOrCreate()) diff --git a/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala b/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala index a1aac5c2f..fc01cfaa0 100644 --- a/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala +++ b/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala @@ -22,10 +22,10 @@ object FileUtils { bytes } - def createMosaicTempDir(): String = { - val tempRoot = Paths.get("/mosaic_tmp/") + def createMosaicTempDir(prefix: String = ""): String = { + val tempRoot = Paths.get(s"$prefix/mosaic_tmp/") if (!Files.exists(tempRoot)) { - Files.createDirectory(tempRoot) + Files.createDirectories(tempRoot) } val tempDir = Files.createTempDirectory(tempRoot, "mosaic") tempDir.toFile.getAbsolutePath diff --git a/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala b/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala index d48c03bfd..469bb0f44 100644 --- a/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala +++ b/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala @@ -1,7 +1,5 @@ package com.databricks.labs.mosaic.utils -import com.databricks.labs.mosaic.core.raster.api.GDAL -import 
com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL import com.databricks.labs.mosaic.functions.MosaicContext import java.nio.file.{Files, Paths} @@ -10,11 +8,15 @@ object PathUtils { val NO_PATH_STRING = "no_path" - def getCleanPath(path: String): String = { - val cleanPath = path + def replaceDBFSTokens(path: String): String = { + path .replace("file:/", "/") .replace("dbfs:/Volumes", "/Volumes") - .replace("dbfs:/","/dbfs/") + .replace("dbfs:/", "/dbfs/") + } + + def getCleanPath(path: String): String = { + val cleanPath = replaceDBFSTokens(path) if (cleanPath.endsWith(".zip") || cleanPath.contains(".zip:")) { getZipPath(cleanPath) } else { @@ -61,17 +63,51 @@ object PathUtils { if (filePath.endsWith("\"")) result = result.dropRight(1) result } + + def getStemRegex(path: String): String = { + val cleanPath = replaceDBFSTokens(path) + val fileName = Paths.get(cleanPath).getFileName.toString + val stemName = fileName.substring(0, fileName.lastIndexOf(".")) + val stemEscaped = stemName.replace(".", "\\.") + val stemRegex = s"$stemEscaped\\..*".r + stemRegex.toString + } - def copyToTmp(inPath: String): String = { - val copyFromPath = inPath - .replace("file:/", "/") - .replace("dbfs:/Volumes", "/Volumes") - .replace("dbfs:/","/dbfs/") - val driver = MosaicRasterGDAL.identifyDriver(getCleanPath(inPath)) - val extension = if (inPath.endsWith(".zip")) "zip" else GDAL.getExtension(driver) - val tmpPath = createTmpFilePath(extension) - Files.copy(Paths.get(copyFromPath), Paths.get(tmpPath)) - tmpPath + def copyToTmp(inPath: String): String = { + val copyFromPath = replaceDBFSTokens(inPath) + val inPathDir = Paths.get(copyFromPath).getParent.toString + + val fullFileName = copyFromPath.split("/").last + val stemRegex = getStemRegex(inPath) + + wildcardCopy(inPathDir, MosaicContext.tmpDir, stemRegex.toString) + + s"${MosaicContext.tmpDir}/$fullFileName" + } + + def wildcardCopy(inDirPath: String, outDirPath: String, pattern: String): Unit = { + import 
org.apache.commons.io.FileUtils + val copyFromPath = replaceDBFSTokens(inDirPath) + val copyToPath = replaceDBFSTokens(outDirPath) + + val toCopy = Files + .list(Paths.get(copyFromPath)) + .filter(_.getFileName.toString.matches(pattern)) + + toCopy.forEach(path => { + val destination = Paths.get(copyToPath, path.getFileName.toString) + //noinspection SimplifyBooleanMatch + Files.isDirectory(path) match { + case true => FileUtils.copyDirectory(path.toFile, destination.toFile) + case false => Files.copy(path, destination) + } + }) + } + + def parseUnzippedPathFromExtracted(lastExtracted: String, extension: String): String = { + val trimmed = lastExtracted.replace("extracting: ", "").replace(" ", "") + val indexOfFormat = trimmed.indexOf(s".$extension/") + trimmed.substring(0, indexOfFormat + extension.length + 1) } } diff --git a/src/main/scala/com/databricks/labs/mosaic/utils/SysUtils.scala b/src/main/scala/com/databricks/labs/mosaic/utils/SysUtils.scala index 85fa12785..ba1d9c417 100644 --- a/src/main/scala/com/databricks/labs/mosaic/utils/SysUtils.scala +++ b/src/main/scala/com/databricks/labs/mosaic/utils/SysUtils.scala @@ -1,6 +1,6 @@ package com.databricks.labs.mosaic.utils -import java.io.{ByteArrayOutputStream, PrintWriter} +import java.io.{BufferedReader, ByteArrayOutputStream, InputStreamReader, PrintWriter} object SysUtils { @@ -11,16 +11,40 @@ object SysUtils { val stderrStream = new ByteArrayOutputStream val stdoutWriter = new PrintWriter(stdoutStream) val stderrWriter = new PrintWriter(stderrStream) - val exitValue = try { - //noinspection ScalaStyle - cmd.!!(ProcessLogger(stdoutWriter.println, stderrWriter.println)) - } catch { - case _: Exception => "ERROR" - } finally { - stdoutWriter.close() - stderrWriter.close() - } + val exitValue = + try { + // noinspection ScalaStyle + cmd.!!(ProcessLogger(stdoutWriter.println, stderrWriter.println)) + } catch { + case e: Exception => s"ERROR: ${e.getMessage}" + } finally { + stdoutWriter.close() + 
stderrWriter.close() + } (exitValue, stdoutStream.toString, stderrStream.toString) } + def runScript(cmd: Array[String]): (String, String, String) = { + val p = Runtime.getRuntime.exec(cmd) + val stdinStream = new BufferedReader(new InputStreamReader(p.getInputStream)) + val stderrStream = new BufferedReader(new InputStreamReader(p.getErrorStream)) + val exitValue = + try { + p.waitFor() + } catch { + case e: Exception => s"ERROR: ${e.getMessage}" + } + val stdinOutput = stdinStream.lines().toArray.mkString("\n") + val stderrOutput = stderrStream.lines().toArray.mkString("\n") + stdinStream.close() + stderrStream.close() + (s"$exitValue", stdinOutput, stderrOutput) + } + + def getLastOutputLine(prompt: (String, String, String)): String = { + val (_, stdout, _) = prompt + val lines = stdout.split("\n") + lines.last + } + } diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grb similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grb diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib.aux.xml b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grb.aux.xml similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib.aux.xml rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grb.aux.xml 
diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib.aux.xml b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb.aux.xml similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib.aux.xml rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb.aux.xml diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grb similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grb diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib.aux.xml 
b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grb.aux.xml similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib.aux.xml rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grb.aux.xml diff --git a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala index 15eef2009..1337ae6d2 100644 --- a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala +++ b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala @@ -37,8 +37,8 @@ class TestRasterBandGDAL extends SharedSparkSessionGDAL { assume(System.getProperty("os.name") == "Linux") val testRaster = MosaicRasterGDAL.readRaster( - filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib"), - filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib") + filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb"), + filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb") ) val testBand = testRaster.getBand(1) testBand.description shouldBe "1[-] HYBL=\"Hybrid level\"" diff --git a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala index e39279843..bb53d6b79 100644 --- a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala +++ b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala @@ -1,9 +1,12 @@ package 
com.databricks.labs.mosaic.core.raster import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.gdal.MosaicGDAL import com.databricks.labs.mosaic.test.mocks.filePath import org.apache.spark.sql.test.SharedSparkSessionGDAL import org.scalatest.matchers.should.Matchers._ +import org.gdal.gdal.{gdal => gdalJNI} +import org.gdal.gdalconst import scala.sys.process._ import scala.util.Try @@ -43,7 +46,7 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { testRaster.SRID shouldBe 0 testRaster.extent shouldBe Seq(-8895604.157333, 1111950.519667, -7783653.637667, 2223901.039333) testRaster.getRaster.GetProjection() - noException should be thrownBy testRaster.spatialRef + noException should be thrownBy testRaster.getSpatialReference an[Exception] should be thrownBy testRaster.getBand(-1) an[Exception] should be thrownBy testRaster.getBand(Int.MaxValue) @@ -54,8 +57,8 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { assume(System.getProperty("os.name") == "Linux") val testRaster = MosaicRasterGDAL.readRaster( - filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib"), - filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib") + filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb"), + filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb") ) testRaster.xSize shouldBe 14 testRaster.ySize shouldBe 14 @@ -96,8 +99,8 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { assume(System.getProperty("os.name") == "Linux") val testRaster = MosaicRasterGDAL.readRaster( - filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), - filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") + 
filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), + filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") ) testRaster.pixelXSize - 463.312716527 < 0.0000001 shouldBe true @@ -115,4 +118,214 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { testRaster.getRaster.delete() } + test("Raster filter operations are correct.") { + assume(System.getProperty("os.name") == "Linux") + + gdalJNI.AllRegister() + + MosaicGDAL.setBlockSize(30) + + val ds = gdalJNI.GetDriverByName("GTiff").Create("/mosaic_tmp/test.tif", 50, 50, 1, gdalconst.gdalconstConstants.GDT_Float32) + + val values = 0 until 50 * 50 + ds.GetRasterBand(1).WriteRaster(0, 0, 50, 50, values.toArray) + ds.FlushCache() + + var result = MosaicRasterGDAL(ds, "", "", "GTiff", -1).filter(5, "avg").flushCache() + + var resultValues = result.getBand(1).values + + var inputMatrix = values.toArray.grouped(50).toArray + var resultMatrix = resultValues.grouped(50).toArray + + // first block + resultMatrix(10)(11) shouldBe ( + inputMatrix(8)(9) + inputMatrix(8)(10) + inputMatrix(8)(11) + inputMatrix(8)(12) + inputMatrix(8)(13) + + inputMatrix(9)(9) + inputMatrix(9)(10) + inputMatrix(9)(11) + inputMatrix(9)(12) + inputMatrix(9)(13) + + inputMatrix(10)(9) + inputMatrix(10)(10) + inputMatrix(10)(11) + inputMatrix(10)(12) + inputMatrix(10)(13) + + inputMatrix(11)(9) + inputMatrix(11)(10) + inputMatrix(11)(11) + inputMatrix(11)(12) + inputMatrix(11)(13) + + inputMatrix(12)(9) + inputMatrix(12)(10) + inputMatrix(12)(11) + inputMatrix(12)(12) + inputMatrix(12)(13) + ).toDouble / 25.0 + + // block overlap + resultMatrix(30)(32) shouldBe ( + inputMatrix(28)(30) + inputMatrix(28)(31) + inputMatrix(28)(32) + inputMatrix(28)(33) + inputMatrix(28)(34) + + inputMatrix(29)(30) + inputMatrix(29)(31) + inputMatrix(29)(32) + inputMatrix(29)(33) + inputMatrix(29)(34) + + inputMatrix(30)(30) + inputMatrix(30)(31) + inputMatrix(30)(32) + inputMatrix(30)(33) + inputMatrix(30)(34) + + inputMatrix(31)(30) + 
inputMatrix(31)(31) + inputMatrix(31)(32) + inputMatrix(31)(33) + inputMatrix(31)(34) + + inputMatrix(32)(30) + inputMatrix(32)(31) + inputMatrix(32)(32) + inputMatrix(32)(33) + inputMatrix(32)(34) + ).toDouble / 25.0 + + // mode + + result = MosaicRasterGDAL(ds, "", "", "GTiff", -1).filter(5, "mode").flushCache() + + resultValues = result.getBand(1).values + + inputMatrix = values.toArray.grouped(50).toArray + resultMatrix = resultValues.grouped(50).toArray + + // first block + + resultMatrix(10)(11) shouldBe Seq( + inputMatrix(8)(9), + inputMatrix(8)(10), + inputMatrix(8)(11), + inputMatrix(8)(12), + inputMatrix(8)(13), + inputMatrix(9)(9), + inputMatrix(9)(10), + inputMatrix(9)(11), + inputMatrix(9)(12), + inputMatrix(9)(13), + inputMatrix(10)(9), + inputMatrix(10)(10), + inputMatrix(10)(11), + inputMatrix(10)(12), + inputMatrix(10)(13), + inputMatrix(11)(9), + inputMatrix(11)(10), + inputMatrix(11)(11), + inputMatrix(11)(12), + inputMatrix(11)(13), + inputMatrix(12)(9), + inputMatrix(12)(10), + inputMatrix(12)(11), + inputMatrix(12)(12), + inputMatrix(12)(13) + ).groupBy(identity).maxBy(_._2.size)._1.toDouble + + // corner + + resultMatrix(49)(49) shouldBe Seq( + inputMatrix(47)(47), + inputMatrix(47)(48), + inputMatrix(47)(49), + inputMatrix(48)(47), + inputMatrix(48)(48), + inputMatrix(48)(49), + inputMatrix(49)(47), + inputMatrix(49)(48), + inputMatrix(49)(49) + ).groupBy(identity).maxBy(_._2.size)._1.toDouble + + // median + + result = MosaicRasterGDAL(ds, "", "", "GTiff", -1).filter(5, "median").flushCache() + + resultValues = result.getBand(1).values + + inputMatrix = values.toArray.grouped(50).toArray + resultMatrix = resultValues.grouped(50).toArray + + // first block + + resultMatrix(10)(11) shouldBe Seq( + inputMatrix(8)(9), + inputMatrix(8)(10), + inputMatrix(8)(11), + inputMatrix(8)(12), + inputMatrix(8)(13), + inputMatrix(9)(9), + inputMatrix(9)(10), + inputMatrix(9)(11), + inputMatrix(9)(12), + inputMatrix(9)(13), + inputMatrix(10)(9), + 
inputMatrix(10)(10), + inputMatrix(10)(11), + inputMatrix(10)(12), + inputMatrix(10)(13), + inputMatrix(11)(9), + inputMatrix(11)(10), + inputMatrix(11)(11), + inputMatrix(11)(12), + inputMatrix(11)(13), + inputMatrix(12)(9), + inputMatrix(12)(10), + inputMatrix(12)(11), + inputMatrix(12)(12), + inputMatrix(12)(13) + ).sorted.apply(12).toDouble + + // min filter + + result = MosaicRasterGDAL(ds, "", "", "GTiff", -1).filter(5, "min").flushCache() + + resultValues = result.getBand(1).values + + inputMatrix = values.toArray.grouped(50).toArray + resultMatrix = resultValues.grouped(50).toArray + + // first block + + resultMatrix(10)(11) shouldBe Seq( + inputMatrix(8)(9), + inputMatrix(8)(10), + inputMatrix(8)(11), + inputMatrix(8)(12), + inputMatrix(8)(13), + inputMatrix(9)(9), + inputMatrix(9)(10), + inputMatrix(9)(11), + inputMatrix(9)(12), + inputMatrix(9)(13), + inputMatrix(10)(9), + inputMatrix(10)(10), + inputMatrix(10)(11), + inputMatrix(10)(12), + inputMatrix(10)(13), + inputMatrix(11)(9), + inputMatrix(11)(10), + inputMatrix(11)(11), + inputMatrix(11)(12), + inputMatrix(11)(13), + inputMatrix(12)(9), + inputMatrix(12)(10), + inputMatrix(12)(11), + inputMatrix(12)(12), + inputMatrix(12)(13) + ).min.toDouble + + // max filter + + result = MosaicRasterGDAL(ds, "", "", "GTiff", -1).filter(5, "max").flushCache() + + resultValues = result.getBand(1).values + + inputMatrix = values.toArray.grouped(50).toArray + resultMatrix = resultValues.grouped(50).toArray + + // first block + + resultMatrix(10)(11) shouldBe Seq( + inputMatrix(8)(9), + inputMatrix(8)(10), + inputMatrix(8)(11), + inputMatrix(8)(12), + inputMatrix(8)(13), + inputMatrix(9)(9), + inputMatrix(9)(10), + inputMatrix(9)(11), + inputMatrix(9)(12), + inputMatrix(9)(13), + inputMatrix(10)(9), + inputMatrix(10)(10), + inputMatrix(10)(11), + inputMatrix(10)(12), + inputMatrix(10)(13), + inputMatrix(11)(9), + inputMatrix(11)(10), + inputMatrix(11)(11), + inputMatrix(11)(12), + inputMatrix(11)(13), + 
inputMatrix(12)(9), + inputMatrix(12)(10), + inputMatrix(12)(11), + inputMatrix(12)(12), + inputMatrix(12)(13) + ).max.toDouble + + } + } diff --git a/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala b/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala index 99a1563ca..623993b01 100644 --- a/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala @@ -34,7 +34,7 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSessionGDAL { .take(1) } - + test("Read grib with GDALFileFormat") { assume(System.getProperty("os.name") == "Linux") @@ -43,25 +43,22 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSessionGDAL { noException should be thrownBy spark.read .format("gdal") - .option("extensions", "grib") - .option("raster_storage", "disk") - .option("extensions", "grib") + .option("extensions", "grb") + .option("raster.read.strategy", "retile_on_read") .load(filePath) .take(1) noException should be thrownBy spark.read .format("gdal") - .option("extensions", "grib") - .option("raster_storage", "disk") - .option("extensions", "grib") + .option("extensions", "grb") + .option("raster.read.strategy", "retile_on_read") .load(filePath) .take(1) noException should be thrownBy spark.read .format("gdal") - .option("extensions", "grib") - .option("raster_storage", "disk") - .option("extensions", "grib") + .option("extensions", "grb") + .option("raster.read.strategy", "retile_on_read") .load(filePath) .select("metadata") .take(1) @@ -92,7 +89,7 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSessionGDAL { .select("metadata") .take(1) - noException should be thrownBy spark.read + noException should be thrownBy spark.read .format("gdal") .option(MOSAIC_RASTER_READ_STRATEGY, "retile_on_read") .load(filePath) diff --git 
a/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala b/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala index 6e99aa1df..fba2b74cb 100644 --- a/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala @@ -1,12 +1,12 @@ package com.databricks.labs.mosaic.datasource.multiread -import com.databricks.labs.mosaic.functions.MosaicContext import com.databricks.labs.mosaic.JTS import com.databricks.labs.mosaic.core.index.H3IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext import com.databricks.labs.mosaic.test.MosaicSpatialQueryTest import org.apache.spark.sql.test.SharedSparkSessionGDAL import org.scalatest.matchers.must.Matchers.{be, noException} -import org.scalatest.matchers.should.Matchers.{an, convertToAnyShouldWrapper} +import org.scalatest.matchers.should.Matchers.an import java.nio.file.{Files, Paths} @@ -14,149 +14,161 @@ class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSess test("Read big tif with Raster As Grid Reader") { assume(System.getProperty("os.name") == "Linux") - spark.sparkContext.setLogLevel("INFO") MosaicContext.build(H3IndexSystem, JTS) - val tif = "/binary/big_tiff.tif" + val tif = "/modis/" val filePath = getClass.getResource(tif).getPath val df = MosaicContext.read .format("raster_to_grid") .option("retile", "true") - .option("sizeInMB", "64") + .option("sizeInMB", "128") .option("resolution", "1") .load(filePath) .select("measure") - //df.queryExecution.optimizedPlan + df.queryExecution.optimizedPlan - //noException should be thrownBy df.queryExecution.executedPlan + noException should be thrownBy df.queryExecution.executedPlan df.count() } -// test("Read netcdf with Raster As Grid Reader") { -// assume(System.getProperty("os.name") == "Linux") -// MosaicContext.build(H3IndexSystem, 
JTS) -// -// val netcdf = "/binary/netcdf-coral/" -// val filePath = getClass.getResource(netcdf).getPath -// -// noException should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("retile", "true") -// .option("tileSize", "10") -// .option("readSubdataset", "true") -// .option("subdataset", "1") -// .option("kRingInterpolate", "3") -// .load(filePath) -// .select("measure") -// .queryExecution -// .executedPlan -// -// } -// -// test("Read grib with Raster As Grid Reader") { -// assume(System.getProperty("os.name") == "Linux") -// MosaicContext.build(H3IndexSystem, JTS) -// -// val grib = "/binary/grib-cams/" -// val filePath = getClass.getResource(grib).getPath -// -// noException should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("extensions", "grib") -// .option("combiner", "min") -// .option("retile", "true") -// .option("tileSize", "10") -// .option("kRingInterpolate", "3") -// .load(filePath) -// .select("measure") -// .take(1) -// -// } -// -// test("Read tif with Raster As Grid Reader") { -// assume(System.getProperty("os.name") == "Linux") -// MosaicContext.build(H3IndexSystem, JTS) -// -// val tif = "/modis/" -// val filePath = getClass.getResource(tif).getPath -// -// noException should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("combiner", "max") -// .option("tileSize", "10") -// .option("kRingInterpolate", "3") -// .load(filePath) -// .select("measure") -// .take(1) -// -// } -// -// test("Read zarr with Raster As Grid Reader") { -// assume(System.getProperty("os.name") == "Linux") -// MosaicContext.build(H3IndexSystem, JTS) -// -// val zarr = "/binary/zarr-example/" -// val filePath = getClass.getResource(zarr).getPath -// -// noException should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("combiner", "median") -// .option("vsizip", "true") -// .option("tileSize", "10") -// .load(filePath) -// .select("measure") -// .take(1) -// -// 
noException should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("combiner", "count") -// .option("vsizip", "true") -// .load(filePath) -// .select("measure") -// .take(1) -// -// noException should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("combiner", "average") -// .option("vsizip", "true") -// .load(filePath) -// .select("measure") -// .take(1) -// -// noException should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("combiner", "avg") -// .option("vsizip", "true") -// .load(filePath) -// .select("measure") -// .take(1) -// -// val paths = Files.list(Paths.get(filePath)).toArray.map(_.toString) -// -// an[Error] should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("combiner", "count_+") -// .option("vsizip", "true") -// .load(paths: _*) -// .select("measure") -// .take(1) -// -// an[Error] should be thrownBy MosaicContext.read -// .format("invalid") -// .load(paths: _*) -// -// an[Error] should be thrownBy MosaicContext.read -// .format("invalid") -// .load(filePath) -// -// noException should be thrownBy MosaicContext.read -// .format("raster_to_grid") -// .option("kRingInterpolate", "3") -// .load(filePath) -// -// } + test("Read netcdf with Raster As Grid Reader") { + assume(System.getProperty("os.name") == "Linux") + MosaicContext.build(H3IndexSystem, JTS) + + val netcdf = "/binary/netcdf-coral/" + val filePath = getClass.getResource(netcdf).getPath + + //noException should be thrownBy + + + MosaicContext.read + .format("raster_to_grid") + .option("retile", "true") + .option("tileSize", "10") + .option("readSubdataset", "true") + .option("subdataset", "1") + .option("kRingInterpolate", "3") + .load(filePath) + .select("measure") + .queryExecution + .executedPlan + + } + + test("Read grib with Raster As Grid Reader") { + assume(System.getProperty("os.name") == "Linux") + MosaicContext.build(H3IndexSystem, JTS) + + val grib = "/binary/grib-cams/" + 
val filePath = getClass.getResource(grib).getPath + + noException should be thrownBy MosaicContext.read + .format("raster_to_grid") + .option("extensions", "grib") + .option("combiner", "min") + .option("retile", "true") + .option("tileSize", "10") + .option("kRingInterpolate", "3") + .load(filePath) + .select("measure") + .take(1) + + } + + test("Read tif with Raster As Grid Reader") { + assume(System.getProperty("os.name") == "Linux") + MosaicContext.build(H3IndexSystem, JTS) + + val tif = "/modis/" + val filePath = getClass.getResource(tif).getPath + + noException should be thrownBy MosaicContext.read + .format("raster_to_grid") + .option("combiner", "max") + .option("tileSize", "10") + .option("kRingInterpolate", "3") + .load(filePath) + .select("measure") + .take(1) + + } + + test("Read zarr with Raster As Grid Reader") { + assume(System.getProperty("os.name") == "Linux") + MosaicContext.build(H3IndexSystem, JTS) + + val zarr = "/binary/zarr-example/" + val filePath = getClass.getResource(zarr).getPath + + noException should be thrownBy MosaicContext.read + .format("raster_to_grid") + .option("readSubdataset", "true") + .option("subdatasetName", "/group_with_attrs/F_order_array") + .option("combiner", "median") + .option("vsizip", "true") + .option("tileSize", "10") + .load(filePath) + .select("measure") + .take(1) + + noException should be thrownBy MosaicContext.read + .format("raster_to_grid") + .option("readSubdataset", "true") + .option("subdatasetName", "/group_with_attrs/F_order_array") + .option("combiner", "count") + .option("vsizip", "true") + .load(filePath) + .select("measure") + .take(1) + + noException should be thrownBy MosaicContext.read + .format("raster_to_grid") + .option("readSubdataset", "true") + .option("subdatasetName", "/group_with_attrs/F_order_array") + .option("combiner", "average") + .option("vsizip", "true") + .load(filePath) + .select("measure") + .take(1) + + noException should be thrownBy MosaicContext.read + 
.format("raster_to_grid") + .option("readSubdataset", "true") + .option("subdatasetName", "/group_with_attrs/F_order_array") + .option("combiner", "avg") + .option("vsizip", "true") + .load(filePath) + .select("measure") + .take(1) + + val paths = Files.list(Paths.get(filePath)).toArray.map(_.toString) + + an[Error] should be thrownBy MosaicContext.read + .format("raster_to_grid") + .option("combiner", "count_+") + .option("vsizip", "true") + .load(paths: _*) + .select("measure") + .take(1) + + an[Error] should be thrownBy MosaicContext.read + .format("invalid") + .load(paths: _*) + + an[Error] should be thrownBy MosaicContext.read + .format("invalid") + .load(filePath) + + noException should be thrownBy MosaicContext.read + .format("raster_to_grid") + .option("readSubdataset", "true") + .option("subdatasetName", "/group_with_attrs/F_order_array") + .option("kRingInterpolate", "3") + .load(filePath) + + } } diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgBehaviors.scala index 8ce57f5b8..611bf8f77 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgBehaviors.scala @@ -35,7 +35,9 @@ trait RST_CombineAvgBehaviors extends QueryTest { rastersInMemory.union(rastersInMemory) .createOrReplaceTempView("source") - noException should be thrownBy spark.sql(""" + //noException should be thrownBy + + spark.sql(""" |select rst_combineavg(collect_set(tiles)) as tiles |from ( | select path, rst_tessellate(tile, 2) as tiles diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterBehaviors.scala new file mode 100644 index 000000000..d06923dc1 --- /dev/null +++ 
b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterBehaviors.scala @@ -0,0 +1,36 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.scalatest.matchers.should.Matchers._ + +trait RST_FilterBehaviors extends QueryTest { + + // noinspection MapGetGet + def behaviors(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + spark.sparkContext.setLogLevel("FATAL") + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val gridTiles = rastersInMemory + .withColumn("result", rst_filter($"tile", 3, "mode")) + .select("result") + .collect() + + gridTiles.length should be(7) + + rastersInMemory.createOrReplaceTempView("source") + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterTest.scala new file mode 100644 index 000000000..a243f7168 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_FilterTest extends QueryTest with SharedSparkSessionGDAL with RST_FilterBehaviors { + + private 
val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing RST_Filter with manual GDAL registration (H3, JTS).") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behaviors(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala index bd867ee65..d01f79fec 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala @@ -37,7 +37,7 @@ trait RST_MinBehaviors extends QueryTest { val result = df.as[Double].collect().min - result < 0 shouldBe true + result == 0 shouldBe true an[Exception] should be thrownBy spark.sql(""" |select rst_min() from source diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala index c346e82db..cfffd9e6b 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala @@ -4,6 +4,7 @@ import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions._ import org.scalatest.matchers.should.Matchers._ trait RST_TessellateBehaviors extends 
QueryTest { @@ -24,8 +25,10 @@ trait RST_TessellateBehaviors extends QueryTest { val gridTiles = rastersInMemory .withColumn("tiles", rst_tessellate($"tile", 3)) - .select("tiles") - + .withColumn("bbox", st_aswkt(rst_boundingbox($"tile"))) + .select("bbox", "path", "tiles") + .withColumn("avg", rst_avg($"tiles")) + rastersInMemory .createOrReplaceTempView("source") @@ -37,9 +40,9 @@ trait RST_TessellateBehaviors extends QueryTest { .withColumn("tiles", rst_tessellate($"tile", 3)) .select("tiles") - val result = gridTiles.collect() + val result = gridTiles.select(explode(col("avg")).alias("a")).groupBy("a").count().collect() - result.length should be(380) + result.length should be(441) } diff --git a/src/test/scala/com/databricks/labs/mosaic/test/package.scala b/src/test/scala/com/databricks/labs/mosaic/test/package.scala index 435ee552c..2e5951be7 100644 --- a/src/test/scala/com/databricks/labs/mosaic/test/package.scala +++ b/src/test/scala/com/databricks/labs/mosaic/test/package.scala @@ -164,7 +164,7 @@ package object test { } val geotiffBytes: Array[Byte] = fileBytes("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") val gribBytes: Array[Byte] = - fileBytes("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib") + fileBytes("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb") val netcdfBytes: Array[Byte] = fileBytes("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc") def polyDf(sparkSession: SparkSession, mosaicContext: MosaicContext): DataFrame = { diff --git a/src/test/scala/org/apache/spark/sql/test/MosaicTestSparkSession.scala b/src/test/scala/org/apache/spark/sql/test/MosaicTestSparkSession.scala index 8029c30a7..84a613b31 100644 --- a/src/test/scala/org/apache/spark/sql/test/MosaicTestSparkSession.scala +++ b/src/test/scala/org/apache/spark/sql/test/MosaicTestSparkSession.scala @@ -8,13 +8,13 @@ class MosaicTestSparkSession(sc: 
SparkContext) extends TestSparkSession(sc) { this( new SparkContext( - "local[4]", + "local[8]", "test-sql-context", sparkConf .set("spark.sql.adaptive.enabled", "false") .set("spark.driver.memory", "32g") .set("spark.executor.memory", "32g") - .set("spark.sql.shuffle.partitions", "4") + .set("spark.sql.shuffle.partitions", "8") .set("spark.sql.testkey", "true") ) ) diff --git a/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala b/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala index 984fff9d8..12dcac6f3 100644 --- a/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala +++ b/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala @@ -1,6 +1,5 @@ package org.apache.spark.sql.test -import com.databricks.labs.mosaic._ import com.databricks.labs.mosaic.gdal.MosaicGDAL import com.databricks.labs.mosaic.utils.FileUtils import com.databricks.labs.mosaic.{MOSAIC_GDAL_NATIVE, MOSAIC_RASTER_CHECKPOINT} @@ -8,7 +7,6 @@ import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.gdal.gdal.gdal -import java.nio.file.{Files, Paths} import scala.util.Try trait SharedSparkSessionGDAL extends SharedSparkSession { @@ -20,10 +18,10 @@ trait SharedSparkSessionGDAL extends SharedSparkSession { override def createSparkSession: TestSparkSession = { val conf = sparkConf - conf.set(MOSAIC_RASTER_CHECKPOINT, FileUtils.createMosaicTempDir()) + conf.set(MOSAIC_RASTER_CHECKPOINT, FileUtils.createMosaicTempDir(prefix = "/mnt/")) SparkSession.cleanupAnyExistingSession() val session = new MosaicTestSparkSession(conf) - session.sparkContext.setLogLevel("INFO") + session.sparkContext.setLogLevel("FATAL") Try { MosaicGDAL.enableGDAL(session) } From fa80aabb3c73848c9e9e44d6acf97fb23df7e359 Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Tue, 13 Feb 2024 13:02:22 +0000 Subject: [PATCH 04/26] Fix zip logic for zarr files. 
--- pom.xml | 2 +- .../labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala | 4 +++- .../mosaic/datasource/multiread/RasterAsGridReaderTest.scala | 5 +---- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pom.xml b/pom.xml index cdb4c8d2c..2cb2f8787 100644 --- a/pom.xml +++ b/pom.xml @@ -149,7 +149,7 @@ org.scoverage scoverage-maven-plugin - 2.0.1 + 2.0.2 scoverage-report diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala index b63bd851e..0309d9bab 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala @@ -427,7 +427,9 @@ case class MosaicRasterGDAL( path } if (Files.isDirectory(Paths.get(tmpPath))) { - SysUtils.runCommand(s"zip -r0 $tmpPath.zip $tmpPath") + val parentDir = Paths.get(tmpPath).getParent.toString + val fileName = Paths.get(tmpPath).getFileName.toString + SysUtils.runScript(Array("/bin/sh", "-c", s"cd $parentDir && zip -r0 $fileName.zip $fileName")) s"$tmpPath.zip" } else { tmpPath diff --git a/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala b/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala index fba2b74cb..f174954cd 100644 --- a/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala @@ -41,10 +41,7 @@ class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSess val netcdf = "/binary/netcdf-coral/" val filePath = getClass.getResource(netcdf).getPath - //noException should be thrownBy - - - MosaicContext.read + noException should be thrownBy MosaicContext.read .format("raster_to_grid") .option("retile", "true") .option("tileSize", "10") From 
5a85ef91db08a8a67a370c53ee3ef00b0cb42e2a Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Tue, 13 Feb 2024 14:13:12 +0000 Subject: [PATCH 05/26] Fix zarr zip paths. --- .../labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala | 2 +- src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala index 0309d9bab..a083ba8b8 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala @@ -677,7 +677,7 @@ object MosaicRasterGDAL extends RasterReader { // the way we zip using uuid is not compatible with GDAL // we need to unzip and read the file if it was zipped by us val parentDir = Paths.get(zippedPath).getParent - val prompt = SysUtils.runScript(Array("/bin/sh", "-c", s"cd $parentDir && unzip -o $zippedPath -d /")) + val prompt = SysUtils.runScript(Array("/bin/sh", "-c", s"cd $parentDir && unzip -o $zippedPath -d $parentDir")) // zipped files will have the old uuid name of the raster // we need to get the last extracted file name, but the last extracted file name is not the raster name // we can't list folders due to concurrent writes diff --git a/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala b/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala index fc01cfaa0..a36c0bec0 100644 --- a/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala +++ b/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala @@ -22,7 +22,7 @@ object FileUtils { bytes } - def createMosaicTempDir(prefix: String = ""): String = { + def createMosaicTempDir(prefix: String = "/tmp"): String = { val tempRoot = Paths.get(s"$prefix/mosaic_tmp/") if (!Files.exists(tempRoot)) { 
Files.createDirectories(tempRoot) From 8356b6a0bd3efe614108d95e03245f97ebef8515 Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Mon, 26 Feb 2024 14:07:15 +0000 Subject: [PATCH 06/26] Add COG to format extensions list. Add createInfo concept, this will contain driver, parentPath, currentPath, etc. Make gdal programs no-failure operations. Capture errors and warnings of gdal programs in the raster tile metadata. Add RST_Transform expression. Add ReadAsPath reading strategy. --- python/mosaic/api/raster.py | 27 ++ .../mosaic/core/raster/api/FormatLookup.scala | 1 + .../labs/mosaic/core/raster/api/GDAL.scala | 30 +- .../core/raster/gdal/MosaicRasterGDAL.scala | 151 ++++++---- .../mosaic/core/raster/io/RasterReader.scala | 38 +-- .../raster/operator/gdal/GDALBuildVRT.scala | 19 +- .../core/raster/operator/gdal/GDALCalc.scala | 31 ++- .../core/raster/operator/gdal/GDALInfo.scala | 14 +- .../raster/operator/gdal/GDALTranslate.scala | 20 +- .../core/raster/operator/gdal/GDALWarp.scala | 25 +- .../operator/retile/OverlappingTiles.scala | 2 +- .../operator/retile/RasterTessellate.scala | 4 +- .../core/raster/operator/retile/ReTile.scala | 2 +- .../operator/separate/SeparateBands.scala | 13 +- .../mosaic/core/types/RasterTileType.scala | 3 +- .../core/types/model/MosaicRasterTile.scala | 48 ++-- .../mosaic/datasource/gdal/ReTileOnRead.scala | 5 +- .../mosaic/datasource/gdal/ReadAsPath.scala | 124 +++++++++ .../mosaic/datasource/gdal/ReadInMemory.scala | 12 +- .../mosaic/datasource/gdal/ReadStrategy.scala | 1 + .../expressions/raster/RST_CombineAvg.scala | 7 +- .../raster/RST_CombineAvgAgg.scala | 5 +- .../expressions/raster/RST_DerivedBand.scala | 4 +- .../raster/RST_DerivedBandAgg.scala | 5 +- .../expressions/raster/RST_FromContent.scala | 5 +- .../expressions/raster/RST_FromFile.scala | 7 +- .../expressions/raster/RST_MakeTiles.scala | 5 +- .../expressions/raster/RST_MapAlgebra.scala | 8 +- .../expressions/raster/RST_MergeAgg.scala | 5 +- 
.../expressions/raster/RST_SetSRID.scala | 8 +- .../expressions/raster/RST_Transform.scala | 61 ++++ .../mosaic/expressions/raster/package.scala | 19 +- .../labs/mosaic/functions/MosaicContext.scala | 3 + .../labs/mosaic/gdal/MosaicGDAL.scala | 1 + .../core/raster/TestRasterBandGDAL.scala | 28 +- .../mosaic/core/raster/TestRasterGDAL.scala | 262 +++++++++--------- .../raster/RST_CombineAvgBehaviors.scala | 6 +- .../raster/RST_DerivedBandBehaviors.scala | 6 +- .../raster/RST_MergeBehaviors.scala | 8 +- .../raster/RST_TessellateBehaviors.scala | 2 +- .../raster/RST_TransformBehaviors.scala | 49 ++++ .../raster/RST_TransformTest.scala | 32 +++ 42 files changed, 744 insertions(+), 362 deletions(-) create mode 100644 src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadAsPath.scala create mode 100644 src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Transform.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TransformBehaviors.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TransformTest.scala diff --git a/python/mosaic/api/raster.py b/python/mosaic/api/raster.py index 3638510dc..c61bafcd2 100644 --- a/python/mosaic/api/raster.py +++ b/python/mosaic/api/raster.py @@ -55,6 +55,7 @@ "rst_subdivide", "rst_summary", "rst_tessellate", + "rst_transform", "rst_to_overlapping_tiles", "rst_tryopen", "rst_upperleftx", @@ -997,6 +998,32 @@ def rst_tessellate(raster_tile: ColumnOrName, resolution: ColumnOrName) -> Colum ) +def rst_transform(raster_tile: ColumnOrName, srid: ColumnOrName) -> Column: + """ + Transforms the raster to the given SRID. + The result is a Mosaic raster tile struct of the transformed raster. + The result is stored in the checkpoint directory. + + Parameters + ---------- + raster_tile : Column (RasterTileType) + Mosaic raster tile struct column. + srid : Column (IntegerType) + EPSG authority code for the file's projection. 
+ + Returns + ------- + Column (RasterTileType) + Mosaic raster tile struct column. + + """ + return config.mosaic_context.invoke_function( + "rst_transform", + pyspark_to_java_column(raster_tile), + pyspark_to_java_column(srid), + ) + + def rst_fromcontent( raster_bin: ColumnOrName, driver: ColumnOrName, size_in_mb: Any = -1 ) -> Column: diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/api/FormatLookup.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/api/FormatLookup.scala index e3aeb5296..8bf2d9cdb 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/api/FormatLookup.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/api/FormatLookup.scala @@ -17,6 +17,7 @@ object FormatLookup { "CAD" -> "dwg", "CEOS" -> "ceos", "COASP" -> "coasp", + "COG" -> "tif", "COSAR" -> "cosar", "CPG" -> "cpg", "CSW" -> "csw", diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala index b86489359..6e2fee0f6 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala @@ -96,27 +96,26 @@ object GDAL { */ def readRaster( inputRaster: Any, - parentPath: String, - shortDriverName: String, + createInfo: Map[String, String], inputDT: DataType ): MosaicRasterGDAL = { inputDT match { case StringType => val path = inputRaster.asInstanceOf[UTF8String].toString - MosaicRasterGDAL.readRaster(path, parentPath) + MosaicRasterGDAL.readRaster(createInfo) case BinaryType => val bytes = inputRaster.asInstanceOf[Array[Byte]] - val raster = MosaicRasterGDAL.readRaster(bytes, parentPath, shortDriverName) + val raster = MosaicRasterGDAL.readRaster(bytes, createInfo) // If the raster is coming as a byte array, we can't check for zip condition. // We first try to read the raster directly, if it fails, we read it as a zip. 
if (raster == null) { + val parentPath = createInfo("parentPath") val zippedPath = s"/vsizip/$parentPath" - MosaicRasterGDAL.readRaster(bytes, zippedPath, shortDriverName) + MosaicRasterGDAL.readRaster(bytes, createInfo + ("path" -> zippedPath)) } else { raster } - case _ => - throw new IllegalArgumentException(s"Unsupported data type: $inputDT") + case _ => throw new IllegalArgumentException(s"Unsupported data type: $inputDT") } } @@ -160,7 +159,10 @@ object GDAL { * @return * Returns a Raster object. */ - def raster(path: String, parentPath: String): MosaicRasterGDAL = MosaicRasterGDAL.readRaster(path, parentPath) + def raster(path: String, parentPath: String): MosaicRasterGDAL = { + val createInfo = Map("path" -> path, "parentPath" -> parentPath) + MosaicRasterGDAL.readRaster(createInfo) + } /** * Reads a raster from the given byte array. If the byte array is a zip @@ -171,8 +173,10 @@ object GDAL { * @return * Returns a Raster object. */ - def raster(content: Array[Byte], parentPath: String, driverShortName: String): MosaicRasterGDAL = - MosaicRasterGDAL.readRaster(content, parentPath, driverShortName) + def raster(content: Array[Byte], parentPath: String, driverShortName: String): MosaicRasterGDAL = { + val createInfo = Map("parentPath" -> parentPath, "driver" -> driverShortName) + MosaicRasterGDAL.readRaster(content, createInfo) + } /** * Reads a raster from the given path. It extracts the specified band from @@ -186,8 +190,10 @@ object GDAL { * @return * Returns a Raster band object. */ - def band(path: String, bandIndex: Int, parentPath: String): MosaicRasterBandGDAL = - MosaicRasterGDAL.readBand(path, bandIndex, parentPath) + def band(path: String, bandIndex: Int, parentPath: String): MosaicRasterBandGDAL = { + val createInfo = Map("path" -> path, "parentPath" -> parentPath) + MosaicRasterGDAL.readBand(bandIndex, createInfo) + } /** * Converts raster x, y coordinates to lat, lon coordinates. 
diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala index 108591529..10f04416d 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala @@ -26,13 +26,17 @@ import scala.util.{Failure, Success, Try} //noinspection DuplicatedCode case class MosaicRasterGDAL( raster: Dataset, - path: String, - parentPath: String, - driverShortName: String, + createInfo: Map[String, String], memSize: Long ) extends RasterWriter with RasterCleaner { + def path: String = createInfo("path") + + def parentPath: String = createInfo("parentPath") + + def driverShortName: Option[String] = createInfo.get("driver") + def getWriteOptions: MosaicRasterWriteOptions = MosaicRasterWriteOptions(this) def getCompression: String = { @@ -72,7 +76,9 @@ case class MosaicRasterGDAL( * @return * The raster's driver short name. */ - def getDriversShortName: String = driverShortName + def getDriversShortName: String = driverShortName.getOrElse( + Try(raster.GetDriver().getShortName).getOrElse("NONE") + ) /** * @return @@ -114,7 +120,7 @@ case class MosaicRasterGDAL( def pixelDiagSize: Double = math.sqrt(pixelXSize * pixelXSize + pixelYSize * pixelYSize) /** @return Returns file extension. */ - def getRasterFileExtension: String = GDAL.getExtension(driverShortName) + def getRasterFileExtension: String = GDAL.getExtension(getDriversShortName) /** @return Returns the raster's bands as a Seq. */ def getBands: Seq[MosaicRasterBandGDAL] = (1 to numBands).map(getBand) @@ -141,7 +147,7 @@ case class MosaicRasterGDAL( * A MosaicRaster object. 
*/ def openRaster(path: String): Dataset = { - MosaicRasterGDAL.openRaster(path, Some(driverShortName)) + MosaicRasterGDAL.openRaster(path, driverShortName) } /** @@ -202,7 +208,6 @@ case class MosaicRasterGDAL( .toInt } - /** * @return * Sets the raster's SRID. This is the EPSG code of the raster's CRS. @@ -212,15 +217,18 @@ case class MosaicRasterGDAL( srs.ImportFromEPSG(srid) raster.SetSpatialRef(srs) val driver = raster.GetDriver() - val newPath = PathUtils.createTmpFilePath(GDAL.getExtension(driverShortName)) + val newPath = PathUtils.createTmpFilePath(GDAL.getExtension(getDriversShortName)) driver.CreateCopy(newPath, raster) - val newRaster = MosaicRasterGDAL.openRaster(newPath, Some(driverShortName)) + val newRaster = MosaicRasterGDAL.openRaster(newPath, driverShortName) dispose(this) - MosaicRasterGDAL(newRaster, newPath, parentPath, driverShortName, -1) + val createInfo = Map( + "path" -> newPath, + "parentPath" -> parentPath, + "driver" -> getDriversShortName + ) + MosaicRasterGDAL(newRaster, createInfo, -1) } - - /** * @return * Returns the raster's proj4 string. 
@@ -340,10 +348,9 @@ case class MosaicRasterGDAL( def isEmpty: Boolean = { val bands = getBands if (bands.isEmpty) { - subdatasets - .values - .filter(_.toLowerCase(Locale.ROOT).startsWith(driverShortName.toLowerCase(Locale.ROOT))) - .flatMap(readRaster(_, path).getBands) + subdatasets.values + .filter(_.toLowerCase(Locale.ROOT).startsWith(getDriversShortName.toLowerCase(Locale.ROOT))) + .flatMap(bp => readRaster(createInfo + ("path" -> bp)).getBands) .takeWhile(_.isEmpty) .nonEmpty } else { @@ -381,7 +388,7 @@ case class MosaicRasterGDAL( val cleanPath = filePath.replace("/vsizip/", "") val zipPath = if (cleanPath.endsWith("zip")) cleanPath else s"$cleanPath.zip" if (path != PathUtils.getCleanPath(parentPath)) { - Try(gdal.GetDriverByName(driverShortName).Delete(path)) + Try(gdal.GetDriverByName(getDriversShortName).Delete(path)) Try(Files.deleteIfExists(Paths.get(cleanPath))) Try(Files.deleteIfExists(Paths.get(path))) Try(Files.deleteIfExists(Paths.get(filePath))) @@ -502,7 +509,7 @@ case class MosaicRasterGDAL( * usable again. */ def refresh(): MosaicRasterGDAL = { - MosaicRasterGDAL(openRaster(path), path, parentPath, driverShortName, memSize) + MosaicRasterGDAL(openRaster(path), createInfo, memSize) } /** @@ -555,22 +562,35 @@ case class MosaicRasterGDAL( * Returns the raster's subdataset with given name. */ def getSubdataset(subsetName: String): MosaicRasterGDAL = { - val path = subdatasets.getOrElse( - s"${subsetName}_tmp", - throw new Exception(s""" - |Subdataset $subsetName not found! - |Available subdatasets: - | ${subdatasets.keys.filterNot(_.startsWith("SUBDATASET_")).mkString(", ")} - | """.stripMargin) - ) - val sanitized = PathUtils.getCleanPath(path) + val path = subdatasets.get(s"${subsetName}_tmp") + val gdalError = gdal.GetLastErrorMsg() + val error = path match { + case Some(_) => "" + case None => + s""" + |Subdataset $subsetName not found! 
+ |Available subdatasets: + | ${subdatasets.keys.filterNot(_.startsWith("SUBDATASET_")).mkString(", ")} + | """.stripMargin + } + val sanitized = PathUtils.getCleanPath(path.getOrElse(PathUtils.NO_PATH_STRING)) val subdatasetPath = PathUtils.getSubdatasetPath(sanitized) val ds = openRaster(subdatasetPath) // Avoid costly IO to compute MEM size here // It will be available when the raster is serialized for next operation // If value is needed then it will be computed when getMemSize is called - MosaicRasterGDAL(ds, path, parentPath, driverShortName, -1) + val createInfo = Map( + "path" -> path.getOrElse(PathUtils.NO_PATH_STRING), + "parentPath" -> parentPath, + "driver" -> getDriversShortName, + "last_error" -> + s""" + |GDAL Error: $gdalError + |$error + |""".stripMargin + ) + MosaicRasterGDAL(ds, createInfo, -1) } def convolve(kernel: Array[Array[Double]]): MosaicRasterGDAL = { @@ -584,7 +604,13 @@ case class MosaicRasterGDAL( band.convolve(kernel) } - MosaicRasterGDAL(outputRaster, resultRasterPath, parentPath, driverShortName, -1) + val createInfo = Map( + "path" -> resultRasterPath, + "parentPath" -> parentPath, + "driver" -> getDriversShortName + ) + + MosaicRasterGDAL(outputRaster, createInfo, -1) } @@ -604,7 +630,13 @@ case class MosaicRasterGDAL( band.filter(kernelSize, operation, outputBand) } - val result = MosaicRasterGDAL(outputRaster, resultRasterPath, parentPath, driverShortName, this.memSize) + val createInfo = Map( + "path" -> resultRasterPath, + "parentPath" -> parentPath, + "driver" -> getDriversShortName + ) + + val result = MosaicRasterGDAL(outputRaster, createInfo, this.memSize) result.flushCache() } @@ -659,25 +691,44 @@ object MosaicRasterGDAL extends RasterReader { * @example * Raster: path = "file:///path/to/file.tif" Subdataset: path = * "file:///path/to/file.tif:subdataset" - * @param inPath - * The path to the raster file. + * @param createInfo + * The create info for the raster. 
This should contain the following + * keys: + * - path: The path to the raster file. + * - parentPath: The path of the parent raster file. * @return * A MosaicRaster object. */ - override def readRaster(inPath: String, parentPath: String): MosaicRasterGDAL = { + override def readRaster(createInfo: Map[String, String]): MosaicRasterGDAL = { + val inPath = createInfo("path") val isSubdataset = PathUtils.isSubdataset(inPath) val path = PathUtils.getCleanPath(inPath) val readPath = if (isSubdataset) PathUtils.getSubdatasetPath(path) else PathUtils.getZipPath(path) val dataset = openRaster(readPath, None) - val driverShortName = dataset.GetDriver().getShortName - + val error = + if (dataset == null) { + val error = gdal.GetLastErrorMsg() + s""" + Error reading raster from path: $readPath + Error: $error + """ + } else "" + val driverShortName = Try(dataset.GetDriver().getShortName).getOrElse("NONE") // Avoid costly IO to compute MEM size here // It will be available when the raster is serialized for next operation // If value is needed then it will be computed when getMemSize is called // We cannot just use memSize value of the parent due to the fact that the raster could be a subdataset - val raster = MosaicRasterGDAL(dataset, path, parentPath, driverShortName, -1) + val raster = MosaicRasterGDAL( + dataset, + createInfo ++ + Map( + "driver" -> driverShortName, + "last_error" -> error + ), + -1 + ) raster } @@ -685,17 +736,19 @@ object MosaicRasterGDAL extends RasterReader { * Reads a raster from a byte array. * @param contentBytes * The byte array containing the raster data. - * @param driverShortName - * The driver short name of the raster. + * @param createInfo + * Mosaic creation info of the raster. Note: This is not the same as the + * metadata of the raster. This is not the same as GDAL creation options. * @return * A MosaicRaster object. 
*/ - override def readRaster(contentBytes: Array[Byte], parentPath: String, driverShortName: String): MosaicRasterGDAL = { + override def readRaster(contentBytes: Array[Byte], createInfo: Map[String, String]): MosaicRasterGDAL = { if (Option(contentBytes).isEmpty || contentBytes.isEmpty) { - MosaicRasterGDAL(null, "", parentPath, "", -1) + MosaicRasterGDAL(null, createInfo, -1) } else { // This is a temp UUID for purposes of reading the raster through GDAL from memory // The stable UUID is kept in metadata of the raster + val driverShortName = createInfo("driver") val extension = GDAL.getExtension(driverShortName) val tmpPath = PathUtils.createTmpFilePath(extension) Files.write(Paths.get(tmpPath), contentBytes) @@ -721,12 +774,12 @@ object MosaicRasterGDAL extends RasterReader { if (dataset == null) { throw new Exception(s"Error reading raster from bytes: ${prompt._3}") } - MosaicRasterGDAL(dataset, unzippedPath, parentPath, driverShortName, contentBytes.length) + MosaicRasterGDAL(dataset, createInfo + ("path" -> unzippedPath), contentBytes.length) } else { - MosaicRasterGDAL(ds, readPath, parentPath, driverShortName, contentBytes.length) + MosaicRasterGDAL(ds, createInfo + ("path" -> readPath), contentBytes.length) } } else { - MosaicRasterGDAL(dataset, tmpPath, parentPath, driverShortName, contentBytes.length) + MosaicRasterGDAL(dataset, createInfo + ("path" -> tmpPath), contentBytes.length) } } } @@ -738,15 +791,19 @@ object MosaicRasterGDAL extends RasterReader { * @example * Raster: path = "file:///path/to/file.tif" Subdataset: path = * "file:///path/to/file.tif:subdataset" - * @param path - * The path to the raster file. + * @param createInfo + * The create info for the raster. This should contain the following + * keys: + * - path: The path to the raster file. + * - parentPath: The path of the parent raster file. + * - driver: Optional: The driver short name of the raster file * @param bandIndex * The band index to read. * @return * A MosaicRaster object. 
*/ - override def readBand(path: String, bandIndex: Int, parentPath: String): MosaicRasterBandGDAL = { - val raster = readRaster(path, parentPath) + override def readBand(bandIndex: Int, createInfo: Map[String, String]): MosaicRasterBandGDAL = { + val raster = readRaster(createInfo) // TODO: Raster and Band are coupled, this can cause a pointer leak raster.getBand(bandIndex) } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/io/RasterReader.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/io/RasterReader.scala index b207789ae..d8a1a90c1 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/io/RasterReader.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/io/RasterReader.scala @@ -20,14 +20,16 @@ trait RasterReader extends Logging { * @example * Raster: path = "/path/to/file.tif" Subdataset: path = * "FORMAT:/path/to/file.tif:subdataset" - * @param path - * The path to the raster file. - * @param parentPath - * The path of the parent raster file. + * @param createInfo + * The create info for the raster. This should contain the following + * keys: + * - path: The path to the raster file. + * - parentPath: The path of the parent raster file. + * - driver: Optional: The driver short name of the raster file * @return * A MosaicRaster object. */ - def readRaster(path: String, parentPath: String): MosaicRasterGDAL + def readRaster(createInfo: Map[String, String]): MosaicRasterGDAL /** * Reads a raster from an in memory buffer. Use the buffer bytes to produce @@ -35,30 +37,32 @@ trait RasterReader extends Logging { * * @param contentBytes * The file bytes. - * @param parentPath - * The path of the parent raster file. - * @param driverShortName - * The driver short name of the raster file. + * @param createInfo + * The create info for the raster. This should contain the following + * keys: + * - parentPath: The path of the parent raster file. 
+ * - driver: The driver short name of the raster file * @return * A MosaicRaster object. */ - def readRaster(contentBytes: Array[Byte], parentPath: String, driverShortName: String): MosaicRasterGDAL + def readRaster(contentBytes: Array[Byte], createInfo: Map[String, String]): MosaicRasterGDAL /** * Reads a raster band from a file system path. Reads a subdataset band if * the path is to a subdataset. + * * @example * Raster: path = "/path/to/file.tif" Subdataset: path = * "FORMAT:/path/to/file.tif:subdataset" - * @param path - * The path to the raster file. - * @param bandIndex - * The band index to read. - * @param parentPath - * The path of the parent raster file. + * @param createInfo + * The create info for the raster. This should contain the following + * keys: + * - path: The path to the raster file. + * - parentPath: The path of the parent raster file. + * - driver: Optional: The driver short name of the raster file * @return * A MosaicRaster object. */ - def readBand(path: String, bandIndex: Int, parentPath: String): MosaicRasterBandGDAL + def readBand(bandIndex: Int, createInfo: Map[String, String]): MosaicRasterBandGDAL } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala index 9e1e97401..cb79dc263 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala @@ -24,16 +24,17 @@ object GDALBuildVRT { val vrtOptionsVec = OperatorOptions.parseOptions(effectiveCommand) val vrtOptions = new BuildVRTOptions(vrtOptionsVec) val result = gdal.BuildVRT(outputPath, rasters.map(_.getRaster).toArray, vrtOptions) - if (result == null) { - throw new Exception(s""" - |Build VRT failed. 
- |Command: $effectiveCommand - |Error: ${gdal.GetLastErrorMsg} - |""".stripMargin) - } - // TODO: Figure out multiple parents, should this be an array? + val errorMsg = gdal.GetLastErrorMsg + val createInfo = Map( + "path" -> outputPath, + "parentPath" -> rasters.head.getParentPath, + "driver" -> "VRT", + "last_command" -> effectiveCommand, + "last_error" -> errorMsg, + "all_parents" -> rasters.map(_.getParentPath).mkString(";") + ) // VRT files are just meta files, mem size doesnt make much sense so we keep -1 - MosaicRasterGDAL(result, outputPath, rasters.head.getParentPath, "VRT", -1).flushCache() + MosaicRasterGDAL(result, createInfo, -1).flushCache() } } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala index fa92c3b37..e22228817 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala @@ -3,6 +3,7 @@ package com.databricks.labs.mosaic.core.raster.operator.gdal import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterGDAL, MosaicRasterWriteOptions} import com.databricks.labs.mosaic.utils.SysUtils +import org.gdal.gdal.gdal /** GDALCalc is a helper object for executing GDAL Calc commands. 
*/ object GDALCalc { @@ -33,18 +34,26 @@ object GDALCalc { val effectiveCommand = OperatorOptions.appendOptions(gdalCalcCommand, MosaicRasterWriteOptions.GTiff) val toRun = effectiveCommand.replace("gdal_calc", gdal_calc) val commandRes = SysUtils.runCommand(s"python3 $toRun") - if (commandRes._1.startsWith("ERROR")) { - throw new RuntimeException(s""" - |GDAL Calc command failed: - |$toRun - |STDOUT: - |${commandRes._2} - |STDERR: - |${commandRes._3} - |""".stripMargin) - } + val errorMsg = gdal.GetLastErrorMsg val result = GDAL.raster(resultPath, resultPath) - result + val createInfo = Map( + "path" -> resultPath, + "parentPath" -> resultPath, + "driver" -> "GTiff", + "last_command" -> effectiveCommand, + "last_error" -> errorMsg, + "all_parents" -> resultPath, + "full_error" -> s""" + |GDAL Calc command failed: + |GDAL err: + |$errorMsg + |STDOUT: + |${commandRes._2} + |STDERR: + |${commandRes._3} + |""".stripMargin + ) + result.copy(createInfo = createInfo) } } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALInfo.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALInfo.scala index 7a60a837a..d3ccd471b 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALInfo.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALInfo.scala @@ -25,14 +25,14 @@ object GDALInfo { val gdalInfo = gdal.GDALInfo(raster.getRaster, infoOptions) if (gdalInfo == null) { - throw new Exception(s""" - |GDAL Info failed. - |Command: $command - |Error: ${gdal.GetLastErrorMsg} - |""".stripMargin) + s""" + |GDAL Info failed. 
+ |Command: $command + |Error: ${gdal.GetLastErrorMsg} + |""".stripMargin + } else { + gdalInfo } - - gdalInfo } } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala index fd24a0f73..2fb106fda 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala @@ -31,15 +31,19 @@ object GDALTranslate { val translateOptionsVec = OperatorOptions.parseOptions(effectiveCommand) val translateOptions = new TranslateOptions(translateOptionsVec) val result = gdal.Translate(outputPath, raster.getRaster, translateOptions) - if (result == null) { - throw new Exception(s""" - |Translate failed. - |Command: $effectiveCommand - |Error: ${gdal.GetLastErrorMsg} - |""".stripMargin) - } + val errorMsg = gdal.GetLastErrorMsg val size = Files.size(Paths.get(outputPath)) - raster.copy(raster = result, path = outputPath, memSize = size, driverShortName = writeOptions.format).flushCache() + val createInfo = Map( + "path" -> outputPath, + "parentPath" -> raster.getParentPath, + "driver" -> writeOptions.format, + "last_command" -> effectiveCommand, + "last_error" -> errorMsg, + "all_parents" -> raster.getParentPath + ) + raster + .copy(raster = result, createInfo = createInfo, memSize = size) + .flushCache() } } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala index ba6dce58d..516560a76 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala @@ -27,23 +27,18 @@ object GDALWarp { val warpOptionsVec = OperatorOptions.parseOptions(effectiveCommand) val warpOptions = new 
WarpOptions(warpOptionsVec) val result = gdal.Warp(outputPath, rasters.map(_.getRaster).toArray, warpOptions) - // TODO: Figure out multiple parents, should this be an array? // Format will always be the same as the first raster - if (result == null) { - throw new Exception(s""" - |Warp failed. - |Command: $effectiveCommand - |Error: ${gdal.GetLastErrorMsg} - |""".stripMargin) - } + val errorMsg = gdal.GetLastErrorMsg val size = Files.size(Paths.get(outputPath)) - rasters.head - .copy( - raster = result, - path = outputPath, - memSize = size - ) - .flushCache() + val createInfo = Map( + "path" -> outputPath, + "parentPath" -> rasters.head.getParentPath, + "driver" -> rasters.head.getWriteOptions.format, + "last_command" -> effectiveCommand, + "last_error" -> errorMsg, + "all_parents" -> rasters.map(_.getParentPath).mkString(";") + ) + rasters.head.copy(raster = result, createInfo = createInfo, memSize = size).flushCache() } } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala index 4e9f61c5e..072380666 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala @@ -69,7 +69,7 @@ object OverlappingTiles { val (_, valid) = tiles.flatten.partition(_._1) - valid.map(t => MosaicRasterTile(null, t._2, raster.getParentPath, raster.getDriversShortName)) + valid.map(t => MosaicRasterTile(null, t._2)) } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala index 701cf8cf1..9920af923 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala +++ 
b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala @@ -38,13 +38,13 @@ object RasterTessellate { val cellID = cell.cellIdAsLong(indexSystem) val isValidCell = indexSystem.isValid(cellID) if (!isValidCell) { - (false, MosaicRasterTile(cell.index, null, "", "")) + (false, MosaicRasterTile(cell.index, null)) } else { val cellRaster = tmpRaster.getRasterForCell(cellID, indexSystem, geometryAPI) val isValidRaster = !cellRaster.isEmpty ( isValidRaster, - MosaicRasterTile(cell.index, cellRaster, raster.getParentPath, raster.getDriversShortName) + MosaicRasterTile(cell.index, cellRaster) ) } }) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala index b12a8f847..7b218199e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala @@ -58,7 +58,7 @@ object ReTile { val (_, valid) = tiles.partition(_._1) - valid.map(t => MosaicRasterTile(null, t._2, raster.getParentPath, raster.getDriversShortName)) + valid.map(t => MosaicRasterTile(null, t._2)) } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/separate/SeparateBands.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/separate/SeparateBands.scala index 25f73bf8b..9580cc441 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/separate/SeparateBands.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/separate/SeparateBands.scala @@ -5,7 +5,10 @@ import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALTranslate import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.utils.PathUtils -/** ReTile is a helper object for splitting multi-band rasters into single-band-per-row. 
*/ +/** + * ReTile is a helper object for splitting multi-band rasters into + * single-band-per-row. + */ object SeparateBands { /** @@ -24,11 +27,13 @@ object SeparateBands { val fileExtension = raster.getRasterFileExtension val rasterPath = PathUtils.createTmpFilePath(fileExtension) val shortDriver = raster.getDriversShortName + val outOptions = raster.getWriteOptions val result = GDALTranslate.executeTranslate( rasterPath, raster, - command = s"gdal_translate -of $shortDriver -b ${i + 1} -co COMPRESS=DEFLATE" + command = s"gdal_translate -of $shortDriver -b ${i + 1}", + writeOptions = outOptions ) val isEmpty = result.isEmpty @@ -38,13 +43,13 @@ object SeparateBands { if (isEmpty) dispose(result) - (isEmpty, result, i) + (isEmpty, result.copy(createInfo = result.createInfo ++ Map("bandIndex" -> (i + 1).toString)), i) } val (_, valid) = tiles.partition(_._1) - valid.map(t => new MosaicRasterTile(null, t._2, raster.getParentPath, raster.getDriversShortName)) + valid.map(t => new MosaicRasterTile(null, t._2)) } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala b/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala index 5203178e0..137d482ce 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala @@ -36,8 +36,7 @@ object RasterTileType { Array( StructField("index_id", idType), StructField("raster", rasterType), - StructField("parentPath", StringType), - StructField("driver", StringType) + StructField("metadata", MapType(StringType, StringType)) ) ) } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala b/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala index a0710dbe9..bf36ea8f2 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala +++ 
b/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala @@ -3,6 +3,7 @@ package com.databricks.labs.mosaic.core.types.model import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.expressions.raster.{buildMapString, extractMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{BinaryType, DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -16,18 +17,16 @@ import scala.util.{Failure, Success, Try} * Index ID. * @param raster * Raster instance corresponding to the tile. - * @param parentPath - * Parent path of the raster. - * @param driver - * Driver used to read the raster. */ case class MosaicRasterTile( index: Either[Long, String], - raster: MosaicRasterGDAL, - parentPath: String, - driver: String + raster: MosaicRasterGDAL ) { + def parentPath: String = raster.createInfo("parentPath") + + def driver: String = raster.createInfo("driver") + def getIndex: Either[Long, String] = index def getParentPath: String = parentPath @@ -57,18 +56,8 @@ case class MosaicRasterTile( (indexSystem.getCellIdDataType, index) match { case (_: LongType, Left(_)) => this case (_: StringType, Right(_)) => this - case (_: LongType, Right(value)) => new MosaicRasterTile( - index = Left(indexSystem.parse(value)), - raster = raster, - parentPath = parentPath, - driver = driver - ) - case (_: StringType, Left(value)) => new MosaicRasterTile( - index = Right(indexSystem.format(value)), - raster = raster, - parentPath = parentPath, - driver = driver - ) + case (_: LongType, Right(value)) => this.copy(index = Left(indexSystem.parse(value))) + case (_: StringType, Left(value)) => this.copy(index = Right(indexSystem.format(value))) case _ => throw new IllegalArgumentException("Invalid cell id data type") } } @@ -110,22 +99,21 @@ case class 
MosaicRasterTile( def serialize( rasterDataType: DataType ): InternalRow = { - val parentPathUTF8 = UTF8String.fromString(parentPath) - val driverUTF8 = UTF8String.fromString(driver) val encodedRaster = encodeRaster(rasterDataType) + val mapData = buildMapString(raster.createInfo) if (Option(index).isDefined) { if (index.isLeft) InternalRow.fromSeq( - Seq(index.left.get, encodedRaster, parentPathUTF8, driverUTF8) + Seq(index.left.get, encodedRaster, mapData) ) else { // Copy from tmp to checkpoint. // Have to use GDAL Driver to do this since sidecar files are not copied by spark. InternalRow.fromSeq( - Seq(UTF8String.fromString(index.right.get), encodedRaster, parentPathUTF8, driverUTF8) + Seq(UTF8String.fromString(index.right.get), encodedRaster, mapData) ) } } else { - InternalRow.fromSeq(Seq(null, encodedRaster, parentPathUTF8, driverUTF8)) + InternalRow.fromSeq(Seq(null, encodedRaster, mapData)) } } @@ -147,6 +135,7 @@ case class MosaicRasterTile( case Success(value) => value.toInt case Failure(_) => -1 } + } /** Companion object. 
*/ @@ -165,18 +154,17 @@ object MosaicRasterTile { def deserialize(row: InternalRow, idDataType: DataType, rasterType: DataType): MosaicRasterTile = { val index = row.get(0, idDataType) val rawRaster = row.get(1, rasterType) - val parentPath = row.get(2, StringType).toString - val driver = row.get(3, StringType).toString - val raster = GDAL.readRaster(rawRaster, parentPath, driver, rasterType) + val createInfo = extractMap(row.getMap(2)) + val raster = GDAL.readRaster(rawRaster, createInfo, rasterType) // noinspection TypeCheckCanBeMatch if (Option(index).isDefined) { if (index.isInstanceOf[Long]) { - new MosaicRasterTile(Left(index.asInstanceOf[Long]), raster, parentPath, driver) + new MosaicRasterTile(Left(index.asInstanceOf[Long]), raster) } else { - new MosaicRasterTile(Right(index.asInstanceOf[UTF8String].toString), raster, parentPath, driver) + new MosaicRasterTile(Right(index.asInstanceOf[UTF8String].toString), raster) } } else { - new MosaicRasterTile(null, raster, parentPath, driver) + new MosaicRasterTile(null, raster) } } diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala index a38e76900..867167c58 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala @@ -134,8 +134,9 @@ object ReTileOnRead extends ReadStrategy { */ def localSubdivide(inPath: String, parentPath: String, sizeInMB: Int): Seq[MosaicRasterTile] = { val cleanPath = PathUtils.getCleanPath(inPath) - val raster = MosaicRasterGDAL.readRaster(cleanPath, parentPath) - val inTile = new MosaicRasterTile(null, raster, parentPath, raster.getDriversShortName) + val createInfo = Map("path" -> cleanPath, "parentPath" -> parentPath) + val raster = MosaicRasterGDAL.readRaster(createInfo) + val inTile = new MosaicRasterTile(null, raster) val tiles = 
BalancedSubdivision.splitRaster(inTile, sizeInMB) RasterCleaner.dispose(raster) RasterCleaner.dispose(inTile) diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadAsPath.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadAsPath.scala new file mode 100644 index 000000000..0973146ef --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadAsPath.scala @@ -0,0 +1,124 @@ +package com.databricks.labs.mosaic.datasource.gdal + +import com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory} +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.datasource.Utils +import com.databricks.labs.mosaic.datasource.gdal.GDALFileFormat._ +import com.databricks.labs.mosaic.utils.PathUtils +import org.apache.hadoop.fs.{FileStatus, FileSystem} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ + +import java.nio.file.{Files, Paths} + +/** An object defining the read-as-path read strategy for the GDAL file format (one un-split tile per file). */ +object ReadAsPath extends ReadStrategy { + + val tileDataType: DataType = StringType + + // noinspection DuplicatedCode + /** + * Returns the schema of the GDAL file format. + * @note + * Different read strategies can have different schemas. This is because + * the schema is defined by the read strategy. For retiling we always use + * checkpoint location. In this case rasters are stored off spark rows. + * If you need the tiles in memory please load them from path stored in + * the tile returned by the reader. + * + * @param options + * Options passed to the reader. + * @param files + * List of files to read. + * @param parentSchema + * Parent schema.
+ * @param sparkSession + * Spark session. + * + * @return + * Schema of the GDAL file format. + */ + override def getSchema( + options: Map[String, String], + files: Seq[FileStatus], + parentSchema: StructType, + sparkSession: SparkSession + ): StructType = { + val trimmedSchema = parentSchema.filter(field => field.name != CONTENT && field.name != LENGTH) + val indexSystem = IndexSystemFactory.getIndexSystem(sparkSession) + StructType(trimmedSchema) + .add(StructField(UUID, LongType, nullable = false)) + .add(StructField(X_SIZE, IntegerType, nullable = false)) + .add(StructField(Y_SIZE, IntegerType, nullable = false)) + .add(StructField(BAND_COUNT, IntegerType, nullable = false)) + .add(StructField(METADATA, MapType(StringType, StringType), nullable = false)) + .add(StructField(SUBDATASETS, MapType(StringType, StringType), nullable = false)) + .add(StructField(SRID, IntegerType, nullable = false)) + .add(StructField(LENGTH, LongType, nullable = false)) + // Note that for retiling we always use checkpoint location. + // In this case rasters are stored off spark rows. + // If you need the tiles in memory please load them from path stored in the tile returned by the reader. + .add(StructField(TILE, RasterTileType(indexSystem.getCellIdDataType, tileDataType), nullable = false)) + } + + /** + * Reads the content of the file. + * @param status + * File status. + * @param fs + * File system. + * @param requiredSchema + * Required schema. + * @param options + * Options passed to the reader. + * @param indexSystem + * Index system. + * + * @return + * Iterator of internal rows. 
+ */ + override def read( + status: FileStatus, + fs: FileSystem, + requiredSchema: StructType, + options: Map[String, String], + indexSystem: IndexSystem + ): Iterator[InternalRow] = { + val inPath = status.getPath.toString + val uuid = getUUID(status) + + val tmpPath = PathUtils.copyToTmp(inPath) + val createInfo = Map("path" -> tmpPath, "parentPath" -> inPath) + val raster = MosaicRasterGDAL.readRaster(createInfo) + val tile = MosaicRasterTile(null, raster) + + val trimmedSchema = StructType(requiredSchema.filter(field => field.name != TILE)) + val fields = trimmedSchema.fieldNames.map { + case PATH => status.getPath.toString + case MODIFICATION_TIME => status.getModificationTime + case UUID => uuid + case X_SIZE => tile.getRaster.xSize + case Y_SIZE => tile.getRaster.ySize + case BAND_COUNT => tile.getRaster.numBands + case METADATA => tile.getRaster.metadata + case SUBDATASETS => tile.getRaster.subdatasets + case SRID => tile.getRaster.SRID + case LENGTH => tile.getRaster.getMemSize + case other => throw new RuntimeException(s"Unsupported field name: $other") + } + // Writing to bytes is destructive so we delay reading content and content length until the last possible moment + val row = Utils.createRow(fields ++ Seq(tile.formatCellId(indexSystem).serialize(tileDataType))) + RasterCleaner.dispose(tile) + + val rows = Seq(row) + + Files.deleteIfExists(Paths.get(tmpPath)) + + rows.iterator + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala index 15ddec2ed..7e6687079 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala @@ -6,12 +6,12 @@ import com.databricks.labs.mosaic.core.raster.io.RasterCleaner import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.datasource.Utils import 
com.databricks.labs.mosaic.datasource.gdal.GDALFileFormat._ +import com.databricks.labs.mosaic.expressions.raster.buildMapString import com.databricks.labs.mosaic.utils.PathUtils import org.apache.hadoop.fs.{FileStatus, FileSystem} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String /** An object defining the in memory read strategy for the GDAL file format. */ object ReadInMemory extends ReadStrategy { @@ -78,9 +78,12 @@ object ReadInMemory extends ReadStrategy { ): Iterator[InternalRow] = { val inPath = status.getPath.toString val readPath = PathUtils.getCleanPath(inPath) - val driverShortName = MosaicRasterGDAL.identifyDriver(readPath) val contentBytes: Array[Byte] = readContent(fs, status) - val raster = MosaicRasterGDAL.readRaster(readPath, inPath) + val createInfo = Map( + "path" -> readPath, + "parentPath" -> inPath + ) + val raster = MosaicRasterGDAL.readRaster(createInfo) val uuid = getUUID(status) val fields = requiredSchema.fieldNames.filter(_ != TILE).map { @@ -96,8 +99,9 @@ object ReadInMemory extends ReadStrategy { case SRID => raster.SRID case other => throw new RuntimeException(s"Unsupported field name: $other") } + val mapData = buildMapString(raster.createInfo) val rasterTileSer = InternalRow.fromSeq( - Seq(null, contentBytes, UTF8String.fromString(inPath), UTF8String.fromString(driverShortName), null) + Seq(null, contentBytes, mapData) ) val row = Utils.createRow( fields ++ Seq(rasterTileSer) diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadStrategy.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadStrategy.scala index cacc1c133..ab141b069 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadStrategy.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadStrategy.scala @@ -72,6 +72,7 @@ object ReadStrategy { readStrategy match { case 
MOSAIC_RASTER_READ_IN_MEMORY => ReadInMemory case MOSAIC_RASTER_RE_TILE_ON_READ => ReTileOnRead + case MOSAIC_RASTER_READ_AS_PATH => ReadAsPath case _ => ReadInMemory } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala index f94eae6e8..de163ca37 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala @@ -28,12 +28,7 @@ case class RST_CombineAvg( /** Combines the rasters using average of pixels. */ override def rasterTransform(tiles: Seq[MosaicRasterTile]): Any = { val index = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null - MosaicRasterTile( - index, - CombineAVG.compute(tiles.map(_.getRaster)), - tiles.head.getParentPath, - tiles.head.getDriver - ) + MosaicRasterTile(index, CombineAVG.compute(tiles.map(_.getRaster))) } } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala index be12bd12c..3bf9248c2 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala @@ -79,11 +79,8 @@ case class RST_CombineAvgAgg( // If merging multiple index rasters, the index value is dropped val idx = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null var combined = CombineAVG.compute(tiles.map(_.getRaster)).flushCache() - // TODO: should parent path be an array? 
- val parentPath = tiles.head.getParentPath - val driver = tiles.head.getDriver - val result = MosaicRasterTile(idx, combined, parentPath, driver) + val result = MosaicRasterTile(idx, combined) .formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) .serialize(tileType) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBand.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBand.scala index 3e7de13b5..459f1774a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBand.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBand.scala @@ -37,9 +37,7 @@ case class RST_DerivedBand( val index = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null MosaicRasterTile( index, - PixelCombineRasters.combine(tiles.map(_.getRaster), pythonFunc, funcName), - tiles.head.getParentPath, - tiles.head.getDriver + PixelCombineRasters.combine(tiles.map(_.getRaster), pythonFunc, funcName) ) } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandAgg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandAgg.scala index 7d972b6e1..836b79cbd 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandAgg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandAgg.scala @@ -89,11 +89,8 @@ case class RST_DerivedBandAgg( val idx = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null var combined = PixelCombineRasters.combine(tiles.map(_.getRaster), pythonFunc, funcName) - // TODO: should parent path be an array? 
- val parentPath = tiles.head.getParentPath - val driver = tiles.head.getDriver - val result = MosaicRasterTile(idx, combined, parentPath, driver) + val result = MosaicRasterTile(idx, combined) .formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) .serialize(BinaryType) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromContent.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromContent.scala index 956c9c049..1021eb083 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromContent.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromContent.scala @@ -68,8 +68,9 @@ case class RST_FromContent( val targetSize = sizeInMB.eval(input).asInstanceOf[Int] if (targetSize <= 0 || rasterArr.length <= targetSize) { // - no split required - var raster = MosaicRasterGDAL.readRaster(rasterArr, PathUtils.NO_PATH_STRING, driver) - var tile = MosaicRasterTile(null, raster, PathUtils.NO_PATH_STRING, driver) + val createInfo = Map("parentPath" -> PathUtils.NO_PATH_STRING, "driver" -> driver) + var raster = MosaicRasterGDAL.readRaster(rasterArr, createInfo) + var tile = MosaicRasterTile(null, raster) val row = tile.formatCellId(indexSystem).serialize(tileType) RasterCleaner.dispose(raster) RasterCleaner.dispose(tile) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala index ddd4c6af2..cd5808f30 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala @@ -15,7 +15,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import 
org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, Expression, Literal, NullIntolerant} -import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import java.nio.file.{Files, Paths, StandardCopyOption} @@ -66,8 +66,9 @@ case class RST_FromFile( val driver = MosaicRasterGDAL.identifyDriver(path) val targetSize = sizeInMB.eval(input).asInstanceOf[Int] if (targetSize <= 0 && Files.size(Paths.get(readPath)) <= Integer.MAX_VALUE) { - var raster = MosaicRasterGDAL.readRaster(readPath, path) - var tile = MosaicRasterTile(null, raster, path, raster.getDriversShortName) + val createInfo = Map("path" -> readPath, "parentPath" -> path) + var raster = MosaicRasterGDAL.readRaster(createInfo) + var tile = MosaicRasterTile(null, raster) val row = tile.formatCellId(indexSystem).serialize(tileType) RasterCleaner.dispose(raster) RasterCleaner.dispose(tile) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala index 586337556..12c70d025 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala @@ -125,8 +125,9 @@ case class RST_MakeTiles( if (targetSize <= 0 && inputSize <= Integer.MAX_VALUE) { // - no split required - val raster = GDAL.readRaster(rawInput, PathUtils.NO_PATH_STRING, driver, inputExpr.dataType) - val tile = MosaicRasterTile(null, raster, PathUtils.NO_PATH_STRING, driver) + val createInfo = Map("parentPath" -> PathUtils.NO_PATH_STRING, "driver" -> driver) + val raster = GDAL.readRaster(rawInput, createInfo, inputExpr.dataType) + val tile = MosaicRasterTile(null, raster) val row = tile.formatCellId(indexSystem).serialize(tileType) RasterCleaner.dispose(raster) 
RasterCleaner.dispose(tile) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MapAlgebra.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MapAlgebra.scala index 606f8cc72..bc2aea949 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MapAlgebra.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MapAlgebra.scala @@ -45,12 +45,8 @@ case class RST_MapAlgebra( val resultPath = PathUtils.createTmpFilePath(extension) val command = parseSpec(jsonSpec, resultPath, tiles) val index = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null - MosaicRasterTile( - index, - GDALCalc.executeCalc(command, resultPath), - resultPath, - tiles.head.getDriver - ) + val result = GDALCalc.executeCalc(command, resultPath) + MosaicRasterTile(index, result) } def parseSpec(jsonSpec: String, resultPath: String, tiles: Seq[MosaicRasterTile]): String = { diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala index b639a42da..ae56a01ab 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala @@ -80,11 +80,8 @@ case class RST_MergeAgg( // If merging multiple index rasters, the index value is dropped val idx = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null var merged = MergeRasters.merge(tiles.map(_.getRaster)).flushCache() - // TODO: should parent path be an array? 
- val parentPath = tiles.head.getParentPath - val driver = tiles.head.getDriver - val result = MosaicRasterTile(idx, merged, parentPath, driver) + val result = MosaicRasterTile(idx, merged) .formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) .serialize(BinaryType) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetSRID.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetSRID.scala index a3d44289a..2aabb3df9 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetSRID.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetSRID.scala @@ -1,8 +1,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI -import com.databricks.labs.mosaic.core.raster.io.RasterCleaner -import com.databricks.labs.mosaic.core.raster.operator.clip.RasterClipByVector import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} @@ -11,6 +9,7 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** The expression for clipping a raster by a vector. 
*/ case class RST_SetSRID( @@ -19,14 +18,15 @@ case class RST_SetSRID( expressionConfig: MosaicExpressionConfig ) extends Raster1ArgExpression[RST_SetSRID]( rastersExpr, - sridExpr, - RasterTileType(expressionConfig.getCellIdType), + sridExpr, returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, rastersExpr) + val geometryAPI: GeometryAPI = GeometryAPI(expressionConfig.getGeometryAPI) /** diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Transform.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Transform.scala new file mode 100644 index 000000000..7681f2bba --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Transform.scala @@ -0,0 +1,61 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.operator.proj.RasterProject +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.Raster1ArgExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types._ +import org.gdal.osr.SpatialReference + +/** The expression for reprojecting a raster tile to the spatial reference identified by an EPSG SRID.
+ */ +case class RST_Transform( + tileExpr: Expression, + srid: Expression, + expressionConfig: MosaicExpressionConfig +) extends Raster1ArgExpression[RST_Transform]( + tileExpr, + srid, + returnsRaster = true, + expressionConfig + ) + with NullIntolerant + with CodegenFallback { + + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + + /** Projects the tile's raster to the EPSG SRID passed as arg1 (an Int) and returns a copy of the tile holding the projected raster. */ + override def rasterTransform(tile: MosaicRasterTile, arg1: Any): Any = { + val targetSrid = arg1.asInstanceOf[Int] + val sReff = new SpatialReference() + sReff.ImportFromEPSG(targetSrid) + sReff.SetAxisMappingStrategy(org.gdal.osr.osrConstants.OAMS_TRADITIONAL_GIS_ORDER) + val result = RasterProject.project(tile.raster, sReff) + tile.copy(raster = result) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Transform extends WithExpressionInfo { + + override def name: String = "rst_transform" + + override def usage: String = "_FUNC_(expr1, expr2) - Returns the raster tile reprojected to the spatial reference identified by the EPSG SRID."
+ + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile, 4326); + | {index, reprojected raster, create info} + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Transform](2, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/package.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/package.scala index a229aae89..7db83db25 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/package.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/package.scala @@ -1,6 +1,6 @@ package com.databricks.labs.mosaic.expressions -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapBuilder, ArrayBasedMapData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapBuilder, ArrayBasedMapData, ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -21,8 +21,8 @@ package object raster { * The measure type of the resulting pixel value. * * @return - * The datatype to be used for serialization of the result of - * [[com.databricks.labs.mosaic.expressions.raster.base.RasterToGridExpression]]. + * The datatype to be used for serialization of the result of + * [[com.databricks.labs.mosaic.expressions.raster.base.RasterToGridExpression]]. */ def RasterToGridType(cellIDType: DataType, measureType: DataType): DataType = { ArrayType( @@ -49,6 +49,19 @@ package object raster { mapBuilder.build() } + /** + * Extracts a scala Map[String, String] from a spark map. + * @param mapData + * The map to be used. + * @return + * Deserialized map.
+ */ + def extractMap(mapData: MapData): Map[String, String] = { + val keys = mapData.keyArray().toArray[UTF8String](StringType).map(_.toString) + val values = mapData.valueArray().toArray[UTF8String](StringType).map(_.toString) + keys.zip(values).toMap + } + /** * Builds a spark map from a scala Map[String, Double]. * @param metaData diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala index cc0a611dd..293d27593 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala @@ -305,6 +305,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends mosaicRegistry.registerExpression[RST_Subdatasets](expressionConfig) mosaicRegistry.registerExpression[RST_Summary](expressionConfig) mosaicRegistry.registerExpression[RST_Tessellate](expressionConfig) + mosaicRegistry.registerExpression[RST_Transform](expressionConfig) mosaicRegistry.registerExpression[RST_FromContent](expressionConfig) mosaicRegistry.registerExpression[RST_FromFile](expressionConfig) mosaicRegistry.registerExpression[RST_ToOverlappingTiles](expressionConfig) @@ -749,6 +750,8 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends def rst_summary(raster: Column): Column = ColumnAdapter(RST_Summary(raster.expr, expressionConfig)) def rst_tessellate(raster: Column, resolution: Column): Column = ColumnAdapter(RST_Tessellate(raster.expr, resolution.expr, expressionConfig)) + def rst_transform(raster: Column, srid: Column): Column = + ColumnAdapter(RST_Transform(raster.expr, srid.expr, expressionConfig)) def rst_tessellate(raster: Column, resolution: Int): Column = ColumnAdapter(RST_Tessellate(raster.expr, lit(resolution).expr, expressionConfig)) def rst_fromcontent(raster: Column, driver: Column): Column = diff --git 
a/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala index c7676a88d..b9e972d6f 100644 --- a/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala @@ -47,6 +47,7 @@ object MosaicGDAL extends Logging { /** Configures the GDAL environment. */ def configureGDAL(mosaicConfig: MosaicExpressionConfig): Unit = { val CPL_TMPDIR = MosaicContext.tmpDir + val GDAL_PAM_PROXY_DIR = MosaicContext.tmpDir gdal.SetConfigOption("GDAL_VRT_ENABLE_PYTHON", "YES") gdal.SetConfigOption("GDAL_DISABLE_READDIR_ON_OPEN", "TRUE") gdal.SetConfigOption("CPL_TMPDIR", CPL_TMPDIR) diff --git a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala index 1337ae6d2..88c0f4bbb 100644 --- a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala +++ b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala @@ -10,10 +10,11 @@ class TestRasterBandGDAL extends SharedSparkSessionGDAL { test("Read band metadata and pixel data from GeoTIFF file.") { assume(System.getProperty("os.name") == "Linux") - val testRaster = MosaicRasterGDAL.readRaster( - filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), - filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") + val createInfo = Map( + "path" -> filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), + "parentPath" -> filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") ) + val testRaster = MosaicRasterGDAL.readRaster(createInfo) val testBand = testRaster.getBand(1) testBand.getBand testBand.index shouldBe 1 @@ -36,10 +37,11 @@ class TestRasterBandGDAL extends SharedSparkSessionGDAL { test("Read band metadata and pixel data from a GRIdded Binary file.") { assume(System.getProperty("os.name") == 
"Linux") - val testRaster = MosaicRasterGDAL.readRaster( - filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb"), - filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb") + val createInfo = Map( + "path" -> filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb"), + "parentPath" -> filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb") ) + val testRaster = MosaicRasterGDAL.readRaster(createInfo) val testBand = testRaster.getBand(1) testBand.description shouldBe "1[-] HYBL=\"Hybrid level\"" testBand.dataType shouldBe 7 @@ -55,15 +57,17 @@ class TestRasterBandGDAL extends SharedSparkSessionGDAL { test("Read band metadata and pixel data from a NetCDF file.") { assume(System.getProperty("os.name") == "Linux") - val superRaster = MosaicRasterGDAL.readRaster( - filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc"), - filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc") + val createInfo = Map( + "path" -> filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc"), + "parentPath" -> filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc") ) + val superRaster = MosaicRasterGDAL.readRaster(createInfo) val subdatasetPath = superRaster.subdatasets("bleaching_alert_area") - val testRaster = MosaicRasterGDAL.readRaster( - subdatasetPath, - subdatasetPath + val sdCreate = Map( + "path" -> subdatasetPath, + "parentPath" -> subdatasetPath ) + val testRaster = MosaicRasterGDAL.readRaster(sdCreate) val testBand = testRaster.getBand(1) testBand.dataType shouldBe 1 diff --git a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala index bb53d6b79..8bbe4b1f3 100644 --- 
a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala +++ b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala @@ -34,11 +34,12 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { test("Read raster metadata from GeoTIFF file.") { assume(System.getProperty("os.name") == "Linux") - - val testRaster = MosaicRasterGDAL.readRaster( - filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), - filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") + + val createInfo = Map( + "path" -> filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), + "parentPath" -> filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") ) + val testRaster = MosaicRasterGDAL.readRaster(createInfo) testRaster.xSize shouldBe 2400 testRaster.ySize shouldBe 2400 testRaster.numBands shouldBe 1 @@ -56,10 +57,11 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { test("Read raster metadata from a GRIdded Binary file.") { assume(System.getProperty("os.name") == "Linux") - val testRaster = MosaicRasterGDAL.readRaster( - filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb"), - filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb") + val createInfo = Map( + "path" -> filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb"), + "parentPath" -> filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb") ) + val testRaster = MosaicRasterGDAL.readRaster(createInfo) testRaster.xSize shouldBe 14 testRaster.ySize shouldBe 14 testRaster.numBands shouldBe 14 @@ -72,17 +74,19 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { test("Read raster metadata from a NetCDF file.") { assume(System.getProperty("os.name") == "Linux") - - val superRaster = 
MosaicRasterGDAL.readRaster( - filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc"), - filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc") + + val createInfo = Map( + "path" -> filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc"), + "parentPath" -> filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc") ) + val superRaster = MosaicRasterGDAL.readRaster(createInfo) val subdatasetPath = superRaster.subdatasets("bleaching_alert_area") - val testRaster = MosaicRasterGDAL.readRaster( - subdatasetPath, - subdatasetPath + val sdCreateInfo = Map( + "path" -> subdatasetPath, + "parentPath" -> subdatasetPath ) + val testRaster = MosaicRasterGDAL.readRaster(sdCreateInfo) testRaster.xSize shouldBe 7200 testRaster.ySize shouldBe 3600 @@ -98,10 +102,11 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { test("Raster pixel and extent sizes are correct.") { assume(System.getProperty("os.name") == "Linux") - val testRaster = MosaicRasterGDAL.readRaster( - filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), - filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") + val createInfo = Map( + "path" -> filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), + "parentPath" -> filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") ) + val testRaster = MosaicRasterGDAL.readRaster(createInfo) testRaster.pixelXSize - 463.312716527 < 0.0000001 shouldBe true testRaster.pixelYSize - -463.312716527 < 0.0000001 shouldBe true @@ -125,13 +130,18 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { MosaicGDAL.setBlockSize(30) - val ds = gdalJNI.GetDriverByName("GTiff").Create("/mosaic_tmp/test.tif", 50, 50, 1, gdalconst.gdalconstConstants.GDT_Float32) + val ds = gdalJNI.GetDriverByName("GTiff").Create("/tmp/mosaic_tmp/test.tif", 50, 50, 1, gdalconst.gdalconstConstants.GDT_Float32) val values = 0 until 50 * 50 ds.GetRasterBand(1).WriteRaster(0, 0, 50, 50, values.toArray) 
ds.FlushCache() - var result = MosaicRasterGDAL(ds, "", "", "GTiff", -1).filter(5, "avg").flushCache() + val createInfo = Map( + "path" -> "", + "parentPath" -> "", + "driver" -> "GTiff" + ) + var result = MosaicRasterGDAL(ds, createInfo, -1).filter(5, "avg").flushCache() var resultValues = result.getBand(1).values @@ -158,7 +168,7 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { // mode - result = MosaicRasterGDAL(ds, "", "", "GTiff", -1).filter(5, "mode").flushCache() + result = MosaicRasterGDAL(ds, createInfo, -1).filter(5, "mode").flushCache() resultValues = result.getBand(1).values @@ -194,136 +204,136 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { inputMatrix(12)(12), inputMatrix(12)(13) ).groupBy(identity).maxBy(_._2.size)._1.toDouble - + // corner resultMatrix(49)(49) shouldBe Seq( - inputMatrix(47)(47), - inputMatrix(47)(48), - inputMatrix(47)(49), - inputMatrix(48)(47), - inputMatrix(48)(48), - inputMatrix(48)(49), - inputMatrix(49)(47), - inputMatrix(49)(48), - inputMatrix(49)(49) + inputMatrix(47)(47), + inputMatrix(47)(48), + inputMatrix(47)(49), + inputMatrix(48)(47), + inputMatrix(48)(48), + inputMatrix(48)(49), + inputMatrix(49)(47), + inputMatrix(49)(48), + inputMatrix(49)(49) ).groupBy(identity).maxBy(_._2.size)._1.toDouble - + // median - - result = MosaicRasterGDAL(ds, "", "", "GTiff", -1).filter(5, "median").flushCache() - + + result = MosaicRasterGDAL(ds, createInfo, -1).filter(5, "median").flushCache() + resultValues = result.getBand(1).values - + inputMatrix = values.toArray.grouped(50).toArray resultMatrix = resultValues.grouped(50).toArray - + // first block - + resultMatrix(10)(11) shouldBe Seq( - inputMatrix(8)(9), - inputMatrix(8)(10), - inputMatrix(8)(11), - inputMatrix(8)(12), - inputMatrix(8)(13), - inputMatrix(9)(9), - inputMatrix(9)(10), - inputMatrix(9)(11), - inputMatrix(9)(12), - inputMatrix(9)(13), - inputMatrix(10)(9), - inputMatrix(10)(10), - inputMatrix(10)(11), - inputMatrix(10)(12), - inputMatrix(10)(13), 
- inputMatrix(11)(9), - inputMatrix(11)(10), - inputMatrix(11)(11), - inputMatrix(11)(12), - inputMatrix(11)(13), - inputMatrix(12)(9), - inputMatrix(12)(10), - inputMatrix(12)(11), - inputMatrix(12)(12), - inputMatrix(12)(13) + inputMatrix(8)(9), + inputMatrix(8)(10), + inputMatrix(8)(11), + inputMatrix(8)(12), + inputMatrix(8)(13), + inputMatrix(9)(9), + inputMatrix(9)(10), + inputMatrix(9)(11), + inputMatrix(9)(12), + inputMatrix(9)(13), + inputMatrix(10)(9), + inputMatrix(10)(10), + inputMatrix(10)(11), + inputMatrix(10)(12), + inputMatrix(10)(13), + inputMatrix(11)(9), + inputMatrix(11)(10), + inputMatrix(11)(11), + inputMatrix(11)(12), + inputMatrix(11)(13), + inputMatrix(12)(9), + inputMatrix(12)(10), + inputMatrix(12)(11), + inputMatrix(12)(12), + inputMatrix(12)(13) ).sorted.apply(12).toDouble - + // min filter - - result = MosaicRasterGDAL(ds, "", "", "GTiff", -1).filter(5, "min").flushCache() - + + result = MosaicRasterGDAL(ds, createInfo, -1).filter(5, "min").flushCache() + resultValues = result.getBand(1).values - + inputMatrix = values.toArray.grouped(50).toArray resultMatrix = resultValues.grouped(50).toArray - + // first block - + resultMatrix(10)(11) shouldBe Seq( - inputMatrix(8)(9), - inputMatrix(8)(10), - inputMatrix(8)(11), - inputMatrix(8)(12), - inputMatrix(8)(13), - inputMatrix(9)(9), - inputMatrix(9)(10), - inputMatrix(9)(11), - inputMatrix(9)(12), - inputMatrix(9)(13), - inputMatrix(10)(9), - inputMatrix(10)(10), - inputMatrix(10)(11), - inputMatrix(10)(12), - inputMatrix(10)(13), - inputMatrix(11)(9), - inputMatrix(11)(10), - inputMatrix(11)(11), - inputMatrix(11)(12), - inputMatrix(11)(13), - inputMatrix(12)(9), - inputMatrix(12)(10), - inputMatrix(12)(11), - inputMatrix(12)(12), - inputMatrix(12)(13) + inputMatrix(8)(9), + inputMatrix(8)(10), + inputMatrix(8)(11), + inputMatrix(8)(12), + inputMatrix(8)(13), + inputMatrix(9)(9), + inputMatrix(9)(10), + inputMatrix(9)(11), + inputMatrix(9)(12), + inputMatrix(9)(13), + inputMatrix(10)(9), 
+ inputMatrix(10)(10), + inputMatrix(10)(11), + inputMatrix(10)(12), + inputMatrix(10)(13), + inputMatrix(11)(9), + inputMatrix(11)(10), + inputMatrix(11)(11), + inputMatrix(11)(12), + inputMatrix(11)(13), + inputMatrix(12)(9), + inputMatrix(12)(10), + inputMatrix(12)(11), + inputMatrix(12)(12), + inputMatrix(12)(13) ).min.toDouble - + // max filter - - result = MosaicRasterGDAL(ds, "", "", "GTiff", -1).filter(5, "max").flushCache() - + + result = MosaicRasterGDAL(ds, createInfo, -1).filter(5, "max").flushCache() + resultValues = result.getBand(1).values - + inputMatrix = values.toArray.grouped(50).toArray resultMatrix = resultValues.grouped(50).toArray - + // first block - + resultMatrix(10)(11) shouldBe Seq( - inputMatrix(8)(9), - inputMatrix(8)(10), - inputMatrix(8)(11), - inputMatrix(8)(12), - inputMatrix(8)(13), - inputMatrix(9)(9), - inputMatrix(9)(10), - inputMatrix(9)(11), - inputMatrix(9)(12), - inputMatrix(9)(13), - inputMatrix(10)(9), - inputMatrix(10)(10), - inputMatrix(10)(11), - inputMatrix(10)(12), - inputMatrix(10)(13), - inputMatrix(11)(9), - inputMatrix(11)(10), - inputMatrix(11)(11), - inputMatrix(11)(12), - inputMatrix(11)(13), - inputMatrix(12)(9), - inputMatrix(12)(10), - inputMatrix(12)(11), - inputMatrix(12)(12), - inputMatrix(12)(13) + inputMatrix(8)(9), + inputMatrix(8)(10), + inputMatrix(8)(11), + inputMatrix(8)(12), + inputMatrix(8)(13), + inputMatrix(9)(9), + inputMatrix(9)(10), + inputMatrix(9)(11), + inputMatrix(9)(12), + inputMatrix(9)(13), + inputMatrix(10)(9), + inputMatrix(10)(10), + inputMatrix(10)(11), + inputMatrix(10)(12), + inputMatrix(10)(13), + inputMatrix(11)(9), + inputMatrix(11)(10), + inputMatrix(11)(11), + inputMatrix(11)(12), + inputMatrix(11)(13), + inputMatrix(12)(9), + inputMatrix(12)(10), + inputMatrix(12)(11), + inputMatrix(12)(12), + inputMatrix(12)(13) ).max.toDouble } diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgBehaviors.scala 
b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgBehaviors.scala index 611bf8f77..4b7943d8a 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgBehaviors.scala @@ -4,7 +4,7 @@ import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.functions.collect_set +import org.apache.spark.sql.functions.{collect_list, collect_set} import org.scalatest.matchers.should.Matchers._ trait RST_CombineAvgBehaviors extends QueryTest { @@ -28,7 +28,7 @@ trait RST_CombineAvgBehaviors extends QueryTest { .select("path", "tiles") .groupBy("path") .agg( - rst_combineavg(collect_set($"tiles")).as("tiles") + rst_combineavg(collect_list($"tiles")).as("tiles") ) .select("tiles") @@ -38,7 +38,7 @@ trait RST_CombineAvgBehaviors extends QueryTest { //noException should be thrownBy spark.sql(""" - |select rst_combineavg(collect_set(tiles)) as tiles + |select rst_combineavg(collect_list(tiles)) as tiles |from ( | select path, rst_tessellate(tile, 2) as tiles | from source diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandBehaviors.scala index ef6466a88..d883fc5cc 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandBehaviors.scala @@ -4,7 +4,7 @@ import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext import org.apache.spark.sql.QueryTest -import 
org.apache.spark.sql.functions.{collect_set, lit} +import org.apache.spark.sql.functions.{collect_list, lit} import org.scalatest.matchers.should.Matchers._ trait RST_DerivedBandBehaviors extends QueryTest { @@ -40,7 +40,7 @@ trait RST_DerivedBandBehaviors extends QueryTest { .select("path", "tiles") .groupBy("path") .agg( - rst_derivedband(collect_set($"tiles"), lit(pyFuncCode), lit(funcName)).as("tiles") + rst_derivedband(collect_list($"tiles"), lit(pyFuncCode), lit(funcName)).as("tiles") ) .select("tiles") @@ -52,7 +52,7 @@ trait RST_DerivedBandBehaviors extends QueryTest { noException should be thrownBy spark.sql( """ |select rst_derivedband( - | collect_set(tiles), + | collect_list(tiles), |" |import numpy as np |def multiply(in_ar, out_ar, xoff, yoff, xsize, ysize, raster_xsize,raster_ysize, buf_radius, gt, **kwargs): diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeBehaviors.scala index 893d6bdf4..330345f20 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeBehaviors.scala @@ -4,7 +4,7 @@ import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.functions.collect_set +import org.apache.spark.sql.functions.collect_list import org.scalatest.matchers.should.Matchers._ trait RST_MergeBehaviors extends QueryTest { @@ -29,7 +29,7 @@ trait RST_MergeBehaviors extends QueryTest { .select("path", "tile") .groupBy("path") .agg( - collect_set("tile").as("tiles") + collect_list("tile").as("tiles") ) .select( rst_merge($"tiles").as("tile") @@ -41,7 +41,7 @@ trait RST_MergeBehaviors extends QueryTest { spark.sql(""" |select rst_merge(tiles) as 
tile |from ( - | select collect_set(tile) as tiles + | select collect_list(tile) as tiles | from ( | select path, rst_tessellate(tile, 3) as tile | from source @@ -55,7 +55,7 @@ trait RST_MergeBehaviors extends QueryTest { .select("path", "tile") .groupBy("path") .agg( - collect_set("tile").as("tiles") + collect_list("tile").as("tiles") ) .select( rst_merge($"tiles").as("tile") diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala index 050e5cb4d..38d3fc778 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala @@ -42,7 +42,7 @@ trait RST_TessellateBehaviors extends QueryTest { val result = gridTiles.select(explode(col("avg")).alias("a")).groupBy("a").count().collect() - result.length should be(462) + result.length should be(441) val netcdf = spark.read .format("gdal") diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TransformBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TransformBehaviors.scala new file mode 100644 index 000000000..9ce449b13 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TransformBehaviors.scala @@ -0,0 +1,49 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions._ +import org.scalatest.matchers.should.Matchers._ + +trait RST_TransformBehaviors extends QueryTest { + + // noinspection MapGetGet + def behavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + 
spark.sparkContext.setLogLevel("FATAL") + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val gridTiles = rastersInMemory + .withColumn("tile", rst_transform($"tile", lit(27700))) + .withColumn("bbox", st_aswkt(rst_boundingbox($"tile"))) + .select("bbox", "path", "tile") + .withColumn("avg", rst_avg($"tile")) + + rastersInMemory + .createOrReplaceTempView("source") + + noException should be thrownBy spark.sql(""" + |select rst_transform(tile, 27700) from source + |""".stripMargin) + + noException should be thrownBy rastersInMemory + .withColumn("tile", rst_transform($"tile", lit(27700))) + .select("tile") + + val result = gridTiles.select(explode(col("avg")).alias("a")).groupBy("a").count().collect() + + result.length should be(7) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TransformTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TransformTest.scala new file mode 100644 index 000000000..b7c10e548 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TransformTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_TransformTest extends QueryTest with SharedSparkSessionGDAL with RST_TransformBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> 
CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing RST_Transform with manual GDAL registration (H3, JTS).") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behavior(H3IndexSystem, JTS) + } + } + +} From 591e1fda0cff3d68fbbf39e36b190ba13a3489f9 Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Mon, 26 Feb 2024 14:35:47 +0000 Subject: [PATCH 07/26] Remove large file test, the test can only be run locally since the file was 2GB. --- .../multiread/RasterAsGridReaderTest.scala | 22 ------------------- 1 file changed, 22 deletions(-) diff --git a/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala b/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala index f174954cd..721201eaa 100644 --- a/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala @@ -12,28 +12,6 @@ import java.nio.file.{Files, Paths} class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSessionGDAL { - test("Read big tif with Raster As Grid Reader") { - assume(System.getProperty("os.name") == "Linux") - MosaicContext.build(H3IndexSystem, JTS) - - val tif = "/modis/" - val filePath = getClass.getResource(tif).getPath - - val df = MosaicContext.read - .format("raster_to_grid") - .option("retile", "true") - .option("sizeInMB", "128") - .option("resolution", "1") - .load(filePath) - .select("measure") - - df.queryExecution.optimizedPlan - - noException should be thrownBy df.queryExecution.executedPlan - - df.count() - } - test("Read netcdf with Raster As Grid Reader") { assume(System.getProperty("os.name") == "Linux") 
MosaicContext.build(H3IndexSystem, JTS) From ff07073a228b2900f1b7ed13fbadd8339aa99fa7 Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Mon, 26 Feb 2024 15:03:21 +0000 Subject: [PATCH 08/26] Fix python build. --- .github/actions/python_build/action.yml | 10 +++++----- .github/actions/scala_build/action.yml | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/actions/python_build/action.yml b/.github/actions/python_build/action.yml index 17e0c53f6..3a8230af5 100644 --- a/.github/actions/python_build/action.yml +++ b/.github/actions/python_build/action.yml @@ -13,15 +13,15 @@ runs: # - install pip libs # note: gdal requires the extra args cd python - pip install build wheel pyspark==${{ matrix.spark }} numpy==${{ matrix.numpy }} - pip install --no-build-isolation --no-cache-dir --force-reinstall gdal==${{ matrix.gdal }} - pip install . + sudo pip install build wheel pyspark==${{ matrix.spark }} numpy==${{ matrix.numpy }} + sudo pip install --no-build-isolation --no-cache-dir --force-reinstall gdal==${{ matrix.gdal }} + sudo pip install . 
- name: Test and build python package shell: bash run: | cd python - python -m unittest - python -m build + sudo python -m unittest + sudo python -m build - name: Copy python artifacts to GH Actions run shell: bash run: cp python/dist/*.whl staging \ No newline at end of file diff --git a/.github/actions/scala_build/action.yml b/.github/actions/scala_build/action.yml index b33c1b453..e2c19c0a8 100644 --- a/.github/actions/scala_build/action.yml +++ b/.github/actions/scala_build/action.yml @@ -38,7 +38,7 @@ runs: - name: Test and build the scala JAR - skip tests is false if: inputs.skip_tests == 'false' shell: bash - run: sudo mvn -q clean install + run: sudo mvn -q clean install -DskipTests -Dscoverage.skip - name: Build the scala JAR - skip tests is true if: inputs.skip_tests == 'true' shell: bash From 6d3bce6ebeca951bfa22cd45f03effa4e7dbc9ea Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Mon, 26 Feb 2024 15:24:24 +0000 Subject: [PATCH 09/26] Fix python build. --- python/test/utils/spark_test_case.py | 1 + .../labs/mosaic/functions/MosaicContext.scala | 12 +++++++++++- .../mosaic/functions/MosaicExpressionConfig.scala | 7 +++++++ .../com/databricks/labs/mosaic/gdal/MosaicGDAL.scala | 4 ++-- .../scala/com/databricks/labs/mosaic/package.scala | 1 + .../com/databricks/labs/mosaic/utils/PathUtils.scala | 6 +++--- .../labs/mosaic/models/knn/SpatialKNNBehaviors.scala | 4 ++-- 7 files changed, 27 insertions(+), 8 deletions(-) diff --git a/python/test/utils/spark_test_case.py b/python/test/utils/spark_test_case.py index 6ae23b1b3..c9747fbef 100644 --- a/python/test/utils/spark_test_case.py +++ b/python/test/utils/spark_test_case.py @@ -33,6 +33,7 @@ def setUpClass(cls) -> None: .getOrCreate() ) cls.spark.conf.set("spark.databricks.labs.mosaic.jar.autoattach", "false") + cls.spark.conf.set("spark.databricks.labs.mosaic.raster.tmp.prefix", "/") cls.spark.sparkContext.setLogLevel("FATAL") @classmethod diff --git 
a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala index 293d27593..9cec6033e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala @@ -1025,10 +1025,20 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends object MosaicContext extends Logging { - val tmpDir: String = FileUtils.createMosaicTempDir() + var _tmpDir: String = "" val mosaicVersion: String = "0.4.0" private var instance: Option[MosaicContext] = None + + def tmpDir(mosaicConfig: MosaicExpressionConfig): String = { + if (_tmpDir == "" || mosaicConfig == null) { + val prefix = mosaicConfig.getTmpPrefix + _tmpDir = FileUtils.createMosaicTempDir(prefix) + _tmpDir + } else { + _tmpDir + } + } def build(indexSystem: IndexSystem, geometryAPI: GeometryAPI): MosaicContext = { instance = Some(new MosaicContext(indexSystem, geometryAPI)) diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala index d6643f59b..b16b719cc 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala @@ -35,6 +35,8 @@ case class MosaicExpressionConfig(configs: Map[String, String]) { def getCellIdType: DataType = IndexSystemFactory.getIndexSystem(getIndexSystem).cellIdType def getRasterBlockSize: Int = configs.getOrElse(MOSAIC_RASTER_BLOCKSIZE, MOSAIC_RASTER_BLOCKSIZE_DEFAULT).toInt + + def getTmpPrefix: String = configs.getOrElse(MOSAIC_RASTER_TMP_PREFIX, "/tmp") def setGDALConf(conf: RuntimeConfig): MosaicExpressionConfig = { val toAdd = conf.getAll.filter(_._1.startsWith(MOSAIC_GDAL_PREFIX)) @@ -56,6 +58,10 @@ case class MosaicExpressionConfig(configs: Map[String, 
String]) { def setRasterCheckpoint(checkpoint: String): MosaicExpressionConfig = { MosaicExpressionConfig(configs + (MOSAIC_RASTER_CHECKPOINT -> checkpoint)) } + + def setTmpPrefix(prefix: String): MosaicExpressionConfig = { + MosaicExpressionConfig(configs + (MOSAIC_RASTER_TMP_PREFIX -> prefix)) + } def setConfig(key: String, value: String): MosaicExpressionConfig = { MosaicExpressionConfig(configs + (key -> value)) @@ -75,6 +81,7 @@ object MosaicExpressionConfig { .setGeometryAPI(spark.conf.get(MOSAIC_GEOMETRY_API, JTS.name)) .setIndexSystem(spark.conf.get(MOSAIC_INDEX_SYSTEM, H3.name)) .setRasterCheckpoint(spark.conf.get(MOSAIC_RASTER_CHECKPOINT, MOSAIC_RASTER_CHECKPOINT_DEFAULT)) + .setTmpPrefix(spark.conf.get(MOSAIC_RASTER_TMP_PREFIX, "/tmp")) .setGDALConf(spark.conf) } diff --git a/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala index b9e972d6f..6cc928edd 100644 --- a/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala @@ -46,8 +46,8 @@ object MosaicGDAL extends Logging { /** Configures the GDAL environment. 
*/ def configureGDAL(mosaicConfig: MosaicExpressionConfig): Unit = { - val CPL_TMPDIR = MosaicContext.tmpDir - val GDAL_PAM_PROXY_DIR = MosaicContext.tmpDir + val CPL_TMPDIR = MosaicContext.tmpDir(mosaicConfig) + val GDAL_PAM_PROXY_DIR = MosaicContext.tmpDir(mosaicConfig) gdal.SetConfigOption("GDAL_VRT_ENABLE_PYTHON", "YES") gdal.SetConfigOption("GDAL_DISABLE_READDIR_ON_OPEN", "TRUE") gdal.SetConfigOption("CPL_TMPDIR", CPL_TMPDIR) diff --git a/src/main/scala/com/databricks/labs/mosaic/package.scala b/src/main/scala/com/databricks/labs/mosaic/package.scala index eea63cd79..86bdbcec7 100644 --- a/src/main/scala/com/databricks/labs/mosaic/package.scala +++ b/src/main/scala/com/databricks/labs/mosaic/package.scala @@ -22,6 +22,7 @@ package object mosaic { val MOSAIC_GDAL_NATIVE = "spark.databricks.labs.mosaic.gdal.native" val MOSAIC_RASTER_CHECKPOINT = "spark.databricks.labs.mosaic.raster.checkpoint" val MOSAIC_RASTER_CHECKPOINT_DEFAULT = "/dbfs/tmp/mosaic/raster/checkpoint" + val MOSAIC_RASTER_TMP_PREFIX = "spark.databricks.labs.mosaic.raster.tmp.prefix" val MOSAIC_RASTER_BLOCKSIZE = "spark.databricks.labs.mosaic.raster.blocksize" val MOSAIC_RASTER_BLOCKSIZE_DEFAULT = "128" diff --git a/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala b/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala index 469bb0f44..aacd897f8 100644 --- a/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala +++ b/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala @@ -49,7 +49,7 @@ object PathUtils { } def createTmpFilePath(extension: String): String = { - val tmpDir = MosaicContext.tmpDir + val tmpDir = MosaicContext.tmpDir(null) val uuid = java.util.UUID.randomUUID.toString val outPath = s"$tmpDir/raster_${uuid.replace("-", "_")}.$extension" Files.createDirectories(Paths.get(outPath).getParent) @@ -80,9 +80,9 @@ object PathUtils { val fullFileName = copyFromPath.split("/").last val stemRegex = getStemRegex(inPath) - wildcardCopy(inPathDir, 
MosaicContext.tmpDir, stemRegex.toString) + wildcardCopy(inPathDir, MosaicContext.tmpDir(null), stemRegex.toString) - s"${MosaicContext.tmpDir}/$fullFileName" + s"${MosaicContext.tmpDir(null)}/$fullFileName" } def wildcardCopy(inDirPath: String, outDirPath: String, pattern: String): Unit = { diff --git a/src/test/scala/com/databricks/labs/mosaic/models/knn/SpatialKNNBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/models/knn/SpatialKNNBehaviors.scala index e9508c3c6..be9b5c402 100644 --- a/src/test/scala/com/databricks/labs/mosaic/models/knn/SpatialKNNBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/models/knn/SpatialKNNBehaviors.scala @@ -29,7 +29,7 @@ trait SpatialKNNBehaviors { this: AnyFlatSpec => val boroughs: DataFrame = getBoroughs(mc) - val tempLocation = MosaicContext.tmpDir + val tempLocation = MosaicContext.tmpDir(null) spark.sparkContext.setCheckpointDir(tempLocation) spark.sparkContext.setLogLevel("ERROR") @@ -94,7 +94,7 @@ trait SpatialKNNBehaviors { this: AnyFlatSpec => val boroughs: DataFrame = getBoroughs(mc) - val tempLocation = MosaicContext.tmpDir + val tempLocation = MosaicContext.tmpDir(null) spark.sparkContext.setCheckpointDir(tempLocation) spark.sparkContext.setLogLevel("ERROR") From 398386b997bd742e1045d94aa9c0430929fe2ea7 Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Mon, 26 Feb 2024 15:31:33 +0000 Subject: [PATCH 10/26] Fix python build. 
--- .github/actions/python_build/action.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/actions/python_build/action.yml b/.github/actions/python_build/action.yml index 3a8230af5..17e0c53f6 100644 --- a/.github/actions/python_build/action.yml +++ b/.github/actions/python_build/action.yml @@ -13,15 +13,15 @@ runs: # - install pip libs # note: gdal requires the extra args cd python - sudo pip install build wheel pyspark==${{ matrix.spark }} numpy==${{ matrix.numpy }} - sudo pip install --no-build-isolation --no-cache-dir --force-reinstall gdal==${{ matrix.gdal }} - sudo pip install . + pip install build wheel pyspark==${{ matrix.spark }} numpy==${{ matrix.numpy }} + pip install --no-build-isolation --no-cache-dir --force-reinstall gdal==${{ matrix.gdal }} + pip install . - name: Test and build python package shell: bash run: | cd python - sudo python -m unittest - sudo python -m build + python -m unittest + python -m build - name: Copy python artifacts to GH Actions run shell: bash run: cp python/dist/*.whl staging \ No newline at end of file From 3609daab194c44a74e63b65b89f251a91dcdfdbe Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Mon, 26 Feb 2024 15:48:08 +0000 Subject: [PATCH 11/26] Fix python build. 
--- python/test/utils/spark_test_case.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/test/utils/spark_test_case.py b/python/test/utils/spark_test_case.py index c9747fbef..223f5d021 100644 --- a/python/test/utils/spark_test_case.py +++ b/python/test/utils/spark_test_case.py @@ -17,6 +17,8 @@ def setUpClass(cls) -> None: cls.library_location = f"{mosaic.__path__[0]}/lib/mosaic-{version('databricks-mosaic')}-jar-with-dependencies.jar" if not os.path.exists(cls.library_location): cls.library_location = f"{mosaic.__path__[0]}/lib/mosaic-{version('databricks-mosaic')}-SNAPSHOT-jar-with-dependencies.jar" + if not os.path.exists("/mosaic_test/"): + os.makedirs("/mosaic_test/") cls.spark = ( SparkSession.builder.master("local[*]") @@ -33,7 +35,7 @@ def setUpClass(cls) -> None: .getOrCreate() ) cls.spark.conf.set("spark.databricks.labs.mosaic.jar.autoattach", "false") - cls.spark.conf.set("spark.databricks.labs.mosaic.raster.tmp.prefix", "/") + cls.spark.conf.set("spark.databricks.labs.mosaic.raster.tmp.prefix", "/mosaic_test/") cls.spark.sparkContext.setLogLevel("FATAL") @classmethod From db50866b706e8dc7fa8c44ee821e8cf971e8771b Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Mon, 26 Feb 2024 16:00:30 +0000 Subject: [PATCH 12/26] Fix python build. 
--- python/test/utils/spark_test_case.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/test/utils/spark_test_case.py b/python/test/utils/spark_test_case.py index 223f5d021..42917f278 100644 --- a/python/test/utils/spark_test_case.py +++ b/python/test/utils/spark_test_case.py @@ -17,8 +17,11 @@ def setUpClass(cls) -> None: cls.library_location = f"{mosaic.__path__[0]}/lib/mosaic-{version('databricks-mosaic')}-jar-with-dependencies.jar" if not os.path.exists(cls.library_location): cls.library_location = f"{mosaic.__path__[0]}/lib/mosaic-{version('databricks-mosaic')}-SNAPSHOT-jar-with-dependencies.jar" - if not os.path.exists("/mosaic_test/"): - os.makedirs("/mosaic_test/") + + pwd_dir = os.getcwd() + tmp_dir = f"{pwd_dir}/mosaic_test/" + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) cls.spark = ( SparkSession.builder.master("local[*]") @@ -35,7 +38,7 @@ def setUpClass(cls) -> None: .getOrCreate() ) cls.spark.conf.set("spark.databricks.labs.mosaic.jar.autoattach", "false") - cls.spark.conf.set("spark.databricks.labs.mosaic.raster.tmp.prefix", "/mosaic_test/") + cls.spark.conf.set("spark.databricks.labs.mosaic.raster.tmp.prefix", tmp_dir) cls.spark.sparkContext.setLogLevel("FATAL") @classmethod From f78629a07dd6ce7966c58efce0d93791860e9d1d Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Mon, 26 Feb 2024 16:18:20 +0000 Subject: [PATCH 13/26] Fix python build. 
--- .../com/databricks/labs/mosaic/functions/MosaicContext.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala index 9cec6033e..281dd0eac 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala @@ -1031,7 +1031,7 @@ object MosaicContext extends Logging { private var instance: Option[MosaicContext] = None def tmpDir(mosaicConfig: MosaicExpressionConfig): String = { - if (_tmpDir == "" || mosaicConfig == null) { + if (_tmpDir == "" || mosaicConfig != null) { val prefix = mosaicConfig.getTmpPrefix _tmpDir = FileUtils.createMosaicTempDir(prefix) _tmpDir From 67b73850356dbd334349224a5507fe386032b7d6 Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Fri, 1 Mar 2024 13:32:38 +0000 Subject: [PATCH 14/26] Fix python build. --- python/test/utils/spark_test_case.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/test/utils/spark_test_case.py b/python/test/utils/spark_test_case.py index 42917f278..640713ba7 100644 --- a/python/test/utils/spark_test_case.py +++ b/python/test/utils/spark_test_case.py @@ -20,8 +20,11 @@ def setUpClass(cls) -> None: pwd_dir = os.getcwd() tmp_dir = f"{pwd_dir}/mosaic_test/" + check_dir = f"{pwd_dir}/checkpoint" if not os.path.exists(tmp_dir): os.makedirs(tmp_dir) + if not os.path.exists(check_dir): + os.makedirs(check_dir) cls.spark = ( SparkSession.builder.master("local[*]") @@ -39,6 +42,7 @@ def setUpClass(cls) -> None: ) cls.spark.conf.set("spark.databricks.labs.mosaic.jar.autoattach", "false") cls.spark.conf.set("spark.databricks.labs.mosaic.raster.tmp.prefix", tmp_dir) + cls.spark.conf.set("spark.databricks.labs.mosaic.raster.checkpoint", check_dir) cls.spark.sparkContext.setLogLevel("FATAL") @classmethod From 11d15be2a06d8260957cdd8d7964206c6e33542d Mon Sep 17 
00:00:00 2001 From: "milos.colic" Date: Fri, 1 Mar 2024 13:35:00 +0000 Subject: [PATCH 15/26] Fix python build. --- python/test/test_raster_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/test/test_raster_functions.py b/python/test/test_raster_functions.py index be5fd4656..834c02a5c 100644 --- a/python/test/test_raster_functions.py +++ b/python/test/test_raster_functions.py @@ -19,7 +19,7 @@ def test_read_raster(self): result.metadata["LONGNAME"], "MODIS/Terra+Aqua BRDF/Albedo Nadir BRDF-Adjusted Ref Daily L3 Global - 500m", ) - self.assertEqual(result.tile["driver"], "GTiff") + self.assertEqual(result.tile["metadata"]["driver"], "GTiff") def test_raster_scalar_functions(self): result = ( @@ -115,7 +115,7 @@ def test_raster_flatmap_functions(self): ) tessellate_result.write.format("noop").mode("overwrite").save() - self.assertEqual(tessellate_result.count(), 66) + self.assertEqual(tessellate_result.count(), 63) overlap_result = ( self.generate_singleband_raster_df() From 5b4233471692c3c4f29232af6d970e654d0a297a Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Fri, 1 Mar 2024 13:48:33 +0000 Subject: [PATCH 16/26] Fix python build. 
--- .../mosaic/core/raster/operator/clip/RasterClipByVector.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala index 56c29563f..ddae2fed2 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala @@ -49,7 +49,7 @@ object RasterClipByVector { val result = GDALWarp.executeWarp( resultFileName, Seq(raster), - command = s"gdalwarp -wo CUTLINE_ALL_TOUCHED=TRUE -wo SOURCE_EXTRA=3 -cutline $shapeFileName -crop_to_cutline" + command = s"gdalwarp -wo CUTLINE_ALL_TOUCHED=TRUE -cutline $shapeFileName -crop_to_cutline" ) VectorClipper.cleanUpClipper(shapeFileName) From 86a3901fac68f464f061f4ea375af4618e1b6d3a Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Fri, 1 Mar 2024 16:46:48 +0000 Subject: [PATCH 17/26] Fix python build. 
--- .github/actions/scala_build/action.yml | 2 +- python/test/test_raster_functions.py | 3 +- python/test/test_vector_functions.py | 4 +-- .../core/raster/gdal/MosaicRasterGDAL.scala | 33 ++++++++++--------- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/.github/actions/scala_build/action.yml b/.github/actions/scala_build/action.yml index e2c19c0a8..b9e366658 100644 --- a/.github/actions/scala_build/action.yml +++ b/.github/actions/scala_build/action.yml @@ -27,7 +27,7 @@ runs: sudo apt-get update -y # - install natives sudo apt-get install -y unixodbc libcurl3-gnutls libsnappy-dev libopenjp2-7 - sudo apt-get install -y gdal-bin libgdal-dev python3-numpy python3-gdal + sudo apt-get install -y gdal-bin libgdal-dev python3-numpy python3-gdal zip unzip # - install pip libs pip install --upgrade pip pip install gdal==${{ matrix.gdal }} diff --git a/python/test/test_raster_functions.py b/python/test/test_raster_functions.py index 834c02a5c..cda55143d 100644 --- a/python/test/test_raster_functions.py +++ b/python/test/test_raster_functions.py @@ -187,11 +187,12 @@ def test_netcdf_load_tessellate_clip_merge(self): df = ( self.spark.read.format("gdal") - .option("raster.read.strategy", "retile_on_read") + .option("raster.read.strategy", "in_memory") .load( "test/data/prAdjust_day_HadGEM2-CC_SMHI-DBSrev930-GFD-1981-2010-postproc_rcp45_r1i1p1_20201201-20201231.nc" ) .select(api.rst_separatebands("tile").alias("tile")) + .repartition(self.spark.sparkContext.defaultParallelism) .withColumn( "timestep", element_at( diff --git a/python/test/test_vector_functions.py b/python/test/test_vector_functions.py index afca19778..c2bb2be74 100644 --- a/python/test/test_vector_functions.py +++ b/python/test/test_vector_functions.py @@ -37,9 +37,7 @@ def test_st_z(self): .select(col("id").cast("double")) .withColumn( "points", - api.st_geomfromwkt( - concat(lit("POINT (9 9 "), "id", lit(")")) - ), + api.st_geomfromwkt(concat(lit("POINT (9 9 "), "id", lit(")"))), ) 
.withColumn("z", api.st_z("points")) .collect() diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala index 10f04416d..2749d7cfa 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala @@ -76,9 +76,10 @@ case class MosaicRasterGDAL( * @return * The raster's driver short name. */ - def getDriversShortName: String = driverShortName.getOrElse( - Try(raster.GetDriver().getShortName).getOrElse("NONE") - ) + def getDriversShortName: String = + driverShortName.getOrElse( + Try(raster.GetDriver().getShortName).getOrElse("NONE") + ) /** * @return @@ -469,7 +470,8 @@ case class MosaicRasterGDAL( if (Files.isDirectory(Paths.get(tmpPath))) { val parentDir = Paths.get(tmpPath).getParent.toString val fileName = Paths.get(tmpPath).getFileName.toString - SysUtils.runScript(Array("/bin/sh", "-c", s"cd $parentDir && zip -r0 $fileName.zip $fileName")) + val prompt = SysUtils.runScript(Array("/bin/sh", "-c", s"cd $parentDir && zip -r0 $fileName.zip $fileName")) + if (prompt._3.nonEmpty) throw new Exception(s"Error zipping file: ${prompt._3}. Please verify that zip is installed. Run 'apt install zip'.") s"$tmpPath.zip" } else { tmpPath @@ -566,12 +568,11 @@ case class MosaicRasterGDAL( val gdalError = gdal.GetLastErrorMsg() val error = path match { case Some(_) => "" - case None => - s""" - |Subdataset $subsetName not found! - |Available subdatasets: - | ${subdatasets.keys.filterNot(_.startsWith("SUBDATASET_")).mkString(", ")} - | """.stripMargin + case None => s""" + |Subdataset $subsetName not found! 
+ |Available subdatasets: + | ${subdatasets.keys.filterNot(_.startsWith("SUBDATASET_")).mkString(", ")} + | """.stripMargin } val sanitized = PathUtils.getCleanPath(path.getOrElse(PathUtils.NO_PATH_STRING)) val subdatasetPath = PathUtils.getSubdatasetPath(sanitized) @@ -584,11 +585,13 @@ case class MosaicRasterGDAL( "path" -> path.getOrElse(PathUtils.NO_PATH_STRING), "parentPath" -> parentPath, "driver" -> getDriversShortName, - "last_error" -> - s""" - |GDAL Error: $gdalError - |$error - |""".stripMargin + "last_error" -> { + if (gdalError.nonEmpty || error.nonEmpty) s""" + |GDAL Error: $gdalError + |$error + |""".stripMargin + else "" + } ) MosaicRasterGDAL(ds, createInfo, -1) } From 61ba42b30fc2e0787640f7839df1d4a73f25b697 Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Fri, 1 Mar 2024 16:48:35 +0000 Subject: [PATCH 18/26] Fix python build. --- .github/actions/scala_build/action.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/actions/scala_build/action.yml b/.github/actions/scala_build/action.yml index b9e366658..8a0d20359 100644 --- a/.github/actions/scala_build/action.yml +++ b/.github/actions/scala_build/action.yml @@ -38,7 +38,7 @@ runs: - name: Test and build the scala JAR - skip tests is false if: inputs.skip_tests == 'false' shell: bash - run: sudo mvn -q clean install -DskipTests -Dscoverage.skip + run: sudo mvn -q clean install - name: Build the scala JAR - skip tests is true if: inputs.skip_tests == 'true' shell: bash From 778e3be1a2587b592098f1354ca40d3f6ed04414 Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Fri, 1 Mar 2024 17:07:35 +0000 Subject: [PATCH 19/26] Fix python build. 
--- .../databricks/labs/mosaic/functions/MosaicContext.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala index 281dd0eac..5fd5f6e6b 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala @@ -682,15 +682,17 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends def rst_maketiles(input: Column, driver: String, size: Int, withCheckpoint: Boolean): Column = ColumnAdapter(RST_MakeTiles(input.expr, lit(driver).expr, lit(size).expr, lit(withCheckpoint).expr, expressionConfig)) def rst_maketiles(input: Column): Column = - ColumnAdapter(RST_MakeTiles(input.expr, lit(MOSAIC_NO_DRIVER).expr, lit(-1).expr, lit(false).expr, expressionConfig)) + ColumnAdapter(RST_MakeTiles(input.expr, lit("no_driver").expr, lit(-1).expr, lit(false).expr, expressionConfig)) def rst_maketiles(input: Column, size: Int): Column = - ColumnAdapter(RST_MakeTiles(input.expr, lit(MOSAIC_NO_DRIVER).expr, lit(size).expr, lit(false).expr, expressionConfig)) + ColumnAdapter(RST_MakeTiles(input.expr, lit("no_driver").expr, lit(size).expr, lit(false).expr, expressionConfig)) def rst_maketiles(input: Column, driver: String): Column = ColumnAdapter(RST_MakeTiles(input.expr, lit(driver).expr, lit(-1).expr, lit(false).expr, expressionConfig)) + def rst_maketiles(input: Column, driver: String, size: Int): Column = + ColumnAdapter(RST_MakeTiles(input.expr, lit(driver).expr, lit(size).expr, lit(false).expr, expressionConfig)) def rst_maketiles(input: Column, driver: String, withCheckpoint: Boolean): Column = ColumnAdapter(RST_MakeTiles(input.expr, lit(driver).expr, lit(-1).expr, lit(withCheckpoint).expr, expressionConfig)) def rst_maketiles(input: Column, size: Int, withCheckpoint: Boolean): Column = - 
ColumnAdapter(RST_MakeTiles(input.expr, lit(MOSAIC_NO_DRIVER).expr, lit(size).expr, lit(withCheckpoint).expr, expressionConfig)) + ColumnAdapter(RST_MakeTiles(input.expr, lit("no_driver").expr, lit(size).expr, lit(withCheckpoint).expr, expressionConfig)) def rst_max(raster: Column): Column = ColumnAdapter(RST_Max(raster.expr, expressionConfig)) def rst_min(raster: Column): Column = ColumnAdapter(RST_Min(raster.expr, expressionConfig)) def rst_median(raster: Column): Column = ColumnAdapter(RST_Median(raster.expr, expressionConfig)) From d92b93d8ab9889f8f11903320a6f1f9a9bec5654 Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Fri, 1 Mar 2024 17:20:57 +0000 Subject: [PATCH 20/26] Fix R build. --- .../labs/mosaic/functions/MosaicContext.scala | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala index 5fd5f6e6b..40be5cb24 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala @@ -681,18 +681,10 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends ColumnAdapter(RST_MakeTiles(input.expr, driver.expr, size.expr, withCheckpoint.expr, expressionConfig)) def rst_maketiles(input: Column, driver: String, size: Int, withCheckpoint: Boolean): Column = ColumnAdapter(RST_MakeTiles(input.expr, lit(driver).expr, lit(size).expr, lit(withCheckpoint).expr, expressionConfig)) - def rst_maketiles(input: Column): Column = - ColumnAdapter(RST_MakeTiles(input.expr, lit("no_driver").expr, lit(-1).expr, lit(false).expr, expressionConfig)) - def rst_maketiles(input: Column, size: Int): Column = - ColumnAdapter(RST_MakeTiles(input.expr, lit("no_driver").expr, lit(size).expr, lit(false).expr, expressionConfig)) - def rst_maketiles(input: Column, driver: String): Column = - 
ColumnAdapter(RST_MakeTiles(input.expr, lit(driver).expr, lit(-1).expr, lit(false).expr, expressionConfig)) def rst_maketiles(input: Column, driver: String, size: Int): Column = ColumnAdapter(RST_MakeTiles(input.expr, lit(driver).expr, lit(size).expr, lit(false).expr, expressionConfig)) - def rst_maketiles(input: Column, driver: String, withCheckpoint: Boolean): Column = - ColumnAdapter(RST_MakeTiles(input.expr, lit(driver).expr, lit(-1).expr, lit(withCheckpoint).expr, expressionConfig)) - def rst_maketiles(input: Column, size: Int, withCheckpoint: Boolean): Column = - ColumnAdapter(RST_MakeTiles(input.expr, lit("no_driver").expr, lit(size).expr, lit(withCheckpoint).expr, expressionConfig)) + def rst_maketiles(input: Column): Column = + ColumnAdapter(RST_MakeTiles(input.expr, lit("no_driver").expr, lit(-1).expr, lit(false).expr, expressionConfig)) def rst_max(raster: Column): Column = ColumnAdapter(RST_Max(raster.expr, expressionConfig)) def rst_min(raster: Column): Column = ColumnAdapter(RST_Min(raster.expr, expressionConfig)) def rst_median(raster: Column): Column = ColumnAdapter(RST_Median(raster.expr, expressionConfig)) From 7d803bed65316242b5a71e74e04731ea917206d0 Mon Sep 17 00:00:00 2001 From: Stuart Lynn Date: Fri, 1 Mar 2024 17:34:29 +0000 Subject: [PATCH 21/26] fix R tests --- R/sparkR-mosaic/sparkrMosaic/DESCRIPTION | 4 ++-- .../sparkrMosaic/tests/testthat/testRasterFunctions.R | 6 +++--- R/sparklyr-mosaic/sparklyrMosaic/DESCRIPTION | 4 ++-- .../sparklyrMosaic/tests/testthat/testRasterFunctions.R | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/R/sparkR-mosaic/sparkrMosaic/DESCRIPTION b/R/sparkR-mosaic/sparkrMosaic/DESCRIPTION index f689fe17a..876a46cf7 100644 --- a/R/sparkR-mosaic/sparkrMosaic/DESCRIPTION +++ b/R/sparkR-mosaic/sparkrMosaic/DESCRIPTION @@ -8,7 +8,7 @@ Description: This package extends SparkR to bring the Databricks Mosaic for geos License: Databricks Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 
7.2.3 +RoxygenNote: 7.3.1 Collate: 'enableGDAL.R' 'enableMosaic.R' @@ -20,4 +20,4 @@ Imports: Suggests: testthat (>= 3.0.0), readr (>= 2.1.5) -Config/testthat/edition: 3 \ No newline at end of file +Config/testthat/edition: 3 diff --git a/R/sparkR-mosaic/sparkrMosaic/tests/testthat/testRasterFunctions.R b/R/sparkR-mosaic/sparkrMosaic/tests/testthat/testRasterFunctions.R index 36296e9d8..6e23454dc 100644 --- a/R/sparkR-mosaic/sparkrMosaic/tests/testthat/testRasterFunctions.R +++ b/R/sparkR-mosaic/sparkrMosaic/tests/testthat/testRasterFunctions.R @@ -15,7 +15,7 @@ test_that("mosaic can read single-band GeoTiff", { expect_equal(row$srid, 0) expect_equal(row$bandCount, 1) expect_equal(row$metadata[[1]]$LONGNAME, "MODIS/Terra+Aqua BRDF/Albedo Nadir BRDF-Adjusted Ref Daily L3 Global - 500m") - expect_equal(row$tile[[1]]$driver, "GTiff") + expect_equal(row$tile[[1]]$metadata$driver, "GTiff") }) @@ -61,7 +61,7 @@ test_that("raster flatmap functions behave as intended", { tessellate_sdf <- withColumn(tessellate_sdf, "rst_tessellate", rst_tessellate(column("tile"), lit(3L))) expect_no_error(write.df(tessellate_sdf, source = "noop", mode = "overwrite")) - expect_equal(nrow(tessellate_sdf), 66) + expect_equal(nrow(tessellate_sdf), 63) overlap_sdf <- generate_singleband_raster_df() overlap_sdf <- withColumn(overlap_sdf, "rst_to_overlapping_tiles", rst_to_overlapping_tiles(column("tile"), lit(200L), lit(200L), lit(10L))) @@ -117,7 +117,7 @@ test_that("the tessellate-join-clip-merge flow works on NetCDF files", { raster_sdf <- read.df( path = "sparkrMosaic/tests/testthat/data/prAdjust_day_HadGEM2-CC_SMHI-DBSrev930-GFD-1981-2010-postproc_rcp45_r1i1p1_20201201-20201231.nc", source = "gdal", - raster.read.strategy = "retile_on_read" + raster.read.strategy = "in_memory" ) raster_sdf <- withColumn(raster_sdf, "tile", rst_separatebands(column("tile"))) diff --git a/R/sparklyr-mosaic/sparklyrMosaic/DESCRIPTION b/R/sparklyr-mosaic/sparklyrMosaic/DESCRIPTION index c9d2048fb..315e6bf3c 
100644 --- a/R/sparklyr-mosaic/sparklyrMosaic/DESCRIPTION +++ b/R/sparklyr-mosaic/sparklyrMosaic/DESCRIPTION @@ -8,7 +8,7 @@ Description: This package extends sparklyr to bring the Databricks Mosaic for ge License: Databricks Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.1 Collate: 'enableGDAL.R' 'enableMosaic.R' @@ -20,4 +20,4 @@ Suggests: testthat (>= 3.0.0), sparklyr.nested (>= 0.0.4), readr (>= 2.1.5) -Config/testthat/edition: 3 \ No newline at end of file +Config/testthat/edition: 3 diff --git a/R/sparklyr-mosaic/sparklyrMosaic/tests/testthat/testRasterFunctions.R b/R/sparklyr-mosaic/sparklyrMosaic/tests/testthat/testRasterFunctions.R index 3bb021c64..3cf016fa7 100644 --- a/R/sparklyr-mosaic/sparklyrMosaic/tests/testthat/testRasterFunctions.R +++ b/R/sparklyr-mosaic/sparklyrMosaic/tests/testthat/testRasterFunctions.R @@ -18,7 +18,7 @@ test_that("mosaic can read single-band GeoTiff", { expect_equal(row$srid, 0) expect_equal(row$bandCount, 1) expect_equal(row$metadata[[1]]$LONGNAME, "MODIS/Terra+Aqua BRDF/Albedo Nadir BRDF-Adjusted Ref Daily L3 Global - 500m") - expect_equal(row$tile[[1]]$driver, "GTiff") + expect_equal(row$tile[[1]]$metadata$driver, "GTiff") }) @@ -90,7 +90,7 @@ test_that("raster flatmap functions behave as intended", { mutate(rst_tessellate = rst_tessellate(tile, 3L)) expect_no_error(spark_write_source(tessellate_sdf, "noop", mode = "overwrite")) - expect_equal(sdf_nrow(tessellate_sdf), 66) + expect_equal(sdf_nrow(tessellate_sdf), 63) overlap_sdf <- generate_singleband_raster_df() %>% mutate(rst_to_overlapping_tiles = rst_to_overlapping_tiles(tile, 200L, 200L, 10L)) @@ -157,7 +157,7 @@ test_that("the tessellate-join-clip-merge flow works on NetCDF files", { name = "raster_raw", source = "gdal", path = "data/prAdjust_day_HadGEM2-CC_SMHI-DBSrev930-GFD-1981-2010-postproc_rcp45_r1i1p1_20201201-20201231.nc", - options = list("raster.read.strategy" = "retile_on_read") + options = list("raster.read.strategy" = 
"in_memory") ) %>% mutate(tile = rst_separatebands(tile)) %>% sdf_register("raster") From 5e9e84696f1cfb8a5e6c8c3b5643d907d5bde676 Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Fri, 1 Mar 2024 17:57:38 +0000 Subject: [PATCH 22/26] Fix gribs build. --- .../datasource/GDALFileFormatTest.scala | 66 ++++++++++--------- 1 file changed, 34 insertions(+), 32 deletions(-) diff --git a/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala b/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala index 623993b01..16017d29e 100644 --- a/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala @@ -34,36 +34,6 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSessionGDAL { .take(1) } - - test("Read grib with GDALFileFormat") { - assume(System.getProperty("os.name") == "Linux") - - val grib = "/binary/grib-cams/" - val filePath = getClass.getResource(grib).getPath - - noException should be thrownBy spark.read - .format("gdal") - .option("extensions", "grb") - .option("raster.read.strategy", "retile_on_read") - .load(filePath) - .take(1) - - noException should be thrownBy spark.read - .format("gdal") - .option("extensions", "grb") - .option("raster.read.strategy", "retile_on_read") - .load(filePath) - .take(1) - - noException should be thrownBy spark.read - .format("gdal") - .option("extensions", "grb") - .option("raster.read.strategy", "retile_on_read") - .load(filePath) - .select("metadata") - .take(1) - - } test("Read tif with GDALFileFormat") { assume(System.getProperty("os.name") == "Linux") @@ -82,14 +52,16 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSessionGDAL { .load(filePath) .take(1) - noException should be thrownBy spark.read + // noException should be thrownBy + + spark.read .format("gdal") .option("driverName", "TIF") .load(filePath) .select("metadata") .take(1) - noException should 
be thrownBy spark.read + noException should be thrownBy spark.read .format("gdal") .option(MOSAIC_RASTER_READ_STRATEGY, "retile_on_read") .load(filePath) @@ -138,4 +110,34 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSessionGDAL { } + test("Read grib with GDALFileFormat") { + assume(System.getProperty("os.name") == "Linux") + + val grib = "/binary/grib-cams/" + val filePath = getClass.getResource(grib).getPath + + noException should be thrownBy spark.read + .format("gdal") + .option("extensions", "grb") + .option("raster.read.strategy", "retile_on_read") + .load(filePath) + .take(1) + + noException should be thrownBy spark.read + .format("gdal") + .option("extensions", "grb") + .option("raster.read.strategy", "retile_on_read") + .load(filePath) + .take(1) + + noException should be thrownBy spark.read + .format("gdal") + .option("extensions", "grb") + .option("raster.read.strategy", "retile_on_read") + .load(filePath) + .select("metadata") + .take(1) + + } + } From 71db41a9e9126d507ed13c621944d5350a66072b Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Mon, 4 Mar 2024 12:39:52 +0000 Subject: [PATCH 23/26] Fix gribs build. 
--- .../mosaic/core/raster/gdal/MosaicRasterGDAL.scala | 3 +++ .../labs/mosaic/datasource/GDALFileFormatTest.scala | 10 ++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala index 2749d7cfa..bb9d09f25 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala @@ -411,6 +411,9 @@ case class MosaicRasterGDAL( def getMemSize: Long = { if (memSize == -1) { val toRead = if (path.startsWith("/vsizip/")) path.replace("/vsizip/", "") else path + if (Files.notExists(Paths.get(toRead))) { + throw new Exception(s"File not found: ${gdal.GetLastErrorMsg()}") + } Files.size(Paths.get(toRead)) } else { memSize diff --git a/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala b/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala index 16017d29e..c51190ea6 100644 --- a/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala @@ -51,9 +51,7 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSessionGDAL { .option("driverName", "TIF") .load(filePath) .take(1) - - // noException should be thrownBy - + spark.read .format("gdal") .option("driverName", "TIF") @@ -61,7 +59,7 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSessionGDAL { .select("metadata") .take(1) - noException should be thrownBy spark.read + spark.read .format("gdal") .option(MOSAIC_RASTER_READ_STRATEGY, "retile_on_read") .load(filePath) @@ -116,7 +114,7 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSessionGDAL { val grib = "/binary/grib-cams/" val filePath = getClass.getResource(grib).getPath - noException should be thrownBy 
spark.read + spark.read .format("gdal") .option("extensions", "grb") .option("raster.read.strategy", "retile_on_read") @@ -130,7 +128,7 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSessionGDAL { .load(filePath) .take(1) - noException should be thrownBy spark.read + spark.read .format("gdal") .option("extensions", "grb") .option("raster.read.strategy", "retile_on_read") From 850cb55a71b9fbb19a5533558826e17e0ee527ed Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Mon, 4 Mar 2024 17:02:50 +0000 Subject: [PATCH 24/26] Fix gribs build. --- .../mosaic/datasource/gdal/ReTileOnRead.scala | 2 +- .../labs/mosaic/utils/PathUtils.scala | 37 ++++++++++++------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala index 867167c58..b68a580d2 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala @@ -93,7 +93,7 @@ object ReTileOnRead extends ReadStrategy { val uuid = getUUID(status) val sizeInMB = options.getOrElse("sizeInMB", "16").toInt - val tmpPath = PathUtils.copyToTmp(inPath) + var tmpPath = PathUtils.copyToTmpWithRetry(inPath, 5) val tiles = localSubdivide(tmpPath, inPath, sizeInMB) val rows = tiles.map(tile => { diff --git a/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala b/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala index aacd897f8..2f896c046 100644 --- a/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala +++ b/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala @@ -2,7 +2,8 @@ package com.databricks.labs.mosaic.utils import com.databricks.labs.mosaic.functions.MosaicContext -import java.nio.file.{Files, Paths} +import java.nio.file.{Files, Path, Paths} +import scala.jdk.CollectionConverters._ object PathUtils { 
@@ -63,7 +64,7 @@ object PathUtils { if (filePath.endsWith("\"")) result = result.dropRight(1) result } - + def getStemRegex(path: String): String = { val cleanPath = replaceDBFSTokens(path) val fileName = Paths.get(cleanPath).getFileName.toString @@ -72,15 +73,25 @@ object PathUtils { val stemRegex = s"$stemEscaped\\..*".r stemRegex.toString } + + def copyToTmpWithRetry(inPath: String, retries: Int = 3): String = { + var tmpPath = copyToTmp(inPath) + var i = 0 + while (Files.notExists(Paths.get(tmpPath)) && i < retries) { + tmpPath = copyToTmp(inPath) + i += 1 + } + tmpPath + } def copyToTmp(inPath: String): String = { val copyFromPath = replaceDBFSTokens(inPath) val inPathDir = Paths.get(copyFromPath).getParent.toString - + val fullFileName = copyFromPath.split("/").last val stemRegex = getStemRegex(inPath) - wildcardCopy(inPathDir, MosaicContext.tmpDir(null), stemRegex.toString) + wildcardCopy(inPathDir, MosaicContext.tmpDir(null), stemRegex) s"${MosaicContext.tmpDir(null)}/$fullFileName" } @@ -93,17 +104,17 @@ object PathUtils { val toCopy = Files .list(Paths.get(copyFromPath)) .filter(_.getFileName.toString.matches(pattern)) - - toCopy.forEach(path => { + .collect(java.util.stream.Collectors.toList[Path]) + .asScala + + for (path <- toCopy) { val destination = Paths.get(copyToPath, path.getFileName.toString) - //noinspection SimplifyBooleanMatch - Files.isDirectory(path) match { - case true => FileUtils.copyDirectory(path.toFile, destination.toFile) - case false => Files.copy(path, destination) - } - }) + // noinspection SimplifyBooleanMatch + if (Files.isDirectory(path)) FileUtils.copyDirectory(path.toFile, destination.toFile) + else FileUtils.copyFile(path.toFile, destination.toFile) + } } - + def parseUnzippedPathFromExtracted(lastExtracted: String, extension: String): String = { val trimmed = lastExtracted.replace("extracting: ", "").replace(" ", "") val indexOfFormat = trimmed.indexOf(s".$extension/") From e84834b198e772435daa33f1a56b8ef75ea77656 Mon Sep 
17 00:00:00 2001 From: "milos.colic" Date: Mon, 4 Mar 2024 17:57:10 +0000 Subject: [PATCH 25/26] Fix R build. --- R/sparkR-mosaic/tests.R | 8 ++++++-- R/sparklyr-mosaic/tests.R | 2 ++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/R/sparkR-mosaic/tests.R b/R/sparkR-mosaic/tests.R index 556b48b0f..7253cac2d 100644 --- a/R/sparkR-mosaic/tests.R +++ b/R/sparkR-mosaic/tests.R @@ -21,11 +21,15 @@ print("Looking for mosaic jar in") mosaic_jar_path <- paste0(staging_dir, mosaic_jar) print(mosaic_jar_path) +pwd <- getwd() spark <- sparkR.session( master = "local[*]" - ,sparkJars = mosaic_jar_path + ,sparkJars = mosaic_jar_path, + sparkConfig = list( + spark.databricks.labs.mosaic.raster.tmp.prefix = paste0(pwd, "/mosaic_tmp", sep="") + ,spark.databricks.labs.mosaic.raster.checkpoint = paste0(pwd, "/mosaic_checkpoint", sep="") + ) ) - enableMosaic() testthat::test_local(path="./sparkrMosaic") \ No newline at end of file diff --git a/R/sparklyr-mosaic/tests.R b/R/sparklyr-mosaic/tests.R index eef6d7c8b..b806a3d73 100644 --- a/R/sparklyr-mosaic/tests.R +++ b/R/sparklyr-mosaic/tests.R @@ -21,6 +21,8 @@ print(paste("Looking for mosaic jar in", mosaic_jar_path)) config <- sparklyr::spark_config() config$`sparklyr.jars.default` <- c(mosaic_jar_path) +config$`spark.databricks.labs.mosaic.raster.tmp.prefix` <- paste0(getwd(), "/mosaic_tmp", sep="") +config$`spark.databricks.labs.mosaic.raster.checkpoint` <- paste0(getwd(), "/mosaic_checkpoint", sep="") sc <- spark_connect(master="local[*]", config=config) enableMosaic(sc) From c53961521134f288e596dfa56a836470cf84ff2a Mon Sep 17 00:00:00 2001 From: "milos.colic" Date: Tue, 5 Mar 2024 16:54:06 +0000 Subject: [PATCH 26/26] Fix coverage tests. Fix convolve operations. 
--- python/mosaic/api/raster.py | 91 +++++++++++++++- .../labs/mosaic/core/raster/api/GDAL.scala | 1 - .../mosaic/core/raster/gdal/GDALBlock.scala | 3 +- .../raster/gdal/MosaicRasterBandGDAL.scala | 64 ++++++----- .../core/raster/gdal/MosaicRasterGDAL.scala | 15 ++- .../core/types/model/MosaicRasterTile.scala | 7 +- .../expressions/raster/RST_Convolve.scala | 11 +- .../expressions/raster/RST_FromFile.scala | 3 +- .../expressions/raster/RST_MakeTiles.scala | 22 ++-- .../labs/mosaic/functions/MosaicContext.scala | 4 + .../labs/mosaic/utils/FileUtils.scala | 3 +- .../raster/RST_ConvolveBehaviors.scala | 43 ++++++++ .../expressions/raster/RST_ConvolveTest.scala | 32 ++++++ .../raster/RST_FilterBehaviors.scala | 42 ++++++++ .../raster/RST_MakeTilesBehaviors.scala | 101 ++++++++++++++++++ .../raster/RST_MakeTilesTest.scala | 31 ++++++ 16 files changed, 425 insertions(+), 48 deletions(-) create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ConvolveBehaviors.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ConvolveTest.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTilesBehaviors.scala create mode 100644 src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTilesTest.scala diff --git a/python/mosaic/api/raster.py b/python/mosaic/api/raster.py index c61bafcd2..c703b9134 100644 --- a/python/mosaic/api/raster.py +++ b/python/mosaic/api/raster.py @@ -5,7 +5,6 @@ from pyspark.sql.functions import lit from typing import Any - ####################### # Raster functions # ####################### @@ -15,16 +14,19 @@ "rst_boundingbox", "rst_clip", "rst_combineavg", + "rst_convolve", "rst_derivedband", "rst_frombands", "rst_fromcontent", "rst_fromfile", + "rst_filter", "rst_georeference", "rst_getnodata", "rst_getsubdataset", "rst_height", "rst_initnodata", "rst_isempty", + "rst_maketiles", "rst_mapalgebra", "rst_memsize", "rst_merge", @@ -157,6 
+159,32 @@ def rst_combineavg(raster_tiles: ColumnOrName) -> Column: ) +def rst_convolve(raster_tile: ColumnOrName, kernel: ColumnOrName) -> Column: + """ + Applies a convolution filter to the raster. + The result is a Mosaic raster tile struct column for the filtered raster. + The result is stored in the checkpoint directory. + + Parameters + ---------- + raster_tile : Column (RasterTileType) + Mosaic raster tile struct column. + kernel : Column (ArrayType(ArrayType(DoubleType))) + The kernel to apply to the raster. + + Returns + ------- + Column (RasterTileType) + Mosaic raster tile struct column. + + """ + return config.mosaic_context.invoke_function( + "rst_convolve", + pyspark_to_java_column(raster_tile), + pyspark_to_java_column(kernel), + ) + + def rst_derivedband( raster_tile: ColumnOrName, python_func: ColumnOrName, func_name: ColumnOrName ) -> Column: @@ -317,6 +345,43 @@ def rst_isempty(raster_tile: ColumnOrName) -> Column: ) +def rst_maketiles(input: ColumnOrName, driver: Any = "no_driver", size_in_mb: Any = -1, + with_checkpoint: Any = False) -> Column: + """ + Tiles the raster into tiles of the given size. + :param input: If the raster is stored on disc, the path + to the raster is provided. If the raster is stored in memory, the bytes of + the raster are provided. + :param driver: The driver to use for reading the raster. If not specified, the driver is + inferred from the file extension. If the input is a byte array, the driver + has to be specified. + :param size_in_mb: The size of the tiles in MB. If set to -1, the file is loaded and returned + as a single tile. If set to 0, the file is loaded and subdivided into + tiles of size 64MB. If set to a positive value, the file is loaded and + subdivided into tiles of the specified size. If the file is too big to fit + in memory, it is subdivided into tiles of size 64MB. + :param with_checkpoint: If set to true, the tiles are written to the checkpoint directory. 
If set + to false, the tiles are returned as in-memory byte arrays. + :return: A collection of tiles of the raster. + """ + if type(size_in_mb) == int: + size_in_mb = lit(size_in_mb) + + if type(with_checkpoint) == bool: + with_checkpoint = lit(with_checkpoint) + + if type(driver) == str: + driver = lit(driver) + + return config.mosaic_context.invoke_function( + "rst_maketiles", + pyspark_to_java_column(input), + pyspark_to_java_column(driver), + pyspark_to_java_column(size_in_mb), + pyspark_to_java_column(with_checkpoint), + ) + + def rst_mapalgebra(raster_tile: ColumnOrName, json_spec: ColumnOrName) -> Column: """ Parameters @@ -631,7 +696,7 @@ def rst_rastertogridmin(raster_tile: ColumnOrName, resolution: ColumnOrName) -> def rst_rastertoworldcoord( - raster_tile: ColumnOrName, x: ColumnOrName, y: ColumnOrName + raster_tile: ColumnOrName, x: ColumnOrName, y: ColumnOrName ) -> Column: """ Computes the world coordinates of the raster pixel at the given x and y coordinates. @@ -1062,6 +1127,28 @@ def rst_fromfile(raster_path: ColumnOrName, size_in_mb: Any = -1) -> Column: ) +def rst_filter(raster_tile: ColumnOrName, kernel_size: Any, operation: Any) -> Column: + """ + Applies a filter to the raster. + :param raster_tile: Mosaic raster tile struct column. + :param kernel_size: The size of the kernel. Has to be odd. + :param operation: The operation to apply to the kernel. + :return: A new raster tile with the filter applied. 
+ """ + if type(kernel_size) == int: + kernel_size = lit(kernel_size) + + if type(operation) == str: + operation = lit(operation) + + return config.mosaic_context.invoke_function( + "rst_filter", + pyspark_to_java_column(raster_tile), + pyspark_to_java_column(kernel_size), + pyspark_to_java_column(operation), + ) + + def rst_to_overlapping_tiles( raster_tile: ColumnOrName, width: ColumnOrName, diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala index 6e2fee0f6..3517f9368 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala @@ -101,7 +101,6 @@ object GDAL { ): MosaicRasterGDAL = { inputDT match { case StringType => - val path = inputRaster.asInstanceOf[UTF8String].toString MosaicRasterGDAL.readRaster(createInfo) case BinaryType => val bytes = inputRaster.asInstanceOf[Array[Byte]] diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/GDALBlock.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/GDALBlock.scala index 8c5c7a495..5014b0ee3 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/GDALBlock.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/GDALBlock.scala @@ -57,7 +57,6 @@ case class GDALBlock[T: ClassTag]( } } - // TODO: Test and fix, not tested, other filters work. 
def convolveAt(x: Int, y: Int, kernel: Array[Array[Double]]): Double = { val kernelWidth = kernel.head.length val kernelHeight = kernel.length @@ -70,7 +69,7 @@ case class GDALBlock[T: ClassTag]( val yIndex = y + (i - kernelCenterY) if (xIndex >= 0 && xIndex < width && yIndex >= 0 && yIndex < height) { val maskValue = maskAt(xIndex, yIndex) - val value = rasterElementAt(xIndex, yIndex) + val value = elementAt(xIndex, yIndex) if (maskValue != 0.0 && num.toDouble(value) != noDataValue) { sum += num.toDouble(value) * kernel(i)(j) } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala index c9181b4f8..683d3791e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala @@ -265,34 +265,48 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) { stats.getValid_count == 0 } - def convolve(kernel: Array[Array[Double]]): Unit = { - val kernelWidth = kernel.head.length - val kernelHeight = kernel.length - val blockSize = MosaicGDAL.defaultBlockSize - val strideX = kernelWidth / 2 - val strideY = kernelHeight / 2 - - val block = Array.ofDim[Double](blockSize * blockSize) - val maskBlock = Array.ofDim[Double](blockSize * blockSize) - val result = Array.ofDim[Double](blockSize * blockSize) - - for (yOffset <- 0 until ySize by blockSize - strideY) { - for (xOffset <- 0 until xSize by blockSize - strideX) { - val xSize = Math.min(blockSize, this.xSize - xOffset) - val ySize = Math.min(blockSize, this.ySize - yOffset) - - band.ReadRaster(xOffset, yOffset, xSize, ySize, block) - band.GetMaskBand().ReadRaster(xOffset, yOffset, xSize, ySize, maskBlock) - - val currentBlock = GDALBlock[Double](block, maskBlock, noDataValue, xOffset, yOffset, xSize, ySize, Padding.NoPadding) - - for (y <- 0 until ySize) { - for (x <- 0 
until xSize) { - result(y * xSize + x) = currentBlock.convolveAt(x, y, kernel) + /** + * Applies a kernel filter to the band. It assumes the kernel is square and + * has an odd number of rows and columns. + * + * @param kernel + * The kernel to apply to the band. + * @return + * The band with the kernel filter applied. + */ + def convolve(kernel: Array[Array[Double]], outputBand: Band): Unit = { + val kernelSize = kernel.length + require(kernelSize % 2 == 1, "Kernel size must be odd") + + val blockSize = MosaicGDAL.blockSize + val stride = kernelSize / 2 + + for (yOffset <- 0 until ySize by blockSize) { + for (xOffset <- 0 until xSize by blockSize) { + + val currentBlock = GDALBlock( + this, + stride, + xOffset, + yOffset, + blockSize + ) + + val result = Array.ofDim[Double](currentBlock.block.length) + + for (y <- 0 until currentBlock.height) { + for (x <- 0 until currentBlock.width) { + result(y * currentBlock.width + x) = currentBlock.convolveAt(x, y, kernel) } } - band.WriteRaster(xOffset, yOffset, xSize, ySize, block) + val trimmedResult = currentBlock.copy(block = result).trimBlock(stride) + + outputBand.WriteRaster(xOffset, yOffset, trimmedResult.width, trimmedResult.height, trimmedResult.block) + outputBand.FlushCache() + outputBand.GetMaskBand().WriteRaster(xOffset, yOffset, trimmedResult.width, trimmedResult.height, trimmedResult.maskBlock) + outputBand.GetMaskBand().FlushCache() + } } } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala index bb9d09f25..b399fa0a2 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala @@ -601,13 +601,18 @@ case class MosaicRasterGDAL( def convolve(kernel: Array[Array[Double]]): MosaicRasterGDAL = { val resultRasterPath = 
PathUtils.createTmpFilePath(getRasterFileExtension) - val outputRaster = this.raster + + this.raster .GetDriver() - .Create(resultRasterPath, this.xSize, this.ySize, this.numBands, this.raster.GetRasterBand(1).getDataType) + .CreateCopy(resultRasterPath, this.raster, 1) + .delete() + val outputRaster = gdal.Open(resultRasterPath, GF_Write) + for (bandIndex <- 1 to this.numBands) { val band = this.getBand(bandIndex) - band.convolve(kernel) + val outputBand = outputRaster.GetRasterBand(bandIndex) + band.convolve(kernel, outputBand) } val createInfo = Map( @@ -616,8 +621,8 @@ case class MosaicRasterGDAL( "driver" -> getDriversShortName ) - MosaicRasterGDAL(outputRaster, createInfo, -1) - + val result = MosaicRasterGDAL(outputRaster, createInfo, this.memSize) + result.flushCache() } def filter(kernelSize: Int, operation: String): MosaicRasterGDAL = { diff --git a/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala b/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala index bf36ea8f2..97122a7d7 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala @@ -100,14 +100,15 @@ case class MosaicRasterTile( rasterDataType: DataType ): InternalRow = { val encodedRaster = encodeRaster(rasterDataType) - val mapData = buildMapString(raster.createInfo) + val path = if (rasterDataType == StringType) encodedRaster.toString else raster.createInfo("path") + val parentPath = if (raster.createInfo("parentPath").isEmpty) raster.createInfo("path") else raster.createInfo("parentPath") + val newCreateInfo = raster.createInfo + ("path" -> path, "parentPath" -> parentPath) + val mapData = buildMapString(newCreateInfo) if (Option(index).isDefined) { if (index.isLeft) InternalRow.fromSeq( Seq(index.left.get, encodedRaster, mapData) ) else { - // Copy from tmp to checkpoint. 
- // Have to use GDAL Driver to do this since sidecar files are not copied by spark. InternalRow.fromSeq( Seq(UTF8String.fromString(index.right.get), encodedRaster, mapData) ) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Convolve.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Convolve.scala index db20f8a3a..d831ad849 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Convolve.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Convolve.scala @@ -9,6 +9,8 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ /** The expression for applying kernel filter on a raster. */ case class RST_Convolve( @@ -39,7 +41,14 @@ case class RST_Convolve( * The clipped raster. 
*/ override def rasterTransform(tile: MosaicRasterTile, arg1: Any): Any = { - val kernel = arg1.asInstanceOf[Array[Array[Double]]] + val kernel = arg1.asInstanceOf[ArrayData].array.map(_.asInstanceOf[ArrayData].array.map( + el => kernelExpr.dataType match { + case ArrayType(ArrayType(DoubleType, false), false) => el.asInstanceOf[Double] + case ArrayType(ArrayType(DecimalType(), false), false) => el.asInstanceOf[java.math.BigDecimal].doubleValue() + case ArrayType(ArrayType(IntegerType, false), false) => el.asInstanceOf[Int].toDouble + case _ => throw new IllegalArgumentException(s"Unsupported kernel type: ${kernelExpr.dataType}") + } + )) tile.copy( raster = tile.getRaster.convolve(kernel) ) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala index cd5808f30..8e1dc213e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala @@ -65,7 +65,8 @@ case class RST_FromFile( val readPath = PathUtils.getCleanPath(path) val driver = MosaicRasterGDAL.identifyDriver(path) val targetSize = sizeInMB.eval(input).asInstanceOf[Int] - if (targetSize <= 0 && Files.size(Paths.get(readPath)) <= Integer.MAX_VALUE) { + val currentSize = Files.size(Paths.get(PathUtils.replaceDBFSTokens(readPath))) + if (targetSize <= 0 && currentSize <= Integer.MAX_VALUE) { val createInfo = Map("path" -> readPath, "parentPath" -> path) var raster = MosaicRasterGDAL.readRaster(createInfo) var tile = MosaicRasterTile(null, raster) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala index 12c70d025..f9cf7099d 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala +++ 
b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala @@ -97,7 +97,8 @@ case class RST_MakeTiles( private def getInputSize(rawInput: Any): Long = { if (inputExpr.dataType == StringType) { val path = rawInput.asInstanceOf[UTF8String].toString - Files.size(Paths.get(path)) + val cleanPath = PathUtils.replaceDBFSTokens(path) + Files.size(Paths.get(cleanPath)) } else { val bytes = rawInput.asInstanceOf[Array[Byte]] bytes.length @@ -114,7 +115,7 @@ case class RST_MakeTiles( */ override def eval(input: InternalRow): TraversableOnce[InternalRow] = { GDAL.enable(expressionConfig) - + val tileType = dataType.asInstanceOf[StructType].find(_.name == "raster").get.dataType val rawDriver = driverExpr.eval(input).asInstanceOf[UTF8String].toString @@ -122,10 +123,11 @@ case class RST_MakeTiles( val driver = getDriver(rawInput, rawDriver) val targetSize = sizeInMBExpr.eval(input).asInstanceOf[Int] val inputSize = getInputSize(rawInput) + val path = if (inputExpr.dataType == StringType) rawInput.asInstanceOf[UTF8String].toString else PathUtils.NO_PATH_STRING if (targetSize <= 0 && inputSize <= Integer.MAX_VALUE) { // - no split required - val createInfo = Map("parentPath" -> PathUtils.NO_PATH_STRING, "driver" -> driver) + val createInfo = Map("parentPath" -> PathUtils.NO_PATH_STRING, "driver" -> driver, "path" -> path) val raster = GDAL.readRaster(rawInput, createInfo, inputExpr.dataType) val tile = MosaicRasterTile(null, raster) val row = tile.formatCellId(indexSystem).serialize(tileType) @@ -136,9 +138,15 @@ case class RST_MakeTiles( // target size is > 0 and raster size > target size // - write the initial raster to file (unsplit) // - createDirectories in case of context isolation - val rasterPath = PathUtils.createTmpFilePath(GDAL.getExtension(driver)) - Files.createDirectories(Paths.get(rasterPath).getParent) - Files.write(Paths.get(rasterPath), rawInput.asInstanceOf[Array[Byte]]) + val rasterPath = + if (inputExpr.dataType == StringType) { + 
PathUtils.copyToTmpWithRetry(path, 5) + } else { + val rasterPath = PathUtils.createTmpFilePath(GDAL.getExtension(driver)) + Files.createDirectories(Paths.get(rasterPath).getParent) + Files.write(Paths.get(rasterPath), rawInput.asInstanceOf[Array[Byte]]) + rasterPath + } val size = if (targetSize <= 0) 64 else targetSize var tiles = ReTileOnRead.localSubdivide(rasterPath, PathUtils.NO_PATH_STRING, size) val rows = tiles.map(_.formatCellId(indexSystem).serialize(tileType)) @@ -180,7 +188,7 @@ object RST_MakeTiles extends WithExpressionInfo { def checkChkpnt(chkpnt: Expression) = Try(chkpnt.eval().asInstanceOf[Boolean]).isSuccess def checkDriver(driver: Expression) = Try(driver.eval().asInstanceOf[UTF8String].toString).isSuccess val noSize = new Literal(-1, IntegerType) - val noDriver = new Literal(MOSAIC_NO_DRIVER, StringType) + val noDriver = new Literal(UTF8String.fromString(MOSAIC_NO_DRIVER), StringType) val noCheckpoint = new Literal(false, BooleanType) children match { diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala index 40be5cb24..7557306ec 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala @@ -263,7 +263,9 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends mosaicRegistry.registerExpression[RST_BoundingBox](expressionConfig) mosaicRegistry.registerExpression[RST_Clip](expressionConfig) mosaicRegistry.registerExpression[RST_CombineAvg](expressionConfig) + mosaicRegistry.registerExpression[RST_Convolve](expressionConfig) mosaicRegistry.registerExpression[RST_DerivedBand](expressionConfig) + mosaicRegistry.registerExpression[RST_Filter](expressionConfig) mosaicRegistry.registerExpression[RST_GeoReference](expressionConfig) mosaicRegistry.registerExpression[RST_GetNoData](expressionConfig) 
mosaicRegistry.registerExpression[RST_GetSubdataset](expressionConfig) @@ -660,6 +662,8 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends ColumnAdapter(RST_BandMetaData(raster.expr, lit(band).expr, expressionConfig)) def rst_boundingbox(raster: Column): Column = ColumnAdapter(RST_BoundingBox(raster.expr, expressionConfig)) def rst_clip(raster: Column, geometry: Column): Column = ColumnAdapter(RST_Clip(raster.expr, geometry.expr, expressionConfig)) + def rst_convolve(raster: Column, kernel: Column): Column = + ColumnAdapter(RST_Convolve(raster.expr, kernel.expr, expressionConfig)) def rst_pixelcount(raster: Column): Column = ColumnAdapter(RST_PixelCount(raster.expr, expressionConfig)) def rst_combineavg(rasterArray: Column): Column = ColumnAdapter(RST_CombineAvg(rasterArray.expr, expressionConfig)) def rst_derivedband(raster: Column, pythonFunc: Column, funcName: Column): Column = diff --git a/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala b/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala index a36c0bec0..0a881d785 100644 --- a/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala +++ b/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala @@ -7,7 +7,8 @@ object FileUtils { def readBytes(path: String): Array[Byte] = { val bufferSize = 1024 * 1024 // 1MB - val inputStream = new BufferedInputStream(new FileInputStream(path)) + val cleanPath = PathUtils.replaceDBFSTokens(path) + val inputStream = new BufferedInputStream(new FileInputStream(cleanPath)) val buffer = new Array[Byte](bufferSize) var bytesRead = 0 diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ConvolveBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ConvolveBehaviors.scala new file mode 100644 index 000000000..fdc09cc28 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ConvolveBehaviors.scala @@ -0,0 +1,43 @@ +package 
com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions.{array, lit} +import org.scalatest.matchers.should.Matchers._ + +trait RST_ConvolveBehaviors extends QueryTest { + + // noinspection MapGetGet + def behaviors(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + spark.sparkContext.setLogLevel("FATAL") + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val gridTiles = rastersInMemory + .withColumn("result", rst_convolve($"tile", array(array(lit(1.0), lit(2.0), lit(3.0)), array(lit(3.0), lit(2.0), lit(1.0)), array(lit(1.0), lit(3.0), lit(2.0))))) + .select("result") + .collect() + + gridTiles.length should be(7) + + rastersInMemory.createOrReplaceTempView("source") + + spark + .sql(""" + |select rst_convolve(tile, array(array(1, 2, 3), array(2, 3, 1), array(1, 1, 1))) as tile from source + |""".stripMargin) + .collect() + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ConvolveTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ConvolveTest.scala new file mode 100644 index 000000000..28c049364 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ConvolveTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import 
org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_ConvolveTest extends QueryTest with SharedSparkSessionGDAL with RST_ConvolveBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing RST_Convolve with manual GDAL registration (H3, JTS).") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behaviors(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterBehaviors.scala index d06923dc1..2d64a633c 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterBehaviors.scala @@ -4,6 +4,7 @@ import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions.lit import org.scalatest.matchers.should.Matchers._ trait RST_FilterBehaviors extends QueryTest { @@ -29,7 +30,48 @@ trait RST_FilterBehaviors extends QueryTest { gridTiles.length should be(7) + val gridTiles2 = rastersInMemory + .withColumn("result", rst_filter($"tile", lit(3), lit("mode"))) + .select("result") + .collect() + + gridTiles2.length should be(7) + + val gridTiles3 = rastersInMemory + .withColumn("result", rst_filter($"tile", lit(3), lit("avg"))) + .select("result") + .collect() + + 
gridTiles3.length should be(7) + + val gridTiles4 = rastersInMemory + .withColumn("result", rst_filter($"tile", lit(3), lit("min"))) + .select("result") + .collect() + + gridTiles4.length should be(7) + + val gridTiles5 = rastersInMemory + .withColumn("result", rst_filter($"tile", lit(3), lit("max"))) + .select("result") + .collect() + + gridTiles5.length should be(7) + + val gridTiles6 = rastersInMemory + .withColumn("result", rst_filter($"tile", lit(3), lit("median"))) + .select("result") + .collect() + + gridTiles6.length should be(7) + rastersInMemory.createOrReplaceTempView("source") + + noException should be thrownBy spark + .sql(""" + |select rst_filter(tile, 3, 'mode') as tile from source + |""".stripMargin) + .collect() } diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTilesBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTilesBehaviors.scala new file mode 100644 index 000000000..31caed8b9 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTilesBehaviors.scala @@ -0,0 +1,101 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.scalatest.matchers.should.Matchers._ + +trait RST_MakeTilesBehaviors extends QueryTest { + + // noinspection MapGetGet + def behaviors(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + spark.sparkContext.setLogLevel("ERROR") + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("binaryFile") + .load("src/test/resources/modis") + + val gridTiles1 = rastersInMemory + .withColumn("tile", rst_maketiles($"content", "GTiff", -1)) + 
.select(!rst_isempty($"tile")) + .as[Boolean] + .collect() + + gridTiles1.forall(identity) should be(true) + + rastersInMemory.createOrReplaceTempView("source") + + val gridTilesSQL = spark + .sql(""" + |with subquery as ( + | select rst_maketiles(content, 'GTiff', -1) as tile from source + |) + |select not rst_isempty(tile) as result + |from subquery + |""".stripMargin) + .as[Boolean] + .collect() + + gridTilesSQL.forall(identity) should be(true) + + + val gridTilesSQL2 = spark + .sql( + """ + |with subquery as ( + | select rst_maketiles(content, 'GTiff', 4) as tile from source + |) + |select not rst_isempty(tile) as result + |from subquery + |""".stripMargin) + .as[Boolean] + .collect() + + gridTilesSQL2.forall(identity) should be(true) + + val gridTilesSQL3 = spark + .sql( + """ + |with subquery as ( + | select rst_maketiles(path, 'GTiff', 4) as tile from source + |) + |select not rst_isempty(tile) as result + |from subquery + |""".stripMargin) + .as[Boolean] + .collect() + + gridTilesSQL3.forall(identity) should be(true) + + val gridTilesSQL4 = spark + .sql( + """ + |with subquery as ( + | select rst_maketiles(path, 'GTiff', 4, true) as tile from source + |) + |select not rst_isempty(tile) as result + |from subquery + |""".stripMargin) + .as[Boolean] + .collect() + + gridTilesSQL4.forall(identity) should be(true) + + val gridTiles2 = rastersInMemory + .withColumn("tile", rst_maketiles($"path")) + .select(!rst_isempty($"tile")) + .as[Boolean] + .collect() + + gridTiles2.forall(identity) should be(true) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTilesTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTilesTest.scala new file mode 100644 index 000000000..7eaae222b --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTilesTest.scala @@ -0,0 +1,31 @@ +package com.databricks.labs.mosaic.expressions.raster + +import 
com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_MakeTilesTest extends QueryTest with SharedSparkSessionGDAL with RST_MakeTilesBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing RST_MakeTiles with manual GDAL registration (H3, JTS).") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behaviors(H3IndexSystem, JTS) + } + } +}