Skip to content

Commit

Permalink
Implement support for better ndarray copy management in `colour.con…
Browse files Browse the repository at this point in the history
…tinuous.Signal` class.
  • Loading branch information
KelSolaar committed May 16, 2023
1 parent e4dacbc commit 93deeab
Show file tree
Hide file tree
Showing 5 changed files with 271 additions and 18 deletions.
36 changes: 20 additions & 16 deletions colour/continuous/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
full,
is_pandas_installed,
multiline_repr,
ndarray_copy,
ndarray_copy_enable,
optional,
required,
runtime_warning,
Expand Down Expand Up @@ -329,7 +331,7 @@ def domain(self) -> NDArrayFloat:
Continuous signal independent domain variable :math:`x`.
"""

return np.copy(self._domain)
return ndarray_copy(self._domain)

@domain.setter
def domain(self, value: ArrayLike):
Expand Down Expand Up @@ -372,7 +374,7 @@ def range(self) -> NDArrayFloat: # noqa: A003
Continuous signal corresponding range variable :math:`y`.
"""

return np.copy(self._range)
return ndarray_copy(self._range)

@range.setter
def range(self, value: ArrayLike): # noqa: A003
Expand Down Expand Up @@ -531,7 +533,7 @@ def function(self) -> Callable:
if self._domain.size != 0 and self._range.size != 0:
self._function = self._extrapolator(
self._interpolator(
self.domain, self.range, **self._interpolator_kwargs
self._domain, self._range, **self._interpolator_kwargs
),
**self._extrapolator_kwargs,
)
Expand Down Expand Up @@ -590,7 +592,7 @@ def __str__(self) -> str:
[ 9. 100.]]
"""

return str(tstack([self.domain, self.range]))
return str(tstack([self._domain, self._range]))

def __repr__(self) -> str:
"""
Expand Down Expand Up @@ -629,17 +631,17 @@ def __repr__(self) -> str:
[
{
"formatter": lambda x: repr( # noqa: ARG005
tstack([self.domain, self.range])
tstack([self._domain, self._range])
),
},
{
"name": "interpolator",
"formatter": lambda x: self.interpolator.__name__, # noqa: ARG005
"formatter": lambda x: self._interpolator.__name__, # noqa: ARG005
},
{"name": "interpolator_kwargs"},
{
"name": "extrapolator",
"formatter": lambda x: self.extrapolator.__name__, # noqa: ARG005
"formatter": lambda x: self._extrapolator.__name__, # noqa: ARG005
},
{"name": "extrapolator_kwargs"},
],
Expand All @@ -657,12 +659,12 @@ def __hash__(self) -> int:

return hash(
(
self.domain.tobytes(),
self.range.tobytes(),
self.interpolator.__name__,
repr(self.interpolator_kwargs),
self.extrapolator.__name__,
repr(self.extrapolator_kwargs),
self._domain.tobytes(),
self._range.tobytes(),
self._interpolator.__name__,
repr(self._interpolator_kwargs),
self._extrapolator.__name__,
repr(self._extrapolator_kwargs),
)
)

Expand Down Expand Up @@ -842,6 +844,7 @@ def __contains__(self, x: ArrayLike) -> bool:
)
)

@ndarray_copy_enable(False)
def __eq__(self, other: Any) -> bool:
"""
Return whether the continuous signal is equal to given other object.
Expand Down Expand Up @@ -948,7 +951,7 @@ def _fill_domain_nan(
variable.
"""

self.domain = fill_nan(self.domain, method, default)
self.domain = fill_nan(self._domain, method, default)

def _fill_range_nan(
self,
Expand All @@ -973,8 +976,9 @@ def _fill_range_nan(
variable.
"""

self.range = fill_nan(self.range, method, default)
self.range = fill_nan(self._range, method, default)

@ndarray_copy_enable(False)
def arithmetical_operation(
self,
a: ArrayLike | AbstractContinuousFunction,
Expand Down Expand Up @@ -1073,7 +1077,7 @@ def arithmetical_operation(
exclusive_or = np.setxor1d(self._domain, a.domain)
self[exclusive_or] = full(exclusive_or.shape, np.nan)
else:
self.range = ioperator(self.range, a)
self.range = ioperator(self._range, a)

return self
else:
Expand Down
8 changes: 8 additions & 0 deletions colour/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@
from_range_100,
from_range_degrees,
from_range_int,
is_ndarray_copy_enabled,
set_ndarray_copy_enable,
ndarray_copy_enable,
ndarray_copy,
closest_indexes,
closest,
interval,
Expand Down Expand Up @@ -217,6 +221,10 @@
"from_range_100",
"from_range_degrees",
"from_range_int",
"is_ndarray_copy_enabled",
"set_ndarray_copy_enable",
"ndarray_copy_enable",
"ndarray_copy",
"closest_indexes",
"closest",
"normalise_maximum",
Expand Down
140 changes: 140 additions & 0 deletions colour/utilities/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@
"from_range_100",
"from_range_degrees",
"from_range_int",
"is_ndarray_copy_enabled",
"set_ndarray_copy_enable",
"ndarray_copy_enable",
"ndarray_copy",
"closest_indexes",
"closest",
"interval",
Expand Down Expand Up @@ -1788,6 +1792,142 @@ def from_range_int(
return a


_NDARRAY_COPY_ENABLED: bool = True
"""
Global variable storing the current *Colour* state for
:class:`numpy.ndarray` copy.
"""


def is_ndarray_copy_enabled() -> bool:
"""
Return whether *Colour* :class:`numpy.ndarray` copy is enabled: Various API
objects return a copy of their internal :class:`numpy.ndarray` for safety
purposes but this can be a slow operation impacting performance.
Returns
-------
:class:`bool`
Whether *Colour* :class:`numpy.ndarray` copy is enabled.
Examples
--------
>>> with ndarray_copy_enable(False):
... is_ndarray_copy_enabled()
...
False
>>> with ndarray_copy_enable(True):
... is_ndarray_copy_enabled()
...
True
"""

return _NDARRAY_COPY_ENABLED


def set_ndarray_copy_enable(enable: bool):
"""
Set *Colour* :class:`numpy.ndarray` copy enabled state.
Parameters
----------
enable
Whether to enable *Colour* :class:`numpy.ndarray` copy.
Examples
--------
>>> with ndarray_copy_enable(is_ndarray_copy_enabled()):
... print(is_ndarray_copy_enabled())
... set_ndarray_copy_enable(False)
... print(is_ndarray_copy_enabled())
...
True
False
"""

global _NDARRAY_COPY_ENABLED

_NDARRAY_COPY_ENABLED = enable


class ndarray_copy_enable:
"""
Define a context manager and decorator temporarily setting *Colour*
:class:`numpy.ndarray` copy enabled state.
Parameters
----------
enable
Whether to enable or disable *Colour* :class:`numpy.ndarray` copy.
"""

def __init__(self, enable: bool) -> None:
self._enable = enable
self._previous_state = is_ndarray_copy_enabled()

def __enter__(self) -> ndarray_copy_enable:
"""
Set the *Colour* :class:`numpy.ndarray` copy enabled state
upon entering the context manager.
"""

set_ndarray_copy_enable(self._enable)

return self

def __exit__(self, *args: Any):
"""
Set the *Colour* :class:`numpy.ndarray` copy enabled state
upon exiting the context manager.
"""

set_ndarray_copy_enable(self._previous_state)

def __call__(self, function: Callable) -> Callable:
"""Call the wrapped definition."""

@functools.wraps(function)
def wrapper(*args: Any, **kwargs: Any) -> Any:
with self:
return function(*args, **kwargs)

return wrapper


def ndarray_copy(a: NDArray) -> NDArray:
"""
Return a :class:`numpy.ndarray` copy if the relevant *Colour* state is
enabled: Various API objects return a copy of their internal
:class:`numpy.ndarray` for safety purposes but this can be a slow operation
impacting performance.
Parameters
----------
a
Array :math:`a` to return a copy of.
Returns
-------
:class:`numpy.ndarray`
Array :math:`a` copy according to *Colour* state.
Examples
--------
>>> a = np.linspace(0, 1, 10)
>>> id(a) == id(ndarray_copy(a))
False
>>> with ndarray_copy_enable(False):
... id(a) == id(ndarray_copy(a))
...
True
"""

if _NDARRAY_COPY_ENABLED:
return np.copy(a)
else:
return a


def closest_indexes(a: ArrayLike, b: ArrayLike) -> NDArray:
"""
Return the array :math:`a` closest element indexes to the reference array
Expand Down
Loading

0 comments on commit 93deeab

Please sign in to comment.