Skip to content

Commit

Permalink
added ability to export data array to raster (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 authored Jun 25, 2019
1 parent ba18c36 commit d4557fd
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 15 deletions.
8 changes: 8 additions & 0 deletions rioxarray/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,11 @@ class OneDimensionalRaster(RioXarrayError):

class SingleVariableDataset(RioXarrayError):
"""This is for when you have a dataset with a single variable."""


class TooManyDimensions(RioXarrayError):
"""This is raised when there are more dimensions than is supported by the method"""


class InvalidDimensionOrder(RioXarrayError):
"""This is raised when there the dimensions are not ordered correctly."""
104 changes: 90 additions & 14 deletions rioxarray/rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
from rasterio.features import geometry_mask
from scipy.interpolate import griddata

from rioxarray.exceptions import NoDataInBounds, OneDimensionalRaster
from rioxarray.exceptions import (
InvalidDimensionOrder,
NoDataInBounds,
OneDimensionalRaster,
TooManyDimensions,
)
from rioxarray.crs import crs_to_wkt

FILL_VALUE_NAMES = ("_FillValue", "missing_value", "fill_value")
Expand Down Expand Up @@ -327,6 +332,34 @@ def _int_bounds(self):
bottom = float(self._obj[self.y_dim][-1])
return left, bottom, right, top

def _check_dimensions(self):
"""
This function validates that the dimensions 2D/3D and
they are are in the proper order.
Returns:
--------
str or None: Name extra dimension.
"""
extra_dims = list(set(list(self._obj.dims)) - set([self.x_dim, self.y_dim]))
if len(extra_dims) > 1:
raise TooManyDimensions("Only 2D and 3D data arrays supported.")
elif extra_dims and self._obj.dims != (extra_dims[0], self.y_dim, self.x_dim):
raise InvalidDimensionOrder(
"Invalid dimension order. Expected order: {0}. "
"You can use `DataArray.transpose{0}`"
" to reorder your dimensions.".format(
(extra_dims[0], self.y_dim, self.x_dim)
)
)
elif not extra_dims and self._obj.dims != (self.y_dim, self.x_dim):
raise InvalidDimensionOrder(
"Invalid dimension order. Expected order: {0}"
"You can use `DataArray.transpose{0}` "
"to reorder your dimensions.".format((self.y_dim, self.x_dim))
)
return extra_dims[0] if extra_dims else None

def bounds(self, recalc=False):
"""Determine the bounds of the `xarray.DataArray`
Expand Down Expand Up @@ -438,17 +471,13 @@ def reproject(
dst_affine, dst_width, dst_height = _make_dst_affine(
self._obj, self.crs, dst_crs, resolution
)
extra_dims = list(set(list(self._obj.dims)) - set([self.x_dim, self.y_dim]))
if len(extra_dims) > 1:
raise RuntimeError("Reproject only supports 2D and 3D datasets.")
if extra_dims:
assert self._obj.dims == (extra_dims[0], self.y_dim, self.x_dim)
extra_dim = self._check_dimensions()
if extra_dim:
dst_data = np.zeros(
(self._obj[extra_dims[0]].size, dst_height, dst_width),
(self._obj[extra_dim].size, dst_height, dst_width),
dtype=self._obj.dtype.type,
)
else:
assert self._obj.dims == (self.y_dim, self.x_dim)
dst_data = np.zeros((dst_height, dst_width), dtype=self._obj.dtype.type)

try:
Expand Down Expand Up @@ -742,13 +771,10 @@ def interpolate_na(self, method="nearest"):
:class:`xarray.DataArray`: An interpolated :class:`xarray.DataArray` object.
"""
extra_dims = list(set(list(self._obj.dims)) - set([self.x_dim, self.y_dim]))
if len(extra_dims) > 1:
raise RuntimeError("Interpolate only supports 2D and 3D datasets.")
if extra_dims:
assert self._obj.dims == (extra_dims[0], self.y_dim, self.x_dim)
extra_dim = self._check_dimensions()
if extra_dim:
interp_data = []
for _, sub_xds in self._obj.groupby(extra_dims[0]):
for _, sub_xds in self._obj.groupby(extra_dim):
interp_data.append(
self._interpolate_na(sub_xds.load().data, method=method)
)
Expand All @@ -769,6 +795,56 @@ def interpolate_na(self, method="nearest"):

return interp_array

def to_raster(
self, raster_path, driver="GTiff", dtype=None, tags=None, **profile_kwargs
):
"""
Export the DataArray to a raster file.
Parameters
----------
raster_path: str
The path to output the raster to.
driver: str, optional
The name of the GDAL/rasterio driver to use to export the raster.
Default is "GTiff".
dtype: str, optional
The data type to write the raster to. Default is the datasets dtype.
tags: dict, optional
A dictionary of tags to write to the raster.
**profile_kwargs
Additional keyword arguments to pass into writing the raster. The
nodata, transform, crs, count, width, and height attributes
are automatically added.
"""
width, height = self.shape
dtype = str(self._obj.dtype) if dtype is None else dtype
extra_dim = self._check_dimensions()
count = 1
if extra_dim is not None:
count = self._obj[extra_dim].size
with rasterio.open(
raster_path,
"w",
driver=driver,
height=int(height),
width=int(width),
count=count,
dtype=dtype,
crs=self.crs,
transform=self.transform(recalc=True),
nodata=self.nodata,
**profile_kwargs,
) as dst:
data = self._obj.values.astype(dtype)
if data.ndim == 2:
dst.write(data, 1)
else:
dst.write(data)
if tags is not None:
dst.update_tags(**tags)


@xarray.register_dataset_accessor("rio")
class RasterDataset(XRasterBase):
Expand Down
4 changes: 3 additions & 1 deletion sphinx/history.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
History
=======

In pre-release stages now.
0.0.4
------
- Added ability to export data array to raster (pull #8)
34 changes: 34 additions & 0 deletions test/integration/test_integration_rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy
import pytest
import rasterio
import xarray
from numpy.testing import assert_almost_equal, assert_array_equal
from rasterio.crs import CRS
Expand Down Expand Up @@ -672,3 +673,36 @@ def test_geographic_resample_integer():
# mds_interp.to_netcdf(sentinel_2_interp)
# test
_assert_xarrays_equal(mds_interp, mdc)


def test_to_raster(tmpdir):
tmp_raster = tmpdir.join("modis_raster.tiff")
test_tags = {"test": "1"}
with xarray.open_dataarray(
os.path.join(TEST_INPUT_DATA_DIR, "MODIS_ARRAY.nc"), autoclose=True
) as mda:
mda.rio.to_raster(str(tmp_raster), tags=test_tags)
xds = mda.copy()

with rasterio.open(str(tmp_raster)) as rds:
assert rds.crs == xds.rio.crs
assert_array_equal(rds.transform, xds.rio.transform())
assert_array_equal(rds.nodata, xds.rio.nodata)
assert_array_equal(rds.read(1), xds.values)
assert rds.count == 1
assert rds.tags() == {"AREA_OR_POINT": "Area", **test_tags}


def test_to_raster_3d(tmpdir):
tmp_raster = tmpdir.join("planet_3d_raster.tiff")
with xarray.open_dataset(
os.path.join(TEST_INPUT_DATA_DIR, "PLANET_SCOPE_3D.nc"), autoclose=True
) as mda:
mda.green.rio.to_raster(str(tmp_raster))
xds = mda.green.copy()

with rasterio.open(str(tmp_raster)) as rds:
assert rds.crs == xds.rio.crs
assert_array_equal(rds.transform, xds.rio.transform())
assert_array_equal(rds.nodata, xds.rio.nodata)
assert_array_equal(rds.read(), xds.values)

0 comments on commit d4557fd

Please sign in to comment.