Fixed a broken python test from refactor. Live tested rasterio udf examples (last 2) and updated code and comments. Added additional json_spec examples to rst_mapalgebra.
mjohns-databricks committed Jan 23, 2024
1 parent 87d45a2 commit de74df2
Showing 3 changed files with 52 additions and 41 deletions.
6 changes: 6 additions & 0 deletions docs/source/api/raster-functions.rst
@@ -908,6 +908,12 @@ rst_mapalgebra
arrays (such as +, -, *, and /) along with logical operators (such as >, <, =). For this distributed implementation,
all rasters must have the same dimensions and no projection checking is performed.
Here are examples of the json_spec: (1) shows default indexing, (2) shows reusing an index,
and (3) shows band indexing.
(1) '{"calc": "A+B/C"}'
(2) '{"calc": "A+B/C", "A_index": 0, "B_index": 1, "C_index": 1}'
(3) '{"calc": "A+B/C", "A_index": 0, "B_index": 1, "C_index": 2, "A_band": 1, "B_band": 1, "C_band": 1}'
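
As a hedged sketch, the three json_spec strings above can also be assembled programmatically instead of typed by hand. The helper name `make_json_spec` and its parameters are hypothetical (they are not part of the library); it relies only on Python's standard `json` module and the key names shown in the examples.

```python
import json


def make_json_spec(calc, indexes=None, bands=None):
    """Hypothetical helper: build a json_spec string for rst_mapalgebra.

    calc    -- the expression, e.g. "A+B/C"
    indexes -- optional {letter: raster index}, emitted as "<letter>_index"
    bands   -- optional {letter: band number}, emitted as "<letter>_band"
    """
    spec = {"calc": calc}
    for name, idx in (indexes or {}).items():
        spec[f"{name}_index"] = idx
    for name, band in (bands or {}).items():
        spec[f"{name}_band"] = band
    return json.dumps(spec)


# (1) default indexing
spec_default = make_json_spec("A+B/C")
# (2) reusing an index (B and C both point at raster 1)
spec_reuse = make_json_spec("A+B/C", indexes={"A": 0, "B": 1, "C": 1})
# (3) band indexing
spec_bands = make_json_spec(
    "A+B/C",
    indexes={"A": 0, "B": 1, "C": 2},
    bands={"A": 1, "B": 1, "C": 1},
)
print(spec_default)
print(spec_reuse)
print(spec_bands)
```

Building the string from a dict avoids quoting mistakes; the result is a plain string that can then be supplied as the json_spec column value.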

:param tile: A column containing the raster tile.
:type tile: Column (RasterTileType)
:param json_spec: A column containing the map algebra operation specification.
83 changes: 44 additions & 39 deletions docs/source/api/rasterio-udfs.rst
@@ -236,7 +236,11 @@ Firstly we will create a spark DataFrame from a directory of raster files.
Next we will define a function that will write a given raster file to disk. A "gotcha" to keep in mind is that you do
not want to have a file context manager open when you go to write out its context as the context manager will not yet
have been flushed.
have been flushed. Another "gotcha" is that the raster dataset might not include a CRS. If this arises, we
recommend adjusting the function to specify the CRS and set it on the dst variable (see
`rasterio.crs <https://rasterio.readthedocs.io/en/stable/api/rasterio.crs.html>`_). Note also that the notional
"file_id" param can be constructed as a repeatable name from other field(s) in your dataframe / table or be random,
depending on your needs.
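
To illustrate the two "file_id" options mentioned above, here is a minimal sketch in plain Python. Both helper names (`repeatable_file_id`, `random_file_id`) and the sample field values are hypothetical; the repeatable variant hashes whatever fields you pass in, so the same inputs always give the same name.

```python
import hashlib
import uuid


def repeatable_file_id(*fields):
    """Hypothetical helper: derive a stable id from other field values."""
    key = "|".join(str(f) for f in fields)
    # Truncated SHA-256 hex digest; identical inputs give an identical id.
    return hashlib.sha256(key.encode("utf-8")).hexdigest()[:16]


def random_file_id():
    """Hypothetical helper: a fresh random id on every call."""
    return uuid.uuid4().hex


# Same inputs produce the same id, which makes re-runs overwrite
# rather than duplicate output files.
id_a = repeatable_file_id("/path/to/raster.tif", 3)
id_b = repeatable_file_id("/path/to/raster.tif", 3)
print(id_a == id_b)
print(random_file_id())
```

A repeatable id suits re-runnable pipelines, while a random id suits cases where every write should produce a new file. In a DataFrame, the same idea could likely be expressed with built-in column functions such as `sha2` over `concat_ws`.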
.. code-block:: python
@@ -253,31 +257,30 @@ have been flushed.
# - [1] populate the initial profile
# # profile is needed in order to georeference the image
profile = None
with MemoryFile(BytesIO(raster)) as memfile:
with memfile.open() as dataset:
profile = dataset.profile
# - [2] get the correct extension
extensions_map = rasterio.drivers.raster_driver_extensions()
driver_map = {v: k for k, v in extensions_map.items()}
extension = driver_map[driver] #e.g. GTiff
file_name = f"{file_id}.{extension}"
# - [3] write local raster
# - this is showing a single band [1]
# being written
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = f"{tmp_dir}/{file_name}"
profile = None
data_arr = None
with MemoryFile(BytesIO(raster)) as memfile:
with memfile.open() as dataset:
profile = dataset.profile
data_arr = dataset.read()
# here you can update profile using .update method
# example https://rasterio.readthedocs.io/en/latest/topics/writing.html
# - [2] get the correct extension
extensions_map = rasterio.drivers.raster_driver_extensions()
driver_map = {v: k for k, v in extensions_map.items()}
extension = driver_map[driver] #e.g. GTiff
file_name = f"{file_id}.{extension}"
# - [3] write local raster
# - this writes every band read into data_arr
# (use dst.write(band_arr, 1) to write a single band)
tmp_path = f"{tmp_dir}/{file_name}"
with rasterio.open(
tmp_path,
"w",
**profile
tmp_path,
"w",
**profile
) as dst:
dst.write(raster,1) # <- adjust as needed
dst.write(data_arr) # <- adjust as needed
# - [4] copy to fuse path
Path(fuse_dir).mkdir(parents=True, exist_ok=True)
fuse_path = f"{fuse_dir}/{file_name}"
@@ -294,25 +297,27 @@ Finally we will apply the function to the DataFrame.
"tile.raster",
lit("GTiff").alias("driver"),
"uuid",
lit("dbfs:/path/to/output/dir").alias("fuse_dir")
lit("/dbfs/path/to/output/dir").alias("fuse_dir")
)
).display()
+----------------------------------------------+
| write_raster(raster, driver, uuid, fuse_dir) |
+----------------------------------------------+
| dbfs:/path/to/output/dir/1234.tif |
| dbfs:/path/to/output/dir/4545.tif |
| dbfs:/path/to/output/dir/3215.tif |
| /dbfs/path/to/output/dir/1234.tif |
| /dbfs/path/to/output/dir/4545.tif |
| /dbfs/path/to/output/dir/3215.tif |
| ... |
+----------------------------------------------+
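
The fuse_dir value above uses the local `/dbfs/...` form rather than the `dbfs:/...` URI form, since the UDF writes with ordinary Python file APIs (such as `Path(...).mkdir`), which operate on local filesystem paths. A minimal sketch of normalizing one form to the other follows; `to_fuse_path` is a hypothetical helper and assumes the usual `/dbfs` mount point.

```python
def to_fuse_path(path):
    """Hypothetical helper: map a 'dbfs:/...' URI to its '/dbfs/...' local form.

    Paths that do not start with 'dbfs:/' are returned unchanged.
    """
    prefix = "dbfs:/"
    if path.startswith(prefix):
        # Drop the scheme and any extra leading slashes, then re-root at /dbfs.
        return "/dbfs/" + path[len(prefix):].lstrip("/")
    return path


fuse_dir = to_fuse_path("dbfs:/path/to/output/dir")
print(fuse_dir)
```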
Sometimes you don't need to be quite as fancy. Consider when you simply want to specify to write out raster contents,
assuming you specify the extension in the file_id. This is just writing binary column to file, nothing further.
assuming you specify the extension in the file_name. This is just writing a binary column to a file, nothing further. Again,
we use a notional "uuid" column as part of the "file_name" param, which has the same considerations as mentioned
above.
.. code-block:: python
@udf("string")
def write_contents(raster, file_name, fuse_dir):
def write_binary(raster_bin, file_name, fuse_dir):
from pathlib import Path
import os
import shutil
@@ -326,7 +331,7 @@ assuming you specify the extension in the file_id. This is just writing binary c
# - write within the tmp_dir context
# - flush the writer before copy
tmp_file = open(tmp_path, "wb")
tmp_file.write(raster) # <- write entire binary content
tmp_file.write(raster_bin) # <- write entire binary content
tmp_file.close()
# - copy local to fuse
shutil.copyfile(tmp_path, fuse_path)
@@ -337,17 +342,17 @@ Finally we will apply the function to the DataFrame.
.. code-block:: python
df.select(
write_contents(
write_binary(
"tile.raster",
F.concat("uuid", F.lit(".tif").alias("file_name"),
lit("dbfs:/path/to/output/dir").alias("fuse_dir")
F.concat("uuid", F.lit(".tif")).alias("file_name"),
F.lit("/dbfs/path/to/output/dir").alias("fuse_dir")
)
).display()
+----------------------------------------+
| write_tif(raster, file_name, fuse_dir) |
+----------------------------------------+
| dbfs:/path/to/output/dir/1234.tif |
| dbfs:/path/to/output/dir/4545.tif |
| dbfs:/path/to/output/dir/3215.tif |
| ... |
+----------------------------------------+
+-------------------------------------------+
| write_binary(raster, file_name, fuse_dir) |
+-------------------------------------------+
| /dbfs/path/to/output/dir/1234.tif |
| /dbfs/path/to/output/dir/4545.tif |
| /dbfs/path/to/output/dir/3215.tif |
| ... |
+-------------------------------------------+
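
The write-then-copy pattern used by write_binary can be tried locally without Spark. The sketch below is an assumption-laden stand-in, not the documented UDF: `write_binary_local` is a hypothetical name, a temporary local directory plays the role of the fuse path, and the payload is dummy bytes rather than a real raster.

```python
import shutil
import tempfile
from pathlib import Path


def write_binary_local(raster_bin, file_name, out_dir):
    """Hypothetical local analogue of write_binary (no Spark involved)."""
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    out_path = f"{out_dir}/{file_name}"
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = f"{tmp_dir}/{file_name}"
        # Write the full binary content to a local temp file first.
        with open(tmp_path, "wb") as tmp_file:
            tmp_file.write(raster_bin)
        # The file is closed (and therefore flushed) before the copy.
        shutil.copyfile(tmp_path, out_path)
    return out_path


# Round trip with dummy bytes; a temp dir stands in for the fuse directory.
demo_payload = b"not a real raster"
with tempfile.TemporaryDirectory() as demo_dir:
    demo_path = write_binary_local(demo_payload, "1234.tif", f"{demo_dir}/out")
    demo_name = Path(demo_path).name
    demo_bytes_back = Path(demo_path).read_bytes()
print(demo_name, demo_bytes_back == demo_payload)
```

Closing the temp file before calling `shutil.copyfile` is the point of the pattern: the copy then sees the complete content rather than a partially flushed buffer.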
4 changes: 2 additions & 2 deletions python/test/test_vector_functions.py
@@ -169,10 +169,10 @@ def test_aggregation_functions(self):
.join(right_df, col("left_index.index_id") == col("right_index.index_id"))
.groupBy("left_id", "right_id")
.agg(
api.st_intersects_aggregate(
api.st_intersects_agg(
col("left_index"), col("right_index")
).alias("agg_intersects"),
api.st_intersection_aggregate(
api.st_intersection_agg(
col("left_index"), col("right_index")
).alias("agg_intersection"),
first("left_geom").alias("left_geom"),