diff --git a/rioxarray/exceptions.py b/rioxarray/exceptions.py index acf629e6..723a3915 100644 --- a/rioxarray/exceptions.py +++ b/rioxarray/exceptions.py @@ -18,3 +18,11 @@ class OneDimensionalRaster(RioXarrayError): class SingleVariableDataset(RioXarrayError): """This is for when you have a dataset with a single variable.""" + + +class TooManyDimensions(RioXarrayError): + """This is raised when there are more dimensions than is supported by the method""" + + +class InvalidDimensionOrder(RioXarrayError): + """This is raised when there the dimensions are not ordered correctly.""" diff --git a/rioxarray/rioxarray.py b/rioxarray/rioxarray.py index 68c10e76..797b2010 100644 --- a/rioxarray/rioxarray.py +++ b/rioxarray/rioxarray.py @@ -23,7 +23,12 @@ from rasterio.features import geometry_mask from scipy.interpolate import griddata -from rioxarray.exceptions import NoDataInBounds, OneDimensionalRaster +from rioxarray.exceptions import ( + InvalidDimensionOrder, + NoDataInBounds, + OneDimensionalRaster, + TooManyDimensions, +) from rioxarray.crs import crs_to_wkt FILL_VALUE_NAMES = ("_FillValue", "missing_value", "fill_value") @@ -327,6 +332,34 @@ def _int_bounds(self): bottom = float(self._obj[self.y_dim][-1]) return left, bottom, right, top + def _check_dimensions(self): + """ + This function validates that the dimensions 2D/3D and + they are are in the proper order. + + Returns: + -------- + str or None: Name extra dimension. + """ + extra_dims = list(set(list(self._obj.dims)) - set([self.x_dim, self.y_dim])) + if len(extra_dims) > 1: + raise TooManyDimensions("Only 2D and 3D data arrays supported.") + elif extra_dims and self._obj.dims != (extra_dims[0], self.y_dim, self.x_dim): + raise InvalidDimensionOrder( + "Invalid dimension order. Expected order: {0}. " + "You can use `DataArray.transpose{0}`" + " to reorder your dimensions.".format( + (extra_dims[0], self.y_dim, self.x_dim) + ) + ) + elif not extra_dims and self._obj.dims != (self.y_dim, self.x_dim): + raise InvalidDimensionOrder( + "Invalid dimension order. Expected order: {0}" + "You can use `DataArray.transpose{0}` " + "to reorder your dimensions.".format((self.y_dim, self.x_dim)) + ) + return extra_dims[0] if extra_dims else None + def bounds(self, recalc=False): """Determine the bounds of the `xarray.DataArray` @@ -438,17 +471,13 @@ def reproject( dst_affine, dst_width, dst_height = _make_dst_affine( self._obj, self.crs, dst_crs, resolution ) - extra_dims = list(set(list(self._obj.dims)) - set([self.x_dim, self.y_dim])) - if len(extra_dims) > 1: - raise RuntimeError("Reproject only supports 2D and 3D datasets.") - if extra_dims: - assert self._obj.dims == (extra_dims[0], self.y_dim, self.x_dim) + extra_dim = self._check_dimensions() + if extra_dim: dst_data = np.zeros( - (self._obj[extra_dims[0]].size, dst_height, dst_width), + (self._obj[extra_dim].size, dst_height, dst_width), dtype=self._obj.dtype.type, ) else: - assert self._obj.dims == (self.y_dim, self.x_dim) dst_data = np.zeros((dst_height, dst_width), dtype=self._obj.dtype.type) try: @@ -742,13 +771,10 @@ def interpolate_na(self, method="nearest"): :class:`xarray.DataArray`: An interpolated :class:`xarray.DataArray` object. """ - extra_dims = list(set(list(self._obj.dims)) - set([self.x_dim, self.y_dim])) - if len(extra_dims) > 1: - raise RuntimeError("Interpolate only supports 2D and 3D datasets.") - if extra_dims: - assert self._obj.dims == (extra_dims[0], self.y_dim, self.x_dim) + extra_dim = self._check_dimensions() + if extra_dim: interp_data = [] - for _, sub_xds in self._obj.groupby(extra_dims[0]): + for _, sub_xds in self._obj.groupby(extra_dim): interp_data.append( self._interpolate_na(sub_xds.load().data, method=method) ) @@ -769,6 +795,56 @@ def interpolate_na(self, method="nearest"): return interp_array + def to_raster( + self, raster_path, driver="GTiff", dtype=None, tags=None, **profile_kwargs + ): + """ + Export the DataArray to a raster file. + + Parameters + ---------- + raster_path: str + The path to output the raster to. + driver: str, optional + The name of the GDAL/rasterio driver to use to export the raster. + Default is "GTiff". + dtype: str, optional + The data type to write the raster to. Default is the datasets dtype. + tags: dict, optional + A dictionary of tags to write to the raster. + **profile_kwargs + Additional keyword arguments to pass into writing the raster. The + nodata, transform, crs, count, width, and height attributes + are automatically added. + + """ + width, height = self.shape + dtype = str(self._obj.dtype) if dtype is None else dtype + extra_dim = self._check_dimensions() + count = 1 + if extra_dim is not None: + count = self._obj[extra_dim].size + with rasterio.open( + raster_path, + "w", + driver=driver, + height=int(height), + width=int(width), + count=count, + dtype=dtype, + crs=self.crs, + transform=self.transform(recalc=True), + nodata=self.nodata, + **profile_kwargs, + ) as dst: + data = self._obj.values.astype(dtype) + if data.ndim == 2: + dst.write(data, 1) + else: + dst.write(data) + if tags is not None: + dst.update_tags(**tags) + @xarray.register_dataset_accessor("rio") class RasterDataset(XRasterBase): diff --git a/sphinx/history.rst b/sphinx/history.rst index ff7fe9eb..19b0abda 100644 --- a/sphinx/history.rst +++ b/sphinx/history.rst @@ -1,4 +1,6 @@ History ======= -In pre-release stages now. +0.0.4 +------ +- Added ability to export data array to raster (pull #8) diff --git a/test/integration/test_integration_rioxarray.py b/test/integration/test_integration_rioxarray.py index 0d0cf491..6db69b27 100644 --- a/test/integration/test_integration_rioxarray.py +++ b/test/integration/test_integration_rioxarray.py @@ -2,6 +2,7 @@ import numpy import pytest +import rasterio import xarray from numpy.testing import assert_almost_equal, assert_array_equal from rasterio.crs import CRS @@ -672,3 +673,36 @@ def test_geographic_resample_integer(): # mds_interp.to_netcdf(sentinel_2_interp) # test _assert_xarrays_equal(mds_interp, mdc) + + +def test_to_raster(tmpdir): + tmp_raster = tmpdir.join("modis_raster.tiff") + test_tags = {"test": "1"} + with xarray.open_dataarray( + os.path.join(TEST_INPUT_DATA_DIR, "MODIS_ARRAY.nc"), autoclose=True + ) as mda: + mda.rio.to_raster(str(tmp_raster), tags=test_tags) + xds = mda.copy() + + with rasterio.open(str(tmp_raster)) as rds: + assert rds.crs == xds.rio.crs + assert_array_equal(rds.transform, xds.rio.transform()) + assert_array_equal(rds.nodata, xds.rio.nodata) + assert_array_equal(rds.read(1), xds.values) + assert rds.count == 1 + assert rds.tags() == {"AREA_OR_POINT": "Area", **test_tags} + + +def test_to_raster_3d(tmpdir): + tmp_raster = tmpdir.join("planet_3d_raster.tiff") + with xarray.open_dataset( + os.path.join(TEST_INPUT_DATA_DIR, "PLANET_SCOPE_3D.nc"), autoclose=True + ) as mda: + mda.green.rio.to_raster(str(tmp_raster)) + xds = mda.green.copy() + + with rasterio.open(str(tmp_raster)) as rds: + assert rds.crs == xds.rio.crs + assert_array_equal(rds.transform, xds.rio.transform()) + assert_array_equal(rds.nodata, xds.rio.nodata) + assert_array_equal(rds.read(), xds.values)