Skip to content

Commit

Permalink
Merge pull request #359 from w-k-jones/wrap_decorators
Browse files Browse the repository at this point in the history
Wrap decorators to preserve docstrings and move to separate module
  • Loading branch information
w-k-jones authored Nov 21, 2023
2 parents d92d1ba + c137f44 commit 5518611
Show file tree
Hide file tree
Showing 5 changed files with 376 additions and 367 deletions.
2 changes: 1 addition & 1 deletion tobac/tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
from pandas.testing import assert_frame_equal
from copy import deepcopy
from tobac.utils.internal import (
from tobac.utils.decorators import (
xarray_to_iris,
iris_to_xarray,
xarray_to_irispandas,
Expand Down
4 changes: 2 additions & 2 deletions tobac/utils/bulk_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
import logging
from . import internal as internal_utils
from . import decorators
from typing import Callable, Union
from functools import partial
import numpy as np
Expand Down Expand Up @@ -141,7 +141,7 @@ def get_statistics(
return features


@internal_utils.iris_to_xarray
@decorators.iris_to_xarray
def get_statistics_from_mask(
segmentation_mask: xr.DataArray,
*fields: xr.DataArray,
Expand Down
370 changes: 370 additions & 0 deletions tobac/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,370 @@
"""Decorators for use with other tobac functions
"""

import functools
import warnings


def iris_to_xarray(func):
"""Decorator that converts all input of a function that is in the form of
Iris cubes into xarray DataArrays and converts all outputs with type
xarray DataArrays back into Iris cubes.
Parameters
----------
func : function
Function to be decorated
Returns
-------
wrapper : function
Function including decorator
"""

import iris
import xarray

@functools.wraps(func)
def wrapper(*args, **kwargs):
# print(kwargs)
if any([type(arg) == iris.cube.Cube for arg in args]) or any(
[type(arg) == iris.cube.Cube for arg in kwargs.values()]
):
# print("converting iris to xarray and back")
args = tuple(
[
xarray.DataArray.from_iris(arg)
if type(arg) == iris.cube.Cube
else arg
for arg in args
]
)
kwargs_new = dict(
zip(
kwargs.keys(),
[
xarray.DataArray.from_iris(arg)
if type(arg) == iris.cube.Cube
else arg
for arg in kwargs.values()
],
)
)
# print(args)
# print(kwargs)
output = func(*args, **kwargs_new)
if type(output) == tuple:
output = tuple(
[
xarray.DataArray.to_iris(output_item)
if type(output_item) == xarray.DataArray
else output_item
for output_item in output
]
)
elif type(output) == xarray.DataArray:
output = xarray.DataArray.to_iris(output)
# if output is neither tuple nor an xr.DataArray
else:
output = func(*args, **kwargs)

else:
output = func(*args, **kwargs)
return output

return wrapper


def xarray_to_iris(func):
"""Decorator that converts all input of a function that is in the form of
xarray DataArrays into Iris cubes and converts all outputs with type
Iris cubes back into xarray DataArrays.
Parameters
----------
func : function
Function to be decorated.
Returns
-------
wrapper : function
Function including decorator.
Examples
--------
>>> segmentation_xarray = xarray_to_iris(segmentation)
This line creates a new function that can process xarray fields and
also outputs fields in xarray format, but otherwise works just like
the original function:
>>> mask_xarray, features = segmentation_xarray(
features, data_xarray, dxy, threshold
)
"""

import iris
import xarray

@functools.wraps(func)
def wrapper(*args, **kwargs):
# print(args)
# print(kwargs)
if any([type(arg) == xarray.DataArray for arg in args]) or any(
[type(arg) == xarray.DataArray for arg in kwargs.values()]
):
# print("converting xarray to iris and back")
args = tuple(
[
xarray.DataArray.to_iris(arg)
if type(arg) == xarray.DataArray
else arg
for arg in args
]
)
if kwargs:
kwargs_new = dict(
zip(
kwargs.keys(),
[
xarray.DataArray.to_iris(arg)
if type(arg) == xarray.DataArray
else arg
for arg in kwargs.values()
],
)
)
else:
kwargs_new = kwargs
# print(args)
# print(kwargs)
output = func(*args, **kwargs_new)
if type(output) == tuple:
output = tuple(
[
xarray.DataArray.from_iris(output_item)
if type(output_item) == iris.cube.Cube
else output_item
for output_item in output
]
)
else:
if type(output) == iris.cube.Cube:
output = xarray.DataArray.from_iris(output)

else:
output = func(*args, **kwargs)
# print(output)
return output

return wrapper


def irispandas_to_xarray(func):
"""Decorator that converts all input of a function that is in the form of
Iris cubes/pandas Dataframes into xarray DataArrays/xarray Datasets and
converts all outputs with the type xarray DataArray/xarray Dataset
back into Iris cubes/pandas Dataframes.
Parameters
----------
func : function
Function to be decorated.
Returns
-------
wrapper : function
Function including decorator.
"""
import iris
import xarray
import pandas as pd

@functools.wraps(func)
def wrapper(*args, **kwargs):
# print(kwargs)
if any(
[(type(arg) == iris.cube.Cube or type(arg) == pd.DataFrame) for arg in args]
) or any(
[
(type(arg) == iris.cube.Cube or type(arg) == pd.DataFrame)
for arg in kwargs.values()
]
):
# print("converting iris to xarray and back")
args = tuple(
[
xarray.DataArray.from_iris(arg)
if type(arg) == iris.cube.Cube
else arg.to_xarray()
if type(arg) == pd.DataFrame
else arg
for arg in args
]
)
kwargs = dict(
zip(
kwargs.keys(),
[
xarray.DataArray.from_iris(arg)
if type(arg) == iris.cube.Cube
else arg.to_xarray()
if type(arg) == pd.DataFrame
else arg
for arg in kwargs.values()
],
)
)

output = func(*args, **kwargs)
if type(output) == tuple:
output = tuple(
[
xarray.DataArray.to_iris(output_item)
if type(output_item) == xarray.DataArray
else output_item.to_dataframe()
if type(output_item) == xarray.Dataset
else output_item
for output_item in output
]
)
else:
if type(output) == xarray.DataArray:
output = xarray.DataArray.to_iris(output)
elif type(output) == xarray.Dataset:
output = output.to_dataframe()

else:
output = func(*args, **kwargs)
return output

return wrapper


def xarray_to_irispandas(func):
"""Decorator that converts all input of a function that is in the form of
DataArrays/xarray Datasets into xarray Iris cubes/pandas Dataframes and
converts all outputs with the type Iris cubes/pandas Dataframes back into
xarray DataArray/xarray Dataset.
Parameters
----------
func : function
Function to be decorated.
Returns
-------
wrapper : function
Function including decorator.
Examples
--------
>>> linking_trackpy_xarray = xarray_to_irispandas(
linking_trackpy
)
This line creates a new function that can process xarray inputs and
also outputs in xarray formats, but otherwise works just like the
original function:
>>> track_xarray = linking_trackpy_xarray(
features_xarray, field_xarray, dt, dx
)
"""
import iris
import xarray
import pandas as pd

@functools.wraps(func)
def wrapper(*args, **kwargs):
# print(args)
# print(kwargs)
if any(
[
(type(arg) == xarray.DataArray or type(arg) == xarray.Dataset)
for arg in args
]
) or any(
[
(type(arg) == xarray.DataArray or type(arg) == xarray.Dataset)
for arg in kwargs.values()
]
):
# print("converting xarray to iris and back")
args = tuple(
[
xarray.DataArray.to_iris(arg)
if type(arg) == xarray.DataArray
else arg.to_dataframe()
if type(arg) == xarray.Dataset
else arg
for arg in args
]
)
if kwargs:
kwargs_new = dict(
zip(
kwargs.keys(),
[
xarray.DataArray.to_iris(arg)
if type(arg) == xarray.DataArray
else arg.to_dataframe()
if type(arg) == xarray.Dataset
else arg
for arg in kwargs.values()
],
)
)
else:
kwargs_new = kwargs
# print(args)
# print(kwargs)
output = func(*args, **kwargs_new)
if type(output) == tuple:
output = tuple(
[
xarray.DataArray.from_iris(output_item)
if type(output_item) == iris.cube.Cube
else output_item.to_xarray()
if type(output_item) == pd.DataFrame
else output_item
for output_item in output
]
)
else:
if type(output) == iris.cube.Cube:
output = xarray.DataArray.from_iris(output)
elif type(output) == pd.DataFrame:
output = output.to_xarray()

else:
output = func(*args, **kwargs)
# print(output)
return output

return wrapper


def njit_if_available(func, **kwargs):
"""Decorator to wrap a function with numba.njit if available.
If numba isn't available, it just returns the function.
Parameters
----------
func: function object
Function to wrap with njit
kwargs:
Keyword arguments to pass to numba njit
"""
try:
from numba import njit

return njit(func, kwargs)
except KeyboardInterrupt as kie:
raise
except Exception as exc:
warnings.warn(
"Numba not able to be imported; periodic boundary calculations will be slower."
"Exception raised: " + repr(exc)
)
return func
Loading

0 comments on commit 5518611

Please sign in to comment.