diff --git a/speasy/core/data_containers.py b/speasy/core/data_containers.py index 01246a21..945578e0 100644 --- a/speasy/core/data_containers.py +++ b/speasy/core/data_containers.py @@ -1,7 +1,7 @@ from copy import deepcopy from datetime import datetime from sys import getsizeof -from typing import Dict, List, Tuple, Protocol, TypeVar +from typing import Dict, List, Tuple, Protocol, TypeVar, Union import astropy.units import numpy as np @@ -38,19 +38,16 @@ def from_dictionary(dictionary: Dict[str, str or Dict[str, str] or List], dtype= def reserve_like(other: T, length: int = 0) -> T: ... - def __getitem__(self, key)->T: + def __getitem__(self, key) -> T: ... - def __setitem__(self, k, v: T): + def __setitem__(self, k, v: Union[T, float, int]): ... - def __len__(self)->int: + def __len__(self) -> int: ... - def view(self, index_range: slice) -> T: - ... - - def __eq__(self, other: T) -> bool: + def __eq__(self, other: Union[T, float, int]) -> Union[bool, np.ndarray]: ... @property @@ -77,6 +74,9 @@ def name(self) -> str: def nbytes(self) -> int: ... + def view(self, index_range: Union[slice, np.ndarray]) -> T: + ... + class DataContainer(DataContainerProtocol['DataContainer']): __slots__ = ['__values', '__name', '__meta', '__is_time_dependent'] @@ -128,7 +128,7 @@ def unit(self) -> str: def nbytes(self) -> int: return self.__values.nbytes + getsizeof(self.__meta) + getsizeof(self.__name) - def view(self, index_range: slice): + def view(self, index_range: Union[slice, np.ndarray]): return DataContainer(name=self.__name, meta=self.__meta, values=self.__values[index_range], is_time_dependent=self.__is_time_dependent) @@ -201,27 +201,21 @@ def __len__(self): def __getitem__(self, key): return self.view(key) - def __setitem__(self, k, v: 'DataContainer'): - assert type(v) is DataContainer - self.__values[k] = v.__values - - def __eq__(self, other: 'DataContainer') -> bool: - return self.__meta == other.__meta and \ - self.__name == other.__name and \ - self.is_time_dependent == other.is_time_dependent and \ - np.all(self.__values.shape == other.__values.shape) and \ - np.array_equal(self.__values, other.__values, equal_nan=True) - - def replace_val_by_nan(self, val): - if not np.issubdtype(self.__values.dtype, np.floating): - raise ValueError("DataContainer must be a floating type to replace val by nan") - self.__values[self.__values == val] = np.nan - - def clamp_by_nan(self, valid_range: Tuple[float, float]): - if not np.issubdtype(self.__values.dtype, np.float64): - raise ValueError("DataContainer must be a floating type to clamp by nan") - self.__values[self.__values < valid_range[0]] = np.nan - self.__values[self.__values > valid_range[1]] = np.nan + def __setitem__(self, k, v: Union['DataContainer', float, int]): + if type(v) is DataContainer: + self.__values[k] = v.__values + else: + self.__values[k] = v + + def __eq__(self, other: Union['DataContainer', float, int]) -> Union[bool, np.ndarray]: + if type(other) is DataContainer: + return self.__meta == other.__meta and \ + self.__name == other.__name and \ + self.is_time_dependent == other.is_time_dependent and \ + np.all(self.__values.shape == other.__values.shape) and \ + np.array_equal(self.__values, other.__values, equal_nan=True) + else: + return self.__values.__eq__(other) @property def meta(self): @@ -268,15 +262,19 @@ def reserve_like(other: 'VariableAxis', length: int = 0) -> 'VariableAxis': def __getitem__(self, key): if isinstance(key, slice): return self.view(slice(_to_index(key.start, self.__data.values), _to_index(key.stop, self.__data.values))) + else: + return self.view(key) - def __setitem__(self, k, v: 'VariableAxis'): - assert type(v) is VariableAxis - self.__data[k] = v.__data + def __setitem__(self, k, v: Union['VariableAxis', float, int]): + if type(v) is VariableAxis: + self.__data[k] = v.__data + else: + self.__data[k] = v def __len__(self): return len(self.__data) - def view(self, index_range: slice) -> 'VariableAxis': + def view(self, index_range: Union[slice, np.ndarray]) -> 'VariableAxis': return VariableAxis(data=self.__data[index_range]) def __eq__(self, other: 'VariableAxis') -> bool: @@ -310,7 +308,7 @@ def nbytes(self) -> int: class VariableTimeAxis(DataContainerProtocol['VariableTimeAxis']): __slots__ = ['__data'] - def __init__(self, values: np.array = None, meta: Dict = None, data: DataContainer = None): + def __init__(self, values: np.array = None, meta: Dict = None, name: str = "time", data: DataContainer = None): if data is not None: self.__data = data else: @@ -318,7 +316,7 @@ def __init__(self, values: np.array = None, meta: Dict = None, data: DataContain raise ValueError( f"Please provide datetime64[ns] for time axis, got {values.dtype}") self.__data = DataContainer( - values=values, name='time', meta=meta, is_time_dependent=True) + values=values, name=name, meta=meta, is_time_dependent=True) def to_dictionary(self, array_to_list=False) -> Dict[str, object]: d = self.__data.to_dictionary(array_to_list=array_to_list) @@ -347,17 +345,18 @@ def reserve_like(other: 'VariableTimeAxis', length: int = 0) -> 'VariableTimeAxi return VariableTimeAxis(data=DataContainer.reserve_like(other.__data, length)) def __getitem__(self, key): - if isinstance(key, slice): - return self.view(key) + return self.view(key) - def __setitem__(self, k, v: 'VariableTimeAxis'): - assert type(v) is VariableTimeAxis - self.__data[k] = v.__data + def __setitem__(self, k, v: Union['VariableTimeAxis', float, int]): + if type(v) is VariableTimeAxis: + self.__data[k] = v.__data + else: + self.__data[k] = v def __len__(self): return len(self.__data) - def view(self, index_range: slice) -> "VariableTimeAxis": + def view(self, index_range: Union[slice, np.ndarray]) -> 'VariableTimeAxis': return VariableTimeAxis(data=self.__data[index_range]) def __eq__(self, other: 'VariableTimeAxis') -> bool: diff --git a/speasy/products/variable.py b/speasy/products/variable.py index 7f6e9f9c..85e01653 100644 --- a/speasy/products/variable.py +++ b/speasy/products/variable.py @@ -1,5 +1,5 @@ from copy import deepcopy -from typing import Dict, List, Optional, Any, Tuple +from typing import Dict, List, Optional, Any, Tuple, Union import astropy.table import astropy.units @@ -108,7 +108,7 @@ def __init__( self.__values_container = values self.__axes = axes - def view(self, index_range: slice) -> "SpeasyVariable": + def view(self, index_range: Union[slice, np.ndarray]) -> "SpeasyVariable": """Return view of the current variable within the desired :data:`index_range`. Parameters @@ -121,6 +121,9 @@ def view(self, index_range: slice) -> "SpeasyVariable": speasy.common.variable.SpeasyVariable view of the variable on the given range """ + if type(index_range) is np.ndarray: + if np.isdtype(index_range.dtype, np.bool): + index_range = np.where(index_range)[0] return SpeasyVariable( axes=[ axis[index_range] if axis.is_time_dependent else axis @@ -169,24 +172,35 @@ def filter_columns(self, columns: List[str]) -> "SpeasyVariable": columns=columns, ) - def __eq__(self, other: "SpeasyVariable") -> bool: - """Check if this variable equals another. + def __eq__(self, other: Union["SpeasyVariable", float, int]) -> bool: + """Check if this variable equals another. Or apply the numpy array comparison if other is a scalar. Parameters ---------- - other: speasy.common.variable.SpeasyVariable - another SpeasyVariable object to compare with + other: speasy.common.variable.SpeasyVariable, float, int + SpeasyVariable or scalar to compare with Returns ------- - bool: - True if all attributes are equal + bool, np.ndarray + True if both variables are equal or an array with the element wise comparison between values and the given scalar """ - return ( - type(other) is SpeasyVariable - and self.__axes == other.__axes - and self.__values_container == other.__values_container - ) + if type(other) is SpeasyVariable: + return self.__axes == other.__axes and self.__values_container == other.__values_container + else: + return self.values.__eq__(other) + + def __le__(self, other): + return self.values.__le__(other) + + def __lt__(self, other): + return self.values.__lt__(other) + + def __ge__(self, other): + return self.values.__ge__(other) + + def __gt__(self, other): + return self.values.__gt__(other) def __len__(self): return len(self.__axes[0]) @@ -201,15 +215,19 @@ def __getitem__(self, key): return self.filter_columns(key) if type(key) is str and key in self.__columns: return self.filter_columns([key]) + if type(key) is np.ndarray: + return self.view(key) raise ValueError( f"No idea how to slice SpeasyVariable with given value: {key}") - def __setitem__(self, k, v: "SpeasyVariable"): - assert type(v) is SpeasyVariable - self.__values_container[k] = v.__values_container - for axis, src_axis in zip(self.__axes, v.__axes): - if axis.is_time_dependent: - axis[k] = src_axis + def __setitem__(self, k, v: Union["SpeasyVariable", float, int]): + if type(v) is SpeasyVariable: + self.__values_container[k] = v.__values_container + for axis, src_axis in zip(self.__axes, v.__axes): + if axis.is_time_dependent: + axis[k] = src_axis + else: + self.__values_container[k] = v def __mul__(self, other): return np.multiply(self, other) @@ -642,7 +660,7 @@ def replace_fillval_by_nan(self, inplace=False) -> "SpeasyVariable": else: res = deepcopy(self) if (fill_value := self.fill_value) is not None: - res.__values_container.replace_val_by_nan(fill_value) + res[res == fill_value] = np.nan return res def clamp_with_nan(self, inplace=False, valid_min=None, valid_max=None) -> "SpeasyVariable": @@ -673,7 +691,7 @@ def clamp_with_nan(self, inplace=False, valid_min=None, valid_max=None) -> "Spea res = deepcopy(self) valid_min = valid_min or self.valid_range[0] valid_max = valid_max or self.valid_range[1] - res.__values_container.clamp_by_nan((valid_min, valid_max)) + res[np.logical_or(res > valid_max, res < valid_min)] = np.nan return res def sanitized(self, drop_fill_values=True, drop_invalid_values=True, drop_nan=True, inplace=False, valid_min=None, @@ -714,24 +732,21 @@ def sanitized(self, drop_fill_values=True, drop_invalid_values=True, drop_nan=Tr indexes_without_fill = None indexes_without_invalid = None if drop_nan: - indexes_without_nan = res.values != np.nan + indexes_without_nan = res != np.nan if drop_fill_values and res.fill_value is not None: - indexes_without_fill = res.values != res.fill_value + indexes_without_fill = res != res.fill_value if drop_invalid_values: valid_min = valid_min or res.valid_range[0] valid_max = valid_max or res.valid_range[1] if valid_min is not None and valid_max is not None: indexes_without_invalid = np.logical_and( - res.values >= valid_min, res.values <= valid_max + res >= valid_min, res <= valid_max ) - keep_indexes = np.where(np.logical_and( - indexes_without_nan, indexes_without_fill, indexes_without_invalid - )) - res.__values_container.select(keep_indexes, inplace=True) - for axis in res.__axes: - if axis.is_time_dependent: - axis.select(keep_indexes[0], inplace=True) - return res + return res[ + np.logical_and( + indexes_without_nan, indexes_without_fill, indexes_without_invalid + ) + ] @staticmethod def reserve_like(other: "SpeasyVariable", length: int = 0) -> "SpeasyVariable": diff --git a/tests/test_speasy_variable.py b/tests/test_speasy_variable.py index be66266a..ef5fbf75 100644 --- a/tests/test_speasy_variable.py +++ b/tests/test_speasy_variable.py @@ -165,6 +165,18 @@ def test_can_slice_columns(self): self.assertTrue(np.all(y.values[:, 0] == var.values[:, 1])) self.assertTrue(np.all(y.axes == var.axes)) + def test_can_slice_with_numpy_comparison(self): + var = make_simple_var(1., 10., 1., 1.) + sliced = var[var > 5] + self.assertEqual(len(sliced), 4) + self.assertTrue(np.all(sliced.values > 5)) + + def test_can_set_values_where_condition_is_true(self): + var = make_simple_var(1., 10., 1., 1.) + var[var < 5] = np.nan + self.assertTrue(np.all(np.isnan(var.values[:4]))) + self.assertTrue(not np.any(np.isnan(var.values[4:]))) + @ddt class SpeasyVariableMerge(unittest.TestCase): @@ -312,10 +324,10 @@ def test_replaces_fill_value(self): var = make_simple_var(1., 10., 1., 10., meta={"FILLVAL": 50.}) self.assertEqual(var.fill_value, 50.) cleaned_copy = var.replace_fillval_by_nan(inplace=False) - self.assertTrue(np.isnan(cleaned_copy.values[4,0])) - self.assertFalse(np.isnan(var.values[4,0])) + self.assertTrue(np.isnan(cleaned_copy.values[4, 0])) + self.assertFalse(np.isnan(var.values[4, 0])) var.replace_fillval_by_nan(inplace=True) - self.assertTrue(np.isnan(var.values[4,0])) + self.assertTrue(np.isnan(var.values[4, 0])) def test_clamps(self): var = make_simple_var(1., 10., 1., 10., meta={"VALIDMIN": 20., "VALIDMAX": 80.}) @@ -334,6 +346,7 @@ def test_cleans(self): self.assertFalse(np.any(np.isnan(cleaned_copy.values))) self.assertTrue(len(cleaned_copy) < len(var)) + class TestSpeasyVariableMath(unittest.TestCase): def setUp(self): self.var = make_simple_var(1., 10., 1., 10.) @@ -446,7 +459,6 @@ def test_scalar_result(self, func): self.assertIsInstance(func(v), float) - class SpeasyVariableCompare(unittest.TestCase): def setUp(self): pass