Skip to content

Commit

Permalink
Handles numpy functions that returns arrays of bool
Browse files Browse the repository at this point in the history
Signed-off-by: Alexis Jeandet <[email protected]>
  • Loading branch information
jeandet committed Jan 9, 2025
1 parent 48aad11 commit 6fbed5b
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions speasy/products/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@
DataContainer,
VariableAxis,
VariableTimeAxis,
_to_index,
DataContainerProtocol
_to_index
)
from speasy.plotting import Plot

from .base_product import SpeasyProduct


Expand Down Expand Up @@ -121,9 +119,8 @@ def view(self, index_range: Union[slice, np.ndarray]) -> "SpeasyVariable":
speasy.common.variable.SpeasyVariable
view of the variable on the given range
"""
if type(index_range) is np.ndarray:
if index_range.dtype == bool:
index_range = np.nonzero(index_range)[0]
if (type(index_range) is np.ndarray) and (index_range.dtype == bool):
index_range = np.nonzero(index_range)[0]
return SpeasyVariable(
axes=[
axis[index_range] if axis.is_time_dependent else axis
Expand Down Expand Up @@ -192,7 +189,7 @@ def __eq__(self, other: Union["SpeasyVariable", float, int]) -> bool:

def __ne__(self, other: Union["SpeasyVariable", float, int]) -> bool:
if type(other) is SpeasyVariable:
return not other == self
return not (other == self)
else:
return self.values.__ne__(other)

Expand Down Expand Up @@ -322,6 +319,8 @@ def __array_ufunc__(self, ufunc, method, *inputs, out: 'SpeasyVariable' or None
if isinstance(out, SpeasyVariable):
out.__axes = axes
return out
if type(values) is np.ndarray and values.dtype == bool:
return np.all(values, axis=1, keepdims=True)
else:
return SpeasyVariable(
axes=axes,
Expand Down

0 comments on commit 6fbed5b

Please sign in to comment.