diff --git a/speasy/products/variable.py b/speasy/products/variable.py index b3d42c1e..367ab553 100644 --- a/speasy/products/variable.py +++ b/speasy/products/variable.py @@ -10,11 +10,9 @@ DataContainer, VariableAxis, VariableTimeAxis, - _to_index, - DataContainerProtocol + _to_index ) from speasy.plotting import Plot - from .base_product import SpeasyProduct @@ -121,9 +119,8 @@ def view(self, index_range: Union[slice, np.ndarray]) -> "SpeasyVariable": speasy.common.variable.SpeasyVariable view of the variable on the given range """ - if type(index_range) is np.ndarray: - if index_range.dtype == bool: - index_range = np.nonzero(index_range)[0] + if (type(index_range) is np.ndarray) and (index_range.dtype == bool): + index_range = np.nonzero(index_range)[0] return SpeasyVariable( axes=[ axis[index_range] if axis.is_time_dependent else axis @@ -192,7 +189,7 @@ def __eq__(self, other: Union["SpeasyVariable", float, int]) -> bool: def __ne__(self, other: Union["SpeasyVariable", float, int]) -> bool: if type(other) is SpeasyVariable: - return not other == self + return not (other == self) else: return self.values.__ne__(other) @@ -322,6 +319,8 @@ def __array_ufunc__(self, ufunc, method, *inputs, out: 'SpeasyVariable' or None if isinstance(out, SpeasyVariable): out.__axes = axes return out + if type(values) is np.ndarray and values.dtype == bool: + return np.all(values, axis=1, keepdims=True) else: return SpeasyVariable( axes=axes,