Skip to content

Commit

Permalink
ENH: Interoperability with xarray/dask #479
Browse files Browse the repository at this point in the history
Added first-cut "to_xarray" methods on dimensions, variables, collections,
and fields. Metadata handling is done by xarray using its "decode_cf"
capability. Limitations are listed in the documentation for the
"to_xarray" method on fields.
  • Loading branch information
bekozi committed Apr 3, 2018
1 parent 32748ed commit 954fa5e
Show file tree
Hide file tree
Showing 13 changed files with 192 additions and 46 deletions.
34 changes: 0 additions & 34 deletions misc/sh/test-core.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,37 +46,3 @@ inf `grep -v -E "ERROR 1:.*not recognised as an available field" ${OCGIS_TEST_OU
debug "finished run_tests_core()"

}

#function run_tests(){
#
##source ./logging.sh || exit 1
#source ./test-core.sh || exit 1
#
#export RUN_SERIAL_TESTS="true"
##export RUN_SERIAL_TESTS="false"
#
#export RUN_PARALLEL_TESTS="true"
##export RUN_PARALLEL_TESTS="false"
#
## Wall time for the tests (seconds). Equal to 1.5*`time test.sh`.
##WTIME=900
#
#########################################################################################################################
#
#notify "starting test.sh"
#
##cd ${OCGIS_DIR}/misc/sh || { error "Could not cd to OCGIS_DIR: ${OCGIS_DIR}"; exit 1; }
#rm .noseids
#
##$(timeout -k 5 --foreground ${WTIME} bash ./test-core.sh)
#
#run_tests_core
#
##if [ $? == 124 ]; then
## error "Hit wall time (${WTIME}s) when running test-core.sh"
## exit 1
##else
#notify "sucess test.sh"
##fi
#
#}
6 changes: 6 additions & 0 deletions src/ocgis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ def deepcopy(self):
"""Return a deep copy of self."""
return deepcopy(self)

def to_xarray(self, **kwargs):
"""
Convert this object to a type understood by ``xarray``. This should be overloaded by subclasses.
"""
raise NotImplementedError


@six.add_metaclass(abc.ABCMeta)
class AbstractNamedObject(AbstractInterfaceObject):
Expand Down
42 changes: 37 additions & 5 deletions src/ocgis/collection/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,20 +180,34 @@ def axes_shapes(self):
ret['X'] = x
return ret

@property
def bounds_variables(self):
"""
Create a tuple of bounds variables associated with :meth:`~ocgis.collection.field.Field.coordinate_variables`.
:rtype: tuple(:class:`ocgis.Variable`, ...)
"""
ret = [c.bounds for c in self.coordinate_variables if c.bounds is not None]
ret = tuple(ret)
return ret

@property
def coordinate_variables(self):
"""
Return a tuple of spatial coordinate variables. This will attempt to access coordinate variables on the field's
grid. If no grid is available, spatial coordinates will be pulled from the dimension map. The tuple may have a
length of zero if no coordinate variables are available on the field.
Return a tuple of coordinate variables. This will attempt to access spatial coordinate variables on the field's
grid. If no grid is available, spatial coordinates will be pulled from the dimension map. Time will always be
pulled from the field. The tuple may have a length of zero if no coordinate variables are available on the
field.
:rtype: tuple
"""
grid = self.grid
if grid is not None:
ret = grid.coordinate_variables
ret = list(grid.coordinate_variables)
if self.time is not None:
ret.insert(0, self.time)
else:
poss = [self.x, self.y, self.level]
poss = [self.x, self.y, self.level, self.time]
poss = [p for p in poss if p is not None]
ret = tuple(poss)
return ret
Expand Down Expand Up @@ -948,6 +962,24 @@ def set_y(self, variable, dimension, force=True, should_add=True):

set_field_property(self, DimensionMapKey.Y, variable, force, dimension=dimension, should_add=should_add)

def to_xarray(self, **kwargs):
"""
Convert the field to a :class:`xarray.Dataset` with CF metadata interpretation.
Limitations:
* Bounds are treated as data arrays inside the ``xarray`` dataset.
* Integer masked arrays are upcast to float data types in ``xarray``.
* Group hierarchies are not supported in ``xarray``.
:param dict kwargs: Optional keyword arguments to dataset creation. See :meth:`ocgis.VariableCollection.to_xarray`
for additional information.
:rtype: :class:`xarray.Dataset`
"""
from xarray import decode_cf
ret = super(Field, self).to_xarray(**kwargs)
ret = decode_cf(ret)
return ret

def unwrap(self):
"""
Unwrap the field's coordinates contained in its grid and/or geometry.
Expand Down
2 changes: 1 addition & 1 deletion src/ocgis/driver/dimension_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def get_variable(self, entry_key, parent=None, nullable=False):
:type parent: :class:`~ocgis.VariableCollection`
:param bool nullable: If ``True`` and ``parent`` is not ``None``, return ``None`` if the variable is not found
in ``parent``.
:rtype: str
:rtype: str | None
"""
ret = self._get_element_(entry_key, DMK.VARIABLE, None)

Expand Down
3 changes: 3 additions & 0 deletions src/ocgis/test/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
* icclim: test requires ICCLIM
* benchmark: test used for benchmarking/performance
* cli: test related to the command line interface. requires click as a dependency.
* xarray: test related to xarray optional dependency
nosetests -vs --with-id -a '!slow,!remote' ocgis
"""
Expand Down Expand Up @@ -251,6 +252,8 @@ def assertNcEqual(self, uri_src, uri_dest, check_types=True, close=False, metada
self.assertNumpyAllClose(var_value[idx], dvar_value[idx])
else:
self.assertNumpyAll(var_value[idx], dvar_value[idx], check_arr_dtype=check_types)
elif var_value.dtype == np.dtype('|S1'):
self.assertEqual(var_value.tolist(), dvar_value.tolist())
else:
if close:
self.assertNumpyAllClose(var_value, dvar_value)
Expand Down
16 changes: 14 additions & 2 deletions src/ocgis/test/test_ocgis/test_collection/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
from ocgis.base import get_variable_names
from ocgis.collection.field import Field, get_name_mapping
from ocgis.collection.spatial import SpatialCollection
from ocgis.constants import HeaderName, KeywordArgument, DriverKey, DimensionMapKey, DMK
from ocgis.constants import HeaderName, KeywordArgument, DriverKey, DimensionMapKey, DMK, Topology
from ocgis.conv.nc import NcConverter
from ocgis.driver.csv_ import DriverCSV
from ocgis.driver.nc import DriverNetcdf
from ocgis.driver.vector import DriverVector
from ocgis.spatial.base import create_spatial_mask_variable
from ocgis.spatial.geom_cabinet import GeomCabinetIterator
from ocgis.spatial.grid import Grid
from ocgis.test.base import attr, AbstractTestInterface
from ocgis.test.base import attr, AbstractTestInterface, create_gridxy_global, create_exact_field
from ocgis.util.helpers import reduce_multiply
from ocgis.variable.base import Variable
from ocgis.variable.crs import CoordinateReferenceSystem, WGS84, Spherical
Expand Down Expand Up @@ -446,6 +446,18 @@ def test_time(self):
self.assertEqual(f.time.calendar, desired)
self.assertEqual(f.time.bounds.calendar, desired)

@attr('xarray')
def test_to_xarray(self):
grid = create_gridxy_global(crs=Spherical())
field = create_exact_field(grid, 'foo', ntime=3)
field.attrs['i_am_global'] = 'confirm'
field.grid.abstraction = Topology.POINT
field.set_abstraction_geom()
field.time.set_extrapolated_bounds('time_bounds', 'bounds')
xr = field.to_xarray()
self.assertEqual(xr.attrs['i_am_global'], 'confirm')
self.assertGreater(len(xr.coords), 0)

def test_update_crs(self):
# Test copying allows the CRS to be updated on the copy w/out changing the source CRS.
desired = Spherical()
Expand Down
2 changes: 1 addition & 1 deletion src/ocgis/test/test_ocgis/test_util/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def test_with_no_calc_grouping(self):

ops.prefix = 'ocgis'
ret_ocgis = ops.execute()
self.assertNcEqual(ret, ret_ocgis, check_fill_value=False, check_types=False,
self.assertNcEqual(ret, ret_ocgis, check_fill_value=False, check_types=False, close=True,
ignore_attributes={'global': ['history'], 'ln': ['_FillValue']})

@attr('data')
Expand Down
53 changes: 53 additions & 0 deletions src/ocgis/test/test_ocgis/test_variable/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,6 +1049,7 @@ def test_set_bounds(self):
def test_set_extrapolated_bounds(self):
bv = self.get_boundedvariable(mask=[False, True, False])
self.assertIsNotNone(bv.bounds)
self.assertNotIn('units', bv.bounds.attrs)
bv.set_bounds(None)
self.assertIsNone(bv.bounds)
bv.set_extrapolated_bounds('x_bounds', 'bounds')
Expand Down Expand Up @@ -1121,6 +1122,39 @@ def test_shape(self):
self.assertEqual(len(sub.dimensions[0]), 1)
self.assertEqual(sub.shape, (1,))

@attr('xarray')
def test_to_xarray(self):
# Test a simple variable with a mask and single dimension.
var = Variable(name='foo', value=[1, 2, 3], mask=[False, True, False], dimensions='single', dtype=int,
attrs={'hi': 'there'})
xr = var.to_xarray()
self.assertTrue(np.isnan(xr.values[1]))
self.assertEqual(np.nansum(xr.values), 4)
self.assertEqual(xr.attrs, var.attrs)
self.assertEqual(xr.dims, ('single',))

# Test variable type is maintained without a mask.
var = Variable(name='foo', value=[1, 2, 3], dimensions='single', dtype=int, attrs={'hi': 'there'})
xr = var.to_xarray()
self.assertEqual(xr.values.dtype, int)

# Test a variable with multiple dimensions.
value = np.random.rand(2, 3, 4)
var = Variable(name='nd', value=value, dimensions=('two', 'three', 'four'))
xr = var.to_xarray()
self.assertNumpyAll(var.v(), xr.values)
self.assertEqual(xr.dims, var.dimension_names)

# Test a variable with no dimensions.
var = Variable(name='no_dimensions')
xr = var.to_xarray()
self.assertEqual(xr.name, var.name)

# Test a variable with no dimensions and a scalar value.
var = Variable(name='scalar', value=5.6, dimensions=[])
xr = var.to_xarray()
self.assertEqual(xr.values, 5.6)

def test_units(self):
var = Variable(name='empty')
self.assertIsNone(var.units)
Expand Down Expand Up @@ -1701,3 +1735,22 @@ def test_rename_dimension(self):
self.assertEqual(id(parent), id(vc))
vc.rename_dimension('one', 'one_renamed')
self.assertEqual(var.dimensions[0].name, 'one_renamed')

@attr('xarray')
def test_to_xarray(self):
from xarray import Dataset

vc = self.get_variablecollection()
xr = vc.to_xarray()
self.assertIsInstance(xr, Dataset)
self.assertEqual(vc.keys(), xr.variables.keys())
for k, v in xr.items():
# print(k, v)
ovar_value = vc[k].v()
xr_value = v.values
if ovar_value is None:
self.assertIsNone(xr_value.tolist())
else:
actual = xr_value.tolist()
desired = ovar_value.tolist()
self.assertEqual(actual, desired)
5 changes: 5 additions & 0 deletions src/ocgis/test/test_simple/test_optional_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,8 @@ def test_rtree(self):
si.add(1, Point(1, 2))
ret = list(si.iter_intersects(Point(1, 2), geom_mapping))
self.assertEqual(ret, [1])

@attr('xarray')
def test_xarray(self):
import xarray as xr
_ = xr.DataArray(data=[1, 2, 3])
8 changes: 7 additions & 1 deletion src/ocgis/util/logging_ocgis.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,16 @@ def get_versions():
v_nose = None
else:
v_nose = nose.__version__
try:
import xarray
except ImportError:
v_xarray = None
else:
v_xarray = xarray.__version__

versions = dict(esmf=v_esmf, cfunits=v_cfunits, rtree=v_rtree, gdal=v_gdal, numpy=v_numpy, netcdf4=v_netcdf4,
icclim=v_icclim, fiona=v_fiona, cf_units=v_cf_units, mpi4py=v_mpi4py, six=v_six, pyproj=v_pyproj,
python=sys.version_info, nose=v_nose)
python=sys.version_info, nose=v_nose, xarray=v_xarray)
return versions


Expand Down
46 changes: 45 additions & 1 deletion src/ocgis/variable/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,27 @@ def set_value(self, value, update_mask=False):
if should_set_mask:
self.set_mask(mask_to_set, update=update_mask)

def to_xarray(self):
"""
Convert the variable to a :class:`xarray.DataArray`. This *does not* traverse the parent's hierararchy. Use the
conversion method on the variable's parent to convert all variables in the collection.
:rtype: :class:`xarray.DataArray`
"""
from xarray import DataArray

# Always access a time variable's numeric data.
if hasattr(self, 'value_numtime'):
data = self.value_numtime.data
else:
# Make sure we are passing the masked data array when converting the underlying dataset.
data = self.mv()

# Collect the variable's dimensions.
dims = [d.to_xarray() for d in self.dimensions]

return DataArray(data=data, dims=dims, attrs=self.attrs, name=self.name)

def copy(self):
"""
:return: A shallow copy of the variable.
Expand Down Expand Up @@ -1422,7 +1443,8 @@ def set_bounds(self, value, force=False, clobber_units=None):
self._bounds_name = value.name
self.attrs[bounds_attr_name] = value.name
parent.add_variable(value, force=force)
if clobber_units:
# Do not naively set the units as it may insert a None into the attributes dictionary.
if clobber_units and self.units is not None:
value.units = self.units

# This will synchronize the bounds mask with the variable's mask.
Expand Down Expand Up @@ -2242,6 +2264,28 @@ def strip(self):
self._dimensions = OrderedDict()
self.children = OrderedDict()

def to_xarray(self, **kwargs):
"""
Convert all the variables in the collection to an :class:`xarray.Dataset`.
:param kwargs: Optional keyword arguments to pass to the dataset creation. ``data_vars`` and ``attrs`` are
always overloaded by this method.
:rtype: :class:`xarray.Dataset`
"""
from xarray import Dataset

data_vars = OrderedDict()
# Convert each variable to data array.
for v in self.values():
data_vars[v.name] = v.to_xarray()

# Create the arguments for the dataset creation.
kwargs = kwargs.copy()
kwargs['data_vars'] = data_vars
kwargs['attrs'] = self.attrs

return Dataset(**kwargs)

def write(self, *args, **kwargs):
"""
Write the variable collection to file.
Expand Down
13 changes: 12 additions & 1 deletion src/ocgis/variable/crs.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _update_attr_(target, name, value, clobber=False):
updates = {}
for c in spatial_obj.coordinate_variables:
i = spatial_obj.dimension_map.inquire_is_xyz(c)
if i != DMK.LEVEL:
if i != DMK.LEVEL and i is not None:
updates[c] = self._cf_attributes[i]

for target, name_values in updates.items():
Expand Down Expand Up @@ -309,6 +309,17 @@ def prepare_geometry_variable(self, subset_geom, rhs_tol=10.0, inplace=True):
def set_string_max_length_global(self, value=None):
"""Here for variable compatibility."""

def to_xarray(self, **kwargs):
"""
Convert the CRS variable to a :class:`xarray.DataArray`. This *does not* traverse the parent's hierararchy. Use
the conversion method on the variable's parent to convert all variables in the collection.
:rtype: :class:`xarray.DataArray`
"""
from xarray import DataArray

return DataArray(attrs=self.attrs, name=self.name, data=[])

def wrap_or_unwrap(self, action, target, force=False):
from ocgis.variable.geom import GeometryVariable
from ocgis.spatial.grid import Grid
Expand Down
8 changes: 8 additions & 0 deletions src/ocgis/variable/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,14 @@ def set_size(self, value, src_idx=None):
msg = 'Distributed dimensions require a size definition using "size" or "size_current".'
raise ValueError(msg)

def to_xarray(self):
"""
Convert this object to a type understood by ``xarray``.
:rtype: str
"""
return self.name

def __getitem_main__(self, ret, slc):
length_self = len(self)
try:
Expand Down

0 comments on commit 954fa5e

Please sign in to comment.