diff --git a/.github/actions/python_build/action.yml b/.github/actions/python_build/action.yml index b6abf6239..3eab1778d 100644 --- a/.github/actions/python_build/action.yml +++ b/.github/actions/python_build/action.yml @@ -11,7 +11,7 @@ runs: shell: bash run: | cd python - pip install build wheel pyspark==${{ matrix.spark }} + pip install build wheel pyspark==${{ matrix.spark }} numpy==${{ matrix.numpy }} pip install . - name: Test and build python package shell: bash diff --git a/.github/actions/scala_build/action.yml b/.github/actions/scala_build/action.yml index 950ae2200..5e12ead9a 100644 --- a/.github/actions/scala_build/action.yml +++ b/.github/actions/scala_build/action.yml @@ -12,18 +12,21 @@ runs: with: java-version: '8' distribution: 'zulu' + - name: Configure python interpreter + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + - name: Add packaged GDAL dependencies + shell: bash + run : | + sudo apt-get update && sudo apt-get install -y unixodbc libcurl3-gnutls libsnappy-dev libopenjp2-7 + pip install databricks-mosaic-gdal==${{ matrix.gdal }} + sudo tar -xf /opt/hostedtoolcache/Python/${{ matrix.python }}/x64/lib/python3.9/site-packages/databricks-mosaic-gdal/resources/gdal-${{ matrix.gdal }}-filetree.tar.xz -C / + sudo tar -xhf /opt/hostedtoolcache/Python/${{ matrix.python }}/x64/lib/python3.9/site-packages/databricks-mosaic-gdal/resources/gdal-${{ matrix.gdal }}-symlinks.tar.xz -C / - name: Test and build the scala JAR - skip tests is false if: inputs.skip_tests == 'false' shell: bash - run: | - pip install databricks-mosaic-gdal==3.4.3 - sudo tar -xf /home/runner/.local/lib/python3.8/site-packages/databricks-mosaic-gdal/resources/gdal-3.4.3-filetree.tar.xz -C / - sudo tar -xhf /home/runner/.local/lib/python3.8/site-packages/databricks-mosaic-gdal/resources/gdal-3.4.3-symlinks.tar.xz -C / - sudo add-apt-repository ppa:ubuntugis/ubuntugis-unstable - sudo apt clean && sudo apt -o Acquire::Retries=3 update --fix-missing 
-y - sudo apt-get -o Acquire::Retries=3 update -y - sudo apt-get -o Acquire::Retries=3 install -y gdal-bin=3.4.3+dfsg-1~focal0 libgdal-dev=3.4.3+dfsg-1~focal0 python3-gdal=3.4.3+dfsg-1~focal0 - sudo mvn -q clean install + run: sudo mvn -q clean install - name: Build the scala JAR - skip tests is true if: inputs.skip_tests == 'true' shell: bash diff --git a/.github/workflows/build_main.yml b/.github/workflows/build_main.yml index ac5cb0623..5d303827b 100644 --- a/.github/workflows/build_main.yml +++ b/.github/workflows/build_main.yml @@ -2,10 +2,10 @@ name: build main on: push: branches-ignore: - - "R/*" - - "r/*" - - "python/*" - - "scala/*" + - "R/**" + - "r/**" + - "python/**" + - "scala/**" pull_request: branches: - "**" @@ -16,8 +16,10 @@ jobs: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} strategy: matrix: - python: [ 3.9 ] - spark: [ 3.2.1 ] + python: [ 3.9.5 ] + numpy: [ 1.21.5 ] + gdal: [ 3.4.3 ] + spark: [ 3.3.2 ] R: [ 4.1.2 ] steps: - name: checkout code diff --git a/.github/workflows/build_python.yml b/.github/workflows/build_python.yml index c2492002c..cf9771037 100644 --- a/.github/workflows/build_python.yml +++ b/.github/workflows/build_python.yml @@ -3,7 +3,7 @@ name: build_python on: push: branches: - - "python/*" + - "python/**" jobs: build: @@ -12,8 +12,10 @@ jobs: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} strategy: matrix: - python: [ 3.9 ] - spark: [ 3.2.1 ] + python: [ 3.9.5 ] + numpy: [ 1.21.5 ] + gdal: [ 3.4.3 ] + spark: [ 3.3.2 ] R: [ 4.1.2 ] steps: - name: checkout code diff --git a/.github/workflows/build_r.yml b/.github/workflows/build_r.yml index 644ba9d7d..28d8172d5 100644 --- a/.github/workflows/build_r.yml +++ b/.github/workflows/build_r.yml @@ -3,8 +3,8 @@ name: build_R on: push: branches: - - 'r/*' - - 'R/*' + - 'r/**' + - 'R/**' jobs: build: @@ -13,8 +13,10 @@ jobs: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} strategy: matrix: - python: [ 3.9 ] - spark: [ 3.2.1 ] + python: [ 3.9.5 ] + numpy: [ 1.21.5 ] + gdal: [ 3.4.3 ] + spark: [ 3.3.2 ] R: [ 
4.1.2 ] steps: - name: checkout code diff --git a/.github/workflows/build_scala.yml b/.github/workflows/build_scala.yml index c6297e6f0..b94dbe686 100644 --- a/.github/workflows/build_scala.yml +++ b/.github/workflows/build_scala.yml @@ -2,7 +2,7 @@ name: build_scala on: push: branches: - - "scala/" + - "scala/**" jobs: build: @@ -11,8 +11,10 @@ jobs: GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }} strategy: matrix: - python: [ 3.9 ] - spark: [ 3.2.1 ] + python: [ 3.9.5 ] + numpy: [ 1.21.5 ] + gdal: [ 3.4.3 ] + spark: [ 3.3.2 ] R: [ 4.1.2 ] steps: - name: checkout code diff --git a/python/mosaic/api/aggregators.py b/python/mosaic/api/aggregators.py index 3f1f54ac8..b4e9bf7e1 100644 --- a/python/mosaic/api/aggregators.py +++ b/python/mosaic/api/aggregators.py @@ -48,9 +48,7 @@ def st_intersection_aggregate( ) -def st_intersection_agg( - leftIndex: ColumnOrName, rightIndex: ColumnOrName -) -> Column: +def st_intersection_agg(leftIndex: ColumnOrName, rightIndex: ColumnOrName) -> Column: """ Computes the intersection of all `leftIndex` : `rightIndex` pairs and unions these to produce a single geometry. @@ -100,9 +98,7 @@ def st_intersects_aggregate( ) -def st_intersects_agg( - leftIndex: ColumnOrName, rightIndex: ColumnOrName -) -> Column: +def st_intersects_agg(leftIndex: ColumnOrName, rightIndex: ColumnOrName) -> Column: """ Tests if any `leftIndex` : `rightIndex` pairs intersect. diff --git a/python/mosaic/api/functions.py b/python/mosaic/api/functions.py index 76df46554..41d8325d5 100644 --- a/python/mosaic/api/functions.py +++ b/python/mosaic/api/functions.py @@ -18,7 +18,6 @@ "st_convexhull", "st_buffer", "st_bufferloop", - "st_buffer_cap_style", "st_dump", "st_envelope", "st_srid", @@ -208,33 +207,6 @@ def st_bufferloop( ) -def st_buffer_cap_style(geom: ColumnOrName, radius: ColumnOrName, cap_style: ColumnOrName) -> Column: - """ - Compute the buffered geometry based on geom and radius. 
- - Parameters - ---------- - geom : Column - The input geometry - radius : Column - The radius of buffering - cap_style : Column - The cap style of the buffer - - Returns - ------- - Column - A geometry - - """ - return config.mosaic_context.invoke_function( - "st_buffer_cap_style", - pyspark_to_java_column(geom), - pyspark_to_java_column(radius), - pyspark_to_java_column(cap_style) - ) - - def st_dump(geom: ColumnOrName) -> Column: """ Explodes a multi-geometry into one row per constituent geometry. diff --git a/python/mosaic/api/gdal.py b/python/mosaic/api/gdal.py index 0887bfc00..5113e9bbc 100644 --- a/python/mosaic/api/gdal.py +++ b/python/mosaic/api/gdal.py @@ -25,14 +25,11 @@ def setup_gdal( ------- """ - sc = spark.sparkContext - mosaicContextClass = getattr( - sc._jvm.com.databricks.labs.mosaic.functions, "MosaicContext" + mosaicGDALObject = getattr( + spark.sparkContext._jvm.com.databricks.labs.mosaic.gdal, "MosaicGDAL" ) - mosaicGDALObject = getattr(sc._jvm.com.databricks.labs.mosaic.gdal, "MosaicGDAL") mosaicGDALObject.prepareEnvironment(spark._jsparkSession, init_script_path) print("GDAL setup complete.\n") - print(f"Shared objects (*.so) stored in: {shared_objects_path}.\n") print(f"Init script stored in: {init_script_path}.\n") print( "Please restart the cluster with the generated init script to complete the setup.\n" diff --git a/python/mosaic/api/raster.py b/python/mosaic/api/raster.py index 5bc140f72..e84d41ad4 100644 --- a/python/mosaic/api/raster.py +++ b/python/mosaic/api/raster.py @@ -13,29 +13,29 @@ "rst_boundingbox", "rst_clip", "rst_combineavg", - "rst_fromfile", "rst_frombands", + "rst_fromfile", "rst_georeference", - "ret_getnodata", + "rst_getnodata", "rst_getsubdataset", "rst_height", - "rst_isempty", "rst_initnodata", + "rst_isempty", "rst_memsize", - "rst_metadata", "rst_merge", - "rst_numbands", + "rst_metadata", "rst_ndvi", + "rst_numbands", "rst_pixelheight", "rst_pixelwidth", "rst_rastertogridavg", "rst_rastertogridcount", 
"rst_rastertogridmax", - "rst_rastertogridmin", "rst_rastertogridmedian", - "rst_rastertoworldcoord", + "rst_rastertogridmin", "rst_rastertoworldcoordx", "rst_rastertoworldcoordy", + "rst_rastertoworldcoord", "rst_retile", "rst_rotation", "rst_scalex", @@ -45,17 +45,17 @@ "rst_skewy", "rst_srid", "rst_subdatasets", - "rst_summary", "rst_subdivide", + "rst_summary", "rst_tessellate", "rst_to_overlapping_tiles", "rst_tryopen", "rst_upperleftx", "rst_upperlefty", "rst_width", - "rst_worldtorastercoord", "rst_worldtorastercoordx", "rst_worldtorastercoordy", + "rst_worldtorastercoord", ] @@ -172,7 +172,7 @@ def rst_georeference(raster: ColumnOrName) -> Column: ) -def ret_getnodata(raster: ColumnOrName) -> Column: +def rst_getnodata(raster: ColumnOrName) -> Column: """ Returns the nodata value of the band. @@ -190,7 +190,7 @@ def ret_getnodata(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "ret_getnodata", pyspark_to_java_column(raster) + "rst_getnodata", pyspark_to_java_column(raster) ) @@ -253,8 +253,7 @@ def rst_initnodata(raster: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_initnodata", - pyspark_to_java_column(raster) + "rst_initnodata", pyspark_to_java_column(raster) ) @@ -897,13 +896,16 @@ def rst_fromfile(raster: ColumnOrName, sizeInMB: ColumnOrName) -> Column: """ return config.mosaic_context.invoke_function( - "rst_fromfile", - pyspark_to_java_column(raster), - pyspark_to_java_column(sizeInMB) + "rst_fromfile", pyspark_to_java_column(raster), pyspark_to_java_column(sizeInMB) ) -def rst_to_overlapping_tiles(raster: ColumnOrName, width: ColumnOrName, height: ColumnOrName, overlap: ColumnOrName) -> Column: +def rst_to_overlapping_tiles( + raster: ColumnOrName, + width: ColumnOrName, + height: ColumnOrName, + overlap: ColumnOrName, +) -> Column: """ Tiles the raster into tiles of the given size. 
:param raster: @@ -916,7 +918,7 @@ def rst_to_overlapping_tiles(raster: ColumnOrName, width: ColumnOrName, height: pyspark_to_java_column(raster), pyspark_to_java_column(width), pyspark_to_java_column(height), - pyspark_to_java_column(overlap) + pyspark_to_java_column(overlap), ) @@ -1048,7 +1050,10 @@ def rst_worldtorastercoord( """ return config.mosaic_context.invoke_function( - "rst_worldtorastercoord", pyspark_to_java_column(raster) + "rst_worldtorastercoord", + pyspark_to_java_column(raster), + pyspark_to_java_column(x), + pyspark_to_java_column(y), ) @@ -1074,7 +1079,10 @@ def rst_worldtorastercoordx( """ return config.mosaic_context.invoke_function( - "rst_worldtorastercoordx", pyspark_to_java_column(raster) + "rst_worldtorastercoordx", + pyspark_to_java_column(raster), + pyspark_to_java_column(x), + pyspark_to_java_column(y), ) @@ -1100,5 +1108,8 @@ def rst_worldtorastercoordy( """ return config.mosaic_context.invoke_function( - "rst_worldtorastercoordy", pyspark_to_java_column(raster) + "rst_worldtorastercoordy", + pyspark_to_java_column(raster), + pyspark_to_java_column(x), + pyspark_to_java_column(y), ) diff --git a/python/mosaic/config/config.py b/python/mosaic/config/config.py index a59979be6..bc5f80c9f 100644 --- a/python/mosaic/config/config.py +++ b/python/mosaic/config/config.py @@ -10,3 +10,4 @@ display_handler: DisplayHandler ipython_hook: InteractiveShell notebook_utils = None +default_gdal_init_script_path: str = "/dbfs/FileStore/geospatial/mosaic/gdal/" diff --git a/python/mosaic/core/mosaic_context.py b/python/mosaic/core/mosaic_context.py index ff49a4e37..85b15ab6b 100644 --- a/python/mosaic/core/mosaic_context.py +++ b/python/mosaic/core/mosaic_context.py @@ -51,9 +51,7 @@ def __init__(self, spark: SparkSession): IndexSystem = self._indexSystemFactory.getIndexSystem(self._index_system) GeometryAPIClass = getattr(self._mosaicPackageObject, self._geometry_api) - self._context = self._mosaicContextClass.build( - IndexSystem, GeometryAPIClass() 
- ) + self._context = self._mosaicContextClass.build(IndexSystem, GeometryAPIClass()) def invoke_function(self, name: str, *args: Any) -> MosaicColumn: func = getattr(self._context.functions(), name) diff --git a/python/setup.cfg b/python/setup.cfg index d7a7fbe05..2bc75b1be 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -17,10 +17,14 @@ classifiers = [options] packages = find: python_requires = >=3.7.0 +setup_requires = + pyspark==3.3.2 + ipython>=7.22.0 + install_requires = keplergl==0.3.2 h3==3.7.3 - ipython>=7.22.0 + gdal[numpy]==3.4.3 [options.package_data] mosaic = diff --git a/python/test/data/MCD43A4.A2018185.h10v07.006.2018194033728_B04.TIF b/python/test/data/MCD43A4.A2018185.h10v07.006.2018194033728_B04.TIF new file mode 100644 index 000000000..2eae2a2ad Binary files /dev/null and b/python/test/data/MCD43A4.A2018185.h10v07.006.2018194033728_B04.TIF differ diff --git a/python/test/test_gdal_install.py b/python/test/test_gdal_install.py new file mode 100644 index 000000000..b1e2d0ae8 --- /dev/null +++ b/python/test/test_gdal_install.py @@ -0,0 +1,20 @@ +from .utils import SparkTestCase, GDALInstaller + + +class TestGDALInstall(SparkTestCase): + def test_setup_gdal(self): + installer = GDALInstaller(self.spark) + try: + installer.copy_objects() + except Exception: + self.fail("Copying objects with `setup_gdal()` raised an exception.") + + try: + installer_result = installer.run_init_script() + except Exception: + self.fail("Execution of GDAL init script raised an exception.") + + self.assertEqual(installer_result, 0) + + gdalinfo_result = installer.test_gdalinfo() + self.assertEqual(gdalinfo_result, "GDAL 3.4.3, released 2022/04/22\n") diff --git a/python/test/test_raster_functions.py b/python/test/test_raster_functions.py new file mode 100644 index 000000000..077bd06df --- /dev/null +++ b/python/test/test_raster_functions.py @@ -0,0 +1,157 @@ +import logging +import random +import unittest + +from pyspark.sql.functions import abs, col, first, lit, 
sqrt, array + +from .context import api +from .utils import MosaicTestCaseWithGDAL + + +class TestRasterFunctions(MosaicTestCaseWithGDAL): + def test_read_raster(self): + result = self.generate_singleband_raster_df().first() + self.assertEqual(result.length, 1067862) + self.assertEqual(result.x_size, 2400) + self.assertEqual(result.y_size, 2400) + self.assertEqual(result.srid, 0) + self.assertEqual(result.bandCount, 1) + self.assertEqual( + result.metadata["LONGNAME"], + "MODIS/Terra+Aqua BRDF/Albedo Nadir BRDF-Adjusted Ref Daily L3 Global - 500m", + ) + self.assertEqual(result.tile["driver"], "GTiff") + + def test_raster_scalar_functions(self): + result = ( + self.generate_singleband_raster_df() + .withColumn("rst_bandmetadata", api.rst_bandmetadata("tile", lit(1))) + .withColumn("rst_boundingbox", api.rst_boundingbox("tile")) + .withColumn( + "rst_boundingbox", api.st_buffer("rst_boundingbox", lit(-0.001)) + ) + .withColumn("rst_clip", api.rst_clip("tile", "rst_boundingbox")) + .withColumn( + "rst_combineavg", + api.rst_combineavg(array(col("tile"), col("rst_clip"))), + ) + .withColumn("rst_frombands", api.rst_frombands(array("tile", "tile"))) + .withColumn("tile_from_file", api.rst_fromfile("path", lit(-1))) + .withColumn("rst_georeference", api.rst_georeference("tile")) + .withColumn("rst_getnodata", api.rst_getnodata("tile")) + .withColumn("rst_subdatasets", api.rst_subdatasets("tile")) + # .withColumn("rst_getsubdataset", api.rst_getsubdataset("tile")) + .withColumn("rst_height", api.rst_height("tile")) + .withColumn("rst_initnodata", api.rst_initnodata("tile")) + .withColumn("rst_isempty", api.rst_isempty("tile")) + .withColumn("rst_memsize", api.rst_memsize("tile")) + .withColumn("rst_merge", api.rst_merge(array("tile", "tile"))) + .withColumn("rst_metadata", api.rst_metadata("tile")) + .withColumn("rst_ndvi", api.rst_ndvi("tile", lit(1), lit(1))) + .withColumn("rst_numbands", api.rst_numbands("tile")) + .withColumn("rst_pixelheight", 
api.rst_pixelheight("tile")) + .withColumn("rst_pixelwidth", api.rst_pixelwidth("tile")) + .withColumn("rst_rastertogridavg", api.rst_rastertogridavg("tile", lit(9))) + .withColumn( + "rst_rastertogridcount", api.rst_rastertogridcount("tile", lit(9)) + ) + .withColumn("rst_rastertogridmax", api.rst_rastertogridmax("tile", lit(9))) + .withColumn( + "rst_rastertogridmedian", api.rst_rastertogridmedian("tile", lit(9)) + ) + .withColumn("rst_rastertogridmin", api.rst_rastertogridmin("tile", lit(9))) + .withColumn( + "rst_rastertoworldcoordx", + api.rst_rastertoworldcoordx("tile", lit(1200), lit(1200)), + ) + .withColumn( + "rst_rastertoworldcoordy", + api.rst_rastertoworldcoordy("tile", lit(1200), lit(1200)), + ) + .withColumn( + "rst_rastertoworldcoord", + api.rst_rastertoworldcoord("tile", lit(1200), lit(1200)), + ) + .withColumn("rst_rotation", api.rst_rotation("tile")) + .withColumn("rst_scalex", api.rst_scalex("tile")) + .withColumn("rst_scaley", api.rst_scaley("tile")) + .withColumn("rst_srid", api.rst_srid("tile")) + .withColumn("rst_summary", api.rst_summary("tile")) + # .withColumn("rst_tryopen", api.rst_tryopen(col("path"))) # needs an issue + .withColumn("rst_upperleftx", api.rst_upperleftx("tile")) + .withColumn("rst_upperlefty", api.rst_upperlefty("tile")) + .withColumn("rst_width", api.rst_width("tile")) + .withColumn( + "rst_worldtorastercoordx", + api.rst_worldtorastercoordx("tile", lit(0.0), lit(0.0)), + ) + .withColumn( + "rst_worldtorastercoordy", + api.rst_worldtorastercoordy("tile", lit(0.0), lit(0.0)), + ) + .withColumn( + "rst_worldtorastercoord", + api.rst_worldtorastercoord("tile", lit(0.0), lit(0.0)), + ) + ) + result.write.format("noop").mode("overwrite").save() + self.assertEqual(result.count(), 1) + + def test_raster_flatmap_functions(self): + retile_result = self.generate_singleband_raster_df().withColumn( + "rst_retile", api.rst_retile("tile", lit(1200), lit(1200)) + ) + retile_result.write.format("noop").mode("overwrite").save() + 
self.assertEqual(retile_result.count(), 4) + + subdivide_result = self.generate_singleband_raster_df().withColumn( + "rst_subdivide", api.rst_subdivide("tile", lit(1)) + ) + subdivide_result.write.format("noop").mode("overwrite").save() + self.assertEqual(retile_result.count(), 4) + + # TODO: reproject into WGS84 + tessellate_result = self.generate_singleband_raster_df().withColumn( + "rst_tessellate", api.rst_tessellate("tile", lit(3)) + ) + + tessellate_result.write.format("noop").mode("overwrite").save() + self.assertEqual(tessellate_result.count(), 55) + + overlap_result = self.generate_singleband_raster_df().withColumn( + "rst_to_overlapping_tiles", + api.rst_to_overlapping_tiles("tile", lit(200), lit(200), lit(10)), + ) + + overlap_result.write.format("noop").mode("overwrite").save() + self.assertEqual(overlap_result.count(), 86) + + def test_raster_aggregator_functions(self): + collection = ( + self.generate_singleband_raster_df() + .withColumn("extent", api.st_astext(api.rst_boundingbox("tile"))) + .withColumn( + "rst_to_overlapping_tiles", + api.rst_to_overlapping_tiles("tile", lit(200), lit(200), lit(10)), + ) + ) + + merge_result = ( + collection.groupBy("path") + .agg(api.rst_merge_agg("tile").alias("tile")) + .withColumn("extent", api.st_astext(api.rst_boundingbox("tile"))) + ) + + self.assertEqual(merge_result.count(), 1) + self.assertEqual(collection.first()["extent"], merge_result.first()["extent"]) + + combine_avg_result = ( + collection.groupBy("path") + .agg(api.rst_combineavg_agg("tile").alias("tile")) + .withColumn("extent", api.st_astext(api.rst_boundingbox("tile"))) + ) + + self.assertEqual(combine_avg_result.count(), 1) + self.assertEqual( + collection.first()["extent"], combine_avg_result.first()["extent"] + ) diff --git a/python/test/test_functions.py b/python/test/test_vector_functions.py similarity index 99% rename from python/test/test_functions.py rename to python/test/test_vector_functions.py index 8b8666dc7..84d2afbd2 100644 --- 
a/python/test/test_functions.py +++ b/python/test/test_vector_functions.py @@ -6,7 +6,7 @@ from .utils import MosaicTestCase -class TestFunctions(MosaicTestCase): +class TestVectorFunctions(MosaicTestCase): def test_st_point(self): expected = [ "POINT (0 0)", diff --git a/python/test/utils/__init__.py b/python/test/utils/__init__.py index 96a703ab1..c46240505 100644 --- a/python/test/utils/__init__.py +++ b/python/test/utils/__init__.py @@ -1,2 +1,3 @@ from .mosaic_test_case import * from .mosaic_test_case_with_gdal import * +from .setup_gdal import GDALInstaller diff --git a/python/test/utils/mosaic_test_case_with_gdal.py b/python/test/utils/mosaic_test_case_with_gdal.py index 4a5fe321c..233698bab 100644 --- a/python/test/utils/mosaic_test_case_with_gdal.py +++ b/python/test/utils/mosaic_test_case_with_gdal.py @@ -2,6 +2,8 @@ from .mosaic_test_case import MosaicTestCase +from pyspark.sql.dataframe import DataFrame + class MosaicTestCaseWithGDAL(MosaicTestCase): @classmethod @@ -9,3 +11,10 @@ def setUpClass(cls) -> None: super().setUpClass() api.enable_mosaic(cls.spark) api.enable_gdal(cls.spark) + + def generate_singleband_raster_df(self) -> DataFrame: + return ( + self.spark.read.format("gdal") + .option("raster.read.strategy", "in_memory") + .load("test/data/MCD43A4.A2018185.h10v07.006.2018194033728_B04.TIF") + ) diff --git a/python/test/utils/setup_gdal.py b/python/test/utils/setup_gdal.py new file mode 100644 index 000000000..a1febb238 --- /dev/null +++ b/python/test/utils/setup_gdal.py @@ -0,0 +1,36 @@ +import os +import tempfile +import subprocess +from pkg_resources import working_set, Requirement + +from test.context import api + + +class GDALInstaller: + def __init__(self, spark): + self._site_packages = working_set.find(Requirement("keplergl")).location + self._temp_dir = tempfile.TemporaryDirectory() + self.spark = spark + + def __del__(self): + self._temp_dir.cleanup() + + def copy_objects(self): + api.setup_gdal(self.spark, self._temp_dir.name) + + 
 def run_init_script(self): + gdal_install_script_target = os.path.join( + self._temp_dir.name, "mosaic-gdal-init.sh" + ) + os.chmod(gdal_install_script_target, mode=0o744) + result = subprocess.run( + [gdal_install_script_target], + stdout=subprocess.PIPE, + env=dict(os.environ, DATABRICKS_ROOT_VIRTUALENV_ENV=self._site_packages), + ) + print(result.stdout.decode()) + return result.returncode + + def test_gdalinfo(self): + result = subprocess.run(["gdalinfo", "--version"], stdout=subprocess.PIPE) + return result.stdout.decode() diff --git a/python/test/utils/spark_test_case.py b/python/test/utils/spark_test_case.py index 29d9b916c..2c92b5758 100644 --- a/python/test/utils/spark_test_case.py +++ b/python/test/utils/spark_test_case.py @@ -23,7 +23,7 @@ def setUpClass(cls) -> None: .getOrCreate() ) cls.spark.conf.set("spark.databricks.labs.mosaic.jar.autoattach", "false") - cls.spark.sparkContext.setLogLevel("warn") + cls.spark.sparkContext.setLogLevel("WARN") @classmethod def tearDownClass(cls) -> None: diff --git a/src/main/resources/scripts/install-gdal-databricks.sh b/src/main/resources/scripts/install-gdal-databricks.sh index 741ef5031..09916d4b2 100644 --- a/src/main/resources/scripts/install-gdal-databricks.sh +++ b/src/main/resources/scripts/install-gdal-databricks.sh @@ -21,8 +21,8 @@ GDAL_RESOURCE_DIR=$(find $DATABRICKS_ROOT_VIRTUALENV_ENV -name "databricks-mosai # -- untar files to root # - from databricks-mosaic-gdal install dir -tar -xf $GDAL_RESOURCE_DIR/resources/gdal-3.4.3-filetree.tar.xz -C / +tar -xf $GDAL_RESOURCE_DIR/resources/gdal-3.4.3-filetree.tar.xz --skip-old-files -C / # -- untar symlinks to root # - from databricks-mosaic-gdal install dir -tar -xhf $GDAL_RESOURCE_DIR/resources/gdal-3.4.3-symlinks.tar.xz -C / +tar -xhf $GDAL_RESOURCE_DIR/resources/gdal-3.4.3-symlinks.tar.xz --skip-old-files -C / diff --git a/src/main/scala/com/databricks/labs/mosaic/core/crs/CRSBoundsProvider.scala 
b/src/main/scala/com/databricks/labs/mosaic/core/crs/CRSBoundsProvider.scala index f6e89cf23..c26bbfc77 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/crs/CRSBoundsProvider.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/crs/CRSBoundsProvider.scala @@ -1,9 +1,10 @@ package com.databricks.labs.mosaic.core.crs import java.io.InputStream - import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI +import scala.io.Codec + /** * CRSBoundsProvider provides APIs to get bounds extreme points based on CRS * dataset name (ie. EPSG) and CRS id (ie. 4326). The lookup is not exhaustive @@ -67,7 +68,7 @@ object CRSBoundsProvider { */ def apply(geometryAPI: GeometryAPI): CRSBoundsProvider = { val stream: InputStream = getClass.getResourceAsStream("/CRSBounds.csv") - val lines: List[String] = scala.io.Source.fromInputStream(stream).getLines.toList.drop(1) + val lines: List[String] = scala.io.Source.fromInputStream(stream)(Codec.UTF8).getLines.toList.drop(1) val lookupItems = lines .drop(1) .map(line => { diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/CombineAVG.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/CombineAVG.scala index 14647cebb..e41f82a09 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/CombineAVG.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/CombineAVG.scala @@ -19,20 +19,18 @@ object CombineAVG { * A new raster with average of input rasters. 
*/ def compute(rasters: => Seq[MosaicRasterGDAL]): MosaicRasterGDAL = { + val pythonFunc = """ |import numpy as np |import sys | - |def average(in_ar, out_ar, xoff, yoff, xsize, ysize, raster_xsize,raster_ysize, buf_radius, gt, **kwargs): - | div = np.zeros(in_ar[0].shape) - | for i in range(len(in_ar)): - | div += (in_ar[i] != 0) - | div[div == 0] = 1 - | - | y = np.sum(in_ar, axis = 0, dtype = 'float64') - | y = y / div - | - | np.clip(y,0, sys.float_info.max, out = out_ar) + |def average(in_ar, out_ar, xoff, yoff, xsize, ysize, raster_xsize, raster_ysize, buf_radius, gt, **kwargs): + | stacked_array = np.array(in_ar) + | pixel_sum = np.sum(stacked_array, axis=0) + | div = np.sum(stacked_array > 0, axis=0) + | div = np.where(div==0, 1, div) + | np.divide(pixel_sum, div, out=out_ar, casting='unsafe') + | np.clip(out_ar, stacked_array.min(), stacked_array.max(), out=out_ar) |""".stripMargin PixelCombineRasters.combine(rasters, pythonFunc, "average") } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/NDVI.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/NDVI.scala index e907d1eb7..1b60baba8 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/NDVI.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/NDVI.scala @@ -21,10 +21,13 @@ object NDVI { * MosaicRasterGDAL with NDVI computed. 
*/ def compute(raster: => MosaicRasterGDAL, redIndex: Int, nirIndex: Int): MosaicRasterGDAL = { - val ndviPath = PathUtils.createTmpFilePath(raster.uuid.toString, GDAL.getExtension(raster.getDriversShortName)) + val tmpPath = PathUtils.createTmpFilePath(raster.uuid.toString, GDAL.getExtension(raster.getDriversShortName)) + raster.writeToPath(tmpPath) + val tmpRaster = MosaicRasterGDAL(tmpPath, isTemp=true, raster.getParentPath, raster.getDriversShortName, raster.getMemSize) + val ndviPath = PathUtils.createTmpFilePath(raster.uuid.toString + "NDVI", GDAL.getExtension(raster.getDriversShortName)) // noinspection ScalaStyle val gdalCalcCommand = - s"""gdal_calc -A ${raster.getPath} --A_band=$redIndex -B ${raster.getPath} --B_band=$nirIndex --outfile=$ndviPath --calc="(B-A)/(B+A)"""" + s"""gdal_calc -A ${tmpRaster.getPath} --A_band=$redIndex -B ${tmpRaster.getPath} --B_band=$nirIndex --outfile=$ndviPath --calc="(B-A)/(B+A)"""" GDALCalc.executeCalc(gdalCalcCommand, ndviPath) } diff --git a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala index cd4d92a97..eca7f16d6 100644 --- a/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala +++ b/src/main/scala/com/databricks/labs/mosaic/core/raster/operator/gdal/GDALCalc.scala @@ -21,7 +21,7 @@ object GDALCalc { require(gdalCalcCommand.startsWith("gdal_calc"), "Not a valid GDAL Calc command.") import sys.process._ val toRun = gdalCalcCommand.replace("gdal_calc", gdal_calc) - s"sudo python3 $toRun".!! + s"python3 $toRun".!! 
val result = GDAL.raster(resultPath, resultPath) result } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala index 47271a35b..fdb4bfdf0 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_FromBands.scala @@ -1,6 +1,7 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.raster.operator.merge.MergeBands +import com.databricks.labs.mosaic.core.types.RasterTileType import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.RasterArrayExpression @@ -8,7 +9,6 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} -import org.apache.spark.sql.types.BinaryType /** The expression for stacking and resampling input bands. */ case class RST_FromBands( @@ -16,7 +16,7 @@ case class RST_FromBands( expressionConfig: MosaicExpressionConfig ) extends RasterArrayExpression[RST_FromBands]( bandsExpr, - BinaryType, + RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) @@ -30,7 +30,10 @@ case class RST_FromBands( * @return * The stacked and resampled raster. 
*/ - override def rasterTransform(rasters: => Seq[MosaicRasterTile]): Any = MergeBands.merge(rasters.map(_.getRaster), "bilinear") + override def rasterTransform(rasters: => Seq[MosaicRasterTile]): Any = { + val raster = MergeBands.merge(rasters.map(_.getRaster), "bilinear") + new MosaicRasterTile(rasters.head.getIndex, raster, rasters.head.getParentPath, rasters.head.getDriver) + } } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala index 7a45a3eaa..25c1a0442 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_GetNoData.scala @@ -7,6 +7,7 @@ import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} +import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.types.{ArrayType, DoubleType} /** The expression for extracting the no data value of a raster. */ @@ -31,7 +32,7 @@ case class RST_GetNoData( * The no data value of the raster. 
*/ override def rasterTransform(tile: => MosaicRasterTile): Any = { - tile.getRaster.getBands.map(_.noDataValue) + ArrayData.toArrayData(tile.getRaster.getBands.map(_.noDataValue)) } } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala index 85621055e..636b079ac 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_NDVI.scala @@ -1,14 +1,14 @@ package com.databricks.labs.mosaic.expressions.raster -import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL import com.databricks.labs.mosaic.core.raster.operator.NDVI +import com.databricks.labs.mosaic.core.types.RasterTileType +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} -import org.apache.spark.sql.types.BinaryType /** The expression for computing NDVI index. */ case class RST_NDVI( @@ -20,7 +20,7 @@ case class RST_NDVI( rastersExpr, redIndex, nirIndex, - BinaryType, + RasterTileType(expressionConfig.getCellIdType), returnsRaster = true, expressionConfig = expressionConfig ) @@ -38,10 +38,11 @@ case class RST_NDVI( * @return * The raster contains NDVI index. 
*/ - override def rasterTransform(raster: => MosaicRasterGDAL, arg1: Any, arg2: Any): Any = { + override def rasterTransform(tile: => MosaicRasterTile, arg1: Any, arg2: Any): Any = { val redInd = arg1.asInstanceOf[Int] val nirInd = arg2.asInstanceOf[Int] - NDVI.compute(raster, redInd, nirInd) + val result = NDVI.compute(tile.getRaster, redInd, nirInd) + new MosaicRasterTile(tile.getIndex, result, tile.getParentPath, tile.getDriver) } } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala index e349d032c..11e734a67 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoord.scala @@ -3,6 +3,7 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.geometry.api.GeometryAPI import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig @@ -26,10 +27,10 @@ case class RST_RasterToWorldCoord( * GeoTransform. This ensures the projection of the raster is respected. * The output is a WKT point. 
*/ - override def rasterTransform(raster: => MosaicRasterGDAL, arg1: Any, arg2: Any): Any = { + override def rasterTransform(tile: => MosaicRasterTile, arg1: Any, arg2: Any): Any = { val x = arg1.asInstanceOf[Int] val y = arg2.asInstanceOf[Int] - val gt = raster.getRaster.GetGeoTransform() + val gt = tile.getRaster.getRaster.GetGeoTransform() val (xGeo, yGeo) = GDAL.toWorldCoord(gt, x, y) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala index 1b6787088..0ef8a7def 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordX.scala @@ -2,6 +2,7 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig @@ -24,10 +25,10 @@ case class RST_RasterToWorldCoordX( * Returns the world coordinates of the raster x pixel by applying * GeoTransform. This ensures the projection of the raster is respected. 
*/ - override def rasterTransform(raster: => MosaicRasterGDAL, arg1: Any, arg2: Any): Any = { + override def rasterTransform(tile: => MosaicRasterTile, arg1: Any, arg2: Any): Any = { val x = arg1.asInstanceOf[Int] val y = arg2.asInstanceOf[Int] - val gt = raster.getRaster.GetGeoTransform() + val gt = tile.getRaster.getRaster.GetGeoTransform() val (xGeo, _) = GDAL.toWorldCoord(gt, x, y) xGeo diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala index 65ac5decc..2e6703b3c 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_RasterToWorldCoordY.scala @@ -2,6 +2,7 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig @@ -24,10 +25,10 @@ case class RST_RasterToWorldCoordY( * Returns the world coordinates of the raster y pixel by applying * GeoTransform. This ensures the projection of the raster is respected. 
*/ - override def rasterTransform(raster: => MosaicRasterGDAL, arg1: Any, arg2: Any): Any = { + override def rasterTransform(tile: => MosaicRasterTile, arg1: Any, arg2: Any): Any = { val x = arg1.asInstanceOf[Int] val y = arg2.asInstanceOf[Int] - val gt = raster.getRaster.GetGeoTransform() + val gt = tile.getRaster.getRaster.GetGeoTransform() val (_, yGeo) = GDAL.toWorldCoord(gt, x, y) yGeo diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala index 00d4ac7e0..f526cdb2a 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_TryOpen.scala @@ -12,8 +12,8 @@ import org.apache.spark.sql.types._ /** Returns true if the raster is empty. */ case class RST_TryOpen(raster: Expression, expressionConfig: MosaicExpressionConfig) extends RasterExpression[RST_TryOpen](raster, BooleanType, returnsRaster = false, expressionConfig) - with NullIntolerant - with CodegenFallback { + with NullIntolerant + with CodegenFallback { /** Returns true if the raster can be opened. 
*/ override def rasterTransform(tile: => MosaicRasterTile): Any = { @@ -40,4 +40,4 @@ object RST_TryOpen extends WithExpressionInfo { GenericExpressionFactory.getBaseBuilder[RST_TryOpen](1, expressionConfig) } -} +} \ No newline at end of file diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala index adf8b7c19..6f1774bf3 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoord.scala @@ -2,6 +2,7 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig @@ -24,10 +25,10 @@ case class RST_WorldToRasterCoord( * Returns the x and y of the raster by applying GeoTransform as a tuple of * Integers. This will ensure projection of the raster is respected. 
*/ - override def rasterTransform(raster: => MosaicRasterGDAL, arg1: Any, arg2: Any): Any = { + override def rasterTransform(tile: => MosaicRasterTile, arg1: Any, arg2: Any): Any = { val xGeo = arg1.asInstanceOf[Double] val yGeo = arg2.asInstanceOf[Double] - val gt = raster.getRaster.GetGeoTransform() + val gt = tile.getRaster.getRaster.GetGeoTransform() val (x, y) = GDAL.fromWorldCoord(gt, xGeo, yGeo) InternalRow.fromSeq(Seq(x, y)) diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala index 00de3af69..7f6c3d65d 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordX.scala @@ -2,6 +2,7 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig @@ -24,9 +25,9 @@ case class RST_WorldToRasterCoordX( * Returns the x coordinate of the raster by applying GeoTransform. This * will ensure projection of the raster is respected. 
*/ - override def rasterTransform(raster: => MosaicRasterGDAL, arg1: Any, arg2: Any): Any = { + override def rasterTransform(tile: => MosaicRasterTile, arg1: Any, arg2: Any): Any = { val xGeo = arg1.asInstanceOf[Double] - val gt = raster.getRaster.GetGeoTransform() + val gt = tile.getRaster.getRaster.GetGeoTransform() GDAL.fromWorldCoord(gt, xGeo, 0)._1 } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala index 5b3bf09a7..16b2f2831 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/RST_WorldToRasterCoordY.scala @@ -2,6 +2,7 @@ package com.databricks.labs.mosaic.expressions.raster import com.databricks.labs.mosaic.core.raster.api.GDAL import com.databricks.labs.mosaic.core.raster.gdal.MosaicRasterGDAL +import com.databricks.labs.mosaic.core.types.model.MosaicRasterTile import com.databricks.labs.mosaic.expressions.base.{GenericExpressionFactory, WithExpressionInfo} import com.databricks.labs.mosaic.expressions.raster.base.Raster2ArgExpression import com.databricks.labs.mosaic.functions.MosaicExpressionConfig @@ -24,9 +25,9 @@ case class RST_WorldToRasterCoordY( * Returns the y coordinate of the raster by applying GeoTransform. This * will ensure projection of the raster is respected. 
*/ - override def rasterTransform(raster: => MosaicRasterGDAL, arg1: Any, arg2: Any): Any = { + override def rasterTransform(tile: => MosaicRasterTile, arg1: Any, arg2: Any): Any = { val xGeo = arg1.asInstanceOf[Double] - val gt = raster.getRaster.GetGeoTransform() + val gt = tile.getRaster.getRaster.GetGeoTransform() GDAL.fromWorldCoord(gt, xGeo, 0)._2 } diff --git a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala index 0146b9380..1db525c82 100644 --- a/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala +++ b/src/main/scala/com/databricks/labs/mosaic/expressions/raster/base/Raster2ArgExpression.scala @@ -66,7 +66,7 @@ abstract class Raster2ArgExpression[T <: Expression: ClassTag]( * @return * A result of the expression. */ - def rasterTransform(raster: => MosaicRasterGDAL, arg1: Any, arg2: Any): Any + def rasterTransform(raster: => MosaicRasterTile, arg1: Any, arg2: Any): Any /** * Evaluation of the expression. It evaluates the raster path and the loads @@ -87,9 +87,9 @@ abstract class Raster2ArgExpression[T <: Expression: ClassTag]( override def nullSafeEval(input: Any, arg1: Any, arg2: Any): Any = { GDAL.enable() val tile = MosaicRasterTile.deserialize(input.asInstanceOf[InternalRow], expressionConfig.getCellIdType) - val raster = tile.getRaster - val result = rasterTransform(raster, arg1, arg2) - val serialized = serialize(result, returnsRaster, dataType, expressionConfig) + val result = rasterTransform(tile, arg1, arg2) + val serialized = serialize(result, returnsRaster, outputType, expressionConfig) + // passed by name makes things re-evaluated RasterCleaner.dispose(tile) serialized }