diff --git a/.github/actions/scala_build/action.yml b/.github/actions/scala_build/action.yml index b33c1b453..8a0d20359 100644 --- a/.github/actions/scala_build/action.yml +++ b/.github/actions/scala_build/action.yml @@ -27,7 +27,7 @@ runs: sudo apt-get update -y # - install natives sudo apt-get install -y unixodbc libcurl3-gnutls libsnappy-dev libopenjp2-7 - sudo apt-get install -y gdal-bin libgdal-dev python3-numpy python3-gdal + sudo apt-get install -y gdal-bin libgdal-dev python3-numpy python3-gdal zip unzip # - install pip libs pip install --upgrade pip pip install gdal==${{ matrix.gdal }} diff --git a/R/sparkR-mosaic/sparkrMosaic/DESCRIPTION b/R/sparkR-mosaic/sparkrMosaic/DESCRIPTION index f689fe17a..876a46cf7 100644 --- a/R/sparkR-mosaic/sparkrMosaic/DESCRIPTION +++ b/R/sparkR-mosaic/sparkrMosaic/DESCRIPTION @@ -8,7 +8,7 @@ Description: This package extends SparkR to bring the Databricks Mosaic for geos License: Databricks Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.1 Collate: 'enableGDAL.R' 'enableMosaic.R' @@ -20,4 +20,4 @@ Imports: Suggests: testthat (>= 3.0.0), readr (>= 2.1.5) -Config/testthat/edition: 3 \ No newline at end of file +Config/testthat/edition: 3 diff --git a/R/sparkR-mosaic/sparkrMosaic/tests/testthat/testRasterFunctions.R b/R/sparkR-mosaic/sparkrMosaic/tests/testthat/testRasterFunctions.R index 36296e9d8..6e23454dc 100644 --- a/R/sparkR-mosaic/sparkrMosaic/tests/testthat/testRasterFunctions.R +++ b/R/sparkR-mosaic/sparkrMosaic/tests/testthat/testRasterFunctions.R @@ -15,7 +15,7 @@ test_that("mosaic can read single-band GeoTiff", { expect_equal(row$srid, 0) expect_equal(row$bandCount, 1) expect_equal(row$metadata[[1]]$LONGNAME, "MODIS/Terra+Aqua BRDF/Albedo Nadir BRDF-Adjusted Ref Daily L3 Global - 500m") - expect_equal(row$tile[[1]]$driver, "GTiff") + expect_equal(row$tile[[1]]$metadata$driver, "GTiff") }) @@ -61,7 +61,7 @@ test_that("raster flatmap functions behave as intended", { tessellate_sdf <- withColumn(tessellate_sdf, "rst_tessellate", rst_tessellate(column("tile"), lit(3L))) expect_no_error(write.df(tessellate_sdf, source = "noop", mode = "overwrite")) - expect_equal(nrow(tessellate_sdf), 66) + expect_equal(nrow(tessellate_sdf), 63) overlap_sdf <- generate_singleband_raster_df() overlap_sdf <- withColumn(overlap_sdf, "rst_to_overlapping_tiles", rst_to_overlapping_tiles(column("tile"), lit(200L), lit(200L), lit(10L))) @@ -117,7 +117,7 @@ test_that("the tessellate-join-clip-merge flow works on NetCDF files", { raster_sdf <- read.df( path = "sparkrMosaic/tests/testthat/data/prAdjust_day_HadGEM2-CC_SMHI-DBSrev930-GFD-1981-2010-postproc_rcp45_r1i1p1_20201201-20201231.nc", source = "gdal", - raster.read.strategy = "retile_on_read" + raster.read.strategy = "in_memory" ) raster_sdf <- withColumn(raster_sdf, "tile", rst_separatebands(column("tile"))) diff --git a/R/sparkR-mosaic/tests.R b/R/sparkR-mosaic/tests.R index 556b48b0f..7253cac2d 100644 --- a/R/sparkR-mosaic/tests.R +++ b/R/sparkR-mosaic/tests.R @@ -21,11 +21,15 @@ print("Looking for mosaic jar in") mosaic_jar_path <- paste0(staging_dir, mosaic_jar) print(mosaic_jar_path) +pwd <- getwd() spark <- sparkR.session( master = "local[*]" - ,sparkJars = mosaic_jar_path + ,sparkJars = mosaic_jar_path, + sparkConfig = list( + spark.databricks.labs.mosaic.raster.tmp.prefix = paste0(pwd, "/mosaic_tmp", sep="") + ,spark.databricks.labs.mosaic.raster.checkpoint = paste0(pwd, "/mosaic_checkpoint", sep="") + ) ) - enableMosaic() testthat::test_local(path="./sparkrMosaic") \ No newline at end of file diff --git a/R/sparklyr-mosaic/sparklyrMosaic/DESCRIPTION b/R/sparklyr-mosaic/sparklyrMosaic/DESCRIPTION index c9d2048fb..315e6bf3c 100644 --- a/R/sparklyr-mosaic/sparklyrMosaic/DESCRIPTION +++ b/R/sparklyr-mosaic/sparklyrMosaic/DESCRIPTION @@ -8,7 +8,7 @@ Description: This package extends sparklyr to bring the Databricks Mosaic for ge License: Databricks Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.1 Collate: 'enableGDAL.R' 'enableMosaic.R' @@ -20,4 +20,4 @@ Suggests: testthat (>= 3.0.0), sparklyr.nested (>= 0.0.4), readr (>= 2.1.5) -Config/testthat/edition: 3 \ No newline at end of file +Config/testthat/edition: 3 diff --git a/R/sparklyr-mosaic/sparklyrMosaic/tests/testthat/testRasterFunctions.R b/R/sparklyr-mosaic/sparklyrMosaic/tests/testthat/testRasterFunctions.R index 3bb021c64..3cf016fa7 100644 --- a/R/sparklyr-mosaic/sparklyrMosaic/tests/testthat/testRasterFunctions.R +++ b/R/sparklyr-mosaic/sparklyrMosaic/tests/testthat/testRasterFunctions.R @@ -18,7 +18,7 @@ test_that("mosaic can read single-band GeoTiff", { expect_equal(row$srid, 0) expect_equal(row$bandCount, 1) expect_equal(row$metadata[[1]]$LONGNAME, "MODIS/Terra+Aqua BRDF/Albedo Nadir BRDF-Adjusted Ref Daily L3 Global - 500m") - expect_equal(row$tile[[1]]$driver, "GTiff") + expect_equal(row$tile[[1]]$metadata$driver, "GTiff") }) @@ -90,7 +90,7 @@ test_that("raster flatmap functions behave as intended", { mutate(rst_tessellate = rst_tessellate(tile, 3L)) expect_no_error(spark_write_source(tessellate_sdf, "noop", mode = "overwrite")) - expect_equal(sdf_nrow(tessellate_sdf), 66) + expect_equal(sdf_nrow(tessellate_sdf), 63) overlap_sdf <- generate_singleband_raster_df() %>% mutate(rst_to_overlapping_tiles = rst_to_overlapping_tiles(tile, 200L, 200L, 10L)) @@ -157,7 +157,7 @@ test_that("the tessellate-join-clip-merge flow works on NetCDF files", { name = "raster_raw", source = "gdal", path = "data/prAdjust_day_HadGEM2-CC_SMHI-DBSrev930-GFD-1981-2010-postproc_rcp45_r1i1p1_20201201-20201231.nc", - options = list("raster.read.strategy" = "retile_on_read") + options = list("raster.read.strategy" = "in_memory") ) %>% mutate(tile = rst_separatebands(tile)) %>% sdf_register("raster") diff --git a/R/sparklyr-mosaic/tests.R b/R/sparklyr-mosaic/tests.R index eef6d7c8b..b806a3d73 100644 --- a/R/sparklyr-mosaic/tests.R +++ b/R/sparklyr-mosaic/tests.R @@ -21,6 +21,8 @@ print(paste("Looking for mosaic jar in", mosaic_jar_path)) config <- sparklyr::spark_config() config$`sparklyr.jars.default` <- c(mosaic_jar_path) +config$`spark.databricks.labs.mosaic.raster.tmp.prefix` <- paste0(getwd(), "/mosaic_tmp", sep="") +config$`spark.databricks.labs.mosaic.raster.checkpoint` <- paste0(getwd(), "/mosaic_checkpoint", sep="") sc <- spark_connect(master="local[*]", config=config) enableMosaic(sc) diff --git a/python/mosaic/api/raster.py b/python/mosaic/api/raster.py index 3638510dc..c703b9134 100644 --- a/python/mosaic/api/raster.py +++ b/python/mosaic/api/raster.py @@ -5,7 +5,6 @@ from pyspark.sql.functions import lit from typing import Any - ####################### # Raster functions # ####################### @@ -15,16 +14,19 @@ "rst_boundingbox", "rst_clip", "rst_combineavg", + "rst_convolve", "rst_derivedband", "rst_frombands", "rst_fromcontent", "rst_fromfile", + "rst_filter", "rst_georeference", "rst_getnodata", "rst_getsubdataset", "rst_height", "rst_initnodata", "rst_isempty", + "rst_maketiles", "rst_mapalgebra", "rst_memsize", "rst_merge", @@ -55,6 +57,7 @@ "rst_subdivide", "rst_summary", "rst_tessellate", + "rst_transform", "rst_to_overlapping_tiles", "rst_tryopen", "rst_upperleftx", @@ -156,6 +159,32 @@ def rst_combineavg(raster_tiles: ColumnOrName) -> Column: ) +def rst_convolve(raster_tile: ColumnOrName, kernel: ColumnOrName) -> Column: + """ + Applies a convolution filter to the raster. + The result is Mosaic raster tile struct column to the filtered raster. + The result is stored in the checkpoint directory. + + Parameters + ---------- + raster_tile : Column (RasterTileType) + Mosaic raster tile struct column. + kernel : Column (ArrayType(ArrayType(DoubleType))) + The kernel to apply to the raster. + + Returns + ------- + Column (RasterTileType) + Mosaic raster tile struct column. + + """ + return config.mosaic_context.invoke_function( + "rst_convolve", + pyspark_to_java_column(raster_tile), + pyspark_to_java_column(kernel), + ) + + def rst_derivedband( raster_tile: ColumnOrName, python_func: ColumnOrName, func_name: ColumnOrName ) -> Column: @@ -316,6 +345,43 @@ def rst_isempty(raster_tile: ColumnOrName) -> Column: ) +def rst_maketiles(input: ColumnOrName, driver: Any = "no_driver", size_in_mb: Any = -1, + with_checkpoint: Any = False) -> Column: + """ + Tiles the raster into tiles of the given size. + :param input: If the raster is stored on disc, the path + to the raster is provided. If the raster is stored in memory, the bytes of + the raster are provided. + :param driver: The driver to use for reading the raster. If not specified, the driver is + inferred from the file extension. If the input is a byte array, the driver + has to be specified. + :param size_in_mb: The size of the tiles in MB. If set to -1, the file is loaded and returned + as a single tile. If set to 0, the file is loaded and subdivided into + tiles of size 64MB. If set to a positive value, the file is loaded and + subdivided into tiles of the specified size. If the file is too big to fit + in memory, it is subdivided into tiles of size 64MB. + :param with_checkpoint: If set to true, the tiles are written to the checkpoint directory. If set + to false, the tiles are returned as a in-memory byte arrays. + :return: A collection of tiles of the raster. + """ + if type(size_in_mb) == int: + size_in_mb = lit(size_in_mb) + + if type(with_checkpoint) == bool: + with_checkpoint = lit(with_checkpoint) + + if type(driver) == str: + driver = lit(driver) + + return config.mosaic_context.invoke_function( + "rst_maketiles", + pyspark_to_java_column(input), + pyspark_to_java_column(driver), + pyspark_to_java_column(size_in_mb), + pyspark_to_java_column(with_checkpoint), + ) + + def rst_mapalgebra(raster_tile: ColumnOrName, json_spec: ColumnOrName) -> Column: """ Parameters @@ -630,7 +696,7 @@ def rst_rastertogridmin(raster_tile: ColumnOrName, resolution: ColumnOrName) -> def rst_rastertoworldcoord( - raster_tile: ColumnOrName, x: ColumnOrName, y: ColumnOrName + raster_tile: ColumnOrName, x: ColumnOrName, y: ColumnOrName ) -> Column: """ Computes the world coordinates of the raster pixel at the given x and y coordinates. @@ -997,6 +1063,32 @@ def rst_tessellate(raster_tile: ColumnOrName, resolution: ColumnOrName) -> Colum ) +def rst_transform(raster_tile: ColumnOrName, srid: ColumnOrName) -> Column: + """ + Transforms the raster to the given SRID. + The result is a Mosaic raster tile struct of the transformed raster. + The result is stored in the checkpoint directory. + + Parameters + ---------- + raster_tile : Column (RasterTileType) + Mosaic raster tile struct column. + srid : Column (IntegerType) + EPSG authority code for the file's projection. + + Returns + ------- + Column (RasterTileType) + Mosaic raster tile struct column. + + """ + return config.mosaic_context.invoke_function( + "rst_transform", + pyspark_to_java_column(raster_tile), + pyspark_to_java_column(srid), + ) + + def rst_fromcontent( raster_bin: ColumnOrName, driver: ColumnOrName, size_in_mb: Any = -1 ) -> Column: @@ -1035,6 +1127,28 @@ def rst_fromfile(raster_path: ColumnOrName, size_in_mb: Any = -1) -> Column: ) +def rst_filter(raster_tile: ColumnOrName, kernel_size: Any, operation: Any) -> Column: + """ + Applies a filter to the raster. + :param raster_tile: Mosaic raster tile struct column. + :param kernel_size: The size of the kernel. Has to be odd. + :param operation: The operation to apply to the kernel. + :return: A new raster tile with the filter applied. + """ + if type(kernel_size) == int: + kernel_size = lit(kernel_size) + + if type(operation) == str: + operation = lit(operation) + + return config.mosaic_context.invoke_function( + "rst_filter", + pyspark_to_java_column(raster_tile), + pyspark_to_java_column(kernel_size), + pyspark_to_java_column(operation), + ) + + def rst_to_overlapping_tiles( raster_tile: ColumnOrName, width: ColumnOrName, diff --git a/python/test/test_raster_functions.py b/python/test/test_raster_functions.py index be5fd4656..cda55143d 100644 --- a/python/test/test_raster_functions.py +++ b/python/test/test_raster_functions.py @@ -19,7 +19,7 @@ def test_read_raster(self): result.metadata["LONGNAME"], "MODIS/Terra+Aqua BRDF/Albedo Nadir BRDF-Adjusted Ref Daily L3 Global - 500m", ) - self.assertEqual(result.tile["driver"], "GTiff") + self.assertEqual(result.tile["metadata"]["driver"], "GTiff") def test_raster_scalar_functions(self): result = ( @@ -115,7 +115,7 @@ def test_raster_flatmap_functions(self): ) tessellate_result.write.format("noop").mode("overwrite").save() - self.assertEqual(tessellate_result.count(), 66) + self.assertEqual(tessellate_result.count(), 63) overlap_result = ( self.generate_singleband_raster_df() @@ -187,11 +187,12 @@ def test_netcdf_load_tessellate_clip_merge(self): df = ( self.spark.read.format("gdal") - .option("raster.read.strategy", "retile_on_read") + .option("raster.read.strategy", "in_memory") .load( "test/data/prAdjust_day_HadGEM2-CC_SMHI-DBSrev930-GFD-1981-2010-postproc_rcp45_r1i1p1_20201201-20201231.nc" ) .select(api.rst_separatebands("tile").alias("tile")) + .repartition(self.spark.sparkContext.defaultParallelism) .withColumn( "timestep", element_at( diff --git a/python/test/test_vector_functions.py b/python/test/test_vector_functions.py index afca19778..c2bb2be74 100644 --- a/python/test/test_vector_functions.py +++ b/python/test/test_vector_functions.py @@ -37,9 +37,7 @@ def test_st_z(self): .select(col("id").cast("double")) .withColumn( "points", - api.st_geomfromwkt( - concat(lit("POINT (9 9 "), "id", lit(")")) - ), + api.st_geomfromwkt(concat(lit("POINT (9 9 "), "id", lit(")"))), ) .withColumn("z", api.st_z("points")) .collect() diff --git a/python/test/utils/spark_test_case.py b/python/test/utils/spark_test_case.py index 6ae23b1b3..640713ba7 100644 --- a/python/test/utils/spark_test_case.py +++ b/python/test/utils/spark_test_case.py @@ -18,6 +18,14 @@ def setUpClass(cls) -> None: if not os.path.exists(cls.library_location): cls.library_location = f"{mosaic.__path__[0]}/lib/mosaic-{version('databricks-mosaic')}-SNAPSHOT-jar-with-dependencies.jar" + pwd_dir = os.getcwd() + tmp_dir = f"{pwd_dir}/mosaic_test/" + check_dir = f"{pwd_dir}/checkpoint" + if not os.path.exists(tmp_dir): + os.makedirs(tmp_dir) + if not os.path.exists(check_dir): + os.makedirs(check_dir) + cls.spark = ( SparkSession.builder.master("local[*]") .config("spark.jars", cls.library_location) @@ -33,6 +41,8 @@ def setUpClass(cls) -> None: .getOrCreate() ) cls.spark.conf.set("spark.databricks.labs.mosaic.jar.autoattach", "false") + cls.spark.conf.set("spark.databricks.labs.mosaic.raster.tmp.prefix", tmp_dir) + cls.spark.conf.set("spark.databricks.labs.mosaic.raster.checkpoint", check_dir) cls.spark.sparkContext.setLogLevel("FATAL") @classmethod diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/api/FormatLookup.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/api/FormatLookup.scala index e3aeb5296..8bf2d9cdb 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/api/FormatLookup.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/api/FormatLookup.scala @@ -17,6 +17,7 @@ object FormatLookup { "CAD" -> "dwg", "CEOS" -> "ceos", "COASP" -> "coasp", + "COG" -> "tif", "COSAR" -> "cosar", "CPG" -> "cpg", "CSW" -> "csw", diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala index 66bde39a3..3517f9368 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/api/GDAL.scala @@ -4,6 +4,7 @@ import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterBandGDAL, Mosaic import com.databricks.labs.mosaic.core.raster.io.RasterCleaner import com.databricks.labs.mosaic.core.raster.operator.transform.RasterTransform import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import com.databricks.labs.mosaic.gdal.MosaicGDAL import com.databricks.labs.mosaic.gdal.MosaicGDAL.configureGDAL import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.{BinaryType, DataType, StringType} @@ -95,25 +96,25 @@ object GDAL { */ def readRaster( inputRaster: Any, - parentPath: String, - shortDriverName: String, + createInfo: Map[String, String], inputDT: DataType ): MosaicRasterGDAL = { inputDT match { case StringType => - val path = inputRaster.asInstanceOf[UTF8String].toString - MosaicRasterGDAL.readRaster(path, parentPath) + MosaicRasterGDAL.readRaster(createInfo) case BinaryType => val bytes = inputRaster.asInstanceOf[Array[Byte]] - val raster = MosaicRasterGDAL.readRaster(bytes, parentPath, shortDriverName) + val raster = MosaicRasterGDAL.readRaster(bytes, createInfo) // If the raster is coming as a byte array, we can't check for zip condition. // We first try to read the raster directly, if it fails, we read it as a zip. if (raster == null) { + val parentPath = createInfo("parentPath") val zippedPath = s"/vsizip/$parentPath" - MosaicRasterGDAL.readRaster(bytes, zippedPath, shortDriverName) + MosaicRasterGDAL.readRaster(bytes, createInfo + ("path" -> zippedPath)) } else { raster } + case _ => throw new IllegalArgumentException(s"Unsupported data type: $inputDT") } } @@ -122,19 +123,17 @@ object GDAL { * * @param generatedRasters * The rasters to write. - * @param checkpointPath - * The path to write the rasters to. * @return * Returns the paths of the written rasters. */ - def writeRasters(generatedRasters: Seq[MosaicRasterGDAL], checkpointPath: String, rasterDT: DataType): Seq[Any] = { + def writeRasters(generatedRasters: Seq[MosaicRasterGDAL], rasterDT: DataType): Seq[Any] = { generatedRasters.map(raster => if (raster != null) { rasterDT match { case StringType => val uuid = UUID.randomUUID().toString val extension = GDAL.getExtension(raster.getDriversShortName) - val writePath = s"$checkpointPath/$uuid.$extension" + val writePath = s"${MosaicGDAL.checkpointPath}/$uuid.$extension" val outPath = raster.writeToPath(writePath) RasterCleaner.dispose(raster) UTF8String.fromString(outPath) @@ -159,7 +158,10 @@ object GDAL { * @return * Returns a Raster object. */ - def raster(path: String, parentPath: String): MosaicRasterGDAL = MosaicRasterGDAL.readRaster(path, parentPath) + def raster(path: String, parentPath: String): MosaicRasterGDAL = { + val createInfo = Map("path" -> path, "parentPath" -> parentPath) + MosaicRasterGDAL.readRaster(createInfo) + } /** * Reads a raster from the given byte array. If the byte array is a zip @@ -170,8 +172,10 @@ object GDAL { * @return * Returns a Raster object. */ - def raster(content: Array[Byte], parentPath: String, driverShortName: String): MosaicRasterGDAL = - MosaicRasterGDAL.readRaster(content, parentPath, driverShortName) + def raster(content: Array[Byte], parentPath: String, driverShortName: String): MosaicRasterGDAL = { + val createInfo = Map("parentPath" -> parentPath, "driver" -> driverShortName) + MosaicRasterGDAL.readRaster(content, createInfo) + } /** * Reads a raster from the given path. It extracts the specified band from @@ -185,8 +189,10 @@ object GDAL { * @return * Returns a Raster band object. */ - def band(path: String, bandIndex: Int, parentPath: String): MosaicRasterBandGDAL = - MosaicRasterGDAL.readBand(path, bandIndex, parentPath) + def band(path: String, bandIndex: Int, parentPath: String): MosaicRasterBandGDAL = { + val createInfo = Map("path" -> path, "parentPath" -> parentPath) + MosaicRasterGDAL.readBand(bandIndex, createInfo) + } /** * Converts raster x, y coordinates to lat, lon coordinates. diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/GDALBlock.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/GDALBlock.scala new file mode 100644 index 000000000..5014b0ee3 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/GDALBlock.scala @@ -0,0 +1,185 @@ +package com.databricks.labs.mosaic.core.raster.gdal + +import scala.reflect.ClassTag + +case class GDALBlock[T: ClassTag]( + block: Array[T], + maskBlock: Array[Double], + noDataValue: Double, + xOffset: Int, + yOffset: Int, + width: Int, + height: Int, + padding: Padding +)(implicit + num: Numeric[T] +) { + + def elementAt(index: Int): T = block(index) + + def maskAt(index: Int): Double = maskBlock(index) + + def elementAt(x: Int, y: Int): T = block(y * width + x) + + def maskAt(x: Int, y: Int): Double = maskBlock(y * width + x) + + def rasterElementAt(x: Int, y: Int): T = block((y - yOffset) * width + (x - xOffset)) + + def rasterMaskAt(x: Int, y: Int): Double = maskBlock((y - yOffset) * width + (x - xOffset)) + + def valuesAt(x: Int, y: Int, kernelWidth: Int, kernelHeight: Int): Array[Double] = { + val kernelCenterX = kernelWidth / 2 + val kernelCenterY = kernelHeight / 2 + val values = Array.fill[Double](kernelWidth * kernelHeight)(noDataValue) + var n = 0 + for (i <- 0 until kernelHeight) { + for (j <- 0 until kernelWidth) { + val xIndex = x + (j - kernelCenterX) + val yIndex = y + (i - kernelCenterY) + if (xIndex >= 0 && xIndex < width && yIndex >= 0 && yIndex < height) { + val maskValue = maskAt(xIndex, yIndex) + val value = elementAt(xIndex, yIndex) + if (maskValue != 0.0 && num.toDouble(value) != noDataValue) { + values(n) = num.toDouble(value) + n += 1 + } + } + } + } + val result = values.filter(_ != noDataValue) + // always return only one NoDataValue if no valid values are found + // one and only one NoDataValue can be returned + // in all cases that have some valid values, the NoDataValue will be filtered out + if (result.isEmpty) { + Array(noDataValue) + } else { + result + } + } + + def convolveAt(x: Int, y: Int, kernel: Array[Array[Double]]): Double = { + val kernelWidth = kernel.head.length + val kernelHeight = kernel.length + val kernelCenterX = kernelWidth / 2 + val kernelCenterY = kernelHeight / 2 + var sum = 0.0 + for (i <- 0 until kernelHeight) { + for (j <- 0 until kernelWidth) { + val xIndex = x + (j - kernelCenterX) + val yIndex = y + (i - kernelCenterY) + if (xIndex >= 0 && xIndex < width && yIndex >= 0 && yIndex < height) { + val maskValue = maskAt(xIndex, yIndex) + val value = elementAt(xIndex, yIndex) + if (maskValue != 0.0 && num.toDouble(value) != noDataValue) { + sum += num.toDouble(value) * kernel(i)(j) + } + } + } + } + sum + } + + def avgFilterAt(x: Int, y: Int, kernelSize: Int): Double = { + val values = valuesAt(x, y, kernelSize, kernelSize) + values.sum / values.length + } + + def minFilterAt(x: Int, y: Int, kernelSize: Int): Double = { + val values = valuesAt(x, y, kernelSize, kernelSize) + values.min + } + + def maxFilterAt(x: Int, y: Int, kernelSize: Int): Double = { + val values = valuesAt(x, y, kernelSize, kernelSize) + values.max + } + + def medianFilterAt(x: Int, y: Int, kernelSize: Int): Double = { + val values = valuesAt(x, y, kernelSize, kernelSize) + val n = values.length + scala.util.Sorting.quickSort(values) + values(n / 2) + } + + def modeFilterAt(x: Int, y: Int, kernelSize: Int): Double = { + val values = valuesAt(x, y, kernelSize, kernelSize) + val counts = values.groupBy(identity).mapValues(_.length) + counts.maxBy(_._2)._1 + } + + def trimBlock(stride: Int): GDALBlock[Double] = { + val resultBlock = padding.removePadding(block.map(num.toDouble), width, stride) + val resultMask = padding.removePadding(maskBlock, width, stride) + + val newOffset = padding.newOffset(xOffset, yOffset, stride) + val newSize = padding.newSize(width, height, stride) + + new GDALBlock[Double]( + resultBlock, + resultMask, + noDataValue, + newOffset._1, + newOffset._2, + newSize._1, + newSize._2, + Padding.NoPadding + ) + } + +} + +object GDALBlock { + + def getSize(offset: Int, maxSize: Int, blockSize: Int, stride: Int, paddingStrides: Int): Int = { + if (offset + blockSize + paddingStrides * stride > maxSize) { + maxSize - offset + } else { + blockSize + paddingStrides * stride + } + } + + def apply( + band: MosaicRasterBandGDAL, + stride: Int, + xOffset: Int, + yOffset: Int, + blockSize: Int + ): GDALBlock[Double] = { + val noDataValue = band.noDataValue + val rasterWidth = band.xSize + val rasterHeight = band.ySize + // Always read blockSize + stride pixels on every side + // This is fine since kernel size is always much smaller than blockSize + + val padding = Padding( + left = xOffset != 0, + right = xOffset + blockSize < rasterWidth - 1, // not sure about -1 + top = yOffset != 0, + bottom = yOffset + blockSize < rasterHeight - 1 + ) + + val xo = Math.max(0, xOffset - stride) + val yo = Math.max(0, yOffset - stride) + + val xs = getSize(xo, rasterWidth, blockSize, stride, padding.horizontalStrides) + val ys = getSize(yo, rasterHeight, blockSize, stride, padding.verticalStrides) + + val block = Array.ofDim[Double](xs * ys) + val maskBlock = Array.ofDim[Double](xs * ys) + + band.getBand.ReadRaster(xo, yo, xs, ys, block) + band.getBand.GetMaskBand().ReadRaster(xo, yo, xs, ys, maskBlock) + + GDALBlock( + block, + maskBlock, + noDataValue, + xo, + yo, + xs, + ys, + padding + ) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala index 5ff3ceeaf..683d3791e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterBandGDAL.scala @@ -1,5 +1,6 @@ package com.databricks.labs.mosaic.core.raster.gdal +import com.databricks.labs.mosaic.gdal.MosaicGDAL import org.gdal.gdal.Band import org.gdal.gdalconst.gdalconstConstants @@ -219,6 +220,30 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) { } } + /** + * Counts the number of pixels in the band. The mask is used to determine + * if a pixel is valid. If pixel value is noData or mask value is 0.0, the + * pixel is not counted. + * + * @return + * Returns the band's pixel count. + */ + def pixelCount: Int = { + val line = Array.ofDim[Double](band.GetXSize()) + val maskLine = Array.ofDim[Double](band.GetXSize()) + var count = 0 + for (y <- 0 until band.GetYSize()) { + band.ReadRaster(0, y, band.GetXSize(), 1, line) + val maskRead = band.GetMaskBand().ReadRaster(0, y, band.GetXSize(), 1, maskLine) + if (maskRead != gdalconstConstants.CE_None) { + count = count + line.count(_ != noDataValue) + } else { + count = count + line.zip(maskLine).count { case (pixel, mask) => pixel != noDataValue && mask != 0.0 } + } + } + count + } + /** * @return * Returns the band's mask flags. @@ -240,4 +265,94 @@ case class MosaicRasterBandGDAL(band: Band, id: Int) { stats.getValid_count == 0 } + /** + * Applies a kernel filter to the band. It assumes the kernel is square and + * has an odd number of rows and columns. + * + * @param kernel + * The kernel to apply to the band. + * @return + * The band with the kernel filter applied. + */ + def convolve(kernel: Array[Array[Double]], outputBand: Band): Unit = { + val kernelSize = kernel.length + require(kernelSize % 2 == 1, "Kernel size must be odd") + + val blockSize = MosaicGDAL.blockSize + val stride = kernelSize / 2 + + for (yOffset <- 0 until ySize by blockSize) { + for (xOffset <- 0 until xSize by blockSize) { + + val currentBlock = GDALBlock( + this, + stride, + xOffset, + yOffset, + blockSize + ) + + val result = Array.ofDim[Double](currentBlock.block.length) + + for (y <- 0 until currentBlock.height) { + for (x <- 0 until currentBlock.width) { + result(y * currentBlock.width + x) = currentBlock.convolveAt(x, y, kernel) + } + } + + val trimmedResult = currentBlock.copy(block = result).trimBlock(stride) + + outputBand.WriteRaster(xOffset, yOffset, trimmedResult.width, trimmedResult.height, trimmedResult.block) + outputBand.FlushCache() + outputBand.GetMaskBand().WriteRaster(xOffset, yOffset, trimmedResult.width, trimmedResult.height, trimmedResult.maskBlock) + outputBand.GetMaskBand().FlushCache() + + } + } + } + + def filter(kernelSize: Int, operation: String, outputBand: Band): Unit = { + require(kernelSize % 2 == 1, "Kernel size must be odd") + + val blockSize = MosaicGDAL.blockSize + val stride = kernelSize / 2 + + for (yOffset <- 0 until ySize by blockSize) { + for (xOffset <- 0 until xSize by blockSize) { + + val currentBlock = GDALBlock( + this, + stride, + xOffset, + yOffset, + blockSize + ) + + val result = Array.ofDim[Double](currentBlock.block.length) + + for (y <- 0 until currentBlock.height) { + for (x <- 0 until currentBlock.width) { + result(y * currentBlock.width + x) = operation match { + case "avg" => currentBlock.avgFilterAt(x, y, kernelSize) + case "min" => currentBlock.minFilterAt(x, y, kernelSize) + case "max" => currentBlock.maxFilterAt(x, y, kernelSize) + case "median" => currentBlock.medianFilterAt(x, y, kernelSize) + case "mode" => currentBlock.modeFilterAt(x, y, kernelSize) + case _ => throw new Exception("Invalid operation") + } + } + } + + val trimmedResult = currentBlock.copy(block = result).trimBlock(stride) + + outputBand.WriteRaster(xOffset, yOffset, trimmedResult.width, trimmedResult.height, trimmedResult.block) + outputBand.FlushCache() + outputBand.GetMaskBand().WriteRaster(xOffset, yOffset, trimmedResult.width, trimmedResult.height, trimmedResult.maskBlock) + outputBand.GetMaskBand().FlushCache() + + } + } + + } + } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala index 33f980748..b399fa0a2 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterGDAL.scala @@ -9,9 +9,9 @@ import com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose import com.databricks.labs.mosaic.core.raster.io.{RasterCleaner, RasterReader, RasterWriter} import com.databricks.labs.mosaic.core.raster.operator.clip.RasterClipByVector import com.databricks.labs.mosaic.core.types.model.GeometryTypeEnum.POLYGON -import com.databricks.labs.mosaic.utils.PathUtils -import org.gdal.gdal.gdal.GDALInfo -import org.gdal.gdal.{Dataset, InfoOptions, gdal} +import com.databricks.labs.mosaic.gdal.MosaicGDAL +import com.databricks.labs.mosaic.utils.{FileUtils, PathUtils, SysUtils} +import org.gdal.gdal.{Dataset, gdal} import org.gdal.gdalconst.gdalconstConstants._ import org.gdal.osr import org.gdal.osr.SpatialReference @@ -26,37 +26,60 @@ import scala.util.{Failure, Success, Try} //noinspection DuplicatedCode case class MosaicRasterGDAL( raster: Dataset, - path: String, - parentPath: String, - driverShortName: String, + createInfo: Map[String, String], memSize: Long ) extends RasterWriter with RasterCleaner { + def path: String = createInfo("path") + + def parentPath: String = createInfo("parentPath") + + def driverShortName: Option[String] = createInfo.get("driver") + + def getWriteOptions: MosaicRasterWriteOptions = MosaicRasterWriteOptions(this) + + def getCompression: String = { + val compression = Option(raster.GetMetadata_Dict("IMAGE_STRUCTURE")) + .map(_.asScala.toMap.asInstanceOf[Map[String, String]]) + .getOrElse(Map.empty[String, String]) + .getOrElse("COMPRESSION", "NONE") + compression + } + def getSpatialReference: SpatialReference = { - if (raster != null) { - spatialRef + val spatialRef = + if (raster != null) { + raster.GetSpatialRef + } else { + val tmp = refresh() + val result = tmp.raster.GetSpatialRef + dispose(tmp) + result + } + if (spatialRef == null) { + MosaicGDAL.WSG84 } else { - val tmp = refresh() - val result = tmp.spatialRef - dispose(tmp) - result + spatialRef } } + def isSubDataset: Boolean = { + val isSubdataset = PathUtils.isSubdataset(path) + isSubdataset + } + // Factory for creating CRS objects protected val crsFactory: CRSFactory = new CRSFactory - // Only use this with GDAL rasters - private val wgs84 = new osr.SpatialReference() - wgs84.ImportFromEPSG(4326) - wgs84.SetAxisMappingStrategy(osr.osrConstants.OAMS_TRADITIONAL_GIS_ORDER) - /** * @return * The raster's driver short name. */ - def getDriversShortName: String = driverShortName + def getDriversShortName: String = + driverShortName.getOrElse( + Try(raster.GetDriver().getShortName).getOrElse("NONE") + ) /** * @return @@ -98,7 +121,7 @@ case class MosaicRasterGDAL( def pixelDiagSize: Double = math.sqrt(pixelXSize * pixelXSize + pixelYSize * pixelYSize) /** @return Returns file extension. */ - def getRasterFileExtension: String = GDAL.getExtension(driverShortName) + def getRasterFileExtension: String = GDAL.getExtension(getDriversShortName) /** @return Returns the raster's bands as a Seq. */ def getBands: Seq[MosaicRasterBandGDAL] = (1 to numBands).map(getBand) @@ -125,7 +148,7 @@ case class MosaicRasterGDAL( * A MosaicRaster object. */ def openRaster(path: String): Dataset = { - MosaicRasterGDAL.openRaster(path, Some(driverShortName)) + MosaicRasterGDAL.openRaster(path, driverShortName) } /** @@ -159,6 +182,7 @@ case class MosaicRasterGDAL( .map(_.asScala.toMap.asInstanceOf[Map[String, String]]) .getOrElse(Map.empty[String, String]) val keys = subdatasetsMap.keySet + val sanitizedParentPath = PathUtils.getCleanPath(parentPath) keys.flatMap(key => if (key.toUpperCase(Locale.ROOT).contains("NAME")) { val path = subdatasetsMap(key) @@ -166,7 +190,7 @@ case class MosaicRasterGDAL( Seq( key -> pieces.last, s"${pieces.last}_tmp" -> path, - pieces.last -> s"${pieces.head}:$parentPath:${pieces.last}" + pieces.last -> s"${pieces.head}:$sanitizedParentPath:${pieces.last}" ) } else Seq(key -> subdatasetsMap(key)) ).toMap @@ -185,7 +209,6 @@ case class MosaicRasterGDAL( .toInt } - /** * @return * Sets the raster's SRID. This is the EPSG code of the raster's CRS. @@ -195,15 +218,18 @@ case class MosaicRasterGDAL( srs.ImportFromEPSG(srid) raster.SetSpatialRef(srs) val driver = raster.GetDriver() - val newPath = PathUtils.createTmpFilePath(GDAL.getExtension(driverShortName)) + val newPath = PathUtils.createTmpFilePath(GDAL.getExtension(getDriversShortName)) driver.CreateCopy(newPath, raster) - val newRaster = MosaicRasterGDAL.openRaster(newPath, Some(driverShortName)) + val newRaster = MosaicRasterGDAL.openRaster(newPath, driverShortName) dispose(this) - MosaicRasterGDAL(newRaster, newPath, parentPath, driverShortName, -1) + val createInfo = Map( + "path" -> newPath, + "parentPath" -> parentPath, + "driver" -> getDriversShortName + ) + MosaicRasterGDAL(newRaster, createInfo, -1) } - - /** * @return * Returns the raster's proj4 string. @@ -280,12 +306,6 @@ case class MosaicRasterGDAL( */ def getRaster: Dataset = this.raster - /** - * @return - * Returns the raster's spatial reference. - */ - def spatialRef: SpatialReference = Option(raster.GetSpatialRef()).getOrElse(wgs84) - /** * Applies a function to each band of the raster. * @param f @@ -299,10 +319,10 @@ case class MosaicRasterGDAL( * @return * Returns MosaicGeometry representing bounding box of the raster. */ - def bbox(geometryAPI: GeometryAPI, destCRS: SpatialReference = wgs84): MosaicGeometry = { + def bbox(geometryAPI: GeometryAPI, destCRS: SpatialReference = MosaicGDAL.WSG84): MosaicGeometry = { val gt = getGeoTransform - val sourceCRS = spatialRef + val sourceCRS = getSpatialReference val transform = new osr.CoordinateTransformation(sourceCRS, destCRS) val bbox = geometryAPI.geometry( @@ -329,10 +349,9 @@ case class MosaicRasterGDAL( def isEmpty: Boolean = { val bands = getBands if (bands.isEmpty) { - subdatasets - .values - .filter(_.toLowerCase(Locale.ROOT).startsWith(driverShortName.toLowerCase(Locale.ROOT))) - .flatMap(readRaster(_, path).getBands) + subdatasets.values + .filter(_.toLowerCase(Locale.ROOT).startsWith(getDriversShortName.toLowerCase(Locale.ROOT))) + .flatMap(bp => readRaster(createInfo + ("path" -> bp)).getBands) .takeWhile(_.isEmpty) .nonEmpty } else { @@ -367,11 +386,18 @@ case class MosaicRasterGDAL( val isSubdataset = PathUtils.isSubdataset(path) val filePath = if (isSubdataset) PathUtils.fromSubdatasetPath(path) else path val pamFilePath = s"$filePath.aux.xml" + val cleanPath = filePath.replace("/vsizip/", "") + val zipPath = if (cleanPath.endsWith("zip")) cleanPath else s"$cleanPath.zip" if (path != PathUtils.getCleanPath(parentPath)) { - Try(gdal.GetDriverByName(driverShortName).Delete(path)) + Try(gdal.GetDriverByName(getDriversShortName).Delete(path)) + Try(Files.deleteIfExists(Paths.get(cleanPath))) Try(Files.deleteIfExists(Paths.get(path))) Try(Files.deleteIfExists(Paths.get(filePath))) Try(Files.deleteIfExists(Paths.get(pamFilePath))) + if (Files.exists(Paths.get(zipPath))) { + Try(Files.deleteIfExists(Paths.get(zipPath.replace(".zip", "")))) + } + Try(Files.deleteIfExists(Paths.get(zipPath))) } } @@ -385,6 +411,9 @@ case class MosaicRasterGDAL( def getMemSize: Long = { if (memSize == -1) { val toRead = if (path.startsWith("/vsizip/")) path.replace("/vsizip/", "") else path + if (Files.notExists(Paths.get(toRead))) { + throw new Exception(s"File not found: ${gdal.GetLastErrorMsg()}") + } Files.size(Paths.get(toRead)) } else { memSize @@ -402,12 +431,26 @@ case class MosaicRasterGDAL( * A boolean indicating if the write was successful. */ def writeToPath(path: String, dispose: Boolean = true): String = { - val driver = raster.GetDriver() - val ds = driver.CreateCopy(path, this.flushCache().getRaster) - ds.FlushCache() - ds.delete() - if (dispose) RasterCleaner.dispose(this) - path + if (isSubDataset) { + val driver = raster.GetDriver() + val ds = driver.CreateCopy(path, this.flushCache().getRaster, 1) + if (ds == null) { + val error = gdal.GetLastErrorMsg() + throw new Exception(s"Error writing raster to path: $error") + } + ds.FlushCache() + ds.delete() + if (dispose) RasterCleaner.dispose(this) + path + } else { + val thisPath = Paths.get(this.path) + val fromDir = thisPath.getParent + val toDir = Paths.get(path).getParent + val stemRegex = PathUtils.getStemRegex(this.path) + PathUtils.wildcardCopy(fromDir.toString, toDir.toString, stemRegex) + if (dispose) RasterCleaner.dispose(this) + s"$toDir/${thisPath.getFileName}" + } } /** @@ -418,17 +461,36 @@ case class MosaicRasterGDAL( */ def writeToBytes(dispose: Boolean = true): Array[Byte] = { val isSubdataset = PathUtils.isSubdataset(path) - val readPath = - if (isSubdataset) { - val tmpPath = PathUtils.createTmpFilePath(getRasterFileExtension) - writeToPath(tmpPath, dispose = false) + val readPath = { + val tmpPath = + if (isSubdataset) { + val tmpPath = PathUtils.createTmpFilePath(getRasterFileExtension) + writeToPath(tmpPath, dispose = false) + tmpPath + } else { + path + } + if (Files.isDirectory(Paths.get(tmpPath))) { + val parentDir = Paths.get(tmpPath).getParent.toString + val fileName = Paths.get(tmpPath).getFileName.toString + val prompt = SysUtils.runScript(Array("/bin/sh", "-c", s"cd $parentDir && zip -r0 $fileName.zip $fileName")) + if (prompt._3.nonEmpty) throw new Exception(s"Error zipping file: ${prompt._3}. Please verify that zip is installed. Run 'apt install zip'.") + s"$tmpPath.zip" } else { - path + tmpPath } - val byteArray = Files.readAllBytes(Paths.get(readPath)) + } + val byteArray = FileUtils.readBytes(readPath) if (dispose) RasterCleaner.dispose(this) if (readPath != PathUtils.getCleanPath(parentPath)) { Files.deleteIfExists(Paths.get(readPath)) + if (readPath.endsWith(".zip")) { + val nonZipPath = readPath.replace(".zip", "") + if (Files.isDirectory(Paths.get(nonZipPath))) { + SysUtils.runCommand(s"rm -rf $nonZipPath") + } + Files.deleteIfExists(Paths.get(readPath.replace(".zip", ""))) + } } byteArray } @@ -452,7 +514,7 @@ case class MosaicRasterGDAL( * usable again. */ def refresh(): MosaicRasterGDAL = { - MosaicRasterGDAL(openRaster(path), path, parentPath, driverShortName, memSize) + MosaicRasterGDAL(openRaster(path), createInfo, memSize) } /** @@ -484,6 +546,20 @@ case class MosaicRasterGDAL( .toMap } + /** + * @return + * Returns the raster's band valid pixel count. + */ + def getValidCount: Map[Int, Long] = { + (1 to numBands) + .map(i => { + val band = raster.GetRasterBand(i) + val validCount = band.AsMDArray().GetStatistics().getValid_count + i -> validCount + }) + .toMap + } + /** * @param subsetName * The name of the subdataset to get. @@ -491,22 +567,88 @@ case class MosaicRasterGDAL( * Returns the raster's subdataset with given name. */ def getSubdataset(subsetName: String): MosaicRasterGDAL = { - subdatasets - val path = Option(raster.GetMetadata_Dict("SUBDATASETS")) - .map(_.asScala.toMap.asInstanceOf[Map[String, String]]) - .getOrElse(Map.empty[String, String]) - .values - .find(_.toUpperCase(Locale.ROOT).endsWith(subsetName.toUpperCase(Locale.ROOT))) - .getOrElse(throw new Exception(s""" - |Subdataset $subsetName not found! - |Available subdatasets: - | ${subdatasets.keys.filterNot(_.startsWith("SUBDATASET_")).mkString(", ")} - """.stripMargin)) - val ds = openRaster(path) + val path = subdatasets.get(s"${subsetName}_tmp") + val gdalError = gdal.GetLastErrorMsg() + val error = path match { + case Some(_) => "" + case None => s""" + |Subdataset $subsetName not found! + |Available subdatasets: + | ${subdatasets.keys.filterNot(_.startsWith("SUBDATASET_")).mkString(", ")} + | """.stripMargin + } + val sanitized = PathUtils.getCleanPath(path.getOrElse(PathUtils.NO_PATH_STRING)) + val subdatasetPath = PathUtils.getSubdatasetPath(sanitized) + + val ds = openRaster(subdatasetPath) // Avoid costly IO to compute MEM size here // It will be available when the raster is serialized for next operation // If value is needed then it will be computed when getMemSize is called - MosaicRasterGDAL(ds, path, parentPath, driverShortName, -1) + val createInfo = Map( + "path" -> path.getOrElse(PathUtils.NO_PATH_STRING), + "parentPath" -> parentPath, + "driver" -> getDriversShortName, + "last_error" -> { + if (gdalError.nonEmpty || error.nonEmpty) s""" + |GDAL Error: $gdalError + |$error + |""".stripMargin + else "" + } + ) + MosaicRasterGDAL(ds, createInfo, -1) + } + + def convolve(kernel: Array[Array[Double]]): MosaicRasterGDAL = { + val resultRasterPath = PathUtils.createTmpFilePath(getRasterFileExtension) + + this.raster + .GetDriver() + .CreateCopy(resultRasterPath, this.raster, 1) + .delete() + + val outputRaster = gdal.Open(resultRasterPath, GF_Write) + + for (bandIndex <- 1 to this.numBands) { + val band = this.getBand(bandIndex) + val outputBand = outputRaster.GetRasterBand(bandIndex) + band.convolve(kernel, outputBand) + } + + val createInfo = Map( + "path" -> resultRasterPath, + "parentPath" -> parentPath, + "driver" -> getDriversShortName + ) + + val result = MosaicRasterGDAL(outputRaster, createInfo, this.memSize) + result.flushCache() + } + + def filter(kernelSize: Int, operation: String): MosaicRasterGDAL = { + val resultRasterPath = PathUtils.createTmpFilePath(getRasterFileExtension) + + this.raster + .GetDriver() + .CreateCopy(resultRasterPath, this.raster, 1) + .delete() + + val outputRaster = gdal.Open(resultRasterPath, GF_Write) + + for (bandIndex <- 1 to this.numBands) { + val band = this.getBand(bandIndex) + val outputBand = outputRaster.GetRasterBand(bandIndex) + band.filter(kernelSize, operation, outputBand) + } + + val createInfo = Map( + "path" -> resultRasterPath, + "parentPath" -> parentPath, + "driver" -> getDriversShortName + ) + + val result = MosaicRasterGDAL(outputRaster, createInfo, this.memSize) + result.flushCache() } } @@ -560,25 +702,44 @@ object MosaicRasterGDAL extends RasterReader { * @example * Raster: path = "file:///path/to/file.tif" Subdataset: path = * "file:///path/to/file.tif:subdataset" - * @param inPath - * The path to the raster file. + * @param createInfo + * The create info for the raster. This should contain the following + * keys: + * - path: The path to the raster file. + * - parentPath: The path of the parent raster file. * @return * A MosaicRaster object. */ - override def readRaster(inPath: String, parentPath: String): MosaicRasterGDAL = { + override def readRaster(createInfo: Map[String, String]): MosaicRasterGDAL = { + val inPath = createInfo("path") val isSubdataset = PathUtils.isSubdataset(inPath) val path = PathUtils.getCleanPath(inPath) val readPath = if (isSubdataset) PathUtils.getSubdatasetPath(path) else PathUtils.getZipPath(path) val dataset = openRaster(readPath, None) - val driverShortName = dataset.GetDriver().getShortName - + val error = + if (dataset == null) { + val error = gdal.GetLastErrorMsg() + s""" + Error reading raster from path: $readPath + Error: $error + """ + } else "" + val driverShortName = Try(dataset.GetDriver().getShortName).getOrElse("NONE") // Avoid costly IO to compute MEM size here // It will be available when the raster is serialized for next operation // If value is needed then it will be computed when getMemSize is called // We cannot just use memSize value of the parent due to the fact that the raster could be a subdataset - val raster = MosaicRasterGDAL(dataset, path, parentPath, driverShortName, -1) + val raster = MosaicRasterGDAL( + dataset, + createInfo ++ + Map( + "driver" -> driverShortName, + "last_error" -> error + ), + -1 + ) raster } @@ -586,30 +747,50 @@ object MosaicRasterGDAL extends RasterReader { * Reads a raster from a byte array. * @param contentBytes * The byte array containing the raster data. - * @param driverShortName - * The driver short name of the raster. + * @param createInfo + * Mosaic creation info of the raster. Note: This is not the same as the + * metadata of the raster. This is not the same as GDAL creation options. * @return * A MosaicRaster object. */ - override def readRaster(contentBytes: Array[Byte], parentPath: String, driverShortName: String): MosaicRasterGDAL = { + override def readRaster(contentBytes: Array[Byte], createInfo: Map[String, String]): MosaicRasterGDAL = { if (Option(contentBytes).isEmpty || contentBytes.isEmpty) { - MosaicRasterGDAL(null, "", parentPath, "", -1) + MosaicRasterGDAL(null, createInfo, -1) } else { // This is a temp UUID for purposes of reading the raster through GDAL from memory // The stable UUID is kept in metadata of the raster + val driverShortName = createInfo("driver") val extension = GDAL.getExtension(driverShortName) val tmpPath = PathUtils.createTmpFilePath(extension) Files.write(Paths.get(tmpPath), contentBytes) // Try reading as a tmp file, if that fails, rename as a zipped file val dataset = openRaster(tmpPath, Some(driverShortName)) if (dataset == null) { - val zippedPath = PathUtils.createTmpFilePath("zip") + val zippedPath = s"$tmpPath.zip" Files.move(Paths.get(tmpPath), Paths.get(zippedPath), StandardCopyOption.REPLACE_EXISTING) val readPath = PathUtils.getZipPath(zippedPath) val ds = openRaster(readPath, Some(driverShortName)) - MosaicRasterGDAL(ds, readPath, parentPath, driverShortName, contentBytes.length) + if (ds == null) { + // the way we zip using uuid is not compatible with GDAL + // we need to unzip and read the file if it was zipped by us + val parentDir = Paths.get(zippedPath).getParent + val prompt = SysUtils.runScript(Array("/bin/sh", "-c", s"cd $parentDir && unzip -o $zippedPath -d $parentDir")) + // zipped files will have the old uuid name of the raster + // we need to get the last extracted file name, but the last extracted file name is not the raster name + // we can't list folders due to concurrent writes + val extension = GDAL.getExtension(driverShortName) + val lastExtracted = SysUtils.getLastOutputLine(prompt) + val unzippedPath = PathUtils.parseUnzippedPathFromExtracted(lastExtracted, extension) + val dataset = openRaster(unzippedPath, Some(driverShortName)) + if (dataset == null) { + throw new Exception(s"Error reading raster from bytes: ${prompt._3}") + } + MosaicRasterGDAL(dataset, createInfo + ("path" -> unzippedPath), contentBytes.length) + } else { + MosaicRasterGDAL(ds, createInfo + ("path" -> readPath), contentBytes.length) + } } else { - MosaicRasterGDAL(dataset, tmpPath, parentPath, driverShortName, contentBytes.length) + MosaicRasterGDAL(dataset, createInfo + ("path" -> tmpPath), contentBytes.length) } } } @@ -621,15 +802,19 @@ object MosaicRasterGDAL extends RasterReader { * @example * Raster: path = "file:///path/to/file.tif" Subdataset: path = * "file:///path/to/file.tif:subdataset" - * @param path - * The path to the raster file. + * @param createInfo + * The create info for the raster. This should contain the following + * keys: + * - path: The path to the raster file. + * - parentPath: The path of the parent raster file. + * - driver: Optional: The driver short name of the raster file * @param bandIndex * The band index to read. * @return * A MosaicRaster object. */ - override def readBand(path: String, bandIndex: Int, parentPath: String): MosaicRasterBandGDAL = { - val raster = readRaster(path, parentPath) + override def readBand(bandIndex: Int, createInfo: Map[String, String]): MosaicRasterBandGDAL = { + val raster = readRaster(createInfo) // TODO: Raster and Band are coupled, this can cause a pointer leak raster.getBand(bandIndex) } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterWriteOptions.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterWriteOptions.scala new file mode 100644 index 000000000..68a7bd75a --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/MosaicRasterWriteOptions.scala @@ -0,0 +1,55 @@ +package com.databricks.labs.mosaic.core.raster.gdal + +import com.databricks.labs.mosaic.gdal.MosaicGDAL +import org.gdal.osr.SpatialReference + +case class MosaicRasterWriteOptions( + compression: String = "DEFLATE", + format: String = "GTiff", + extension: String = "tif", + resampling: String = "nearest", + crs: SpatialReference = MosaicGDAL.WSG84, // Assume WGS84 + pixelSize: Option[(Double, Double)] = None, + noDataValue: Option[Double] = None, + missingGeoRef: Boolean = false, + options: Map[String, String] = Map.empty[String, String] +) + +object MosaicRasterWriteOptions { + + val VRT: MosaicRasterWriteOptions = + MosaicRasterWriteOptions( + compression = "NONE", + format = "VRT", + extension = "vrt", + crs = MosaicGDAL.WSG84, + pixelSize = None, + noDataValue = None, + options = Map.empty[String, String] + ) + + val GTiff: MosaicRasterWriteOptions = MosaicRasterWriteOptions() + + def noGPCsNoTransform(raster: MosaicRasterGDAL): Boolean = { + val noGPCs = raster.getRaster.GetGCPCount == 0 + val noGeoTransform = raster.getRaster.GetGeoTransform == null || + (raster.getRaster.GetGeoTransform sameElements Array(0.0, 1.0, 0.0, 0.0, 0.0, 1.0)) + noGPCs && noGeoTransform + } + + def apply(): MosaicRasterWriteOptions = new MosaicRasterWriteOptions() + + def apply(raster: MosaicRasterGDAL): MosaicRasterWriteOptions = { + val compression = raster.getCompression + val format = raster.getRaster.GetDriver.getShortName + val extension = raster.getRasterFileExtension + val resampling = "nearest" + val pixelSize = None + val noDataValue = None + val options = Map.empty[String, String] + val crs = raster.getSpatialReference + val missingGeoRef = noGPCsNoTransform(raster) + new MosaicRasterWriteOptions(compression, format, extension, resampling, crs, pixelSize, noDataValue, missingGeoRef, options) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/Padding.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/Padding.scala new file mode 100644 index 000000000..bb32e772f --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/gdal/Padding.scala @@ -0,0 +1,58 @@ +package com.databricks.labs.mosaic.core.raster.gdal + +case class Padding( + left: Boolean, + right: Boolean, + top: Boolean, + bottom: Boolean +) { + + def removePadding(array: Array[Double], rowWidth: Int, stride: Int): Array[Double] = { + val l = if (left) 1 else 0 + val r = if (right) 1 else 0 + val t = if (top) 1 else 0 + val b = if (bottom) 1 else 0 + + val yStart = t * stride * rowWidth + val yEnd = array.length - b * stride * rowWidth + + val slices = for (i <- yStart until yEnd by rowWidth) yield { + val xStart = i + l * stride + val xEnd = i + rowWidth - r * stride + array.slice(xStart, xEnd) + } + + slices.flatten.toArray + } + + def horizontalStrides: Int = { + if (left && right) 2 + else if (left || right) 1 + else 0 + } + + def verticalStrides: Int = { + if (top && bottom) 2 + else if (top || bottom) 1 + else 0 + } + + def newOffset(xOffset: Int, yOffset: Int, stride: Int): (Int, Int) = { + val x = if (left) xOffset + stride else xOffset + val y = if (top) yOffset + stride else yOffset + (x, y) + } + + def newSize(width: Int, height: Int, stride: Int): (Int, Int) = { + val w = if (left && right) width - 2 * stride else if (left || right) width - stride else width + val h = if (top && bottom) height - 2 * stride else if (top || bottom) height - stride else height + (w, h) + } + +} + +object Padding { + + val NoPadding: Padding = Padding(left = false, right = false, top = false, bottom = false) + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/io/RasterReader.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/io/RasterReader.scala index b207789ae..d8a1a90c1 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/io/RasterReader.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/io/RasterReader.scala @@ -20,14 +20,16 @@ trait RasterReader extends Logging { * @example * Raster: path = "/path/to/file.tif" Subdataset: path = * "FORMAT:/path/to/file.tif:subdataset" - * @param path - * The path to the raster file. - * @param parentPath - * The path of the parent raster file. + * @param createInfo + * The create info for the raster. This should contain the following + * keys: + * - path: The path to the raster file. + * - parentPath: The path of the parent raster file. + * - driver: Optional: The driver short name of the raster file * @return * A MosaicRaster object. */ - def readRaster(path: String, parentPath: String): MosaicRasterGDAL + def readRaster(createInfo: Map[String, String]): MosaicRasterGDAL /** * Reads a raster from an in memory buffer. Use the buffer bytes to produce @@ -35,30 +37,32 @@ trait RasterReader extends Logging { * * @param contentBytes * The file bytes. - * @param parentPath - * The path of the parent raster file. - * @param driverShortName - * The driver short name of the raster file. + * @param createInfo + * The create info for the raster. This should contain the following + * keys: + * - parentPath: The path of the parent raster file. + * - driver: The driver short name of the raster file * @return * A MosaicRaster object. */ - def readRaster(contentBytes: Array[Byte], parentPath: String, driverShortName: String): MosaicRasterGDAL + def readRaster(contentBytes: Array[Byte], createInfo: Map[String, String]): MosaicRasterGDAL /** * Reads a raster band from a file system path. Reads a subdataset band if * the path is to a subdataset. + * * @example * Raster: path = "/path/to/file.tif" Subdataset: path = * "FORMAT:/path/to/file.tif:subdataset" - * @param path - * The path to the raster file. - * @param bandIndex - * The band index to read. - * @param parentPath - * The path of the parent raster file. + * @param createInfo + * The create info for the raster. This should contain the following + * keys: + * - path: The path to the raster file. + * - parentPath: The path of the parent raster file. + * - driver: Optional: The driver short name of the raster file * @return * A MosaicRaster object. */ - def readBand(path: String, bandIndex: Int, parentPath: String): MosaicRasterBandGDAL + def readBand(bandIndex: Int, createInfo: Map[String, String]): MosaicRasterBandGDAL } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala index 41c967947..ddae2fed2 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/clip/RasterClipByVector.scala @@ -19,8 +19,7 @@ object RasterClipByVector { * abstractions over GDAL Warp. It uses CUTLINE_ALL_TOUCHED=TRUE to ensure * that all pixels that touch the geometry are included. This will avoid * the issue of having a pixel that is half in and half out of the - * geometry, important for tessellation. It also uses COMPRESS=DEFLATE to - * ensure that the output is compressed. The method also uses the geometry + * geometry, important for tessellation. The method also uses the geometry * API to generate a shapefile that is used to clip the raster. The * shapefile is deleted after the clip is complete. * @@ -38,16 +37,19 @@ object RasterClipByVector { def clip(raster: MosaicRasterGDAL, geometry: MosaicGeometry, geomCRS: SpatialReference, geometryAPI: GeometryAPI): MosaicRasterGDAL = { val rasterCRS = raster.getSpatialReference val outShortName = raster.getDriversShortName - val geomSrcCRS = if (geomCRS == null ) rasterCRS else geomCRS + val geomSrcCRS = if (geomCRS == null) rasterCRS else geomCRS val resultFileName = PathUtils.createTmpFilePath(GDAL.getExtension(outShortName)) val shapeFileName = VectorClipper.generateClipper(geometry, geomSrcCRS, rasterCRS, geometryAPI) + // For -wo consult https://gdal.org/doxygen/structGDALWarpOptions.html + // SOURCE_EXTRA=3 is used to ensure that when the raster is clipped, the + // pixels that touch the geometry are included. The default is 1, 3 seems to be a good empirical value. val result = GDALWarp.executeWarp( resultFileName, Seq(raster), - command = s"gdalwarp -wo CUTLINE_ALL_TOUCHED=TRUE -of $outShortName -cutline $shapeFileName -crop_to_cutline -co COMPRESS=DEFLATE" + command = s"gdalwarp -wo CUTLINE_ALL_TOUCHED=TRUE -cutline $shapeFileName -crop_to_cutline" ) VectorClipper.cleanUpClipper(shapeFileName) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala index 389defad6..cb79dc263 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALBuildVRT.scala @@ -1,6 +1,6 @@ package com.databricks.labs.mosaic.core.raster.operator.gdal -import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterGDAL, MosaicRasterWriteOptions} import org.gdal.gdal.{BuildVRTOptions, gdal} /** GDALBuildVRT is a wrapper for the GDAL BuildVRT command. */ @@ -20,20 +20,21 @@ object GDALBuildVRT { */ def executeVRT(outputPath: String, rasters: Seq[MosaicRasterGDAL], command: String): MosaicRasterGDAL = { require(command.startsWith("gdalbuildvrt"), "Not a valid GDAL Build VRT command.") - val vrtOptionsVec = OperatorOptions.parseOptions(command) + val effectiveCommand = OperatorOptions.appendOptions(command, MosaicRasterWriteOptions.VRT) + val vrtOptionsVec = OperatorOptions.parseOptions(effectiveCommand) val vrtOptions = new BuildVRTOptions(vrtOptionsVec) val result = gdal.BuildVRT(outputPath, rasters.map(_.getRaster).toArray, vrtOptions) - if (result == null) { - throw new Exception( - s""" - |Build VRT failed. - |Command: $command - |Error: ${gdal.GetLastErrorMsg} - |""".stripMargin) - } - // TODO: Figure out multiple parents, should this be an array? + val errorMsg = gdal.GetLastErrorMsg + val createInfo = Map( + "path" -> outputPath, + "parentPath" -> rasters.head.getParentPath, + "driver" -> "VRT", + "last_command" -> effectiveCommand, + "last_error" -> errorMsg, + "all_parents" -> rasters.map(_.getParentPath).mkString(";") + ) // VRT files are just meta files, mem size doesnt make much sense so we keep -1 - MosaicRasterGDAL(result, outputPath, rasters.head.getParentPath, "VRT", -1).flushCache() + MosaicRasterGDAL(result, createInfo, -1).flushCache() } } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala index c33e293d6..e22228817 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala @@ -1,8 +1,9 @@ package com.databricks.labs.mosaic.core.raster.operator.gdal import com.databricks.labs.mosaic.core.raster.api.GDAL -import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterGDAL, MosaicRasterWriteOptions} import com.databricks.labs.mosaic.utils.SysUtils +import org.gdal.gdal.gdal /** GDALCalc is a helper object for executing GDAL Calc commands. */ object GDALCalc { @@ -30,19 +31,29 @@ object GDALCalc { */ def executeCalc(gdalCalcCommand: String, resultPath: String): MosaicRasterGDAL = { require(gdalCalcCommand.startsWith("gdal_calc"), "Not a valid GDAL Calc command.") - val toRun = gdalCalcCommand.replace("gdal_calc", gdal_calc) + val effectiveCommand = OperatorOptions.appendOptions(gdalCalcCommand, MosaicRasterWriteOptions.GTiff) + val toRun = effectiveCommand.replace("gdal_calc", gdal_calc) val commandRes = SysUtils.runCommand(s"python3 $toRun") - if (commandRes._1 == "ERROR") { - throw new RuntimeException(s""" - |GDAL Calc command failed: - |STDOUT: - |${commandRes._2} - |STDERR: - |${commandRes._3} - |""".stripMargin) - } + val errorMsg = gdal.GetLastErrorMsg val result = GDAL.raster(resultPath, resultPath) - result + val createInfo = Map( + "path" -> resultPath, + "parentPath" -> resultPath, + "driver" -> "GTiff", + "last_command" -> effectiveCommand, + "last_error" -> errorMsg, + "all_parents" -> resultPath, + "full_error" -> s""" + |GDAL Calc command failed: + |GDAL err: + |$errorMsg + |STDOUT: + |${commandRes._2} + |STDERR: + |${commandRes._3} + |""".stripMargin + ) + result.copy(createInfo = createInfo) } } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALInfo.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALInfo.scala new file mode 100644 index 000000000..d3ccd471b --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALInfo.scala @@ -0,0 +1,38 @@ +package com.databricks.labs.mosaic.core.raster.operator.gdal + +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import org.gdal.gdal.{InfoOptions, gdal} + +/** GDALBuildVRT is a wrapper for the GDAL BuildVRT command. */ +object GDALInfo { + + /** + * Executes the GDAL BuildVRT command. For flags check the way gdalinfo.py + * script is called, InfoOptions expects a collection of same flags. + * + * @param raster + * The raster to get info from. + * @param command + * The GDAL Info command. + * @return + * A result json string. + */ + def executeInfo(raster: MosaicRasterGDAL, command: String): String = { + require(command.startsWith("gdalinfo"), "Not a valid GDAL Info command.") + + val infoOptionsVec = OperatorOptions.parseOptions(command) + val infoOptions = new InfoOptions(infoOptionsVec) + val gdalInfo = gdal.GDALInfo(raster.getRaster, infoOptions) + + if (gdalInfo == null) { + s""" + |GDAL Info failed. + |Command: $command + |Error: ${gdal.GetLastErrorMsg} + |""".stripMargin + } else { + gdalInfo + } + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala index bf266cfbf..2fb106fda 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALTranslate.scala @@ -1,6 +1,6 @@ package com.databricks.labs.mosaic.core.raster.operator.gdal -import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.gdal.{MosaicRasterGDAL, MosaicRasterWriteOptions} import org.gdal.gdal.{TranslateOptions, gdal} import java.nio.file.{Files, Paths} @@ -20,21 +20,30 @@ object GDALTranslate { * @return * A MosaicRaster object. */ - def executeTranslate(outputPath: String, raster: MosaicRasterGDAL, command: String): MosaicRasterGDAL = { + def executeTranslate( + outputPath: String, + raster: MosaicRasterGDAL, + command: String, + writeOptions: MosaicRasterWriteOptions + ): MosaicRasterGDAL = { require(command.startsWith("gdal_translate"), "Not a valid GDAL Translate command.") - val translateOptionsVec = OperatorOptions.parseOptions(command) + val effectiveCommand = OperatorOptions.appendOptions(command, writeOptions) + val translateOptionsVec = OperatorOptions.parseOptions(effectiveCommand) val translateOptions = new TranslateOptions(translateOptionsVec) val result = gdal.Translate(outputPath, raster.getRaster, translateOptions) - if (result == null) { - throw new Exception( - s""" - |Translate failed. - |Command: $command - |Error: ${gdal.GetLastErrorMsg} - |""".stripMargin) - } + val errorMsg = gdal.GetLastErrorMsg val size = Files.size(Paths.get(outputPath)) - raster.copy(raster = result, path = outputPath, memSize = size).flushCache() + val createInfo = Map( + "path" -> outputPath, + "parentPath" -> raster.getParentPath, + "driver" -> writeOptions.format, + "last_command" -> effectiveCommand, + "last_error" -> errorMsg, + "all_parents" -> raster.getParentPath + ) + raster + .copy(raster = result, createInfo = createInfo, memSize = size) + .flushCache() } } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala index 2b13a957b..516560a76 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALWarp.scala @@ -23,26 +23,22 @@ object GDALWarp { def executeWarp(outputPath: String, rasters: Seq[MosaicRasterGDAL], command: String): MosaicRasterGDAL = { require(command.startsWith("gdalwarp"), "Not a valid GDAL Warp command.") // Test: gdal.ParseCommandLine(command) - val warpOptionsVec = OperatorOptions.parseOptions(command) + val effectiveCommand = OperatorOptions.appendOptions(command, rasters.head.getWriteOptions) + val warpOptionsVec = OperatorOptions.parseOptions(effectiveCommand) val warpOptions = new WarpOptions(warpOptionsVec) val result = gdal.Warp(outputPath, rasters.map(_.getRaster).toArray, warpOptions) - // TODO: Figure out multiple parents, should this be an array? // Format will always be the same as the first raster - if (result == null) { - throw new Exception(s""" - |Warp failed. - |Command: $command - |Error: ${gdal.GetLastErrorMsg} - |""".stripMargin) - } + val errorMsg = gdal.GetLastErrorMsg val size = Files.size(Paths.get(outputPath)) - rasters.head - .copy( - raster = result, - path = outputPath, - memSize = size - ) - .flushCache() + val createInfo = Map( + "path" -> outputPath, + "parentPath" -> rasters.head.getParentPath, + "driver" -> rasters.head.getWriteOptions.format, + "last_command" -> effectiveCommand, + "last_error" -> errorMsg, + "all_parents" -> rasters.map(_.getParentPath).mkString(";") + ) + rasters.head.copy(raster = result, createInfo = createInfo, memSize = size).flushCache() } } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/OperatorOptions.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/OperatorOptions.scala index b1529d3e7..bc656ec01 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/OperatorOptions.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/OperatorOptions.scala @@ -1,5 +1,7 @@ package com.databricks.labs.mosaic.core.raster.operator.gdal +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterWriteOptions + /** OperatorOptions is a helper object for parsing GDAL command options. */ object OperatorOptions { @@ -18,4 +20,37 @@ object OperatorOptions { optionsVec } + /** + * Add default options to the command. Extract the compression from the + * raster and append it to the command. This operation does not change the + * output format. For changing the output format, use RST_ToFormat. + * + * @param command + * The command to append options to. + * @param writeOptions + * The write options to append. Note that not all available options are + * actually appended. At this point it is up to the bellow logic to + * decide what is supported and for which format. + * @return + */ + def appendOptions(command: String, writeOptions: MosaicRasterWriteOptions): String = { + val compression = writeOptions.compression + if (command.startsWith("gdal_calc")) { + writeOptions.format match { + case f @ "GTiff" => command + s" --format $f --co TILED=YES --co COMPRESS=$compression" + case f @ "COG" => command + s" --format $f --co TILED=YES --co COMPRESS=$compression" + case f @ _ => command + s" --format $f --co COMPRESS=$compression" + } + } else { + writeOptions.format match { + case f @ "GTiff" => command + s" -of $f -co TILED=YES -co COMPRESS=$compression" + case f @ "COG" => command + s" -of $f -co TILED=YES -co COMPRESS=$compression" + case "VRT" => command + case f @ "Zarr" if writeOptions.missingGeoRef => + command + s" -of $f -co COMPRESS=$compression -to SRC_METHOD=NO_GEOTRANSFORM" + case f @ _ => command + s" -of $f -co COMPRESS=$compression" + } + } + } + } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeBands.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeBands.scala index 6333c50c8..8a82d1238 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeBands.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeBands.scala @@ -19,10 +19,10 @@ object MergeBands { * A MosaicRaster object. */ def merge(rasters: Seq[MosaicRasterGDAL], resampling: String): MosaicRasterGDAL = { - val outShortName = rasters.head.getRaster.GetDriver.getShortName + val outOptions = rasters.head.getWriteOptions val vrtPath = PathUtils.createTmpFilePath("vrt") - val rasterPath = PathUtils.createTmpFilePath("tif") + val rasterPath = PathUtils.createTmpFilePath(outOptions.extension) val vrtRaster = GDALBuildVRT.executeVRT( vrtPath, @@ -33,7 +33,8 @@ object MergeBands { val result = GDALTranslate.executeTranslate( rasterPath, vrtRaster, - command = s"gdal_translate -r $resampling -of $outShortName -co COMPRESS=DEFLATE" + command = s"gdal_translate -r $resampling", + outOptions ) dispose(vrtRaster) @@ -55,10 +56,10 @@ object MergeBands { * A MosaicRaster object. */ def merge(rasters: Seq[MosaicRasterGDAL], pixel: (Double, Double), resampling: String): MosaicRasterGDAL = { - val outShortName = rasters.head.getRaster.GetDriver.getShortName + val outOptions = rasters.head.getWriteOptions val vrtPath = PathUtils.createTmpFilePath("vrt") - val rasterPath = PathUtils.createTmpFilePath("tif") + val rasterPath = PathUtils.createTmpFilePath(outOptions.extension) val vrtRaster = GDALBuildVRT.executeVRT( vrtPath, @@ -69,7 +70,8 @@ object MergeBands { val result = GDALTranslate.executeTranslate( rasterPath, vrtRaster, - command = s"gdalwarp -r $resampling -of $outShortName -co COMPRESS=DEFLATE -overwrite" + command = s"gdalwarp -r $resampling", + outOptions ) dispose(vrtRaster) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeRasters.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeRasters.scala index 694d9940a..fafaffbc4 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeRasters.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/merge/MergeRasters.scala @@ -17,21 +17,22 @@ object MergeRasters { * A MosaicRaster object. */ def merge(rasters: Seq[MosaicRasterGDAL]): MosaicRasterGDAL = { - val outShortName = rasters.head.getRaster.GetDriver.getShortName + val outOptions = rasters.head.getWriteOptions val vrtPath = PathUtils.createTmpFilePath("vrt") - val rasterPath = PathUtils.createTmpFilePath("tif") + val rasterPath = PathUtils.createTmpFilePath(outOptions.extension) val vrtRaster = GDALBuildVRT.executeVRT( - vrtPath, - rasters, - command = s"gdalbuildvrt -resolution highest" + vrtPath, + rasters, + command = s"gdalbuildvrt -resolution highest" ) val result = GDALTranslate.executeTranslate( - rasterPath, - vrtRaster, - command = s"gdal_translate -r bilinear -of $outShortName -co COMPRESS=DEFLATE" + rasterPath, + vrtRaster, + command = s"gdal_translate", + outOptions ) dispose(vrtRaster) @@ -39,5 +40,4 @@ object MergeRasters { result } - } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/pixel/PixelCombineRasters.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/pixel/PixelCombineRasters.scala index 5bf49fb96..cda9824dc 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/pixel/PixelCombineRasters.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/pixel/PixelCombineRasters.scala @@ -3,6 +3,7 @@ package com.databricks.labs.mosaic.core.raster.operator.pixel import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose import com.databricks.labs.mosaic.core.raster.operator.gdal.{GDALBuildVRT, GDALTranslate} +import com.databricks.labs.mosaic.gdal.MosaicGDAL.defaultBlockSize import com.databricks.labs.mosaic.utils.PathUtils import java.io.File @@ -20,10 +21,10 @@ object PixelCombineRasters { * A MosaicRaster object. */ def combine(rasters: Seq[MosaicRasterGDAL], pythonFunc: String, pythonFuncName: String): MosaicRasterGDAL = { - val outShortName = rasters.head.getRaster.GetDriver.getShortName + val outOptions = rasters.head.getWriteOptions val vrtPath = PathUtils.createTmpFilePath("vrt") - val rasterPath = PathUtils.createTmpFilePath("tif") + val rasterPath = PathUtils.createTmpFilePath(outOptions.extension) val vrtRaster = GDALBuildVRT.executeVRT( vrtPath, @@ -37,7 +38,8 @@ object PixelCombineRasters { val result = GDALTranslate.executeTranslate( rasterPath, vrtRaster.refresh(), - command = s"gdal_translate -r bilinear -of $outShortName -co COMPRESS=DEFLATE" + command = s"gdal_translate", + outOptions ) dispose(vrtRaster) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala index efd7c8c67..5d7c5f5f2 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/proj/RasterProject.scala @@ -15,8 +15,7 @@ object RasterProject { /** * Projects a raster to a new CRS. The method handles all the abstractions * over GDAL Warp. It uses cubic resampling to ensure that the output is - * smooth. It also uses COMPRESS=DEFLATE to ensure that the output is - * compressed. + * smooth. * * @param raster * The raster to project. @@ -33,11 +32,11 @@ object RasterProject { // Note that Null is the right value here val authName = destCRS.GetAuthorityName(null) val authCode = destCRS.GetAuthorityCode(null) - + val result = GDALWarp.executeWarp( resultFileName, Seq(raster), - command = s"gdalwarp -of $outShortName -t_srs $authName:$authCode -r cubic -overwrite -co COMPRESS=DEFLATE" + command = s"gdalwarp -t_srs $authName:$authCode" ) result diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/BalancedSubdivision.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/BalancedSubdivision.scala index 75e59c1fa..daa0e6266 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/BalancedSubdivision.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/BalancedSubdivision.scala @@ -21,9 +21,12 @@ object BalancedSubdivision { */ def getNumSplits(raster: MosaicRasterGDAL, destSize: Int): Int = { val size = raster.getMemSize - val n = size.toDouble / (destSize * 1000 * 1000) - val nInt = Math.ceil(n).toInt - Math.pow(4, Math.ceil(Math.log(nInt) / Math.log(4))).toInt + var n = 1 + while (true) { + n *= 4 + if (size / n <= destSize * 1000 * 1000) return n + } + n } /** diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala index c1498ea05..072380666 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/OverlappingTiles.scala @@ -48,12 +48,13 @@ object OverlappingTiles { val fileExtension = GDAL.getExtension(tile.getDriver) val rasterPath = PathUtils.createTmpFilePath(fileExtension) - val shortName = raster.getRaster.GetDriver.getShortName + val outOptions = raster.getWriteOptions val result = GDALTranslate.executeTranslate( rasterPath, raster, - command = s"gdal_translate -of $shortName -srcwin $xOff $yOff $width $height" + command = s"gdal_translate -srcwin $xOff $yOff $width $height", + outOptions ) val isEmpty = result.isEmpty @@ -68,7 +69,7 @@ object OverlappingTiles { val (_, valid) = tiles.flatten.partition(_._1) - valid.map(t => MosaicRasterTile(null, t._2, raster.getParentPath, raster.getDriversShortName)) + valid.map(t => MosaicRasterTile(null, t._2)) } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala index 0b31519ea..9920af923 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/RasterTessellate.scala @@ -38,12 +38,13 @@ object RasterTessellate { val cellID = cell.cellIdAsLong(indexSystem) val isValidCell = indexSystem.isValid(cellID) if (!isValidCell) { - (false, MosaicRasterTile(cell.index, null, "", "")) + (false, MosaicRasterTile(cell.index, null)) } else { val cellRaster = tmpRaster.getRasterForCell(cellID, indexSystem, geometryAPI) val isValidRaster = !cellRaster.isEmpty ( - isValidRaster, MosaicRasterTile(cell.index, cellRaster, raster.getParentPath, raster.getDriversShortName) + isValidRaster, + MosaicRasterTile(cell.index, cellRaster) ) } }) diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala index 9a995b6b1..7b218199e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/retile/ReTile.scala @@ -1,7 +1,7 @@ package com.databricks.labs.mosaic.core.raster.operator.retile import com.databricks.labs.mosaic.core.raster.io.RasterCleaner.dispose -import com.databricks.labs.mosaic.core.raster.operator.gdal.{GDALBuildVRT, GDALTranslate} +import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALTranslate import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.utils.PathUtils @@ -39,12 +39,13 @@ object ReTile { val fileExtension = raster.getRasterFileExtension val rasterPath = PathUtils.createTmpFilePath(fileExtension) - val shortDriver = raster.getDriversShortName + val outOptions = raster.getWriteOptions val result = GDALTranslate.executeTranslate( rasterPath, raster, - command = s"gdal_translate -of $shortDriver -srcwin $xMin $yMin $xOffset $yOffset -co COMPRESS=DEFLATE" + command = s"gdal_translate -srcwin $xMin $yMin $xOffset $yOffset", + outOptions ) val isEmpty = result.isEmpty @@ -57,7 +58,7 @@ object ReTile { val (_, valid) = tiles.partition(_._1) - valid.map(t => MosaicRasterTile(null, t._2, raster.getParentPath, raster.getDriversShortName)) + valid.map(t => MosaicRasterTile(null, t._2)) } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/separate/SeparateBands.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/separate/SeparateBands.scala index 25f73bf8b..9580cc441 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/separate/SeparateBands.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/separate/SeparateBands.scala @@ -5,7 +5,10 @@ import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALTranslate import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.utils.PathUtils -/** ReTile is a helper object for splitting multi-band rasters into single-band-per-row. */ +/** + * ReTile is a helper object for splitting multi-band rasters into + * single-band-per-row. + */ object SeparateBands { /** @@ -24,11 +27,13 @@ object SeparateBands { val fileExtension = raster.getRasterFileExtension val rasterPath = PathUtils.createTmpFilePath(fileExtension) val shortDriver = raster.getDriversShortName + val outOptions = raster.getWriteOptions val result = GDALTranslate.executeTranslate( rasterPath, raster, - command = s"gdal_translate -of $shortDriver -b ${i + 1} -co COMPRESS=DEFLATE" + command = s"gdal_translate -of $shortDriver -b ${i + 1}", + writeOptions = outOptions ) val isEmpty = result.isEmpty @@ -38,13 +43,13 @@ object SeparateBands { if (isEmpty) dispose(result) - (isEmpty, result, i) + (isEmpty, result.copy(createInfo = result.createInfo ++ Map("bandIndex" -> (i + 1).toString)), i) } val (_, valid) = tiles.partition(_._1) - valid.map(t => new MosaicRasterTile(null, t._2, raster.getParentPath, raster.getDriversShortName)) + valid.map(t => new MosaicRasterTile(null, t._2)) } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala b/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala index 1cadf2c9a..137d482ce 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/types/RasterTileType.scala @@ -1,11 +1,12 @@ package com.databricks.labs.mosaic.core.types +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types._ /** Type definition for the raster tile. */ class RasterTileType(fields: Array[StructField]) extends StructType(fields) { - def rasterType: DataType = fields(1).dataType + def rasterType: DataType = fields.find(_.name == "raster").get.dataType override def simpleString: String = "RASTER_TILE" @@ -19,20 +20,53 @@ object RasterTileType { * Creates a new instance of [[RasterTileType]]. * * @param idType - * Type of the index ID. + * Type of the index ID. Can be one of [[LongType]], [[IntegerType]] or + * [[StringType]]. + * @param rasterType + * Type of the raster. Can be one of [[ByteType]] or [[StringType]]. Not + * to be confused with the data type of the raster. This is the type of + * the column that contains the raster. + * * @return * An instance of [[RasterTileType]]. */ - def apply(idType: DataType): RasterTileType = { + def apply(idType: DataType, rasterType: DataType): DataType = { require(Seq(LongType, IntegerType, StringType).contains(idType)) new RasterTileType( Array( StructField("index_id", idType), - StructField("raster", BinaryType), - StructField("parentPath", StringType), - StructField("driver", StringType) + StructField("raster", rasterType), + StructField("metadata", MapType(StringType, StringType)) ) ) } + /** + * Creates a new instance of [[RasterTileType]]. + * + * @param idType + * Type of the index ID. Can be one of [[LongType]], [[IntegerType]] or + * [[StringType]]. + * @param tileExpr + * Expression containing a tile. This is used to infer the raster type + * when chaining expressions. + * @return + */ + def apply(idType: DataType, tileExpr: Expression): DataType = { + require(Seq(LongType, IntegerType, StringType).contains(idType)) + tileExpr.dataType match { + case st @ StructType(_) => apply(idType, st.find(_.name == "raster").get.dataType) + case _ @ArrayType(elementType: StructType, _) => apply(idType, elementType.find(_.name == "raster").get.dataType) + case _ => throw new IllegalArgumentException("Unsupported raster type.") + } + } + + def apply(tileExpr: Expression): RasterTileType = { + tileExpr.dataType match { + case StructType(fields) => new RasterTileType(fields) + case ArrayType(elementType: StructType, _) => new RasterTileType(elementType.fields) + case _ => throw new IllegalArgumentException("Unsupported raster type.") + } + } + } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala b/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala index 798b11206..97122a7d7 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/types/model/MosaicRasterTile.scala @@ -3,6 +3,7 @@ package com.databricks.labs.mosaic.core.types.model import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.expressions.raster.{buildMapString, extractMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{BinaryType, DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -16,18 +17,16 @@ import scala.util.{Failure, Success, Try} * Index ID. * @param raster * Raster instance corresponding to the tile. - * @param parentPath - * Parent path of the raster. - * @param driver - * Driver used to read the raster. */ case class MosaicRasterTile( index: Either[Long, String], - raster: MosaicRasterGDAL, - parentPath: String, - driver: String + raster: MosaicRasterGDAL ) { + def parentPath: String = raster.createInfo("parentPath") + + def driver: String = raster.createInfo("driver") + def getIndex: Either[Long, String] = index def getParentPath: String = parentPath @@ -57,18 +56,8 @@ case class MosaicRasterTile( (indexSystem.getCellIdDataType, index) match { case (_: LongType, Left(_)) => this case (_: StringType, Right(_)) => this - case (_: LongType, Right(value)) => new MosaicRasterTile( - index = Left(indexSystem.parse(value)), - raster = raster, - parentPath = parentPath, - driver = driver - ) - case (_: StringType, Left(value)) => new MosaicRasterTile( - index = Right(indexSystem.format(value)), - raster = raster, - parentPath = parentPath, - driver = driver - ) + case (_: LongType, Right(value)) => this.copy(index = Left(indexSystem.parse(value))) + case (_: StringType, Left(value)) => this.copy(index = Right(indexSystem.format(value))) case _ => throw new IllegalArgumentException("Invalid cell id data type") } } @@ -108,21 +97,24 @@ case class MosaicRasterTile( * An instance of [[InternalRow]]. */ def serialize( - rasterDataType: DataType = BinaryType, - checkpointLocation: String = "" + rasterDataType: DataType ): InternalRow = { - val parentPathUTF8 = UTF8String.fromString(parentPath) - val driverUTF8 = UTF8String.fromString(driver) - val encodedRaster = encodeRaster(rasterDataType, checkpointLocation) + val encodedRaster = encodeRaster(rasterDataType) + val path = if (rasterDataType == StringType) encodedRaster.toString else raster.createInfo("path") + val parentPath = if (raster.createInfo("parentPath").isEmpty) raster.createInfo("path") else raster.createInfo("parentPath") + val newCreateInfo = raster.createInfo + ("path" -> path, "parentPath" -> parentPath) + val mapData = buildMapString(newCreateInfo) if (Option(index).isDefined) { if (index.isLeft) InternalRow.fromSeq( - Seq(index.left.get, encodedRaster, parentPathUTF8, driverUTF8) - ) - else InternalRow.fromSeq( - Seq(UTF8String.fromString(index.right.get), encodedRaster, parentPathUTF8, driverUTF8) + Seq(index.left.get, encodedRaster, mapData) ) + else { + InternalRow.fromSeq( + Seq(UTF8String.fromString(index.right.get), encodedRaster, mapData) + ) + } } else { - InternalRow.fromSeq(Seq(null, encodedRaster, parentPathUTF8, driverUTF8)) + InternalRow.fromSeq(Seq(null, encodedRaster, mapData)) } } @@ -134,10 +126,9 @@ case class MosaicRasterTile( * An instance of [[Array]] of [[Byte]] representing WKB. */ private def encodeRaster( - rasterDataType: DataType = BinaryType, - checkpointLocation: String = "" + rasterDataType: DataType = BinaryType ): Any = { - GDAL.writeRasters(Seq(raster), checkpointLocation, rasterDataType).head + GDAL.writeRasters(Seq(raster), rasterDataType).head } def getSequenceNumber: Int = @@ -145,6 +136,7 @@ case class MosaicRasterTile( case Success(value) => value.toInt case Failure(_) => -1 } + } /** Companion object. */ @@ -160,21 +152,20 @@ object MosaicRasterTile { * @return * An instance of [[MosaicRasterTile]]. */ - def deserialize(row: InternalRow, idDataType: DataType): MosaicRasterTile = { + def deserialize(row: InternalRow, idDataType: DataType, rasterType: DataType): MosaicRasterTile = { val index = row.get(0, idDataType) - val rasterBytes = row.get(1, BinaryType) - val parentPath = row.get(2, StringType).toString - val driver = row.get(3, StringType).toString - val raster = GDAL.readRaster(rasterBytes, parentPath, driver, BinaryType) + val rawRaster = row.get(1, rasterType) + val createInfo = extractMap(row.getMap(2)) + val raster = GDAL.readRaster(rawRaster, createInfo, rasterType) // noinspection TypeCheckCanBeMatch if (Option(index).isDefined) { if (index.isInstanceOf[Long]) { - new MosaicRasterTile(Left(index.asInstanceOf[Long]), raster, parentPath, driver) + new MosaicRasterTile(Left(index.asInstanceOf[Long]), raster) } else { - new MosaicRasterTile(Right(index.asInstanceOf[UTF8String].toString), raster, parentPath, driver) + new MosaicRasterTile(Right(index.asInstanceOf[UTF8String].toString), raster) } } else { - new MosaicRasterTile(null, raster, parentPath, driver) + new MosaicRasterTile(null, raster) } } diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala index 285df2191..b68a580d2 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReTileOnRead.scala @@ -19,11 +19,17 @@ import java.nio.file.{Files, Paths} /** An object defining the retiling read strategy for the GDAL file format. */ object ReTileOnRead extends ReadStrategy { + val tileDataType: DataType = StringType + // noinspection DuplicatedCode /** * Returns the schema of the GDAL file format. * @note - * Different read strategies can have different schemas. + * Different read strategies can have different schemas. This is because + * the schema is defined by the read strategy. For retiling we always use + * checkpoint location. In this case rasters are stored off spark rows. + * If you need the tiles in memory please load them from path stored in + * the tile returned by the reader. * * @param options * Options passed to the reader. @@ -54,7 +60,10 @@ object ReTileOnRead extends ReadStrategy { .add(StructField(SUBDATASETS, MapType(StringType, StringType), nullable = false)) .add(StructField(SRID, IntegerType, nullable = false)) .add(StructField(LENGTH, LongType, nullable = false)) - .add(StructField(TILE, RasterTileType(indexSystem.getCellIdDataType), nullable = false)) + // Note that for retiling we always use checkpoint location. + // In this case rasters are stored off spark rows. + // If you need the tiles in memory please load them from path stored in the tile returned by the reader. + .add(StructField(TILE, RasterTileType(indexSystem.getCellIdDataType, tileDataType), nullable = false)) } /** @@ -84,7 +93,7 @@ object ReTileOnRead extends ReadStrategy { val uuid = getUUID(status) val sizeInMB = options.getOrElse("sizeInMB", "16").toInt - val tmpPath = PathUtils.copyToTmp(inPath) + var tmpPath = PathUtils.copyToTmpWithRetry(inPath, 5) val tiles = localSubdivide(tmpPath, inPath, sizeInMB) val rows = tiles.map(tile => { @@ -103,7 +112,7 @@ object ReTileOnRead extends ReadStrategy { case other => throw new RuntimeException(s"Unsupported field name: $other") } // Writing to bytes is destructive so we delay reading content and content length until the last possible moment - val row = Utils.createRow(fields ++ Seq(tile.formatCellId(indexSystem).serialize())) + val row = Utils.createRow(fields ++ Seq(tile.formatCellId(indexSystem).serialize(tileDataType))) RasterCleaner.dispose(tile) row }) @@ -125,8 +134,9 @@ object ReTileOnRead extends ReadStrategy { */ def localSubdivide(inPath: String, parentPath: String, sizeInMB: Int): Seq[MosaicRasterTile] = { val cleanPath = PathUtils.getCleanPath(inPath) - val raster = MosaicRasterGDAL.readRaster(cleanPath, parentPath) - val inTile = new MosaicRasterTile(null, raster, parentPath, raster.getDriversShortName) + val createInfo = Map("path" -> cleanPath, "parentPath" -> parentPath) + val raster = MosaicRasterGDAL.readRaster(createInfo) + val inTile = new MosaicRasterTile(null, raster) val tiles = BalancedSubdivision.splitRaster(inTile, sizeInMB) RasterCleaner.dispose(raster) RasterCleaner.dispose(inTile) diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadAsPath.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadAsPath.scala new file mode 100644 index 000000000..0973146ef --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadAsPath.scala @@ -0,0 +1,124 @@ +package com.databricks.labs.mosaic.datasource.gdal + +import com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory} +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.datasource.Utils +import com.databricks.labs.mosaic.datasource.gdal.GDALFileFormat._ +import com.databricks.labs.mosaic.utils.PathUtils +import org.apache.hadoop.fs.{FileStatus, FileSystem} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ + +import java.nio.file.{Files, Paths} + +/** An object defining the retiling read strategy for the GDAL file format. */ +object ReadAsPath extends ReadStrategy { + + val tileDataType: DataType = StringType + + // noinspection DuplicatedCode + /** + * Returns the schema of the GDAL file format. + * @note + * Different read strategies can have different schemas. This is because + * the schema is defined by the read strategy. For retiling we always use + * checkpoint location. In this case rasters are stored off spark rows. + * If you need the tiles in memory please load them from path stored in + * the tile returned by the reader. + * + * @param options + * Options passed to the reader. + * @param files + * List of files to read. + * @param parentSchema + * Parent schema. + * @param sparkSession + * Spark session. + * + * @return + * Schema of the GDAL file format. + */ + override def getSchema( + options: Map[String, String], + files: Seq[FileStatus], + parentSchema: StructType, + sparkSession: SparkSession + ): StructType = { + val trimmedSchema = parentSchema.filter(field => field.name != CONTENT && field.name != LENGTH) + val indexSystem = IndexSystemFactory.getIndexSystem(sparkSession) + StructType(trimmedSchema) + .add(StructField(UUID, LongType, nullable = false)) + .add(StructField(X_SIZE, IntegerType, nullable = false)) + .add(StructField(Y_SIZE, IntegerType, nullable = false)) + .add(StructField(BAND_COUNT, IntegerType, nullable = false)) + .add(StructField(METADATA, MapType(StringType, StringType), nullable = false)) + .add(StructField(SUBDATASETS, MapType(StringType, StringType), nullable = false)) + .add(StructField(SRID, IntegerType, nullable = false)) + .add(StructField(LENGTH, LongType, nullable = false)) + // Note that for retiling we always use checkpoint location. + // In this case rasters are stored off spark rows. + // If you need the tiles in memory please load them from path stored in the tile returned by the reader. + .add(StructField(TILE, RasterTileType(indexSystem.getCellIdDataType, tileDataType), nullable = false)) + } + + /** + * Reads the content of the file. + * @param status + * File status. + * @param fs + * File system. + * @param requiredSchema + * Required schema. + * @param options + * Options passed to the reader. + * @param indexSystem + * Index system. + * + * @return + * Iterator of internal rows. + */ + override def read( + status: FileStatus, + fs: FileSystem, + requiredSchema: StructType, + options: Map[String, String], + indexSystem: IndexSystem + ): Iterator[InternalRow] = { + val inPath = status.getPath.toString + val uuid = getUUID(status) + + val tmpPath = PathUtils.copyToTmp(inPath) + val createInfo = Map("path" -> tmpPath, "parentPath" -> inPath) + val raster = MosaicRasterGDAL.readRaster(createInfo) + val tile = MosaicRasterTile(null, raster) + + val trimmedSchema = StructType(requiredSchema.filter(field => field.name != TILE)) + val fields = trimmedSchema.fieldNames.map { + case PATH => status.getPath.toString + case MODIFICATION_TIME => status.getModificationTime + case UUID => uuid + case X_SIZE => tile.getRaster.xSize + case Y_SIZE => tile.getRaster.ySize + case BAND_COUNT => tile.getRaster.numBands + case METADATA => tile.getRaster.metadata + case SUBDATASETS => tile.getRaster.subdatasets + case SRID => tile.getRaster.SRID + case LENGTH => tile.getRaster.getMemSize + case other => throw new RuntimeException(s"Unsupported field name: $other") + } + // Writing to bytes is destructive so we delay reading content and content length until the last possible moment + val row = Utils.createRow(fields ++ Seq(tile.formatCellId(indexSystem).serialize(tileDataType))) + RasterCleaner.dispose(tile) + + val rows = Seq(row) + + Files.deleteIfExists(Paths.get(tmpPath)) + + rows.iterator + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala index fe2f17148..7e6687079 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadInMemory.scala @@ -6,12 +6,12 @@ import com.databricks.labs.mosaic.core.raster.io.RasterCleaner import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.datasource.Utils import com.databricks.labs.mosaic.datasource.gdal.GDALFileFormat._ +import com.databricks.labs.mosaic.expressions.raster.buildMapString import com.databricks.labs.mosaic.utils.PathUtils import org.apache.hadoop.fs.{FileStatus, FileSystem} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String /** An object defining the in memory read strategy for the GDAL file format. */ object ReadInMemory extends ReadStrategy { @@ -49,7 +49,9 @@ object ReadInMemory extends ReadStrategy { .add(StructField(METADATA, MapType(StringType, StringType), nullable = false)) .add(StructField(SUBDATASETS, MapType(StringType, StringType), nullable = false)) .add(StructField(SRID, IntegerType, nullable = false)) - .add(StructField(TILE, RasterTileType(indexSystem.getCellIdDataType), nullable = false)) + // Note, for in memory reads the rasters are stored in the tile. + // For that we use Binary Columns. + .add(StructField(TILE, RasterTileType(indexSystem.getCellIdDataType, BinaryType), nullable = false)) } /** @@ -76,9 +78,12 @@ object ReadInMemory extends ReadStrategy { ): Iterator[InternalRow] = { val inPath = status.getPath.toString val readPath = PathUtils.getCleanPath(inPath) - val driverShortName = MosaicRasterGDAL.identifyDriver(readPath) val contentBytes: Array[Byte] = readContent(fs, status) - val raster = MosaicRasterGDAL.readRaster(readPath, inPath) + val createInfo = Map( + "path" -> readPath, + "parentPath" -> inPath + ) + val raster = MosaicRasterGDAL.readRaster(createInfo) val uuid = getUUID(status) val fields = requiredSchema.fieldNames.filter(_ != TILE).map { @@ -94,8 +99,9 @@ object ReadInMemory extends ReadStrategy { case SRID => raster.SRID case other => throw new RuntimeException(s"Unsupported field name: $other") } + val mapData = buildMapString(raster.createInfo) val rasterTileSer = InternalRow.fromSeq( - Seq(null, contentBytes, UTF8String.fromString(inPath), UTF8String.fromString(driverShortName), null) + Seq(null, contentBytes, mapData) ) val row = Utils.createRow( fields ++ Seq(rasterTileSer) diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadStrategy.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadStrategy.scala index cacc1c133..ab141b069 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadStrategy.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/gdal/ReadStrategy.scala @@ -72,6 +72,7 @@ object ReadStrategy { readStrategy match { case MOSAIC_RASTER_READ_IN_MEMORY => ReadInMemory case MOSAIC_RASTER_RE_TILE_ON_READ => ReTileOnRead + case MOSAIC_RASTER_READ_AS_PATH => ReadAsPath case _ => ReadInMemory } diff --git a/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala b/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala index c1f805afa..2f5bf39b6 100644 --- a/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala +++ b/src/main/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReader.scala @@ -1,5 +1,6 @@ package com.databricks.labs.mosaic.datasource.multiread +import com.databricks.labs.mosaic.MOSAIC_RASTER_READ_STRATEGY import com.databricks.labs.mosaic.functions.MosaicContext import org.apache.spark.sql._ import org.apache.spark.sql.functions._ @@ -25,6 +26,12 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead nPartitions } + private def workerNCores = { + sparkSession.sparkContext.range(0, 1).map(_ => java.lang.Runtime.getRuntime.availableProcessors).collect.head + } + + private def nWorkers = sparkSession.sparkContext.getExecutorMemoryStatus.size + override def load(path: String): DataFrame = load(Seq(path): _*) override def load(paths: String*): DataFrame = { @@ -32,11 +39,23 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead val config = getConfig val resolution = config("resolution").toInt val nPartitions = getNPartitions(config) + val readStrategy = config("retile") match { + case "true" => "retile_on_read" + case _ => "in_memory" + } + val tileSize = config("sizeInMB").toInt + + val nCores = nWorkers * workerNCores + val stageCoefficient = math.ceil(math.log(nCores) / math.log(4)) + + val firstStageSize = (tileSize * math.pow(4, stageCoefficient)).toInt val pathsDf = sparkSession.read .format("gdal") .option("extensions", config("extensions")) - .option("raster_storage", "in-memory") + .option(MOSAIC_RASTER_READ_STRATEGY, readStrategy) + .option("vsizip", config("vsizip")) + .option("sizeInMB", firstStageSize) .load(paths: _*) .repartition(nPartitions) @@ -47,29 +66,31 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead val retiledDf = retileRaster(rasterDf, config) val loadedDf = retiledDf + .withColumn( + "tile", + rst_tessellate(col("tile"), lit(resolution)) + ) + .repartition(nPartitions) + .groupBy("tile.index_id") + .agg(rst_combineavg_agg(col("tile")).alias("tile")) .withColumn( "grid_measures", - rasterToGridCombiner(col("tile"), lit(resolution)) + rasterToGridCombiner(col("tile")) ) .select( "grid_measures", "tile" ) .select( - posexplode(col("grid_measures")).as(Seq("band_id", "grid_measures")) - ) - .select( - col("band_id"), - explode(col("grid_measures")).alias("grid_measures") + posexplode(col("grid_measures")).as(Seq("band_id", "measure")), + col("tile").getField("index_id").alias("cell_id") ) .repartition(nPartitions) .select( col("band_id"), - col("grid_measures").getItem("cellID").alias("cell_id"), - col("grid_measures").getItem("measure").alias("measure") + col("cell_id"), + col("measure") ) - .groupBy("band_id", "cell_id") - .agg(avg("measure").alias("measure")) kRingResample(loadedDf, config) @@ -88,16 +109,22 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead */ private def retileRaster(rasterDf: DataFrame, config: Map[String, String]) = { val retile = config("retile").toBoolean - val tileSize = config("tileSize").toInt + val tileSize = config.getOrElse("tileSize", "-1").toInt + val memSize = config.getOrElse("sizeInMB", "-1").toInt val nPartitions = getNPartitions(config) if (retile) { - rasterDf - .withColumn( - "tile", - rst_retile(col("tile"), lit(tileSize), lit(tileSize)) - ) - .repartition(nPartitions) + if (memSize > 0) { + rasterDf + .withColumn("tile", rst_subdivide(col("tile"), lit(memSize))) + .repartition(nPartitions) + } else if (tileSize > 0) { + rasterDf + .withColumn("tile", rst_retile(col("tile"), lit(tileSize), lit(tileSize))) + .repartition(nPartitions) + } else { + rasterDf + } } else { rasterDf } @@ -172,15 +199,15 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead * @return * The raster to grid function. */ - private def getRasterToGridFunc(combiner: String): (Column, Column) => Column = { + private def getRasterToGridFunc(combiner: String): Column => Column = { combiner match { - case "mean" => rst_rastertogridavg - case "min" => rst_rastertogridmin - case "max" => rst_rastertogridmax - case "median" => rst_rastertogridmedian - case "count" => rst_rastertogridcount - case "average" => rst_rastertogridavg - case "avg" => rst_rastertogridavg + case "mean" => rst_avg + case "min" => rst_min + case "max" => rst_max + case "median" => rst_median + case "count" => rst_pixelcount + case "average" => rst_avg + case "avg" => rst_avg case _ => throw new Error("Combiner not supported") } } @@ -200,7 +227,8 @@ class RasterAsGridReader(sparkSession: SparkSession) extends MosaicDataFrameRead "resolution" -> this.extraOptions.getOrElse("resolution", "0"), "combiner" -> this.extraOptions.getOrElse("combiner", "mean"), "retile" -> this.extraOptions.getOrElse("retile", "false"), - "tileSize" -> this.extraOptions.getOrElse("tileSize", "256"), + "tileSize" -> this.extraOptions.getOrElse("tileSize", "-1"), + "sizeInMB" -> this.extraOptions.getOrElse("sizeInMB", "-1"), "kRingInterpolate" -> this.extraOptions.getOrElse("kRingInterpolate", "0") ) } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala new file mode 100644 index 000000000..a5907dbe9 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Avg.scala @@ -0,0 +1,59 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALInfo +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ + + +/** Returns the upper left x of the raster. */ +case class RST_Avg(tileExpr: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Avg](tileExpr, returnsRaster = false, expressionConfig) + with NullIntolerant + with CodegenFallback { + + override def dataType: DataType = ArrayType(DoubleType) + + /** Returns the upper left x of the raster. */ + override def rasterTransform(tile: MosaicRasterTile): Any = { + import org.json4s._ + import org.json4s.jackson.JsonMethods._ + implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats + + val command = s"gdalinfo -stats -json -mm -nogcp -nomd -norat -noct" + val gdalInfo = GDALInfo.executeInfo(tile.raster, command) + // parse json from gdalinfo + val json = parse(gdalInfo).extract[Map[String, Any]] + val maxValues = json("bands").asInstanceOf[List[Map[String, Any]]].map { band => + band("mean").asInstanceOf[Double] + } + ArrayData.toArrayData(maxValues.toArray) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Avg extends WithExpressionInfo { + + override def name: String = "rst_avg" + + override def usage: String = "_FUNC_(expr1) - Returns an array containing mean values for each band." + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile); + | [1.123, 2.123, 3.123] + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Avg](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetaData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetaData.scala index 50535ab95..4cdf63673 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetaData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BandMetaData.scala @@ -25,13 +25,14 @@ case class RST_BandMetaData(raster: Expression, band: Expression, expressionConf extends RasterBandExpression[RST_BandMetaData]( raster, band, - MapType(StringType, StringType), returnsRaster = false, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = MapType(StringType, StringType) + /** * @param raster * The raster to be used. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBox.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBox.scala index f40f6c590..65af217dd 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBox.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_BoundingBox.scala @@ -15,10 +15,12 @@ import org.apache.spark.sql.types._ case class RST_BoundingBox( raster: Expression, expressionConfig: MosaicExpressionConfig -) extends RasterExpression[RST_BoundingBox](raster, BinaryType, returnsRaster = false, expressionConfig = expressionConfig) +) extends RasterExpression[RST_BoundingBox](raster, returnsRaster = false, expressionConfig = expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = BinaryType + /** * Computes the bounding box of the raster. The bbox is returned as a WKB * polygon. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Clip.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Clip.scala index 2284c3fa9..84e4577be 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Clip.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Clip.scala @@ -19,13 +19,14 @@ case class RST_Clip( ) extends Raster1ArgExpression[RST_Clip]( rastersExpr, geometryExpr, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: org.apache.spark.sql.types.DataType = RasterTileType(expressionConfig.getCellIdType, rastersExpr) + val geometryAPI: GeometryAPI = GeometryAPI(expressionConfig.getGeometryAPI) /** diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala index fcaddf928..de163ca37 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvg.scala @@ -9,29 +9,26 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** Expression for combining rasters using average of pixels. */ case class RST_CombineAvg( - rastersExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterArrayExpression[RST_CombineAvg]( - rastersExpr, - RasterTileType(expressionConfig.getCellIdType), + tileExpr, returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** Combines the rasters using average of pixels. */ override def rasterTransform(tiles: Seq[MosaicRasterTile]): Any = { val index = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null - MosaicRasterTile( - index, - CombineAVG.compute(tiles.map(_.getRaster)), - tiles.head.getParentPath, - tiles.head.getDriver - ) + MosaicRasterTile(index, CombineAVG.compute(tiles.map(_.getRaster))) } } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala index 0a6791487..3bf9248c2 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgAgg.scala @@ -6,6 +6,7 @@ import com.databricks.labs.mosaic.core.raster.io.RasterCleaner import com.databricks.labs.mosaic.core.raster.operator.CombineAVG import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile.{deserialize => deserializeTile} import com.databricks.labs.mosaic.expressions.raster.base.RasterExpressionSerialization import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.InternalRow @@ -13,7 +14,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.util.GenericArrayData -import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType} +import org.apache.spark.sql.types.{ArrayType, DataType} import scala.collection.mutable.ArrayBuffer @@ -23,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer */ //noinspection DuplicatedCode case class RST_CombineAvgAgg( - rasterExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0 @@ -32,21 +33,23 @@ case class RST_CombineAvgAgg( with RasterExpressionSerialization { override lazy val deterministic: Boolean = true - override val child: Expression = rasterExpr + override val child: Expression = tileExpr override val nullable: Boolean = false - override val dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override lazy val dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + lazy val tileType: DataType = dataType.asInstanceOf[RasterTileType].rasterType override def prettyName: String = "rst_combine_avg_agg" + val cellIDType: DataType = expressionConfig.getCellIdType private lazy val projection = UnsafeProjection.create(Array[DataType](ArrayType(elementType = dataType, containsNull = false))) private lazy val row = new UnsafeRow(1) - def update(buffer: ArrayBuffer[Any], input: InternalRow): ArrayBuffer[Any] = { + override def update(buffer: ArrayBuffer[Any], input: InternalRow): ArrayBuffer[Any] = { val value = child.eval(input) buffer += InternalRow.copyValue(value) buffer } - def merge(buffer: ArrayBuffer[Any], input: ArrayBuffer[Any]): ArrayBuffer[Any] = { + override def merge(buffer: ArrayBuffer[Any], input: ArrayBuffer[Any]): ArrayBuffer[Any] = { buffer ++= input } @@ -63,23 +66,25 @@ case class RST_CombineAvgAgg( if (buffer.isEmpty) { null + } else if (buffer.size == 1) { + val result = buffer.head + buffer.clear() + result } else { // Do do move the expression - var tiles = buffer.map(row => MosaicRasterTile.deserialize(row.asInstanceOf[InternalRow], expressionConfig.getCellIdType)) + var tiles = buffer.map(row => deserializeTile(row.asInstanceOf[InternalRow], cellIDType, tileType)) + buffer.clear() // If merging multiple index rasters, the index value is dropped val idx = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null var combined = CombineAVG.compute(tiles.map(_.getRaster)).flushCache() - // TODO: should parent path be an array? - val parentPath = tiles.head.getParentPath - val driver = tiles.head.getDriver - val result = MosaicRasterTile(idx, combined, parentPath, driver) + val result = MosaicRasterTile(idx, combined) .formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) - .serialize(BinaryType, expressionConfig.getRasterCheckpoint) + .serialize(tileType) - tiles.foreach(RasterCleaner.dispose(_)) + tiles.foreach(RasterCleaner.dispose) RasterCleaner.dispose(result) tiles = null @@ -101,7 +106,7 @@ case class RST_CombineAvgAgg( buffer } - override protected def withNewChildInternal(newChild: Expression): RST_CombineAvgAgg = copy(rasterExpr = newChild) + override protected def withNewChildInternal(newChild: Expression): RST_CombineAvgAgg = copy(tileExpr = newChild) } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Convolve.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Convolve.scala new file mode 100644 index 000000000..d831ad849 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Convolve.scala @@ -0,0 +1,82 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.Raster1ArgExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ + +/** The expression for applying kernel filter on a raster. */ +case class RST_Convolve( + rastersExpr: Expression, + kernelExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends Raster1ArgExpression[RST_Convolve]( + rastersExpr, + kernelExpr, + returnsRaster = true, + expressionConfig = expressionConfig + ) + with NullIntolerant + with CodegenFallback { + + override def dataType: org.apache.spark.sql.types.DataType = RasterTileType(expressionConfig.getCellIdType, rastersExpr) + + val geometryAPI: GeometryAPI = GeometryAPI(expressionConfig.getGeometryAPI) + + /** + * Clips a raster by a vector. + * + * @param tile + * The raster to be used. + * @param arg1 + * The vector to be used. + * @return + * The clipped raster. + */ + override def rasterTransform(tile: MosaicRasterTile, arg1: Any): Any = { + val kernel = arg1.asInstanceOf[ArrayData].array.map(_.asInstanceOf[ArrayData].array.map( + el => kernelExpr.dataType match { + case ArrayType(ArrayType(DoubleType, false), false) => el.asInstanceOf[Double] + case ArrayType(ArrayType(DecimalType(), false), false) => el.asInstanceOf[java.math.BigDecimal].doubleValue() + case ArrayType(ArrayType(IntegerType, false), false) => el.asInstanceOf[Int].toDouble + case _ => throw new IllegalArgumentException(s"Unsupported kernel type: ${kernelExpr.dataType}") + } + )) + tile.copy( + raster = tile.getRaster.convolve(kernel) + ) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Convolve extends WithExpressionInfo { + + override def name: String = "rst_convolve" + + override def usage: String = + """ + |_FUNC_(expr1) - Returns a raster with the kernel filter applied. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster, kernel); + | {index_id, clipped_raster, parentPath, driver} + | {index_id, clipped_raster, parentPath, driver} + | ... + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Convolve](2, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBand.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBand.scala index fcd1116cd..459f1774a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBand.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBand.scala @@ -9,25 +9,27 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType import org.apache.spark.unsafe.types.UTF8String /** Expression for combining rasters using average of pixels. */ case class RST_DerivedBand( - rastersExpr: Expression, + tileExpr: Expression, pythonFuncExpr: Expression, funcNameExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterArray2ArgExpression[RST_DerivedBand]( - rastersExpr, + tileExpr, pythonFuncExpr, funcNameExpr, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** Combines the rasters using average of pixels. */ override def rasterTransform(tiles: Seq[MosaicRasterTile], arg1: Any, arg2: Any): Any = { val pythonFunc = arg1.asInstanceOf[UTF8String].toString @@ -35,9 +37,7 @@ case class RST_DerivedBand( val index = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null MosaicRasterTile( index, - PixelCombineRasters.combine(tiles.map(_.getRaster), pythonFunc, funcName), - tiles.head.getParentPath, - tiles.head.getDriver + PixelCombineRasters.combine(tiles.map(_.getRaster), pythonFunc, funcName) ) } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandAgg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandAgg.scala index 3ceb03318..836b79cbd 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandAgg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandAgg.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer */ //noinspection DuplicatedCode case class RST_DerivedBandAgg( - rasterExpr: Expression, + tileExpr: Expression, pythonFuncExpr: Expression, funcNameExpr: Expression, expressionConfig: MosaicExpressionConfig, @@ -36,13 +36,13 @@ case class RST_DerivedBandAgg( override lazy val deterministic: Boolean = true override val nullable: Boolean = false - override val dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override lazy val dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) override def prettyName: String = "rst_combine_avg_agg" private lazy val projection = UnsafeProjection.create(Array[DataType](ArrayType(elementType = dataType, containsNull = false))) private lazy val row = new UnsafeRow(1) - override def first: Expression = rasterExpr + override def first: Expression = tileExpr override def second: Expression = pythonFuncExpr override def third: Expression = funcNameExpr @@ -74,21 +74,25 @@ case class RST_DerivedBandAgg( // This works for Literals only val pythonFunc = pythonFuncExpr.eval(null).asInstanceOf[UTF8String].toString val funcName = funcNameExpr.eval(null).asInstanceOf[UTF8String].toString + val rasterType = RasterTileType(tileExpr).rasterType // Do do move the expression - var tiles = buffer.map(row => MosaicRasterTile.deserialize(row.asInstanceOf[InternalRow], expressionConfig.getCellIdType)) + var tiles = buffer.map(row => + MosaicRasterTile.deserialize( + row.asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) + ) // If merging multiple index rasters, the index value is dropped val idx = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null var combined = PixelCombineRasters.combine(tiles.map(_.getRaster), pythonFunc, funcName) - // TODO: should parent path be an array? - val parentPath = tiles.head.getParentPath - val driver = tiles.head.getDriver - val result = MosaicRasterTile(idx, combined, parentPath, driver) + val result = MosaicRasterTile(idx, combined) .formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) - .serialize(BinaryType, expressionConfig.getRasterCheckpoint) + .serialize(BinaryType) tiles.foreach(RasterCleaner.dispose(_)) RasterCleaner.dispose(result) @@ -113,7 +117,7 @@ case class RST_DerivedBandAgg( } override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): RST_DerivedBandAgg = - copy(rasterExpr = newFirst, pythonFuncExpr = newSecond, funcNameExpr = newThird) + copy(tileExpr = newFirst, pythonFuncExpr = newSecond, funcNameExpr = newThird) } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Filter.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Filter.scala new file mode 100644 index 000000000..ee8b34d3b --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Filter.scala @@ -0,0 +1,77 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.unsafe.types.UTF8String + +/** The expression for applying NxN filter on a raster. */ +case class RST_Filter( + rastersExpr: Expression, + kernelSizeExpr: Expression, + operationExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends Raster2ArgExpression[RST_Filter]( + rastersExpr, + kernelSizeExpr, + operationExpr, + returnsRaster = true, + expressionConfig = expressionConfig + ) + with NullIntolerant + with CodegenFallback { + + override def dataType: org.apache.spark.sql.types.DataType = RasterTileType(expressionConfig.getCellIdType, rastersExpr) + + val geometryAPI: GeometryAPI = GeometryAPI(expressionConfig.getGeometryAPI) + + /** + * Clips a raster by a vector. + * + * @param tile + * The raster to be used. + * @param arg1 + * The vector to be used. + * @return + * The clipped raster. + */ + override def rasterTransform(tile: MosaicRasterTile, arg1: Any, arg2: Any): Any = { + val n = arg1.asInstanceOf[Int] + val operation = arg2.asInstanceOf[UTF8String].toString + tile.copy( + raster = tile.getRaster.filter(n, operation) + ) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Filter extends WithExpressionInfo { + + override def name: String = "rst_filter" + + override def usage: String = + """ + |_FUNC_(expr1) - Returns a raster with the filter applied. + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster, kernelSize, operation); + | {index_id, clipped_raster, parentPath, driver} + | {index_id, clipped_raster, parentPath, driver} + | ... + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Filter](3, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala index 268a1c550..8325dc0f4 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala @@ -9,6 +9,7 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.ArrayType /** The expression for stacking and resampling input bands. */ case class RST_FromBands( @@ -16,13 +17,18 @@ case class RST_FromBands( expressionConfig: MosaicExpressionConfig ) extends RasterArrayExpression[RST_FromBands]( bandsExpr, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: org.apache.spark.sql.types.DataType = + RasterTileType( + expressionConfig.getCellIdType, + RasterTileType(bandsExpr).rasterType + ) + /** * Stacks and resamples input bands. * @param rasters diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromContent.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromContent.scala index 8749c0d5b..1021eb083 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromContent.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromContent.scala @@ -15,7 +15,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, Expression, Literal, NullIntolerant} -import org.apache.spark.sql.types.{DataType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import java.nio.file.{Files, Paths} @@ -25,7 +25,7 @@ import java.nio.file.{Files, Paths} * expression in the expression tree for a raster tile. */ case class RST_FromContent( - rasterExpr: Expression, + contentExpr: Expression, driverExpr: Expression, sizeInMB: Expression, expressionConfig: MosaicExpressionConfig @@ -33,8 +33,10 @@ case class RST_FromContent( with Serializable with NullIntolerant with CodegenFallback { + + val tileType: DataType = BinaryType - override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileType) protected val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI) @@ -46,12 +48,13 @@ case class RST_FromContent( override def inline: Boolean = false - override def children: Seq[Expression] = Seq(rasterExpr, driverExpr, sizeInMB) + override def children: Seq[Expression] = Seq(contentExpr, driverExpr, sizeInMB) override def elementSchema: StructType = StructType(Array(StructField("tile", dataType))) /** - * subdivides raster binary content into tiles of the specified size (in MB). + * subdivides raster binary content into tiles of the specified size (in + * MB). * @param input * The input file path. * @return @@ -61,13 +64,14 @@ case class RST_FromContent( GDAL.enable(expressionConfig) val driver = driverExpr.eval(input).asInstanceOf[UTF8String].toString val ext = GDAL.getExtension(driver) - var rasterArr = rasterExpr.eval(input).asInstanceOf[Array[Byte]] + var rasterArr = contentExpr.eval(input).asInstanceOf[Array[Byte]] val targetSize = sizeInMB.eval(input).asInstanceOf[Int] if (targetSize <= 0 || rasterArr.length <= targetSize) { // - no split required - var raster = MosaicRasterGDAL.readRaster(rasterArr, PathUtils.NO_PATH_STRING, driver) - var tile = MosaicRasterTile(null, raster, PathUtils.NO_PATH_STRING, driver) - val row = tile.formatCellId(indexSystem).serialize() + val createInfo = Map("parentPath" -> PathUtils.NO_PATH_STRING, "driver" -> driver) + var raster = MosaicRasterGDAL.readRaster(rasterArr, createInfo) + var tile = MosaicRasterTile(null, raster) + val row = tile.formatCellId(indexSystem).serialize(tileType) RasterCleaner.dispose(raster) RasterCleaner.dispose(tile) rasterArr = null @@ -84,7 +88,7 @@ case class RST_FromContent( // split to tiles up to specifed threshold var tiles = ReTileOnRead.localSubdivide(rasterPath, PathUtils.NO_PATH_STRING, targetSize) - val rows = tiles.map(_.formatCellId(indexSystem).serialize()) + val rows = tiles.map(_.formatCellId(indexSystem).serialize(tileType)) tiles.foreach(RasterCleaner.dispose(_)) Files.deleteIfExists(Paths.get(rasterPath)) rasterArr = null @@ -118,10 +122,10 @@ object RST_FromContent extends WithExpressionInfo { | ... | """.stripMargin - override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { - (children: Seq[Expression]) => { + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { (children: Seq[Expression]) => + { val sizeExpr = if (children.length < 3) new Literal(-1, IntegerType) else children(2) - RST_FromContent(children(0), children(1), sizeExpr, expressionConfig) + RST_FromContent(children.head, children(1), sizeExpr, expressionConfig) } } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala index 4a4bf04c0..8e1dc213e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromFile.scala @@ -15,7 +15,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, Expression, Literal, NullIntolerant} -import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import java.nio.file.{Files, Paths, StandardCopyOption} @@ -32,8 +32,10 @@ case class RST_FromFile( with Serializable with NullIntolerant with CodegenFallback { + + val tileType: DataType = BinaryType - override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileType) protected val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI) @@ -63,10 +65,12 @@ case class RST_FromFile( val readPath = PathUtils.getCleanPath(path) val driver = MosaicRasterGDAL.identifyDriver(path) val targetSize = sizeInMB.eval(input).asInstanceOf[Int] - if (targetSize <= 0 && Files.size(Paths.get(readPath)) <= Integer.MAX_VALUE) { - var raster = MosaicRasterGDAL.readRaster(readPath, path) - var tile = MosaicRasterTile(null, raster, path, raster.getDriversShortName) - val row = tile.formatCellId(indexSystem).serialize() + val currentSize = Files.size(Paths.get(PathUtils.replaceDBFSTokens(readPath))) + if (targetSize <= 0 && currentSize <= Integer.MAX_VALUE) { + val createInfo = Map("path" -> readPath, "parentPath" -> path) + var raster = MosaicRasterGDAL.readRaster(createInfo) + var tile = MosaicRasterTile(null, raster) + val row = tile.formatCellId(indexSystem).serialize(tileType) RasterCleaner.dispose(raster) RasterCleaner.dispose(tile) raster = null @@ -79,7 +83,7 @@ case class RST_FromFile( Files.copy(Paths.get(readPath), Paths.get(tmpPath), StandardCopyOption.REPLACE_EXISTING) val size = if (targetSize <= 0) 64 else targetSize var tiles = ReTileOnRead.localSubdivide(tmpPath, path, size) - val rows = tiles.map(_.formatCellId(indexSystem).serialize()) + val rows = tiles.map(_.formatCellId(indexSystem).serialize(tileType)) tiles.foreach(RasterCleaner.dispose(_)) Files.deleteIfExists(Paths.get(tmpPath)) tiles = null diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReference.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReference.scala index 19f30ab23..797db8e09 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReference.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GeoReference.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the georeference of the raster. */ case class RST_GeoReference(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_GeoReference](raster, MapType(StringType, DoubleType), returnsRaster = false, expressionConfig) + extends RasterExpression[RST_GeoReference](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = MapType(StringType, DoubleType) + /** Returns the georeference of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { val geoTransform = tile.getRaster.getRaster.GetGeoTransform() diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala index 1bff1c5ae..0084141c8 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala @@ -8,7 +8,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.types.{ArrayType, DoubleType} +import org.apache.spark.sql.types.{ArrayType, DataType, DoubleType} /** The expression for extracting the no data value of a raster. */ case class RST_GetNoData( @@ -16,13 +16,14 @@ case class RST_GetNoData( expressionConfig: MosaicExpressionConfig ) extends RasterExpression[RST_GetNoData]( rastersExpr, - ArrayType(DoubleType), returnsRaster = false, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = ArrayType(DoubleType) + /** * Extracts the no data value of a raster. * diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetSubdataset.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetSubdataset.scala index 709f1aca8..1589fa463 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetSubdataset.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetSubdataset.scala @@ -8,20 +8,25 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType import org.apache.spark.unsafe.types.UTF8String /** Returns the subdatasets of the raster. */ -case class RST_GetSubdataset(raster: Expression, subsetName: Expression, expressionConfig: MosaicExpressionConfig) - extends Raster1ArgExpression[RST_GetSubdataset]( - raster, +case class RST_GetSubdataset( + tileExpr: Expression, + subsetName: Expression, + expressionConfig: MosaicExpressionConfig +) extends Raster1ArgExpression[RST_GetSubdataset]( + tileExpr, subsetName, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** Returns the subdatasets of the raster. */ override def rasterTransform(tile: MosaicRasterTile, arg1: Any): Any = { val subsetName = arg1.asInstanceOf[UTF8String].toString diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Height.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Height.scala index d39d1cf70..bd54511b0 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Height.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Height.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the width of the raster. */ case class RST_Height(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Height](raster, IntegerType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_Height](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = IntegerType + /** Returns the width of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = tile.getRaster.ySize diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_InitNoData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_InitNoData.scala index c97c4365d..6f92b79ae 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_InitNoData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_InitNoData.scala @@ -11,20 +11,22 @@ import com.databricks.labs.mosaic.utils.PathUtils import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** The expression that initializes no data values of a raster. */ case class RST_InitNoData( - rastersExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterExpression[RST_InitNoData]( - rastersExpr, - RasterTileType(expressionConfig.getCellIdType), + tileExpr, returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** * Initializes no data values of a raster. * diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmpty.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmpty.scala index 15fca50c5..28e70472e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmpty.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_IsEmpty.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns true if the raster is empty. */ case class RST_IsEmpty(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_IsEmpty](raster, BooleanType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_IsEmpty](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = BooleanType + /** Returns true if the raster is empty. */ override def rasterTransform(tile: MosaicRasterTile): Any = { var raster = tile.getRaster diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala new file mode 100644 index 000000000..f9cf7099d --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTiles.scala @@ -0,0 +1,214 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.MOSAIC_NO_DRIVER +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory} +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.datasource.gdal.ReTileOnRead +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import com.databricks.labs.mosaic.utils.PathUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{CollectionGenerator, Expression, Literal, NullIntolerant} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +import java.nio.file.{Files, Paths} +import scala.util.Try + +/** + * Creates raster tiles from the input column. + * + * @param inputExpr + * The expression for the raster. If the raster is stored on disc, the path + * to the raster is provided. If the raster is stored in memory, the bytes of + * the raster are provided. + * @param sizeInMBExpr + * The size of the tiles in MB. If set to -1, the file is loaded and returned + * as a single tile. If set to 0, the file is loaded and subdivided into + * tiles of size 64MB. If set to a positive value, the file is loaded and + * subdivided into tiles of the specified size. If the file is too big to fit + * in memory, it is subdivided into tiles of size 64MB. + * @param driverExpr + * The driver to use for reading the raster. If not specified, the driver is + * inferred from the file extension. If the input is a byte array, the driver + * has to be specified. + * @param withCheckpointExpr + * If set to true, the tiles are written to the checkpoint directory. If set + * to false, the tiles are returned as a in-memory byte arrays. + * @param expressionConfig + * Additional arguments for the expression (expressionConfigs). + */ +case class RST_MakeTiles( + inputExpr: Expression, + driverExpr: Expression, + sizeInMBExpr: Expression, + withCheckpointExpr: Expression, + expressionConfig: MosaicExpressionConfig +) extends CollectionGenerator + with Serializable + with NullIntolerant + with CodegenFallback { + + override def dataType: DataType = { + require(withCheckpointExpr.isInstanceOf[Literal]) + if (withCheckpointExpr.eval().asInstanceOf[Boolean]) { + // Raster is referenced via a path + RasterTileType(expressionConfig.getCellIdType, StringType) + } else { + // Raster is referenced via a byte array + RasterTileType(expressionConfig.getCellIdType, BinaryType) + } + } + + protected val geometryAPI: GeometryAPI = GeometryAPI.apply(expressionConfig.getGeometryAPI) + + protected val indexSystem: IndexSystem = IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem) + + protected val cellIdDataType: DataType = indexSystem.getCellIdDataType + + override def position: Boolean = false + + override def inline: Boolean = false + + override def children: Seq[Expression] = Seq(inputExpr, driverExpr, sizeInMBExpr, withCheckpointExpr) + + override def elementSchema: StructType = StructType(Array(StructField("tile", dataType))) + + private def getDriver(rawInput: Any, rawDriver: String): String = { + if (rawDriver == MOSAIC_NO_DRIVER) { + if (inputExpr.dataType == StringType) { + val path = rawInput.asInstanceOf[UTF8String].toString + MosaicRasterGDAL.identifyDriver(path) + } else { + throw new IllegalArgumentException("Driver has to be specified for byte array input") + } + } else { + rawDriver + } + } + + private def getInputSize(rawInput: Any): Long = { + if (inputExpr.dataType == StringType) { + val path = rawInput.asInstanceOf[UTF8String].toString + val cleanPath = PathUtils.replaceDBFSTokens(path) + Files.size(Paths.get(cleanPath)) + } else { + val bytes = rawInput.asInstanceOf[Array[Byte]] + bytes.length + } + } + + /** + * Loads a raster from a file and subdivides it into tiles of the specified + * size (in MB). + * @param input + * The input file path. + * @return + * The tiles. + */ + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + GDAL.enable(expressionConfig) + + val tileType = dataType.asInstanceOf[StructType].find(_.name == "raster").get.dataType + + val rawDriver = driverExpr.eval(input).asInstanceOf[UTF8String].toString + val rawInput = inputExpr.eval(input) + val driver = getDriver(rawInput, rawDriver) + val targetSize = sizeInMBExpr.eval(input).asInstanceOf[Int] + val inputSize = getInputSize(rawInput) + val path = if (inputExpr.dataType == StringType) rawInput.asInstanceOf[UTF8String].toString else PathUtils.NO_PATH_STRING + + if (targetSize <= 0 && inputSize <= Integer.MAX_VALUE) { + // - no split required + val createInfo = Map("parentPath" -> PathUtils.NO_PATH_STRING, "driver" -> driver, "path" -> path) + val raster = GDAL.readRaster(rawInput, createInfo, inputExpr.dataType) + val tile = MosaicRasterTile(null, raster) + val row = tile.formatCellId(indexSystem).serialize(tileType) + RasterCleaner.dispose(raster) + RasterCleaner.dispose(tile) + Seq(InternalRow.fromSeq(Seq(row))) + } else { + // target size is > 0 and raster size > target size + // - write the initial raster to file (unsplit) + // - createDirectories in case of context isolation + val rasterPath = + if (inputExpr.dataType == StringType) { + PathUtils.copyToTmpWithRetry(path, 5) + } else { + val rasterPath = PathUtils.createTmpFilePath(GDAL.getExtension(driver)) + Files.createDirectories(Paths.get(rasterPath).getParent) + Files.write(Paths.get(rasterPath), rawInput.asInstanceOf[Array[Byte]]) + rasterPath + } + val size = if (targetSize <= 0) 64 else targetSize + var tiles = ReTileOnRead.localSubdivide(rasterPath, PathUtils.NO_PATH_STRING, size) + val rows = tiles.map(_.formatCellId(indexSystem).serialize(tileType)) + tiles.foreach(RasterCleaner.dispose(_)) + Files.deleteIfExists(Paths.get(rasterPath)) + tiles = null + rows.map(row => InternalRow.fromSeq(Seq(row))) + } + } + + override def makeCopy(newArgs: Array[AnyRef]): Expression = + GenericExpressionFactory.makeCopyImpl[RST_MakeTiles](this, newArgs, children.length, expressionConfig) + + override def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = makeCopy(newChildren.toArray) + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_MakeTiles extends WithExpressionInfo { + + override def name: String = "rst_maketiles" + + override def usage: String = + """ + |_FUNC_(expr1) - Returns a set of new rasters with the specified tile size (tileWidth x tileHeight). + |""".stripMargin + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_path); + | {index_id, raster, parent_path, driver} + | ... + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { (children: Seq[Expression]) => + { + def checkSize(size: Expression) = Try(size.eval().asInstanceOf[Int]).isSuccess + def checkChkpnt(chkpnt: Expression) = Try(chkpnt.eval().asInstanceOf[Boolean]).isSuccess + def checkDriver(driver: Expression) = Try(driver.eval().asInstanceOf[UTF8String].toString).isSuccess + val noSize = new Literal(-1, IntegerType) + val noDriver = new Literal(UTF8String.fromString(MOSAIC_NO_DRIVER), StringType) + val noCheckpoint = new Literal(false, BooleanType) + + children match { + // Note type checking only works for literals + case Seq(input) => RST_MakeTiles(input, noDriver, noSize, noCheckpoint, expressionConfig) + case Seq(input, driver) if checkDriver(driver) => RST_MakeTiles(input, driver, noSize, noCheckpoint, expressionConfig) + case Seq(input, size) if checkSize(size) => RST_MakeTiles(input, noDriver, size, noCheckpoint, expressionConfig) + case Seq(input, checkpoint) if checkChkpnt(checkpoint) => + RST_MakeTiles(input, noDriver, noSize, checkpoint, expressionConfig) + case Seq(input, size, checkpoint) if checkSize(size) && checkChkpnt(checkpoint) => + RST_MakeTiles(input, noDriver, size, checkpoint, expressionConfig) + case Seq(input, driver, size) if checkDriver(driver) && checkSize(size) => + RST_MakeTiles(input, driver, size, noCheckpoint, expressionConfig) + case Seq(input, driver, checkpoint) if checkDriver(driver) && checkChkpnt(checkpoint) => + RST_MakeTiles(input, driver, noSize, checkpoint, expressionConfig) + case Seq(input, driver, size, checkpoint) if checkDriver(driver) && checkSize(size) && checkChkpnt(checkpoint) => + RST_MakeTiles(input, driver, size, checkpoint, expressionConfig) + case _ => RST_MakeTiles(children.head, children(1), children(2), children(3), expressionConfig) + } + } + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MapAlgebra.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MapAlgebra.scala index 24d941b4b..bc2aea949 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MapAlgebra.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MapAlgebra.scala @@ -11,23 +11,25 @@ import com.databricks.labs.mosaic.utils.PathUtils import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType import org.apache.spark.unsafe.types.UTF8String /** The expression for map algebra. */ case class RST_MapAlgebra( - rastersExpr: Expression, + tileExpr: Expression, jsonSpecExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterArray1ArgExpression[RST_MapAlgebra]( - rastersExpr, + tileExpr, jsonSpecExpr, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** * Map Algebra. * @param tiles @@ -43,12 +45,8 @@ case class RST_MapAlgebra( val resultPath = PathUtils.createTmpFilePath(extension) val command = parseSpec(jsonSpec, resultPath, tiles) val index = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null - MosaicRasterTile( - index, - GDALCalc.executeCalc(command, resultPath), - resultPath, - tiles.head.getDriver - ) + val result = GDALCalc.executeCalc(command, resultPath) + MosaicRasterTile(index, result) } def parseSpec(jsonSpec: String, resultPath: String, tiles: Seq[MosaicRasterTile]): String = { diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Max.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Max.scala new file mode 100644 index 000000000..434be4a68 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Max.scala @@ -0,0 +1,50 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALInfo +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ + + +/** Returns the upper left x of the raster. */ +case class RST_Max(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Max](raster, returnsRaster = false, expressionConfig) + with NullIntolerant + with CodegenFallback { + + override def dataType: DataType = ArrayType(DoubleType) + + /** Returns the upper left x of the raster. */ + override def rasterTransform(tile: MosaicRasterTile): Any = { + val nBands = tile.raster.raster.GetRasterCount() + val maxValues = (1 to nBands).map(tile.raster.getBand(_).maxPixelValue) + ArrayData.toArrayData(maxValues.toArray) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Max extends WithExpressionInfo { + + override def name: String = "rst_max" + + override def usage: String = "_FUNC_(expr1) - Returns an array containing max values for each band." + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile); + | [1.123, 2.123, 3.123] + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Max](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Median.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Median.scala new file mode 100644 index 000000000..19d3fc0a6 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Median.scala @@ -0,0 +1,61 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.api.GDAL +import com.databricks.labs.mosaic.core.raster.operator.gdal.GDALWarp +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import com.databricks.labs.mosaic.utils.PathUtils +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ + +/** Returns the upper left x of the raster. */ +case class RST_Median(rasterExpr: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Median](rasterExpr, returnsRaster = false, expressionConfig) + with NullIntolerant + with CodegenFallback { + + override def dataType: DataType = ArrayType(DoubleType) + + /** Returns the upper left x of the raster. */ + override def rasterTransform(tile: MosaicRasterTile): Any = { + val raster = tile.raster + val width = raster.xSize * raster.pixelXSize + val height = raster.ySize * raster.pixelYSize + val outShortName = raster.getDriversShortName + val resultFileName = PathUtils.createTmpFilePath(GDAL.getExtension(outShortName)) + val medRaster = GDALWarp.executeWarp( + resultFileName, + Seq(raster), + command = s"gdalwarp -r med -tr $width $height -of $outShortName" + ) + // Max pixel is a hack since we get a 1x1 raster back + val maxValues = (1 to medRaster.raster.GetRasterCount()).map(medRaster.getBand(_).maxPixelValue) + ArrayData.toArrayData(maxValues.toArray) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Median extends WithExpressionInfo { + + override def name: String = "rst_median" + + override def usage: String = "_FUNC_(expr1) - Returns an array containing mean values for each band." + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile); + | [1.123, 2.123, 3.123] + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Median](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSize.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSize.scala index 772cb4b3b..91770653e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSize.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MemSize.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the memory size of the raster in bytes. */ case class RST_MemSize(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_MemSize](raster, LongType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_MemSize](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = LongType + /** Returns the memory size of the raster in bytes. */ override def rasterTransform(tile: MosaicRasterTile): Any = tile.getRaster.getMemSize diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Merge.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Merge.scala index 08df40d43..09d4b9d3d 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Merge.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Merge.scala @@ -9,20 +9,22 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** Returns a raster that is a result of merging an array of rasters. */ case class RST_Merge( - rastersExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends RasterArrayExpression[RST_Merge]( - rastersExpr, - RasterTileType(expressionConfig.getCellIdType), + tileExpr, returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** * Merges an array of rasters. * @param tiles diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala index fc69523d9..ae56a01ab 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeAgg.scala @@ -20,7 +20,7 @@ import scala.collection.mutable.ArrayBuffer /** Merges rasters into a single raster. */ //noinspection DuplicatedCode case class RST_MergeAgg( - rasterExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0 @@ -29,9 +29,9 @@ case class RST_MergeAgg( with RasterExpressionSerialization { override lazy val deterministic: Boolean = true - override val child: Expression = rasterExpr + override val child: Expression = tileExpr override val nullable: Boolean = false - override val dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override lazy val dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) override def prettyName: String = "rst_merge_agg" private lazy val projection = UnsafeProjection.create(Array[DataType](ArrayType(elementType = dataType, containsNull = false))) @@ -66,20 +66,24 @@ case class RST_MergeAgg( // This is a trick to get the rasters sorted by their parent path to ensure more consistent results // when merging rasters with large overlaps + val rasterType = RasterTileType(tileExpr).rasterType var tiles = buffer - .map(row => MosaicRasterTile.deserialize(row.asInstanceOf[InternalRow], expressionConfig.getCellIdType)) + .map(row => + MosaicRasterTile.deserialize( + row.asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) + ) .sortBy(_.getParentPath) // If merging multiple index rasters, the index value is dropped val idx = if (tiles.map(_.getIndex).groupBy(identity).size == 1) tiles.head.getIndex else null var merged = MergeRasters.merge(tiles.map(_.getRaster)).flushCache() - // TODO: should parent path be an array? - val parentPath = tiles.head.getParentPath - val driver = tiles.head.getDriver - val result = MosaicRasterTile(idx, merged, parentPath, driver) + val result = MosaicRasterTile(idx, merged) .formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) - .serialize(BinaryType, expressionConfig.getRasterCheckpoint) + .serialize(BinaryType) tiles.foreach(RasterCleaner.dispose(_)) RasterCleaner.dispose(merged) @@ -103,7 +107,7 @@ case class RST_MergeAgg( buffer } - override protected def withNewChildInternal(newChild: Expression): RST_MergeAgg = copy(rasterExpr = newChild) + override protected def withNewChildInternal(newChild: Expression): RST_MergeAgg = copy(tileExpr = newChild) } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetaData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetaData.scala index 72236e8ba..3b6bfaf78 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetaData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_MetaData.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the metadata of the raster. */ case class RST_MetaData(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_MetaData](raster, MapType(StringType, StringType), returnsRaster = false, expressionConfig) + extends RasterExpression[RST_MetaData](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = MapType(StringType, StringType) + /** Returns the metadata of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = buildMapString(tile.getRaster.metadata) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Min.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Min.scala new file mode 100644 index 000000000..ea62e106f --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Min.scala @@ -0,0 +1,49 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ + + +/** Returns the upper left x of the raster. */ +case class RST_Min(raster: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Min](raster, returnsRaster = false, expressionConfig) + with NullIntolerant + with CodegenFallback { + + override def dataType: DataType = ArrayType(DoubleType) + + /** Returns the upper left x of the raster. */ + override def rasterTransform(tile: MosaicRasterTile): Any = { + val nBands = tile.raster.raster.GetRasterCount() + val minValues = (1 to nBands).map(tile.raster.getBand(_).minPixelValue) + ArrayData.toArrayData(minValues.toArray) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Min extends WithExpressionInfo { + + override def name: String = "rst_min" + + override def usage: String = "_FUNC_(expr1) - Returns an array containing min values for each band." + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile); + | [1.123, 2.123, 3.123] + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Min](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala index 5f4ca5743..2d7b55fd7 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala @@ -9,24 +9,26 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** The expression for computing NDVI index. */ case class RST_NDVI( - rastersExpr: Expression, + tileExpr: Expression, redIndex: Expression, nirIndex: Expression, expressionConfig: MosaicExpressionConfig ) extends Raster2ArgExpression[RST_NDVI]( - rastersExpr, + tileExpr, redIndex, nirIndex, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** * Computes NDVI index. * @param tile diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBands.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBands.scala index 5fad0186f..383cf6d73 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBands.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NumBands.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the number of bands in the raster. */ case class RST_NumBands(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_NumBands](raster, IntegerType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_NumBands](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = IntegerType + /** Returns the number of bands in the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = tile.getRaster.numBands diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCount.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCount.scala new file mode 100644 index 000000000..b2543a87e --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCount.scala @@ -0,0 +1,48 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.RasterExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ + +/** Returns the upper left x of the raster. */ +case class RST_PixelCount(rasterExpr: Expression, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_PixelCount](rasterExpr, returnsRaster = false, expressionConfig) + with NullIntolerant + with CodegenFallback { + + override def dataType: DataType = ArrayType(LongType) + + /** Returns the upper left x of the raster. */ + override def rasterTransform(tile: MosaicRasterTile): Any = { + val bandCount = tile.raster.raster.GetRasterCount() + val pixelCount = (1 to bandCount).map(tile.raster.getBand(_).pixelCount) + ArrayData.toArrayData(pixelCount.toArray) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_PixelCount extends WithExpressionInfo { + + override def name: String = "rst_pixelcount" + + override def usage: String = "_FUNC_(expr1) - Returns an array containing valid pixel count values for each band." + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile); + | [12, 212, 313] + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_PixelCount](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeight.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeight.scala index 63e060552..13c717a2e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeight.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelHeight.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the pixel height of the raster. */ case class RST_PixelHeight(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_PixelHeight](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_PixelHeight](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the pixel height of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { val gt = tile.getRaster.getGeoTransform diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidth.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidth.scala index 8fc18f759..f1b3e6cee 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidth.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelWidth.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the pixel width of the raster. */ case class RST_PixelWidth(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_PixelWidth](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_PixelWidth](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the pixel width of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { val gt = tile.getRaster.getGeoTransform diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala index 0fba20cca..5f5e5beb1 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala @@ -17,10 +17,12 @@ case class RST_RasterToWorldCoord( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_RasterToWorldCoord](raster, x, y, StringType, returnsRaster = false, expressionConfig = expressionConfig) +) extends Raster2ArgExpression[RST_RasterToWorldCoord](raster, x, y, returnsRaster = false, expressionConfig = expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = StringType + /** * Returns the world coordinates of the raster (x,y) pixel by applying * GeoTransform. This ensures the projection of the raster is respected. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala index 719689a90..0e31ae78a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala @@ -16,10 +16,12 @@ case class RST_RasterToWorldCoordX( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_RasterToWorldCoordX](raster, x, y, DoubleType, returnsRaster = false, expressionConfig) +) extends Raster2ArgExpression[RST_RasterToWorldCoordX](raster, x, y, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** * Returns the world coordinates of the raster x pixel by applying * GeoTransform. This ensures the projection of the raster is respected. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala index 661e07fcf..5c4835452 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala @@ -16,10 +16,12 @@ case class RST_RasterToWorldCoordY( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_RasterToWorldCoordY](raster, x, y, DoubleType, returnsRaster = false, expressionConfig) +) extends Raster2ArgExpression[RST_RasterToWorldCoordY](raster, x, y, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** * Returns the world coordinates of the raster y pixel by applying * GeoTransform. This ensures the projection of the raster is respected. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTile.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTile.scala index 1222c7e01..cc1b0e0ee 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTile.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ReTile.scala @@ -8,6 +8,7 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** * Returns a set of new rasters with the specified tile size (tileWidth x @@ -22,6 +23,8 @@ case class RST_ReTile( with NullIntolerant with CodegenFallback { + override def dataType: DataType = rasterExpr.dataType + /** * Returns a set of new rasters with the specified tile size (tileWidth x * tileHeight). diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Rotation.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Rotation.scala index 17dd5a937..f6268abbf 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Rotation.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Rotation.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the rotation angle of the raster. */ case class RST_Rotation(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Rotation](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_Rotation](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the rotation angle of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { val gt = tile.getRaster.getRaster.GetGeoTransform() diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRID.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRID.scala index b1bf01464..aa708609f 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRID.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SRID.scala @@ -14,10 +14,12 @@ import scala.util.Try /** Returns the SRID of the raster. */ case class RST_SRID(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_SRID](raster, IntegerType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_SRID](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = IntegerType + /** Returns the SRID of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { // Reference: https://gis.stackexchange.com/questions/267321/extracting-epsg-from-a-raster-using-gdal-bindings-in-python diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleX.scala index 0713b4ce8..19ed7efdc 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleX.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the scale x of the raster. */ case class RST_ScaleX(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_ScaleX](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_ScaleX](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the scale x of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(1) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleY.scala index 3937fff4c..019f963e8 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_ScaleY.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the scale y of the raster. */ case class RST_ScaleY(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_ScaleY](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_ScaleY](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the scale y of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(5) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetNoData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetNoData.scala index bec5c3227..c42af9c2e 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetNoData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetNoData.scala @@ -12,22 +12,24 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types.DataType /** Returns a raster with the specified no data values. */ case class RST_SetNoData( - rastersExpr: Expression, + tileExpr: Expression, noDataExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends Raster1ArgExpression[RST_SetNoData]( - rastersExpr, + tileExpr, noDataExpr, - RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + /** * Returns a raster with the specified no data values. * @param tile diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetSRID.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetSRID.scala index a3d44289a..2aabb3df9 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetSRID.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SetSRID.scala @@ -1,8 +1,6 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI -import com.databricks.labs.mosaic.core.raster.io.RasterCleaner -import com.databricks.labs.mosaic.core.raster.operator.clip.RasterClipByVector import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} @@ -11,6 +9,7 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** The expression for clipping a raster by a vector. */ case class RST_SetSRID( @@ -19,14 +18,15 @@ case class RST_SetSRID( expressionConfig: MosaicExpressionConfig ) extends Raster1ArgExpression[RST_SetSRID]( rastersExpr, - sridExpr, - RasterTileType(expressionConfig.getCellIdType), + sridExpr, returnsRaster = true, expressionConfig = expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, rastersExpr) + val geometryAPI: GeometryAPI = GeometryAPI(expressionConfig.getGeometryAPI) /** diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewX.scala index a3d0107fd..8f049f280 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewX.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the skew x of the raster. */ case class RST_SkewX(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_SkewX](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_SkewX](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the skew x of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(2) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewY.scala index 2b8db20b2..100e8fde7 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_SkewY.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the skew y of the raster. */ case class RST_SkewY(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_SkewY](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_SkewY](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the skew y of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(4) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdatasets.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdatasets.scala index 1e17c2e25..091efcc84 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdatasets.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Subdatasets.scala @@ -13,13 +13,14 @@ import org.apache.spark.sql.types._ case class RST_Subdatasets(raster: Expression, expressionConfig: MosaicExpressionConfig) extends RasterExpression[RST_Subdatasets]( raster, - MapType(StringType, StringType), returnsRaster = false, expressionConfig ) with NullIntolerant with CodegenFallback { + override def dataType: DataType = MapType(StringType, StringType) + /** Returns the subdatasets of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = buildMapString(tile.getRaster.subdatasets) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Summary.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Summary.scala index b115b0973..3fa1aec51 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Summary.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Summary.scala @@ -16,10 +16,12 @@ import java.util.{Vector => JVector} /** Returns the summary info the raster. */ case class RST_Summary(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Summary](raster, StringType, returnsRaster = false, expressionConfig: MosaicExpressionConfig) + extends RasterExpression[RST_Summary](raster, returnsRaster = false, expressionConfig: MosaicExpressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = StringType + /** Returns the summary info the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { val vector = new JVector[String]() diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Transform.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Transform.scala new file mode 100644 index 000000000..7681f2bba --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Transform.scala @@ -0,0 +1,61 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.raster.operator.proj.RasterProject +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile +import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} +import com.databricks.labs.mosaic.expressions.raster.base.Raster1ArgExpression +import com.databricks.labs.mosaic.functions.MosaicExpressionConfig +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types._ +import org.gdal.osr.SpatialReference + +/** Returns the upper left x of the raster. */ +case class RST_Transform( + tileExpr: Expression, + srid: Expression, + expressionConfig: MosaicExpressionConfig +) extends Raster1ArgExpression[RST_Transform]( + tileExpr, + srid, + returnsRaster = true, + expressionConfig + ) + with NullIntolerant + with CodegenFallback { + + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) + + /** Returns the upper left x of the raster. */ + override def rasterTransform(tile: MosaicRasterTile, arg1: Any): Any = { + val srid = arg1.asInstanceOf[Int] + val sReff = new SpatialReference() + sReff.ImportFromEPSG(srid) + sReff.SetAxisMappingStrategy(org.gdal.osr.osrConstants.OAMS_TRADITIONAL_GIS_ORDER) + val result = RasterProject.project(tile.raster, sReff) + tile.copy(raster = result) + } + +} + +/** Expression info required for the expression registration for spark SQL. */ +object RST_Transform extends WithExpressionInfo { + + override def name: String = "rst_transform" + + override def usage: String = "_FUNC_(expr1) - Returns an array containing mean values for each band." + + override def example: String = + """ + | Examples: + | > SELECT _FUNC_(raster_tile); + | [1.123, 2.123, 3.123] + | """.stripMargin + + override def builder(expressionConfig: MosaicExpressionConfig): FunctionBuilder = { + GenericExpressionFactory.getBaseBuilder[RST_Avg](1, expressionConfig) + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala index 2af8cf7a5..206ed9c7a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns true if the raster is empty. */ case class RST_TryOpen(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_TryOpen](raster, BooleanType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_TryOpen](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = BooleanType + /** Returns true if the raster can be opened. */ override def rasterTransform(tile: MosaicRasterTile): Any = { Option(tile.getRaster.getRaster).isDefined diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftX.scala index 3ff0fdd67..4481149fa 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftX.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the upper left x of the raster. */ case class RST_UpperLeftX(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_UpperLeftX](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_UpperLeftX](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the upper left x of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(0) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftY.scala index 5ef56ba96..dc512a84a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_UpperLeftY.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the upper left y of the raster. */ case class RST_UpperLeftY(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_UpperLeftY](raster, DoubleType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_UpperLeftY](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = DoubleType + /** Returns the upper left y of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = { tile.getRaster.getRaster.GetGeoTransform()(3) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Width.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Width.scala index 0dfc596b6..5543c1b81 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Width.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_Width.scala @@ -11,10 +11,12 @@ import org.apache.spark.sql.types._ /** Returns the width of the raster. */ case class RST_Width(raster: Expression, expressionConfig: MosaicExpressionConfig) - extends RasterExpression[RST_Width](raster, IntegerType, returnsRaster = false, expressionConfig) + extends RasterExpression[RST_Width](raster, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = IntegerType + /** Returns the width of the raster. */ override def rasterTransform(tile: MosaicRasterTile): Any = tile.getRaster.xSize diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala index 0de5cb009..2d943a120 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala @@ -9,6 +9,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.types.DataType /** Returns the world coordinate of the raster. */ case class RST_WorldToRasterCoord( @@ -16,10 +17,12 @@ case class RST_WorldToRasterCoord( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_WorldToRasterCoord](raster, x, y, PixelCoordsType, returnsRaster = false, expressionConfig) +) extends Raster2ArgExpression[RST_WorldToRasterCoord](raster, x, y, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: DataType = PixelCoordsType + /** * Returns the x and y of the raster by applying GeoTransform as a tuple of * Integers. This will ensure projection of the raster is respected. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala index 7f23d7c9e..43ea89de3 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala @@ -16,10 +16,12 @@ case class RST_WorldToRasterCoordX( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_WorldToRasterCoordX](raster, x, y, IntegerType, returnsRaster = false, expressionConfig) +) extends Raster2ArgExpression[RST_WorldToRasterCoordX](raster, x, y, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: IntegerType = IntegerType + /** * Returns the x coordinate of the raster by applying GeoTransform. This * will ensure projection of the raster is respected. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala index 3e82f9614..93c0c27e7 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala @@ -16,10 +16,12 @@ case class RST_WorldToRasterCoordY( x: Expression, y: Expression, expressionConfig: MosaicExpressionConfig -) extends Raster2ArgExpression[RST_WorldToRasterCoordY](raster, x, y, IntegerType, returnsRaster = false, expressionConfig) +) extends Raster2ArgExpression[RST_WorldToRasterCoordY](raster, x, y, returnsRaster = false, expressionConfig) with NullIntolerant with CodegenFallback { + override def dataType: IntegerType = IntegerType + /** * Returns the y coordinate of the raster by applying GeoTransform. This * will ensure projection of the raster is respected. diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster1ArgExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster1ArgExpression.scala index f01027ff1..35ad927c6 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster1ArgExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster1ArgExpression.scala @@ -2,12 +2,12 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, NullIntolerant} -import org.apache.spark.sql.types.DataType import scala.reflect.ClassTag @@ -21,8 +21,6 @@ import scala.reflect.ClassTag * containing the raster file content. * @param arg1Expr * The expression for the first argument. - * @param outputType - * The output type of the result. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). * @tparam T @@ -31,7 +29,6 @@ import scala.reflect.ClassTag abstract class Raster1ArgExpression[T <: Expression: ClassTag]( rasterExpr: Expression, arg1Expr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends BinaryExpression @@ -43,9 +40,6 @@ abstract class Raster1ArgExpression[T <: Expression: ClassTag]( override def right: Expression = arg1Expr - /** Output Data Type */ - override def dataType: DataType = outputType - /** * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the raster and the arguments to @@ -75,10 +69,15 @@ abstract class Raster1ArgExpression[T <: Expression: ClassTag]( // noinspection DuplicatedCode override def nullSafeEval(input: Any, arg1: Any): Any = { GDAL.enable(expressionConfig) - val tile = MosaicRasterTile.deserialize(input.asInstanceOf[InternalRow], expressionConfig.getCellIdType) + val rasterType = RasterTileType(rasterExpr).rasterType + val tile = MosaicRasterTile.deserialize( + input.asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) val raster = tile.getRaster val result = rasterTransform(tile, arg1) - val serialized = serialize(result, returnsRaster, outputType, expressionConfig) + val serialized = serialize(result, returnsRaster, rasterType, expressionConfig) RasterCleaner.dispose(raster) RasterCleaner.dispose(result) serialized diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala index ccdc7d5b3..c5be60724 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala @@ -2,6 +2,7 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig @@ -22,8 +23,6 @@ import scala.reflect.ClassTag * The expression for the first argument. * @param arg2Expr * The expression for the second argument. - * @param outputType - * The output type of the result. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). * @tparam T @@ -33,7 +32,6 @@ abstract class Raster2ArgExpression[T <: Expression: ClassTag]( rasterExpr: Expression, arg1Expr: Expression, arg2Expr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends TernaryExpression @@ -47,9 +45,6 @@ abstract class Raster2ArgExpression[T <: Expression: ClassTag]( override def third: Expression = arg2Expr - /** Output Data Type */ - override def dataType: DataType = outputType - /** * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the raster and the arguments to @@ -83,9 +78,14 @@ abstract class Raster2ArgExpression[T <: Expression: ClassTag]( // noinspection DuplicatedCode override def nullSafeEval(input: Any, arg1: Any, arg2: Any): Any = { GDAL.enable(expressionConfig) - val tile = MosaicRasterTile.deserialize(input.asInstanceOf[InternalRow], expressionConfig.getCellIdType) + val rasterType = RasterTileType(rasterExpr).rasterType + val tile = MosaicRasterTile.deserialize( + input.asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) val result = rasterTransform(tile, arg1, arg2) - val serialized = serialize(result, returnsRaster, outputType, expressionConfig) + val serialized = serialize(result, returnsRaster, rasterType, expressionConfig) // passed by name makes things re-evaluated RasterCleaner.dispose(tile) serialized diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray1ArgExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray1ArgExpression.scala index d21f96c2d..5dbfd08cc 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray1ArgExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray1ArgExpression.scala @@ -2,11 +2,12 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, NullIntolerant} -import org.apache.spark.sql.types.{ArrayType, DataType} +import org.apache.spark.sql.types.ArrayType import scala.reflect.ClassTag @@ -18,8 +19,6 @@ import scala.reflect.ClassTag * @param rastersExpr * The rasters expression. It is an array column containing rasters as either * paths or as content byte arrays. - * @param outputType - * The output type of the result. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). * @tparam T @@ -28,7 +27,6 @@ import scala.reflect.ClassTag abstract class RasterArray1ArgExpression[T <: Expression: ClassTag]( rastersExpr: Expression, arg1Expr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends BinaryExpression @@ -36,9 +34,6 @@ abstract class RasterArray1ArgExpression[T <: Expression: ClassTag]( with Serializable with RasterExpressionSerialization { - /** Output Data Type */ - override def dataType: DataType = if (returnsRaster) rastersExpr.dataType.asInstanceOf[ArrayType].elementType else outputType - override def left: Expression = rastersExpr override def right: Expression = arg1Expr @@ -72,7 +67,8 @@ abstract class RasterArray1ArgExpression[T <: Expression: ClassTag]( GDAL.enable(expressionConfig) val tiles = RasterArrayUtils.getTiles(input, rastersExpr, expressionConfig) val result = rasterTransform(tiles, arg1) - val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + val resultType = if (returnsRaster) RasterTileType(rastersExpr).rasterType else dataType + val serialized = serialize(result, returnsRaster, resultType, expressionConfig) tiles.foreach(t => RasterCleaner.dispose(t)) serialized } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray2ArgExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray2ArgExpression.scala index a26082f2d..9de963684 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray2ArgExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArray2ArgExpression.scala @@ -2,11 +2,12 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant, TernaryExpression} -import org.apache.spark.sql.types.{ArrayType, DataType} +import org.apache.spark.sql.types.ArrayType import scala.reflect.ClassTag @@ -18,8 +19,6 @@ import scala.reflect.ClassTag * @param rastersExpr * The rasters expression. It is an array column containing rasters as either * paths or as content byte arrays. - * @param outputType - * The output type of the result. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). * @tparam T @@ -29,7 +28,6 @@ abstract class RasterArray2ArgExpression[T <: Expression: ClassTag]( rastersExpr: Expression, arg1Expr: Expression, arg2Expr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends TernaryExpression @@ -37,9 +35,6 @@ abstract class RasterArray2ArgExpression[T <: Expression: ClassTag]( with Serializable with RasterExpressionSerialization { - /** Output Data Type */ - override def dataType: DataType = if (returnsRaster) rastersExpr.dataType.asInstanceOf[ArrayType].elementType else outputType - override def first: Expression = rastersExpr override def second: Expression = arg1Expr @@ -77,7 +72,8 @@ abstract class RasterArray2ArgExpression[T <: Expression: ClassTag]( GDAL.enable(expressionConfig) val tiles = RasterArrayUtils.getTiles(input, rastersExpr, expressionConfig) val result = rasterTransform(tiles, arg1, arg2) - val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + val resultType = if (returnsRaster) RasterTileType(rastersExpr).rasterType else dataType + val serialized = serialize(result, returnsRaster, resultType, expressionConfig) tiles.foreach(t => RasterCleaner.dispose(t)) serialized } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayExpression.scala index b8ad9fc12..8c3a52d9a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayExpression.scala @@ -2,11 +2,12 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant, UnaryExpression} -import org.apache.spark.sql.types.{ArrayType, DataType} +import org.apache.spark.sql.types.ArrayType import scala.reflect.ClassTag @@ -27,7 +28,6 @@ import scala.reflect.ClassTag */ abstract class RasterArrayExpression[T <: Expression: ClassTag]( rastersExpr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends UnaryExpression @@ -37,9 +37,6 @@ abstract class RasterArrayExpression[T <: Expression: ClassTag]( override def child: Expression = rastersExpr - /** Output Data Type */ - override def dataType: DataType = if (returnsRaster) rastersExpr.dataType.asInstanceOf[ArrayType].elementType else outputType - /** * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the rasters to the expression. @@ -67,7 +64,8 @@ abstract class RasterArrayExpression[T <: Expression: ClassTag]( GDAL.enable(expressionConfig) val tiles = RasterArrayUtils.getTiles(input, rastersExpr, expressionConfig) val result = rasterTransform(tiles) - val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + val resultType = if (returnsRaster) RasterTileType(rastersExpr).rasterType else dataType + val serialized = serialize(result, returnsRaster, resultType, expressionConfig) tiles.foreach(t => RasterCleaner.dispose(t)) serialized } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayUtils.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayUtils.scala index 3162bb421..f2d399350 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayUtils.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterArrayUtils.scala @@ -1,5 +1,6 @@ package com.databricks.labs.mosaic.expressions.raster.base +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.InternalRow @@ -12,11 +13,16 @@ object RasterArrayUtils { def getTiles(input: Any, rastersExpr: Expression, expressionConfig: MosaicExpressionConfig): Seq[MosaicRasterTile] = { val rasterDT = rastersExpr.dataType.asInstanceOf[ArrayType].elementType val arrayData = input.asInstanceOf[ArrayData] + val rasterType = RasterTileType(rastersExpr).rasterType val n = arrayData.numElements() (0 until n) .map(i => MosaicRasterTile - .deserialize(arrayData.get(i, rasterDT).asInstanceOf[InternalRow], expressionConfig.getCellIdType) + .deserialize( + arrayData.get(i, rasterDT).asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) ) } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterBandExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterBandExpression.scala index 7cee607ca..97bd3e333 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterBandExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterBandExpression.scala @@ -3,12 +3,12 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterBandGDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, NullIntolerant} -import org.apache.spark.sql.types.DataType import scala.reflect.ClassTag @@ -23,8 +23,6 @@ import scala.reflect.ClassTag * MOSAIC_RASTER_STORAGE is set to MOSAIC_RASTER_STORAGE_BYTE. * @param bandExpr * The expression for the band index. - * @param outputType - * The output type of the result. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). * @tparam T @@ -33,7 +31,6 @@ import scala.reflect.ClassTag abstract class RasterBandExpression[T <: Expression: ClassTag]( rasterExpr: Expression, bandExpr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends BinaryExpression @@ -45,9 +42,6 @@ abstract class RasterBandExpression[T <: Expression: ClassTag]( override def right: Expression = bandExpr - /** Output Data Type */ - override def dataType: DataType = outputType - /** * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the raster band to the @@ -79,13 +73,18 @@ abstract class RasterBandExpression[T <: Expression: ClassTag]( // noinspection DuplicatedCode override def nullSafeEval(inputRaster: Any, inputBand: Any): Any = { GDAL.enable(expressionConfig) - val tile = MosaicRasterTile.deserialize(inputRaster.asInstanceOf[InternalRow], expressionConfig.getCellIdType) + val rasterType = RasterTileType(rasterExpr).rasterType + val tile = MosaicRasterTile.deserialize( + inputRaster.asInstanceOf[InternalRow], + expressionConfig.getCellIdType, + rasterType + ) val bandIndex = inputBand.asInstanceOf[Int] val band = tile.getRaster.getBand(bandIndex) val result = bandTransform(tile, band) - val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + val serialized = serialize(result, returnsRaster, rasterType, expressionConfig) RasterCleaner.dispose(tile) serialized } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpression.scala index 462d3204b..66435f101 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpression.scala @@ -3,6 +3,7 @@ package com.databricks.labs.mosaic.expressions.raster.base import com.databricks.labs.mosaic.core.index.{IndexSystem, IndexSystemFactory} import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.io.RasterCleaner +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.GenericExpressionFactory import com.databricks.labs.mosaic.functions.MosaicExpressionConfig @@ -20,8 +21,6 @@ import scala.reflect.ClassTag * The expression for the raster. If the raster is stored on disc, the path * to the raster is provided. If the raster is stored in memory, the bytes of * the raster are provided. - * @param outputType - * The output type of the result. * @param expressionConfig * Additional arguments for the expression (expressionConfigs). * @tparam T @@ -29,7 +28,6 @@ import scala.reflect.ClassTag */ abstract class RasterExpression[T <: Expression: ClassTag]( rasterExpr: Expression, - outputType: DataType, returnsRaster: Boolean, expressionConfig: MosaicExpressionConfig ) extends UnaryExpression @@ -43,9 +41,6 @@ abstract class RasterExpression[T <: Expression: ClassTag]( override def child: Expression = rasterExpr - /** Output Data Type */ - override def dataType: DataType = outputType - /** * The function to be overridden by the extending class. It is called when * the expression is evaluated. It provides the raster to the expression. @@ -69,9 +64,14 @@ abstract class RasterExpression[T <: Expression: ClassTag]( */ override def nullSafeEval(input: Any): Any = { GDAL.enable(expressionConfig) - val tile = MosaicRasterTile.deserialize(input.asInstanceOf[InternalRow], cellIdDataType) + val rasterType = RasterTileType(rasterExpr).rasterType + val tile = MosaicRasterTile.deserialize( + input.asInstanceOf[InternalRow], + cellIdDataType, + rasterType + ) val result = rasterTransform(tile) - val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + val serialized = serialize(result, returnsRaster, rasterType, expressionConfig) RasterCleaner.dispose(tile) serialized } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpressionSerialization.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpressionSerialization.scala index a9bf17917..dc04cb1c7 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpressionSerialization.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterExpressionSerialization.scala @@ -35,11 +35,9 @@ trait RasterExpressionSerialization { ): Any = { if (returnsRaster) { val tile = data.asInstanceOf[MosaicRasterTile] - val checkpoint = expressionConfig.getRasterCheckpoint - val rasterType = outputDataType.asInstanceOf[StructType].fields(1).dataType val result = tile .formatCellId(IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem)) - .serialize(rasterType, checkpoint) + .serialize(outputDataType) RasterCleaner.dispose(tile) result } else { diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGeneratorExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGeneratorExpression.scala index 29c714788..3fc80752d 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGeneratorExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterGeneratorExpression.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag * rasters based on the input raster. The new rasters are written in the * checkpoint directory. The files are written as GeoTiffs. Subdatasets are not * supported, please flatten beforehand. - * @param rasterExpr + * @param tileExpr * The expression for the raster. If the raster is stored on disc, the path * to the raster is provided. If the raster is stored in memory, the bytes of * the raster are provided. @@ -32,13 +32,13 @@ import scala.reflect.ClassTag * The type of the extending class. */ abstract class RasterGeneratorExpression[T <: Expression: ClassTag]( - rasterExpr: Expression, + tileExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends CollectionGenerator with NullIntolerant with Serializable { - override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType) + override def dataType: DataType = RasterTileType(expressionConfig.getCellIdType, tileExpr) val uuid: String = java.util.UUID.randomUUID().toString.replace("-", "_") @@ -72,11 +72,12 @@ abstract class RasterGeneratorExpression[T <: Expression: ClassTag]( override def eval(input: InternalRow): TraversableOnce[InternalRow] = { GDAL.enable(expressionConfig) - val tile = MosaicRasterTile.deserialize(rasterExpr.eval(input).asInstanceOf[InternalRow], cellIdDataType) + val rasterType = RasterTileType(tileExpr).rasterType + val tile = MosaicRasterTile.deserialize(tileExpr.eval(input).asInstanceOf[InternalRow], cellIdDataType, rasterType) val generatedRasters = rasterGenerator(tile) // Writing rasters disposes of the written raster - val rows = generatedRasters.map(_.formatCellId(indexSystem).serialize()) + val rows = generatedRasters.map(_.formatCellId(indexSystem).serialize(rasterType)) generatedRasters.foreach(gr => RasterCleaner.dispose(gr)) RasterCleaner.dispose(tile) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterTessellateGeneratorExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterTessellateGeneratorExpression.scala index f2545942b..98ff86ca7 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterTessellateGeneratorExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterTessellateGeneratorExpression.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag * checkpoint directory. The files are written as GeoTiffs. Subdatasets are not * supported, please flatten beforehand. * - * @param rasterExpr + * @param tileExpr * The expression for the raster. If the raster is stored on disc, the path * to the raster is provided. If the raster is stored in memory, the bytes of * the raster are provided. @@ -33,7 +33,7 @@ import scala.reflect.ClassTag * The type of the extending class. */ abstract class RasterTessellateGeneratorExpression[T <: Expression: ClassTag]( - rasterExpr: Expression, + tileExpr: Expression, resolutionExpr: Expression, expressionConfig: MosaicExpressionConfig ) extends CollectionGenerator @@ -55,7 +55,8 @@ abstract class RasterTessellateGeneratorExpression[T <: Expression: ClassTag]( * needs to be wrapped in a StructType. The actually type is that of the * structs element. */ - override def elementSchema: StructType = StructType(Array(StructField("element", RasterTileType(indexSystem.getCellIdDataType)))) + override def elementSchema: StructType = + StructType(Array(StructField("element", RasterTileType(indexSystem.getCellIdDataType, tileExpr)))) /** * The function to be overridden by the extending class. It is called when @@ -71,17 +72,15 @@ abstract class RasterTessellateGeneratorExpression[T <: Expression: ClassTag]( override def eval(input: InternalRow): TraversableOnce[InternalRow] = { GDAL.enable(expressionConfig) + val rasterType = RasterTileType(tileExpr).rasterType val tile = MosaicRasterTile - .deserialize( - rasterExpr.eval(input).asInstanceOf[InternalRow], - indexSystem.getCellIdDataType - ) + .deserialize(tileExpr.eval(input).asInstanceOf[InternalRow], indexSystem.getCellIdDataType, rasterType) val inResolution: Int = indexSystem.getResolution(resolutionExpr.eval(input)) val generatedChips = rasterGenerator(tile, inResolution) .map(chip => chip.formatCellId(indexSystem)) val rows = generatedChips - .map(chip => InternalRow.fromSeq(Seq(chip.formatCellId(indexSystem).serialize()))) + .map(chip => InternalRow.fromSeq(Seq(chip.formatCellId(indexSystem).serialize(rasterType)))) RasterCleaner.dispose(tile) generatedChips.foreach(chip => RasterCleaner.dispose(chip)) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterToGridExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterToGridExpression.scala index 743f9cbd6..e7b04f989 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterToGridExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/RasterToGridExpression.scala @@ -37,11 +37,13 @@ abstract class RasterToGridExpression[T <: Expression: ClassTag, P]( resolution: Expression, measureType: DataType, expressionConfig: MosaicExpressionConfig -) extends Raster1ArgExpression[T](rasterExpr, resolution, RasterToGridType(expressionConfig.getCellIdType, measureType), returnsRaster = false, expressionConfig) +) extends Raster1ArgExpression[T](rasterExpr, resolution, returnsRaster = false, expressionConfig) with RasterGridExpression with NullIntolerant with Serializable { + override def dataType: DataType = RasterToGridType(expressionConfig.getCellIdType, measureType) + /** The index system to be used. */ val indexSystem: IndexSystem = IndexSystemFactory.getIndexSystem(expressionConfig.getIndexSystem) val geometryAPI: GeometryAPI = GeometryAPI(expressionConfig.getGeometryAPI) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/package.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/package.scala index a229aae89..7db83db25 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/package.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/package.scala @@ -1,6 +1,6 @@ package com.databricks.labs.mosaic.expressions -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapBuilder, ArrayBasedMapData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapBuilder, ArrayBasedMapData, ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -21,8 +21,8 @@ package object raster { * The measure type of the resulting pixel value. * * @return - * The datatype to be used for serialization of the result of - * [[com.databricks.labs.mosaic.expressions.raster.base.RasterToGridExpression]]. + * The datatype to be used for serialization of the result of + * [[com.databricks.labs.mosaic.expressions.raster.base.RasterToGridExpression]]. */ def RasterToGridType(cellIDType: DataType, measureType: DataType): DataType = { ArrayType( @@ -49,6 +49,19 @@ package object raster { mapBuilder.build() } + /** + * Extracts a scala Map[String, String] from a spark map. + * @param mapData + * The map to be used. + * @return + * Deserialized map. + */ + def extractMap(mapData: MapData): Map[String, String] = { + val keys = mapData.keyArray().toArray[UTF8String](StringType).map(_.toString) + val values = mapData.valueArray().toArray[UTF8String](StringType).map(_.toString) + keys.zip(values).toMap + } + /** * Builds a spark map from a scala Map[String, Double]. * @param metaData diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala index 59f1cf9d5..7557306ec 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicContext.scala @@ -8,20 +8,20 @@ import com.databricks.labs.mosaic.core.types.ChipType import com.databricks.labs.mosaic.datasource.multiread.MosaicDataFrameReader import com.databricks.labs.mosaic.expressions.constructors._ import com.databricks.labs.mosaic.expressions.format._ -import com.databricks.labs.mosaic.expressions.geometry._ import com.databricks.labs.mosaic.expressions.geometry.ST_MinMaxXYZ._ +import com.databricks.labs.mosaic.expressions.geometry._ import com.databricks.labs.mosaic.expressions.index._ import com.databricks.labs.mosaic.expressions.raster._ import com.databricks.labs.mosaic.expressions.util.TrySql +import com.databricks.labs.mosaic.utils.FileUtils import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Column, SparkSession} import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{LongType, StringType} +import org.apache.spark.sql.{Column, SparkSession} -import java.nio.file.Files import scala.reflect.runtime.universe //noinspection DuplicatedCode @@ -258,17 +258,24 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends ) /** RasterAPI dependent functions */ + mosaicRegistry.registerExpression[RST_Avg](expressionConfig) mosaicRegistry.registerExpression[RST_BandMetaData](expressionConfig) mosaicRegistry.registerExpression[RST_BoundingBox](expressionConfig) mosaicRegistry.registerExpression[RST_Clip](expressionConfig) mosaicRegistry.registerExpression[RST_CombineAvg](expressionConfig) + mosaicRegistry.registerExpression[RST_Convolve](expressionConfig) mosaicRegistry.registerExpression[RST_DerivedBand](expressionConfig) + mosaicRegistry.registerExpression[RST_Filter](expressionConfig) mosaicRegistry.registerExpression[RST_GeoReference](expressionConfig) mosaicRegistry.registerExpression[RST_GetNoData](expressionConfig) mosaicRegistry.registerExpression[RST_GetSubdataset](expressionConfig) mosaicRegistry.registerExpression[RST_Height](expressionConfig) mosaicRegistry.registerExpression[RST_InitNoData](expressionConfig) mosaicRegistry.registerExpression[RST_IsEmpty](expressionConfig) + mosaicRegistry.registerExpression[RST_MakeTiles](expressionConfig) + mosaicRegistry.registerExpression[RST_Max](expressionConfig) + mosaicRegistry.registerExpression[RST_Min](expressionConfig) + mosaicRegistry.registerExpression[RST_Median](expressionConfig) mosaicRegistry.registerExpression[RST_MemSize](expressionConfig) mosaicRegistry.registerExpression[RST_Merge](expressionConfig) mosaicRegistry.registerExpression[RST_FromBands](expressionConfig) @@ -278,6 +285,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends mosaicRegistry.registerExpression[RST_NumBands](expressionConfig) mosaicRegistry.registerExpression[RST_PixelWidth](expressionConfig) mosaicRegistry.registerExpression[RST_PixelHeight](expressionConfig) + mosaicRegistry.registerExpression[RST_PixelCount](expressionConfig) mosaicRegistry.registerExpression[RST_RasterToGridAvg](expressionConfig) mosaicRegistry.registerExpression[RST_RasterToGridMax](expressionConfig) mosaicRegistry.registerExpression[RST_RasterToGridMin](expressionConfig) @@ -299,6 +307,7 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends mosaicRegistry.registerExpression[RST_Subdatasets](expressionConfig) mosaicRegistry.registerExpression[RST_Summary](expressionConfig) mosaicRegistry.registerExpression[RST_Tessellate](expressionConfig) + mosaicRegistry.registerExpression[RST_Transform](expressionConfig) mosaicRegistry.registerExpression[RST_FromContent](expressionConfig) mosaicRegistry.registerExpression[RST_FromFile](expressionConfig) mosaicRegistry.registerExpression[RST_ToOverlappingTiles](expressionConfig) @@ -653,9 +662,16 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends ColumnAdapter(RST_BandMetaData(raster.expr, lit(band).expr, expressionConfig)) def rst_boundingbox(raster: Column): Column = ColumnAdapter(RST_BoundingBox(raster.expr, expressionConfig)) def rst_clip(raster: Column, geometry: Column): Column = ColumnAdapter(RST_Clip(raster.expr, geometry.expr, expressionConfig)) + def rst_convolve(raster: Column, kernel: Column): Column = + ColumnAdapter(RST_Convolve(raster.expr, kernel.expr, expressionConfig)) + def rst_pixelcount(raster: Column): Column = ColumnAdapter(RST_PixelCount(raster.expr, expressionConfig)) def rst_combineavg(rasterArray: Column): Column = ColumnAdapter(RST_CombineAvg(rasterArray.expr, expressionConfig)) def rst_derivedband(raster: Column, pythonFunc: Column, funcName: Column): Column = ColumnAdapter(RST_DerivedBand(raster.expr, pythonFunc.expr, funcName.expr, expressionConfig)) + def rst_filter(raster: Column, kernelSize: Column, operation: Column): Column = + ColumnAdapter(RST_Filter(raster.expr, kernelSize.expr, operation.expr, expressionConfig)) + def rst_filter(raster: Column, kernelSize: Int, operation: String): Column = + ColumnAdapter(RST_Filter(raster.expr, lit(kernelSize).expr, lit(operation).expr, expressionConfig)) def rst_georeference(raster: Column): Column = ColumnAdapter(RST_GeoReference(raster.expr, expressionConfig)) def rst_getnodata(raster: Column): Column = ColumnAdapter(RST_GetNoData(raster.expr, expressionConfig)) def rst_getsubdataset(raster: Column, subdatasetName: Column): Column = @@ -665,6 +681,18 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends def rst_height(raster: Column): Column = ColumnAdapter(RST_Height(raster.expr, expressionConfig)) def rst_initnodata(raster: Column): Column = ColumnAdapter(RST_InitNoData(raster.expr, expressionConfig)) def rst_isempty(raster: Column): Column = ColumnAdapter(RST_IsEmpty(raster.expr, expressionConfig)) + def rst_maketiles(input: Column, driver: Column, size: Column, withCheckpoint: Column): Column = + ColumnAdapter(RST_MakeTiles(input.expr, driver.expr, size.expr, withCheckpoint.expr, expressionConfig)) + def rst_maketiles(input: Column, driver: String, size: Int, withCheckpoint: Boolean): Column = + ColumnAdapter(RST_MakeTiles(input.expr, lit(driver).expr, lit(size).expr, lit(withCheckpoint).expr, expressionConfig)) + def rst_maketiles(input: Column, driver: String, size: Int): Column = + ColumnAdapter(RST_MakeTiles(input.expr, lit(driver).expr, lit(size).expr, lit(false).expr, expressionConfig)) + def rst_maketiles(input: Column): Column = + ColumnAdapter(RST_MakeTiles(input.expr, lit("no_driver").expr, lit(-1).expr, lit(false).expr, expressionConfig)) + def rst_max(raster: Column): Column = ColumnAdapter(RST_Max(raster.expr, expressionConfig)) + def rst_min(raster: Column): Column = ColumnAdapter(RST_Min(raster.expr, expressionConfig)) + def rst_median(raster: Column): Column = ColumnAdapter(RST_Median(raster.expr, expressionConfig)) + def rst_avg(raster: Column): Column = ColumnAdapter(RST_Avg(raster.expr, expressionConfig)) def rst_memsize(raster: Column): Column = ColumnAdapter(RST_MemSize(raster.expr, expressionConfig)) def rst_frombands(bandsArray: Column): Column = ColumnAdapter(RST_FromBands(bandsArray.expr, expressionConfig)) def rst_merge(rasterArray: Column): Column = ColumnAdapter(RST_Merge(rasterArray.expr, expressionConfig)) @@ -720,6 +748,8 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends def rst_summary(raster: Column): Column = ColumnAdapter(RST_Summary(raster.expr, expressionConfig)) def rst_tessellate(raster: Column, resolution: Column): Column = ColumnAdapter(RST_Tessellate(raster.expr, resolution.expr, expressionConfig)) + def rst_transform(raster: Column, srid: Column): Column = + ColumnAdapter(RST_Transform(raster.expr, srid.expr, expressionConfig)) def rst_tessellate(raster: Column, resolution: Int): Column = ColumnAdapter(RST_Tessellate(raster.expr, lit(resolution).expr, expressionConfig)) def rst_fromcontent(raster: Column, driver: Column): Column = @@ -993,11 +1023,20 @@ class MosaicContext(indexSystem: IndexSystem, geometryAPI: GeometryAPI) extends object MosaicContext extends Logging { - val tmpDir: String = Files.createTempDirectory("mosaic").toAbsolutePath.toString - + var _tmpDir: String = "" val mosaicVersion: String = "0.4.0" private var instance: Option[MosaicContext] = None + + def tmpDir(mosaicConfig: MosaicExpressionConfig): String = { + if (_tmpDir == "" || mosaicConfig != null) { + val prefix = mosaicConfig.getTmpPrefix + _tmpDir = FileUtils.createMosaicTempDir(prefix) + _tmpDir + } else { + _tmpDir + } + } def build(indexSystem: IndexSystem, geometryAPI: GeometryAPI): MosaicContext = { instance = Some(new MosaicContext(indexSystem, geometryAPI)) @@ -1043,7 +1082,7 @@ object MosaicContext extends Logging { } if (!isML && !isPhoton && !isTest) { - val msg = """|DEPRECATION ERROR: + val msg = """|DEPRECATION ERROR: | Please use a Databricks: | - Photon-enabled Runtime for performance benefits | - Runtime ML for spatial AI benefits diff --git a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala index f306d4e9c..b16b719cc 100644 --- a/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala +++ b/src/main/scala/com/databricks/labs/mosaic/functions/MosaicExpressionConfig.scala @@ -33,6 +33,10 @@ case class MosaicExpressionConfig(configs: Map[String, String]) { def getRasterCheckpoint: String = configs.getOrElse(MOSAIC_RASTER_CHECKPOINT, MOSAIC_RASTER_CHECKPOINT_DEFAULT) def getCellIdType: DataType = IndexSystemFactory.getIndexSystem(getIndexSystem).cellIdType + + def getRasterBlockSize: Int = configs.getOrElse(MOSAIC_RASTER_BLOCKSIZE, MOSAIC_RASTER_BLOCKSIZE_DEFAULT).toInt + + def getTmpPrefix: String = configs.getOrElse(MOSAIC_RASTER_TMP_PREFIX, "/tmp") def setGDALConf(conf: RuntimeConfig): MosaicExpressionConfig = { val toAdd = conf.getAll.filter(_._1.startsWith(MOSAIC_GDAL_PREFIX)) @@ -54,6 +58,10 @@ case class MosaicExpressionConfig(configs: Map[String, String]) { def setRasterCheckpoint(checkpoint: String): MosaicExpressionConfig = { MosaicExpressionConfig(configs + (MOSAIC_RASTER_CHECKPOINT -> checkpoint)) } + + def setTmpPrefix(prefix: String): MosaicExpressionConfig = { + MosaicExpressionConfig(configs + (MOSAIC_RASTER_TMP_PREFIX -> prefix)) + } def setConfig(key: String, value: String): MosaicExpressionConfig = { MosaicExpressionConfig(configs + (key -> value)) @@ -73,6 +81,7 @@ object MosaicExpressionConfig { .setGeometryAPI(spark.conf.get(MOSAIC_GEOMETRY_API, JTS.name)) .setIndexSystem(spark.conf.get(MOSAIC_INDEX_SYSTEM, H3.name)) .setRasterCheckpoint(spark.conf.get(MOSAIC_RASTER_CHECKPOINT, MOSAIC_RASTER_CHECKPOINT_DEFAULT)) + .setTmpPrefix(spark.conf.get(MOSAIC_RASTER_TMP_PREFIX, "/tmp")) .setGDALConf(spark.conf) } diff --git a/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala b/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala index d256a9870..6cc928edd 100644 --- a/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala +++ b/src/main/scala/com/databricks/labs/mosaic/gdal/MosaicGDAL.scala @@ -1,11 +1,12 @@ package com.databricks.labs.mosaic.gdal +import com.databricks.labs.mosaic.MOSAIC_RASTER_BLOCKSIZE_DEFAULT import com.databricks.labs.mosaic.functions.{MosaicContext, MosaicExpressionConfig} import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.gdal.gdal.gdal +import org.gdal.osr.SpatialReference -import java.io.{BufferedInputStream, File, PrintWriter} import java.nio.file.{Files, Paths} import scala.language.postfixOps import scala.util.Try @@ -22,9 +23,22 @@ object MosaicGDAL extends Logging { private val libjniso3003Path = "/usr/lib/libgdalalljni.so.30.0.3" private val libogdisoPath = "/usr/lib/ogdi/4.1/libgdal.so" + val defaultBlockSize = 1024 + val vrtBlockSize = 128 // This is a must value for VRTs before GDAL 3.7 + var blockSize: Int = MOSAIC_RASTER_BLOCKSIZE_DEFAULT.toInt + // noinspection ScalaWeakerAccess val GDAL_ENABLED = "spark.mosaic.gdal.native.enabled" var isEnabled = false + var checkpointPath: String = _ + + // Only use this with GDAL rasters + val WSG84: SpatialReference = { + val wsg84 = new SpatialReference() + wsg84.ImportFromEPSG(4326) + wsg84.SetAxisMappingStrategy(org.gdal.osr.osrConstants.OAMS_TRADITIONAL_GIS_ORDER) + wsg84 + } /** Returns true if GDAL is enabled. */ def wasEnabled(spark: SparkSession): Boolean = @@ -32,16 +46,33 @@ object MosaicGDAL extends Logging { /** Configures the GDAL environment. */ def configureGDAL(mosaicConfig: MosaicExpressionConfig): Unit = { - val CPL_TMPDIR = MosaicContext.tmpDir - val GDAL_PAM_PROXY_DIR = MosaicContext.tmpDir + val CPL_TMPDIR = MosaicContext.tmpDir(mosaicConfig) + val GDAL_PAM_PROXY_DIR = MosaicContext.tmpDir(mosaicConfig) gdal.SetConfigOption("GDAL_VRT_ENABLE_PYTHON", "YES") - gdal.SetConfigOption("GDAL_DISABLE_READDIR_ON_OPEN", "EMPTY_DIR") + gdal.SetConfigOption("GDAL_DISABLE_READDIR_ON_OPEN", "TRUE") gdal.SetConfigOption("CPL_TMPDIR", CPL_TMPDIR) gdal.SetConfigOption("GDAL_PAM_PROXY_DIR", GDAL_PAM_PROXY_DIR) gdal.SetConfigOption("GDAL_PAM_ENABLED", "YES") gdal.SetConfigOption("CPL_VSIL_USE_TEMP_FILE_FOR_RANDOM_WRITE", "NO") gdal.SetConfigOption("CPL_LOG", s"$CPL_TMPDIR/gdal.log") + gdal.SetConfigOption("GDAL_CACHEMAX", "512") + gdal.SetConfigOption("GDAL_NUM_THREADS", "ALL_CPUS") mosaicConfig.getGDALConf.foreach { case (k, v) => gdal.SetConfigOption(k.split("\\.").last, v) } + setBlockSize(mosaicConfig) + checkpointPath = mosaicConfig.getRasterCheckpoint + } + + def setBlockSize(mosaicConfig: MosaicExpressionConfig): Unit = { + val blockSize = mosaicConfig.getRasterBlockSize + if (blockSize > 0) { + this.blockSize = blockSize + } + } + + def setBlockSize(size: Int): Unit = { + if (size > 0) { + this.blockSize = size + } } /** Enables the GDAL environment. */ @@ -91,18 +122,19 @@ object MosaicGDAL extends Logging { } } - /** Reads the resource bytes. */ - private def readResourceBytes(name: String): Array[Byte] = { - val bis = new BufferedInputStream(getClass.getResourceAsStream(name)) - try { Stream.continually(bis.read()).takeWhile(-1 !=).map(_.toByte).toArray } - finally bis.close() - } +// /** Reads the resource bytes. */ +// private def readResourceBytes(name: String): Array[Byte] = { +// val bis = new BufferedInputStream(getClass.getResourceAsStream(name)) +// try { Stream.continually(bis.read()).takeWhile(-1 !=).map(_.toByte).toArray } +// finally bis.close() +// } + +// /** Reads the resource lines. */ +// // noinspection SameParameterValue +// private def readResourceLines(name: String): Array[String] = { +// val bytes = readResourceBytes(name) +// val lines = new String(bytes).split("\n") +// lines +// } - /** Reads the resource lines. */ - // noinspection SameParameterValue - private def readResourceLines(name: String): Array[String] = { - val bytes = readResourceBytes(name) - val lines = new String(bytes).split("\n") - lines - } } diff --git a/src/main/scala/com/databricks/labs/mosaic/package.scala b/src/main/scala/com/databricks/labs/mosaic/package.scala index 58ee2f98e..86bdbcec7 100644 --- a/src/main/scala/com/databricks/labs/mosaic/package.scala +++ b/src/main/scala/com/databricks/labs/mosaic/package.scala @@ -21,13 +21,18 @@ package object mosaic { val MOSAIC_GDAL_PREFIX = "spark.databricks.labs.mosaic.gdal." val MOSAIC_GDAL_NATIVE = "spark.databricks.labs.mosaic.gdal.native" val MOSAIC_RASTER_CHECKPOINT = "spark.databricks.labs.mosaic.raster.checkpoint" - val MOSAIC_RASTER_CHECKPOINT_DEFAULT = "dbfs:/tmp/mosaic/raster/checkpoint" + val MOSAIC_RASTER_CHECKPOINT_DEFAULT = "/dbfs/tmp/mosaic/raster/checkpoint" + val MOSAIC_RASTER_TMP_PREFIX = "spark.databricks.labs.mosaic.raster.tmp.prefix" + val MOSAIC_RASTER_BLOCKSIZE = "spark.databricks.labs.mosaic.raster.blocksize" + val MOSAIC_RASTER_BLOCKSIZE_DEFAULT = "128" val MOSAIC_RASTER_READ_STRATEGY = "raster.read.strategy" val MOSAIC_RASTER_READ_IN_MEMORY = "in_memory" val MOSAIC_RASTER_READ_AS_PATH = "as_path" val MOSAIC_RASTER_RE_TILE_ON_READ = "retile_on_read" + val MOSAIC_NO_DRIVER = "no_driver" + def read: MosaicDataFrameReader = new MosaicDataFrameReader(SparkSession.builder().getOrCreate()) diff --git a/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala b/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala new file mode 100644 index 000000000..0a881d785 --- /dev/null +++ b/src/main/scala/com/databricks/labs/mosaic/utils/FileUtils.scala @@ -0,0 +1,35 @@ +package com.databricks.labs.mosaic.utils + +import java.io.{BufferedInputStream, FileInputStream} +import java.nio.file.{Files, Paths} + +object FileUtils { + + def readBytes(path: String): Array[Byte] = { + val bufferSize = 1024 * 1024 // 1MB + val cleanPath = PathUtils.replaceDBFSTokens(path) + val inputStream = new BufferedInputStream(new FileInputStream(cleanPath)) + val buffer = new Array[Byte](bufferSize) + + var bytesRead = 0 + var bytes = Array.empty[Byte] + + while ({ + bytesRead = inputStream.read(buffer); bytesRead + } != -1) { + bytes = bytes ++ buffer.slice(0, bytesRead) + } + inputStream.close() + bytes + } + + def createMosaicTempDir(prefix: String = "/tmp"): String = { + val tempRoot = Paths.get(s"$prefix/mosaic_tmp/") + if (!Files.exists(tempRoot)) { + Files.createDirectories(tempRoot) + } + val tempDir = Files.createTempDirectory(tempRoot, "mosaic") + tempDir.toFile.getAbsolutePath + } + +} diff --git a/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala b/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala index d48c03bfd..2f896c046 100644 --- a/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala +++ b/src/main/scala/com/databricks/labs/mosaic/utils/PathUtils.scala @@ -1,20 +1,23 @@ package com.databricks.labs.mosaic.utils -import com.databricks.labs.mosaic.core.raster.api.GDAL -import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL import com.databricks.labs.mosaic.functions.MosaicContext -import java.nio.file.{Files, Paths} +import java.nio.file.{Files, Path, Paths} +import scala.jdk.CollectionConverters._ object PathUtils { val NO_PATH_STRING = "no_path" - def getCleanPath(path: String): String = { - val cleanPath = path + def replaceDBFSTokens(path: String): String = { + path .replace("file:/", "/") .replace("dbfs:/Volumes", "/Volumes") - .replace("dbfs:/","/dbfs/") + .replace("dbfs:/", "/dbfs/") + } + + def getCleanPath(path: String): String = { + val cleanPath = replaceDBFSTokens(path) if (cleanPath.endsWith(".zip") || cleanPath.contains(".zip:")) { getZipPath(cleanPath) } else { @@ -47,7 +50,7 @@ object PathUtils { } def createTmpFilePath(extension: String): String = { - val tmpDir = MosaicContext.tmpDir + val tmpDir = MosaicContext.tmpDir(null) val uuid = java.util.UUID.randomUUID.toString val outPath = s"$tmpDir/raster_${uuid.replace("-", "_")}.$extension" Files.createDirectories(Paths.get(outPath).getParent) @@ -62,16 +65,60 @@ object PathUtils { result } - def copyToTmp(inPath: String): String = { - val copyFromPath = inPath - .replace("file:/", "/") - .replace("dbfs:/Volumes", "/Volumes") - .replace("dbfs:/","/dbfs/") - val driver = MosaicRasterGDAL.identifyDriver(getCleanPath(inPath)) - val extension = if (inPath.endsWith(".zip")) "zip" else GDAL.getExtension(driver) - val tmpPath = createTmpFilePath(extension) - Files.copy(Paths.get(copyFromPath), Paths.get(tmpPath)) + def getStemRegex(path: String): String = { + val cleanPath = replaceDBFSTokens(path) + val fileName = Paths.get(cleanPath).getFileName.toString + val stemName = fileName.substring(0, fileName.lastIndexOf(".")) + val stemEscaped = stemName.replace(".", "\\.") + val stemRegex = s"$stemEscaped\\..*".r + stemRegex.toString + } + + def copyToTmpWithRetry(inPath: String, retries: Int = 3): String = { + var tmpPath = copyToTmp(inPath) + var i = 0 + while (Files.notExists(Paths.get(tmpPath)) && i < retries) { + tmpPath = copyToTmp(inPath) + i += 1 + } tmpPath } + def copyToTmp(inPath: String): String = { + val copyFromPath = replaceDBFSTokens(inPath) + val inPathDir = Paths.get(copyFromPath).getParent.toString + + val fullFileName = copyFromPath.split("/").last + val stemRegex = getStemRegex(inPath) + + wildcardCopy(inPathDir, MosaicContext.tmpDir(null), stemRegex) + + s"${MosaicContext.tmpDir(null)}/$fullFileName" + } + + def wildcardCopy(inDirPath: String, outDirPath: String, pattern: String): Unit = { + import org.apache.commons.io.FileUtils + val copyFromPath = replaceDBFSTokens(inDirPath) + val copyToPath = replaceDBFSTokens(outDirPath) + + val toCopy = Files + .list(Paths.get(copyFromPath)) + .filter(_.getFileName.toString.matches(pattern)) + .collect(java.util.stream.Collectors.toList[Path]) + .asScala + + for (path <- toCopy) { + val destination = Paths.get(copyToPath, path.getFileName.toString) + // noinspection SimplifyBooleanMatch + if (Files.isDirectory(path)) FileUtils.copyDirectory(path.toFile, destination.toFile) + else FileUtils.copyFile(path.toFile, destination.toFile) + } + } + + def parseUnzippedPathFromExtracted(lastExtracted: String, extension: String): String = { + val trimmed = lastExtracted.replace("extracting: ", "").replace(" ", "") + val indexOfFormat = trimmed.indexOf(s".$extension/") + trimmed.substring(0, indexOfFormat + extension.length + 1) + } + } diff --git a/src/main/scala/com/databricks/labs/mosaic/utils/SysUtils.scala b/src/main/scala/com/databricks/labs/mosaic/utils/SysUtils.scala index 85fa12785..ba1d9c417 100644 --- a/src/main/scala/com/databricks/labs/mosaic/utils/SysUtils.scala +++ b/src/main/scala/com/databricks/labs/mosaic/utils/SysUtils.scala @@ -1,6 +1,6 @@ package com.databricks.labs.mosaic.utils -import java.io.{ByteArrayOutputStream, PrintWriter} +import java.io.{BufferedReader, ByteArrayOutputStream, InputStreamReader, PrintWriter} object SysUtils { @@ -11,16 +11,40 @@ object SysUtils { val stderrStream = new ByteArrayOutputStream val stdoutWriter = new PrintWriter(stdoutStream) val stderrWriter = new PrintWriter(stderrStream) - val exitValue = try { - //noinspection ScalaStyle - cmd.!!(ProcessLogger(stdoutWriter.println, stderrWriter.println)) - } catch { - case _: Exception => "ERROR" - } finally { - stdoutWriter.close() - stderrWriter.close() - } + val exitValue = + try { + // noinspection ScalaStyle + cmd.!!(ProcessLogger(stdoutWriter.println, stderrWriter.println)) + } catch { + case e: Exception => s"ERROR: ${e.getMessage}" + } finally { + stdoutWriter.close() + stderrWriter.close() + } (exitValue, stdoutStream.toString, stderrStream.toString) } + def runScript(cmd: Array[String]): (String, String, String) = { + val p = Runtime.getRuntime.exec(cmd) + val stdinStream = new BufferedReader(new InputStreamReader(p.getInputStream)) + val stderrStream = new BufferedReader(new InputStreamReader(p.getErrorStream)) + val exitValue = + try { + p.waitFor() + } catch { + case e: Exception => s"ERROR: ${e.getMessage}" + } + val stdinOutput = stdinStream.lines().toArray.mkString("\n") + val stderrOutput = stderrStream.lines().toArray.mkString("\n") + stdinStream.close() + stderrStream.close() + (s"$exitValue", stdinOutput, stderrOutput) + } + + def getLastOutputLine(prompt: (String, String, String)): String = { + val (_, stdout, _) = prompt + val lines = stdout.split("\n") + lines.last + } + } diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grb similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grb diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib.aux.xml b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grb.aux.xml similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grib.aux.xml rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626950.0440469-3609-11-041ac051-015d-49b0-95df-b5daa7084c7e.grb.aux.xml diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib.aux.xml b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb.aux.xml similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib.aux.xml rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb.aux.xml diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grb similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grb diff --git a/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib.aux.xml b/src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grb.aux.xml similarity index 100% rename from src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grib.aux.xml rename to src/test/resources/binary/grib-cams/adaptor.mars.internal-1650627030.319457-19905-15-0ede0273-89e3-4100-a0f2-48916ca607ed.grb.aux.xml diff --git a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala index 15eef2009..88c0f4bbb 100644 --- a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala +++ b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterBandGDAL.scala @@ -10,10 +10,11 @@ class TestRasterBandGDAL extends SharedSparkSessionGDAL { test("Read band metadata and pixel data from GeoTIFF file.") { assume(System.getProperty("os.name") == "Linux") - val testRaster = MosaicRasterGDAL.readRaster( - filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), - filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") + val createInfo = Map( + "path" -> filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), + "parentPath" -> filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") ) + val testRaster = MosaicRasterGDAL.readRaster(createInfo) val testBand = testRaster.getBand(1) testBand.getBand testBand.index shouldBe 1 @@ -36,10 +37,11 @@ class TestRasterBandGDAL extends SharedSparkSessionGDAL { test("Read band metadata and pixel data from a GRIdded Binary file.") { assume(System.getProperty("os.name") == "Linux") - val testRaster = MosaicRasterGDAL.readRaster( - filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib"), - filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib") + val createInfo = Map( + "path" -> filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb"), + "parentPath" -> filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb") ) + val testRaster = MosaicRasterGDAL.readRaster(createInfo) val testBand = testRaster.getBand(1) testBand.description shouldBe "1[-] HYBL=\"Hybrid level\"" testBand.dataType shouldBe 7 @@ -55,15 +57,17 @@ class TestRasterBandGDAL extends SharedSparkSessionGDAL { test("Read band metadata and pixel data from a NetCDF file.") { assume(System.getProperty("os.name") == "Linux") - val superRaster = MosaicRasterGDAL.readRaster( - filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc"), - filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc") + val createInfo = Map( + "path" -> filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc"), + "parentPath" -> filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc") ) + val superRaster = MosaicRasterGDAL.readRaster(createInfo) val subdatasetPath = superRaster.subdatasets("bleaching_alert_area") - val testRaster = MosaicRasterGDAL.readRaster( - subdatasetPath, - subdatasetPath + val sdCreate = Map( + "path" -> subdatasetPath, + "parentPath" -> subdatasetPath ) + val testRaster = MosaicRasterGDAL.readRaster(sdCreate) val testBand = testRaster.getBand(1) testBand.dataType shouldBe 1 diff --git a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala index e39279843..8bbe4b1f3 100644 --- a/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala +++ b/src/test/scala/com/databricks/labs/mosaic/core/raster/TestRasterGDAL.scala @@ -1,9 +1,12 @@ package com.databricks.labs.mosaic.core.raster import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.gdal.MosaicGDAL import com.databricks.labs.mosaic.test.mocks.filePath import org.apache.spark.sql.test.SharedSparkSessionGDAL import org.scalatest.matchers.should.Matchers._ +import org.gdal.gdal.{gdal => gdalJNI} +import org.gdal.gdalconst import scala.sys.process._ import scala.util.Try @@ -31,11 +34,12 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { test("Read raster metadata from GeoTIFF file.") { assume(System.getProperty("os.name") == "Linux") - - val testRaster = MosaicRasterGDAL.readRaster( - filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), - filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") + + val createInfo = Map( + "path" -> filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), + "parentPath" -> filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") ) + val testRaster = MosaicRasterGDAL.readRaster(createInfo) testRaster.xSize shouldBe 2400 testRaster.ySize shouldBe 2400 testRaster.numBands shouldBe 1 @@ -43,7 +47,7 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { testRaster.SRID shouldBe 0 testRaster.extent shouldBe Seq(-8895604.157333, 1111950.519667, -7783653.637667, 2223901.039333) testRaster.getRaster.GetProjection() - noException should be thrownBy testRaster.spatialRef + noException should be thrownBy testRaster.getSpatialReference an[Exception] should be thrownBy testRaster.getBand(-1) an[Exception] should be thrownBy testRaster.getBand(Int.MaxValue) @@ -53,10 +57,11 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { test("Read raster metadata from a GRIdded Binary file.") { assume(System.getProperty("os.name") == "Linux") - val testRaster = MosaicRasterGDAL.readRaster( - filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib"), - filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib") + val createInfo = Map( + "path" -> filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb"), + "parentPath" -> filePath("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb") ) + val testRaster = MosaicRasterGDAL.readRaster(createInfo) testRaster.xSize shouldBe 14 testRaster.ySize shouldBe 14 testRaster.numBands shouldBe 14 @@ -69,17 +74,19 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { test("Read raster metadata from a NetCDF file.") { assume(System.getProperty("os.name") == "Linux") - - val superRaster = MosaicRasterGDAL.readRaster( - filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc"), - filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc") + + val createInfo = Map( + "path" -> filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc"), + "parentPath" -> filePath("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc") ) + val superRaster = MosaicRasterGDAL.readRaster(createInfo) val subdatasetPath = superRaster.subdatasets("bleaching_alert_area") - val testRaster = MosaicRasterGDAL.readRaster( - subdatasetPath, - subdatasetPath + val sdCreateInfo = Map( + "path" -> subdatasetPath, + "parentPath" -> subdatasetPath ) + val testRaster = MosaicRasterGDAL.readRaster(sdCreateInfo) testRaster.xSize shouldBe 7200 testRaster.ySize shouldBe 3600 @@ -95,10 +102,11 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { test("Raster pixel and extent sizes are correct.") { assume(System.getProperty("os.name") == "Linux") - val testRaster = MosaicRasterGDAL.readRaster( - filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), - filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") + val createInfo = Map( + "path" -> filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF"), + "parentPath" -> filePath("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") ) + val testRaster = MosaicRasterGDAL.readRaster(createInfo) testRaster.pixelXSize - 463.312716527 < 0.0000001 shouldBe true testRaster.pixelYSize - -463.312716527 < 0.0000001 shouldBe true @@ -115,4 +123,219 @@ class TestRasterGDAL extends SharedSparkSessionGDAL { testRaster.getRaster.delete() } + test("Raster filter operations are correct.") { + assume(System.getProperty("os.name") == "Linux") + + gdalJNI.AllRegister() + + MosaicGDAL.setBlockSize(30) + + val ds = gdalJNI.GetDriverByName("GTiff").Create("/tmp/mosaic_tmp/test.tif", 50, 50, 1, gdalconst.gdalconstConstants.GDT_Float32) + + val values = 0 until 50 * 50 + ds.GetRasterBand(1).WriteRaster(0, 0, 50, 50, values.toArray) + ds.FlushCache() + + val createInfo = Map( + "path" -> "", + "parentPath" -> "", + "driver" -> "GTiff" + ) + var result = MosaicRasterGDAL(ds, createInfo, -1).filter(5, "avg").flushCache() + + var resultValues = result.getBand(1).values + + var inputMatrix = values.toArray.grouped(50).toArray + var resultMatrix = resultValues.grouped(50).toArray + + // first block + resultMatrix(10)(11) shouldBe ( + inputMatrix(8)(9) + inputMatrix(8)(10) + inputMatrix(8)(11) + inputMatrix(8)(12) + inputMatrix(8)(13) + + inputMatrix(9)(9) + inputMatrix(9)(10) + inputMatrix(9)(11) + inputMatrix(9)(12) + inputMatrix(9)(13) + + inputMatrix(10)(9) + inputMatrix(10)(10) + inputMatrix(10)(11) + inputMatrix(10)(12) + inputMatrix(10)(13) + + inputMatrix(11)(9) + inputMatrix(11)(10) + inputMatrix(11)(11) + inputMatrix(11)(12) + inputMatrix(11)(13) + + inputMatrix(12)(9) + inputMatrix(12)(10) + inputMatrix(12)(11) + inputMatrix(12)(12) + inputMatrix(12)(13) + ).toDouble / 25.0 + + // block overlap + resultMatrix(30)(32) shouldBe ( + inputMatrix(28)(30) + inputMatrix(28)(31) + inputMatrix(28)(32) + inputMatrix(28)(33) + inputMatrix(28)(34) + + inputMatrix(29)(30) + inputMatrix(29)(31) + inputMatrix(29)(32) + inputMatrix(29)(33) + inputMatrix(29)(34) + + inputMatrix(30)(30) + inputMatrix(30)(31) + inputMatrix(30)(32) + inputMatrix(30)(33) + inputMatrix(30)(34) + + inputMatrix(31)(30) + inputMatrix(31)(31) + inputMatrix(31)(32) + inputMatrix(31)(33) + inputMatrix(31)(34) + + inputMatrix(32)(30) + inputMatrix(32)(31) + inputMatrix(32)(32) + inputMatrix(32)(33) + inputMatrix(32)(34) + ).toDouble / 25.0 + + // mode + + result = MosaicRasterGDAL(ds, createInfo, -1).filter(5, "mode").flushCache() + + resultValues = result.getBand(1).values + + inputMatrix = values.toArray.grouped(50).toArray + resultMatrix = resultValues.grouped(50).toArray + + // first block + + resultMatrix(10)(11) shouldBe Seq( + inputMatrix(8)(9), + inputMatrix(8)(10), + inputMatrix(8)(11), + inputMatrix(8)(12), + inputMatrix(8)(13), + inputMatrix(9)(9), + inputMatrix(9)(10), + inputMatrix(9)(11), + inputMatrix(9)(12), + inputMatrix(9)(13), + inputMatrix(10)(9), + inputMatrix(10)(10), + inputMatrix(10)(11), + inputMatrix(10)(12), + inputMatrix(10)(13), + inputMatrix(11)(9), + inputMatrix(11)(10), + inputMatrix(11)(11), + inputMatrix(11)(12), + inputMatrix(11)(13), + inputMatrix(12)(9), + inputMatrix(12)(10), + inputMatrix(12)(11), + inputMatrix(12)(12), + inputMatrix(12)(13) + ).groupBy(identity).maxBy(_._2.size)._1.toDouble + + // corner + + resultMatrix(49)(49) shouldBe Seq( + inputMatrix(47)(47), + inputMatrix(47)(48), + inputMatrix(47)(49), + inputMatrix(48)(47), + inputMatrix(48)(48), + inputMatrix(48)(49), + inputMatrix(49)(47), + inputMatrix(49)(48), + inputMatrix(49)(49) + ).groupBy(identity).maxBy(_._2.size)._1.toDouble + + // median + + result = MosaicRasterGDAL(ds, createInfo, -1).filter(5, "median").flushCache() + + resultValues = result.getBand(1).values + + inputMatrix = values.toArray.grouped(50).toArray + resultMatrix = resultValues.grouped(50).toArray + + // first block + + resultMatrix(10)(11) shouldBe Seq( + inputMatrix(8)(9), + inputMatrix(8)(10), + inputMatrix(8)(11), + inputMatrix(8)(12), + inputMatrix(8)(13), + inputMatrix(9)(9), + inputMatrix(9)(10), + inputMatrix(9)(11), + inputMatrix(9)(12), + inputMatrix(9)(13), + inputMatrix(10)(9), + inputMatrix(10)(10), + inputMatrix(10)(11), + inputMatrix(10)(12), + inputMatrix(10)(13), + inputMatrix(11)(9), + inputMatrix(11)(10), + inputMatrix(11)(11), + inputMatrix(11)(12), + inputMatrix(11)(13), + inputMatrix(12)(9), + inputMatrix(12)(10), + inputMatrix(12)(11), + inputMatrix(12)(12), + inputMatrix(12)(13) + ).sorted.apply(12).toDouble + + // min filter + + result = MosaicRasterGDAL(ds, createInfo, -1).filter(5, "min").flushCache() + + resultValues = result.getBand(1).values + + inputMatrix = values.toArray.grouped(50).toArray + resultMatrix = resultValues.grouped(50).toArray + + // first block + + resultMatrix(10)(11) shouldBe Seq( + inputMatrix(8)(9), + inputMatrix(8)(10), + inputMatrix(8)(11), + inputMatrix(8)(12), + inputMatrix(8)(13), + inputMatrix(9)(9), + inputMatrix(9)(10), + inputMatrix(9)(11), + inputMatrix(9)(12), + inputMatrix(9)(13), + inputMatrix(10)(9), + inputMatrix(10)(10), + inputMatrix(10)(11), + inputMatrix(10)(12), + inputMatrix(10)(13), + inputMatrix(11)(9), + inputMatrix(11)(10), + inputMatrix(11)(11), + inputMatrix(11)(12), + inputMatrix(11)(13), + inputMatrix(12)(9), + inputMatrix(12)(10), + inputMatrix(12)(11), + inputMatrix(12)(12), + inputMatrix(12)(13) + ).min.toDouble + + // max filter + + result = MosaicRasterGDAL(ds, createInfo, -1).filter(5, "max").flushCache() + + resultValues = result.getBand(1).values + + inputMatrix = values.toArray.grouped(50).toArray + resultMatrix = resultValues.grouped(50).toArray + + // first block + + resultMatrix(10)(11) shouldBe Seq( + inputMatrix(8)(9), + inputMatrix(8)(10), + inputMatrix(8)(11), + inputMatrix(8)(12), + inputMatrix(8)(13), + inputMatrix(9)(9), + inputMatrix(9)(10), + inputMatrix(9)(11), + inputMatrix(9)(12), + inputMatrix(9)(13), + inputMatrix(10)(9), + inputMatrix(10)(10), + inputMatrix(10)(11), + inputMatrix(10)(12), + inputMatrix(10)(13), + inputMatrix(11)(9), + inputMatrix(11)(10), + inputMatrix(11)(11), + inputMatrix(11)(12), + inputMatrix(11)(13), + inputMatrix(12)(9), + inputMatrix(12)(10), + inputMatrix(12)(11), + inputMatrix(12)(12), + inputMatrix(12)(13) + ).max.toDouble + + } + } diff --git a/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala b/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala index 99a1563ca..c51190ea6 100644 --- a/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/datasource/GDALFileFormatTest.scala @@ -35,39 +35,6 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSessionGDAL { } - test("Read grib with GDALFileFormat") { - assume(System.getProperty("os.name") == "Linux") - - val grib = "/binary/grib-cams/" - val filePath = getClass.getResource(grib).getPath - - noException should be thrownBy spark.read - .format("gdal") - .option("extensions", "grib") - .option("raster_storage", "disk") - .option("extensions", "grib") - .load(filePath) - .take(1) - - noException should be thrownBy spark.read - .format("gdal") - .option("extensions", "grib") - .option("raster_storage", "disk") - .option("extensions", "grib") - .load(filePath) - .take(1) - - noException should be thrownBy spark.read - .format("gdal") - .option("extensions", "grib") - .option("raster_storage", "disk") - .option("extensions", "grib") - .load(filePath) - .select("metadata") - .take(1) - - } - test("Read tif with GDALFileFormat") { assume(System.getProperty("os.name") == "Linux") @@ -84,15 +51,15 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSessionGDAL { .option("driverName", "TIF") .load(filePath) .take(1) - - noException should be thrownBy spark.read + + spark.read .format("gdal") .option("driverName", "TIF") .load(filePath) .select("metadata") .take(1) - noException should be thrownBy spark.read + spark.read .format("gdal") .option(MOSAIC_RASTER_READ_STRATEGY, "retile_on_read") .load(filePath) @@ -141,4 +108,34 @@ class GDALFileFormatTest extends QueryTest with SharedSparkSessionGDAL { } + test("Read grib with GDALFileFormat") { + assume(System.getProperty("os.name") == "Linux") + + val grib = "/binary/grib-cams/" + val filePath = getClass.getResource(grib).getPath + + spark.read + .format("gdal") + .option("extensions", "grb") + .option("raster.read.strategy", "retile_on_read") + .load(filePath) + .take(1) + + noException should be thrownBy spark.read + .format("gdal") + .option("extensions", "grb") + .option("raster.read.strategy", "retile_on_read") + .load(filePath) + .take(1) + + spark.read + .format("gdal") + .option("extensions", "grb") + .option("raster.read.strategy", "retile_on_read") + .load(filePath) + .select("metadata") + .take(1) + + } + } diff --git a/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala b/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala index 5e3a95bc1..721201eaa 100644 --- a/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala +++ b/src/test/scala/com/databricks/labs/mosaic/datasource/multiread/RasterAsGridReaderTest.scala @@ -1,8 +1,8 @@ package com.databricks.labs.mosaic.datasource.multiread -import com.databricks.labs.mosaic.functions.MosaicContext import com.databricks.labs.mosaic.JTS import com.databricks.labs.mosaic.core.index.H3IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext import com.databricks.labs.mosaic.test.MosaicSpatialQueryTest import org.apache.spark.sql.test.SharedSparkSessionGDAL import org.scalatest.matchers.must.Matchers.{be, noException} @@ -14,7 +14,6 @@ class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSess test("Read netcdf with Raster As Grid Reader") { assume(System.getProperty("os.name") == "Linux") - spark.sparkContext.setLogLevel("FATAL") MosaicContext.build(H3IndexSystem, JTS) val netcdf = "/binary/netcdf-coral/" @@ -36,7 +35,6 @@ class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSess test("Read grib with Raster As Grid Reader") { assume(System.getProperty("os.name") == "Linux") - spark.sparkContext.setLogLevel("FATAL") MosaicContext.build(H3IndexSystem, JTS) val grib = "/binary/grib-cams/" @@ -57,7 +55,6 @@ class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSess test("Read tif with Raster As Grid Reader") { assume(System.getProperty("os.name") == "Linux") - spark.sparkContext.setLogLevel("FATAL") MosaicContext.build(H3IndexSystem, JTS) val tif = "/modis/" @@ -76,7 +73,6 @@ class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSess test("Read zarr with Raster As Grid Reader") { assume(System.getProperty("os.name") == "Linux") - spark.sparkContext.setLogLevel("FATAL") MosaicContext.build(H3IndexSystem, JTS) val zarr = "/binary/zarr-example/" @@ -84,6 +80,8 @@ class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSess noException should be thrownBy MosaicContext.read .format("raster_to_grid") + .option("readSubdataset", "true") + .option("subdatasetName", "/group_with_attrs/F_order_array") .option("combiner", "median") .option("vsizip", "true") .option("tileSize", "10") @@ -93,6 +91,8 @@ class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSess noException should be thrownBy MosaicContext.read .format("raster_to_grid") + .option("readSubdataset", "true") + .option("subdatasetName", "/group_with_attrs/F_order_array") .option("combiner", "count") .option("vsizip", "true") .load(filePath) @@ -101,6 +101,8 @@ class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSess noException should be thrownBy MosaicContext.read .format("raster_to_grid") + .option("readSubdataset", "true") + .option("subdatasetName", "/group_with_attrs/F_order_array") .option("combiner", "average") .option("vsizip", "true") .load(filePath) @@ -109,6 +111,8 @@ class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSess noException should be thrownBy MosaicContext.read .format("raster_to_grid") + .option("readSubdataset", "true") + .option("subdatasetName", "/group_with_attrs/F_order_array") .option("combiner", "avg") .option("vsizip", "true") .load(filePath) @@ -135,6 +139,8 @@ class RasterAsGridReaderTest extends MosaicSpatialQueryTest with SharedSparkSess noException should be thrownBy MosaicContext.read .format("raster_to_grid") + .option("readSubdataset", "true") + .option("subdatasetName", "/group_with_attrs/F_order_array") .option("kRingInterpolate", "3") .load(filePath) diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgBehaviors.scala new file mode 100644 index 000000000..f01ce2d25 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgBehaviors.scala @@ -0,0 +1,48 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions._ +import org.scalatest.matchers.should.Matchers._ + +trait RST_AvgBehaviors extends QueryTest { + + def behavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .withColumn("result", rst_avg($"tile")) + .select("result") + .select(explode($"result").as("result")) + + rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .createOrReplaceTempView("source") + + noException should be thrownBy spark.sql(""" + |select rst_avg(tile) from source + |""".stripMargin) + + val result = df.as[Double].collect().max + + result > 0 shouldBe true + + an[Exception] should be thrownBy spark.sql(""" + |select rst_avg() from source + |""".stripMargin) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgTest.scala new file mode 100644 index 000000000..6805f0723 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_AvgTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_AvgTest extends QueryTest with SharedSparkSessionGDAL with RST_AvgBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing rst_avg behavior with H3IndexSystem and JTS") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behavior(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgBehaviors.scala index 8ce57f5b8..4b7943d8a 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_CombineAvgBehaviors.scala @@ -4,7 +4,7 @@ import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.functions.collect_set +import org.apache.spark.sql.functions.{collect_list, collect_set} import org.scalatest.matchers.should.Matchers._ trait RST_CombineAvgBehaviors extends QueryTest { @@ -28,15 +28,17 @@ trait RST_CombineAvgBehaviors extends QueryTest { .select("path", "tiles") .groupBy("path") .agg( - rst_combineavg(collect_set($"tiles")).as("tiles") + rst_combineavg(collect_list($"tiles")).as("tiles") ) .select("tiles") rastersInMemory.union(rastersInMemory) .createOrReplaceTempView("source") - noException should be thrownBy spark.sql(""" - |select rst_combineavg(collect_set(tiles)) as tiles + //noException should be thrownBy + + spark.sql(""" + |select rst_combineavg(collect_list(tiles)) as tiles |from ( | select path, rst_tessellate(tile, 2) as tiles | from source diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ConvolveBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ConvolveBehaviors.scala new file mode 100644 index 000000000..fdc09cc28 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ConvolveBehaviors.scala @@ -0,0 +1,43 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions.{array, lit} +import org.scalatest.matchers.should.Matchers._ + +trait RST_ConvolveBehaviors extends QueryTest { + + // noinspection MapGetGet + def behaviors(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + spark.sparkContext.setLogLevel("FATAL") + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val gridTiles = rastersInMemory + .withColumn("result", rst_convolve($"tile", array(array(lit(1.0), lit(2.0), lit(3.0)), array(lit(3.0), lit(2.0), lit(1.0)), array(lit(1.0), lit(3.0), lit(2.0))))) + .select("result") + .collect() + + gridTiles.length should be(7) + + rastersInMemory.createOrReplaceTempView("source") + + spark + .sql(""" + |select rst_convolve(tile, array(array(1, 2, 3), array(2, 3, 1), array(1, 1, 1))) as tile from source + |""".stripMargin) + .collect() + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ConvolveTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ConvolveTest.scala new file mode 100644 index 000000000..28c049364 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_ConvolveTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_ConvolveTest extends QueryTest with SharedSparkSessionGDAL with RST_ConvolveBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing RST_Convolve with manual GDAL registration (H3, JTS).") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behaviors(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandBehaviors.scala index ef6466a88..d883fc5cc 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_DerivedBandBehaviors.scala @@ -4,7 +4,7 @@ import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.functions.{collect_set, lit} +import org.apache.spark.sql.functions.{collect_list, lit} import org.scalatest.matchers.should.Matchers._ trait RST_DerivedBandBehaviors extends QueryTest { @@ -40,7 +40,7 @@ trait RST_DerivedBandBehaviors extends QueryTest { .select("path", "tiles") .groupBy("path") .agg( - rst_derivedband(collect_set($"tiles"), lit(pyFuncCode), lit(funcName)).as("tiles") + rst_derivedband(collect_list($"tiles"), lit(pyFuncCode), lit(funcName)).as("tiles") ) .select("tiles") @@ -52,7 +52,7 @@ trait RST_DerivedBandBehaviors extends QueryTest { noException should be thrownBy spark.sql( """ |select rst_derivedband( - | collect_set(tiles), + | collect_list(tiles), |" |import numpy as np |def multiply(in_ar, out_ar, xoff, yoff, xsize, ysize, raster_xsize,raster_ysize, buf_radius, gt, **kwargs): diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterBehaviors.scala new file mode 100644 index 000000000..2d64a633c --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterBehaviors.scala @@ -0,0 +1,78 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions.lit +import org.scalatest.matchers.should.Matchers._ + +trait RST_FilterBehaviors extends QueryTest { + + // noinspection MapGetGet + def behaviors(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + spark.sparkContext.setLogLevel("FATAL") + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val gridTiles = rastersInMemory + .withColumn("result", rst_filter($"tile", 3, "mode")) + .select("result") + .collect() + + gridTiles.length should be(7) + + val gridTiles2 = rastersInMemory + .withColumn("result", rst_filter($"tile", lit(3), lit("mode"))) + .select("result") + .collect() + + gridTiles2.length should be(7) + + val gridTiles3 = rastersInMemory + .withColumn("result", rst_filter($"tile", lit(3), lit("avg"))) + .select("result") + .collect() + + gridTiles3.length should be(7) + + val gridTiles4 = rastersInMemory + .withColumn("result", rst_filter($"tile", lit(3), lit("min"))) + .select("result") + .collect() + + gridTiles4.length should be(7) + + val gridTiles5 = rastersInMemory + .withColumn("result", rst_filter($"tile", lit(3), lit("max"))) + .select("result") + .collect() + + gridTiles5.length should be(7) + + val gridTiles6 = rastersInMemory + .withColumn("result", rst_filter($"tile", lit(3), lit("median"))) + .select("result") + .collect() + + gridTiles6.length should be(7) + + rastersInMemory.createOrReplaceTempView("source") + + noException should be thrownBy spark + .sql(""" + |select rst_filter(tile, 3, 'mode') as tile from source + |""".stripMargin) + .collect() + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterTest.scala new file mode 100644 index 000000000..a243f7168 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_FilterTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_FilterTest extends QueryTest with SharedSparkSessionGDAL with RST_FilterBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing RST_Filter with manual GDAL registration (H3, JTS).") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behaviors(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTilesBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTilesBehaviors.scala new file mode 100644 index 000000000..31caed8b9 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTilesBehaviors.scala @@ -0,0 +1,101 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.scalatest.matchers.should.Matchers._ + +trait RST_MakeTilesBehaviors extends QueryTest { + + // noinspection MapGetGet + def behaviors(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + spark.sparkContext.setLogLevel("ERROR") + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("binaryFile") + .load("src/test/resources/modis") + + val gridTiles1 = rastersInMemory + .withColumn("tile", rst_maketiles($"content", "GTiff", -1)) + .select(!rst_isempty($"tile")) + .as[Boolean] + .collect() + + gridTiles1.forall(identity) should be(true) + + rastersInMemory.createOrReplaceTempView("source") + + val gridTilesSQL = spark + .sql(""" + |with subquery as ( + | select rst_maketiles(content, 'GTiff', -1) as tile from source + |) + |select not rst_isempty(tile) as result + |from subquery + |""".stripMargin) + .as[Boolean] + .collect() + + gridTilesSQL.forall(identity) should be(true) + + + val gridTilesSQL2 = spark + .sql( + """ + |with subquery as ( + | select rst_maketiles(content, 'GTiff', 4) as tile from source + |) + |select not rst_isempty(tile) as result + |from subquery + |""".stripMargin) + .as[Boolean] + .collect() + + gridTilesSQL2.forall(identity) should be(true) + + val gridTilesSQL3 = spark + .sql( + """ + |with subquery as ( + | select rst_maketiles(path, 'GTiff', 4) as tile from source + |) + |select not rst_isempty(tile) as result + |from subquery + |""".stripMargin) + .as[Boolean] + .collect() + + gridTilesSQL3.forall(identity) should be(true) + + val gridTilesSQL4 = spark + .sql( + """ + |with subquery as ( + | select rst_maketiles(path, 'GTiff', 4, true) as tile from source + |) + |select not rst_isempty(tile) as result + |from subquery + |""".stripMargin) + .as[Boolean] + .collect() + + gridTilesSQL4.forall(identity) should be(true) + + val gridTiles2 = rastersInMemory + .withColumn("tile", rst_maketiles($"path")) + .select(!rst_isempty($"tile")) + .as[Boolean] + .collect() + + gridTiles2.forall(identity) should be(true) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTilesTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTilesTest.scala new file mode 100644 index 000000000..7eaae222b --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MakeTilesTest.scala @@ -0,0 +1,31 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_MakeTilesTest extends QueryTest with SharedSparkSessionGDAL with RST_MakeTilesBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing RST_MakeTiles with manual GDAL registration (H3, JTS).") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behaviors(H3IndexSystem, JTS) + } + } +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxBehaviors.scala new file mode 100644 index 000000000..daab1ee90 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxBehaviors.scala @@ -0,0 +1,48 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions._ +import org.scalatest.matchers.should.Matchers._ + +trait RST_MaxBehaviors extends QueryTest { + + def behavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .withColumn("result", rst_max($"tile")) + .select("result") + .select(explode($"result").as("result")) + + rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .createOrReplaceTempView("source") + + noException should be thrownBy spark.sql(""" + |select rst_max(tile) from source + |""".stripMargin) + + val result = df.as[Double].collect().max + + result > 0 shouldBe true + + an[Exception] should be thrownBy spark.sql(""" + |select rst_max() from source + |""".stripMargin) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxTest.scala new file mode 100644 index 000000000..0a0b865cd --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MaxTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_MaxTest extends QueryTest with SharedSparkSessionGDAL with RST_MaxBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing rst_max behavior with H3IndexSystem and JTS") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behavior(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianBehaviors.scala new file mode 100644 index 000000000..1b99fbc6f --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianBehaviors.scala @@ -0,0 +1,52 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions._ +import org.scalatest.matchers.should.Matchers._ + +trait RST_MedianBehaviors extends QueryTest { + + def behavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .withColumn("result", rst_median($"tile")) + .select("result") + .select(explode($"result").as("result")) + + rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .createOrReplaceTempView("source") + + noException should be thrownBy spark.sql(""" + |select rst_median(tile) from source + |""".stripMargin) + + noException should be thrownBy rastersInMemory + .withColumn("result", rst_rastertogridmax($"tile", lit(3))) + .select("result") + + val result = df.as[Double].collect().max + + result > 0 shouldBe true + + an[Exception] should be thrownBy spark.sql(""" + |select rst_median() from source + |""".stripMargin) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianTest.scala new file mode 100644 index 000000000..cfe270813 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MedianTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_MedianTest extends QueryTest with SharedSparkSessionGDAL with RST_MedianBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing rst_median behavior with H3IndexSystem and JTS") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behavior(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeBehaviors.scala index 893d6bdf4..330345f20 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MergeBehaviors.scala @@ -4,7 +4,7 @@ import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.functions.collect_set +import org.apache.spark.sql.functions.collect_list import org.scalatest.matchers.should.Matchers._ trait RST_MergeBehaviors extends QueryTest { @@ -29,7 +29,7 @@ trait RST_MergeBehaviors extends QueryTest { .select("path", "tile") .groupBy("path") .agg( - collect_set("tile").as("tiles") + collect_list("tile").as("tiles") ) .select( rst_merge($"tiles").as("tile") @@ -41,7 +41,7 @@ trait RST_MergeBehaviors extends QueryTest { spark.sql(""" |select rst_merge(tiles) as tile |from ( - | select collect_set(tile) as tiles + | select collect_list(tile) as tiles | from ( | select path, rst_tessellate(tile, 3) as tile | from source @@ -55,7 +55,7 @@ trait RST_MergeBehaviors extends QueryTest { .select("path", "tile") .groupBy("path") .agg( - collect_set("tile").as("tiles") + collect_list("tile").as("tiles") ) .select( rst_merge($"tiles").as("tile") diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala new file mode 100644 index 000000000..d01f79fec --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinBehaviors.scala @@ -0,0 +1,48 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions._ +import org.scalatest.matchers.should.Matchers._ + +trait RST_MinBehaviors extends QueryTest { + + def behavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .withColumn("result", rst_min($"tile")) + .select("result") + .select(explode($"result").as("result")) + + rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .createOrReplaceTempView("source") + + noException should be thrownBy spark.sql(""" + |select rst_min(tile) from source + |""".stripMargin) + + val result = df.as[Double].collect().min + + result == 0 shouldBe true + + an[Exception] should be thrownBy spark.sql(""" + |select rst_min() from source + |""".stripMargin) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinTest.scala new file mode 100644 index 000000000..ec09792f9 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_MinTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_MinTest extends QueryTest with SharedSparkSessionGDAL with RST_MinBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing rst_min behavior with H3IndexSystem and JTS") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behavior(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCountBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCountBehaviors.scala new file mode 100644 index 000000000..87582df1f --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCountBehaviors.scala @@ -0,0 +1,52 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions._ +import org.scalatest.matchers.should.Matchers._ + +trait RST_PixelCountBehaviors extends QueryTest { + + def behavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val df = rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .withColumn("result", rst_pixelcount($"tile")) + .select("result") + .select(explode($"result").as("result")) + + rastersInMemory + .withColumn("tile", rst_tessellate($"tile", lit(3))) + .createOrReplaceTempView("source") + + noException should be thrownBy spark.sql(""" + |select rst_pixelcount(tile) from source + |""".stripMargin) + + noException should be thrownBy rastersInMemory + .withColumn("result", rst_rastertogridmax($"tile", lit(3))) + .select("result") + + val result = df.as[Double].collect().max + + result > 0 shouldBe true + + an[Exception] should be thrownBy spark.sql(""" + |select rst_pixelcount() from source + |""".stripMargin) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCountTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCountTest.scala new file mode 100644 index 000000000..1d24b58a4 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_PixelCountTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_PixelCountTest extends QueryTest with SharedSparkSessionGDAL with RST_PixelCountBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing rst_pixelcount behavior with H3IndexSystem and JTS") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behavior(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala index 8804968e1..38d3fc778 100644 --- a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TessellateBehaviors.scala @@ -4,7 +4,7 @@ import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.index.IndexSystem import com.databricks.labs.mosaic.functions.MosaicContext import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.functions._ import org.scalatest.matchers.should.Matchers._ trait RST_TessellateBehaviors extends QueryTest { @@ -25,8 +25,10 @@ trait RST_TessellateBehaviors extends QueryTest { val gridTiles = rastersInMemory .withColumn("tiles", rst_tessellate($"tile", 3)) - .select("tiles") - + .withColumn("bbox", st_aswkt(rst_boundingbox($"tile"))) + .select("bbox", "path", "tiles") + .withColumn("avg", rst_avg($"tiles")) + rastersInMemory .createOrReplaceTempView("source") @@ -38,9 +40,9 @@ trait RST_TessellateBehaviors extends QueryTest { .withColumn("tiles", rst_tessellate($"tile", 3)) .select("tiles") - val result = gridTiles.collect() + val result = gridTiles.select(explode(col("avg")).alias("a")).groupBy("a").count().collect() - result.length should be(462) + result.length should be(441) val netcdf = spark.read .format("gdal") diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TransformBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TransformBehaviors.scala new file mode 100644 index 000000000..9ce449b13 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TransformBehaviors.scala @@ -0,0 +1,49 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import com.databricks.labs.mosaic.core.index.IndexSystem +import com.databricks.labs.mosaic.functions.MosaicContext +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.functions._ +import org.scalatest.matchers.should.Matchers._ + +trait RST_TransformBehaviors extends QueryTest { + + // noinspection MapGetGet + def behavior(indexSystem: IndexSystem, geometryAPI: GeometryAPI): Unit = { + spark.sparkContext.setLogLevel("FATAL") + val mc = MosaicContext.build(indexSystem, geometryAPI) + mc.register() + val sc = spark + import mc.functions._ + import sc.implicits._ + + val rastersInMemory = spark.read + .format("gdal") + .option("raster_storage", "in-memory") + .load("src/test/resources/modis") + + val gridTiles = rastersInMemory + .withColumn("tile", rst_transform($"tile", lit(27700))) + .withColumn("bbox", st_aswkt(rst_boundingbox($"tile"))) + .select("bbox", "path", "tile") + .withColumn("avg", rst_avg($"tile")) + + rastersInMemory + .createOrReplaceTempView("source") + + noException should be thrownBy spark.sql(""" + |select rst_transform(tile, 27700) from source + |""".stripMargin) + + noException should be thrownBy rastersInMemory + .withColumn("tile", rst_transform($"tile", lit(27700))) + .select("tile") + + val result = gridTiles.select(explode(col("avg")).alias("a")).groupBy("a").count().collect() + + result.length should be(7) + + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TransformTest.scala b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TransformTest.scala new file mode 100644 index 000000000..b7c10e548 --- /dev/null +++ b/src/test/scala/com/databricks/labs/mosaic/expressions/raster/RST_TransformTest.scala @@ -0,0 +1,32 @@ +package com.databricks.labs.mosaic.expressions.raster + +import com.databricks.labs.mosaic.core.geometry.api.JTS +import com.databricks.labs.mosaic.core.index.H3IndexSystem +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSessionGDAL + +import scala.util.Try + +class RST_TransformTest extends QueryTest with SharedSparkSessionGDAL with RST_TransformBehaviors { + + private val noCodegen = + withSQLConf( + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString + ) _ + + // Hotfix for SharedSparkSession afterAll cleanup. + override def afterAll(): Unit = Try(super.afterAll()) + + // These tests are not index system nor geometry API specific. + // Only testing one pairing is sufficient. + test("Testing RST_Transform with manual GDAL registration (H3, JTS).") { + noCodegen { + assume(System.getProperty("os.name") == "Linux") + behavior(H3IndexSystem, JTS) + } + } + +} diff --git a/src/test/scala/com/databricks/labs/mosaic/models/knn/SpatialKNNBehaviors.scala b/src/test/scala/com/databricks/labs/mosaic/models/knn/SpatialKNNBehaviors.scala index e9508c3c6..be9b5c402 100644 --- a/src/test/scala/com/databricks/labs/mosaic/models/knn/SpatialKNNBehaviors.scala +++ b/src/test/scala/com/databricks/labs/mosaic/models/knn/SpatialKNNBehaviors.scala @@ -29,7 +29,7 @@ trait SpatialKNNBehaviors { this: AnyFlatSpec => val boroughs: DataFrame = getBoroughs(mc) - val tempLocation = MosaicContext.tmpDir + val tempLocation = MosaicContext.tmpDir(null) spark.sparkContext.setCheckpointDir(tempLocation) spark.sparkContext.setLogLevel("ERROR") @@ -94,7 +94,7 @@ trait SpatialKNNBehaviors { this: AnyFlatSpec => val boroughs: DataFrame = getBoroughs(mc) - val tempLocation = MosaicContext.tmpDir + val tempLocation = MosaicContext.tmpDir(null) spark.sparkContext.setCheckpointDir(tempLocation) spark.sparkContext.setLogLevel("ERROR") diff --git a/src/test/scala/com/databricks/labs/mosaic/test/package.scala b/src/test/scala/com/databricks/labs/mosaic/test/package.scala index 5d76fb48a..1c8c689aa 100644 --- a/src/test/scala/com/databricks/labs/mosaic/test/package.scala +++ b/src/test/scala/com/databricks/labs/mosaic/test/package.scala @@ -164,7 +164,7 @@ package object test { } val geotiffBytes: Array[Byte] = fileBytes("/modis/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF") val gribBytes: Array[Byte] = - fileBytes("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grib") + fileBytes("/binary/grib-cams/adaptor.mars.internal-1650626995.380916-11651-14-ca8e7236-16ca-4e11-919d-bdbd5a51da35.grb") val netcdfBytes: Array[Byte] = fileBytes("/binary/netcdf-coral/ct5km_baa-max-7d_v3.1_20220101.nc") def polyDf(sparkSession: SparkSession, mosaicContext: MosaicContext): DataFrame = { diff --git a/src/test/scala/org/apache/spark/sql/test/MosaicTestSparkSession.scala b/src/test/scala/org/apache/spark/sql/test/MosaicTestSparkSession.scala new file mode 100644 index 000000000..84a613b31 --- /dev/null +++ b/src/test/scala/org/apache/spark/sql/test/MosaicTestSparkSession.scala @@ -0,0 +1,27 @@ +package org.apache.spark.sql.test + +import org.apache.spark.{SparkConf, SparkContext} + +class MosaicTestSparkSession(sc: SparkContext) extends TestSparkSession(sc) { + + def this(sparkConf: SparkConf) = { + + this( + new SparkContext( + "local[8]", + "test-sql-context", + sparkConf + .set("spark.sql.adaptive.enabled", "false") + .set("spark.driver.memory", "32g") + .set("spark.executor.memory", "32g") + .set("spark.sql.shuffle.partitions", "8") + .set("spark.sql.testkey", "true") + ) + ) + } + + def this() = { + this(new SparkConf) + } + +} diff --git a/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala b/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala index a720b3f9b..12dcac6f3 100644 --- a/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala +++ b/src/test/scala/org/apache/spark/sql/test/SharedSparkSessionGDAL.scala @@ -1,12 +1,12 @@ package org.apache.spark.sql.test -import com.databricks.labs.mosaic._ import com.databricks.labs.mosaic.gdal.MosaicGDAL +import com.databricks.labs.mosaic.utils.FileUtils +import com.databricks.labs.mosaic.{MOSAIC_GDAL_NATIVE, MOSAIC_RASTER_CHECKPOINT} import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.gdal.gdal.gdal -import java.nio.file.Files import scala.util.Try trait SharedSparkSessionGDAL extends SharedSparkSession { @@ -18,9 +18,9 @@ trait SharedSparkSessionGDAL extends SharedSparkSession { override def createSparkSession: TestSparkSession = { val conf = sparkConf - conf.set(MOSAIC_RASTER_CHECKPOINT, Files.createTempDirectory("mosaic").toFile.getAbsolutePath) + conf.set(MOSAIC_RASTER_CHECKPOINT, FileUtils.createMosaicTempDir(prefix = "/mnt/")) SparkSession.cleanupAnyExistingSession() - val session = new TestSparkSession(conf) + val session = new MosaicTestSparkSession(conf) session.sparkContext.setLogLevel("FATAL") Try { MosaicGDAL.enableGDAL(session) @@ -33,5 +33,5 @@ trait SharedSparkSessionGDAL extends SharedSparkSession { MosaicGDAL.enableGDAL(this.spark) gdal.AllRegister() } - + }