Skip to content

Commit

Permalink
[SpeasyVariable] refac clamp and clean with numpy comparison support
Browse files Browse the repository at this point in the history
It is now possible to do something like var[var>10]=52

Signed-off-by: Alexis Jeandet <[email protected]>
  • Loading branch information
jeandet committed Dec 15, 2024
1 parent d8a49af commit bd751d8
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 78 deletions.
83 changes: 41 additions & 42 deletions speasy/core/data_containers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from copy import deepcopy
from datetime import datetime
from sys import getsizeof
from typing import Dict, List, Tuple, Protocol, TypeVar
from typing import Dict, List, Tuple, Protocol, TypeVar, Union

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'Tuple' is not used.

import astropy.units
import numpy as np
Expand Down Expand Up @@ -38,19 +38,16 @@ def from_dictionary(dictionary: Dict[str, str or Dict[str, str] or List], dtype=
def reserve_like(other: T, length: int = 0) -> T:
...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

def __getitem__(self, key)->T:
def __getitem__(self, key) -> T:
...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

def __setitem__(self, k, v: T):
def __setitem__(self, k, v: Union[T, float, int]):
...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

def __len__(self)->int:
def __len__(self) -> int:
...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

def view(self, index_range: slice) -> T:
...

def __eq__(self, other: T) -> bool:
def __eq__(self, other: Union[T, float, int]) -> Union[bool, np.ndarray]:
...

@property
Expand All @@ -77,6 +74,9 @@ def name(self) -> str:
def nbytes(self) -> int:
...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.

def view(self, index_range: Union[slice, np.ndarray]) -> T:
...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.


class DataContainer(DataContainerProtocol['DataContainer']):
__slots__ = ['__values', '__name', '__meta', '__is_time_dependent']
Expand Down Expand Up @@ -128,7 +128,7 @@ def unit(self) -> str:
def nbytes(self) -> int:
return self.__values.nbytes + getsizeof(self.__meta) + getsizeof(self.__name)

def view(self, index_range: slice):
def view(self, index_range: Union[slice, np.ndarray]):
return DataContainer(name=self.__name, meta=self.__meta, values=self.__values[index_range],
is_time_dependent=self.__is_time_dependent)

Expand Down Expand Up @@ -201,27 +201,21 @@ def __len__(self):
def __getitem__(self, key):
return self.view(key)

def __setitem__(self, k, v: 'DataContainer'):
assert type(v) is DataContainer
self.__values[k] = v.__values

def __eq__(self, other: 'DataContainer') -> bool:
return self.__meta == other.__meta and \
self.__name == other.__name and \
self.is_time_dependent == other.is_time_dependent and \
np.all(self.__values.shape == other.__values.shape) and \
np.array_equal(self.__values, other.__values, equal_nan=True)

def replace_val_by_nan(self, val):
if not np.issubdtype(self.__values.dtype, np.floating):
raise ValueError("DataContainer must be a floating type to replace val by nan")
self.__values[self.__values == val] = np.nan

def clamp_by_nan(self, valid_range: Tuple[float, float]):
if not np.issubdtype(self.__values.dtype, np.float64):
raise ValueError("DataContainer must be a floating type to clamp by nan")
self.__values[self.__values < valid_range[0]] = np.nan
self.__values[self.__values > valid_range[1]] = np.nan
def __setitem__(self, k, v: Union['DataContainer', float, int]):
if type(v) is DataContainer:
self.__values[k] = v.__values
else:
self.__values[k] = v

def __eq__(self, other: Union['DataContainer', float, int]) -> Union[bool, np.ndarray]:
if type(other) is DataContainer:
return self.__meta == other.__meta and \
self.__name == other.__name and \
self.is_time_dependent == other.is_time_dependent and \
np.all(self.__values.shape == other.__values.shape) and \
np.array_equal(self.__values, other.__values, equal_nan=True)
else:
return self.__values.__eq__(other)

@property
def meta(self):
Expand Down Expand Up @@ -268,15 +262,19 @@ def reserve_like(other: 'VariableAxis', length: int = 0) -> 'VariableAxis':
def __getitem__(self, key):
if isinstance(key, slice):
return self.view(slice(_to_index(key.start, self.__data.values), _to_index(key.stop, self.__data.values)))
else:
return self.view(key)

def __setitem__(self, k, v: 'VariableAxis'):
assert type(v) is VariableAxis
self.__data[k] = v.__data
def __setitem__(self, k, v: Union['VariableAxis', float, int]):
if type(v) is VariableAxis:
self.__data[k] = v.__data
else:
self.__data[k] = v

def __len__(self):
return len(self.__data)

def view(self, index_range: slice) -> 'VariableAxis':
def view(self, index_range: Union[slice, np.ndarray]) -> 'VariableAxis':
return VariableAxis(data=self.__data[index_range])

def __eq__(self, other: 'VariableAxis') -> bool:
Expand Down Expand Up @@ -310,15 +308,15 @@ def nbytes(self) -> int:
class VariableTimeAxis(DataContainerProtocol['VariableTimeAxis']):
__slots__ = ['__data']

def __init__(self, values: np.array = None, meta: Dict = None, data: DataContainer = None):
def __init__(self, values: np.array = None, meta: Dict = None, name: str = "time", data: DataContainer = None):
if data is not None:
self.__data = data
else:
if values.dtype != np.dtype('datetime64[ns]'):
raise ValueError(
f"Please provide datetime64[ns] for time axis, got {values.dtype}")
self.__data = DataContainer(
values=values, name='time', meta=meta, is_time_dependent=True)
values=values, name=name, meta=meta, is_time_dependent=True)

def to_dictionary(self, array_to_list=False) -> Dict[str, object]:
d = self.__data.to_dictionary(array_to_list=array_to_list)
Expand Down Expand Up @@ -347,17 +345,18 @@ def reserve_like(other: 'VariableTimeAxis', length: int = 0) -> 'VariableTimeAxi
return VariableTimeAxis(data=DataContainer.reserve_like(other.__data, length))

def __getitem__(self, key):
if isinstance(key, slice):
return self.view(key)
return self.view(key)

def __setitem__(self, k, v: 'VariableTimeAxis'):
assert type(v) is VariableTimeAxis
self.__data[k] = v.__data
def __setitem__(self, k, v: Union['VariableTimeAxis', float, int]):
if type(v) is VariableTimeAxis:
self.__data[k] = v.__data
else:
self.__data[k] = v

def __len__(self):
return len(self.__data)

def view(self, index_range: slice) -> "VariableTimeAxis":
def view(self, index_range: Union[slice, np.ndarray]) -> 'VariableTimeAxis':
return VariableTimeAxis(data=self.__data[index_range])

def __eq__(self, other: 'VariableTimeAxis') -> bool:
Expand Down
79 changes: 47 additions & 32 deletions speasy/products/variable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Dict, List, Optional, Any, Tuple
from typing import Dict, List, Optional, Any, Tuple, Union

import astropy.table
import astropy.units
Expand Down Expand Up @@ -108,7 +108,7 @@ def __init__(
self.__values_container = values
self.__axes = axes

def view(self, index_range: slice) -> "SpeasyVariable":
def view(self, index_range: Union[slice, np.ndarray]) -> "SpeasyVariable":
"""Return view of the current variable within the desired :data:`index_range`.
Parameters
Expand All @@ -121,6 +121,9 @@ def view(self, index_range: slice) -> "SpeasyVariable":
speasy.common.variable.SpeasyVariable
view of the variable on the given range
"""
if type(index_range) is np.ndarray:
if np.isdtype(index_range.dtype, np.bool):
index_range = np.where(index_range)[0]
return SpeasyVariable(
axes=[
axis[index_range] if axis.is_time_dependent else axis
Expand Down Expand Up @@ -169,24 +172,35 @@ def filter_columns(self, columns: List[str]) -> "SpeasyVariable":
columns=columns,
)

def __eq__(self, other: "SpeasyVariable") -> bool:
"""Check if this variable equals another.
def __eq__(self, other: Union["SpeasyVariable", float, int]) -> bool:
"""Check if this variable equals another. Or apply the numpy array comparison if other is a scalar.
Parameters
----------
other: speasy.common.variable.SpeasyVariable
another SpeasyVariable object to compare with
other: speasy.common.variable.SpeasyVariable, float, int
SpeasyVariable or scalar to compare with
Returns
-------
bool:
True if all attributes are equal
bool, np.ndarray
True if both variables are equal or an array with the element wise comparison between values and the given scalar
"""
return (
type(other) is SpeasyVariable
and self.__axes == other.__axes
and self.__values_container == other.__values_container
)
if type(other) is SpeasyVariable:
return self.__axes == other.__axes and self.__values_container == other.__values_container
else:
return self.values.__eq__(other)

def __le__(self, other):
return self.values.__le__(other)

def __lt__(self, other):
return self.values.__lt__(other)

def __ge__(self, other):
return self.values.__ge__(other)

def __gt__(self, other):
return self.values.__gt__(other)

def __len__(self):
return len(self.__axes[0])
Expand All @@ -201,15 +215,19 @@ def __getitem__(self, key):
return self.filter_columns(key)
if type(key) is str and key in self.__columns:
return self.filter_columns([key])
if type(key) is np.ndarray:
return self.view(key)
raise ValueError(
f"No idea how to slice SpeasyVariable with given value: {key}")

def __setitem__(self, k, v: "SpeasyVariable"):
assert type(v) is SpeasyVariable
self.__values_container[k] = v.__values_container
for axis, src_axis in zip(self.__axes, v.__axes):
if axis.is_time_dependent:
axis[k] = src_axis
def __setitem__(self, k, v: Union["SpeasyVariable", float, int]):
if type(v) is SpeasyVariable:
self.__values_container[k] = v.__values_container
for axis, src_axis in zip(self.__axes, v.__axes):
if axis.is_time_dependent:
axis[k] = src_axis
else:
self.__values_container[k] = v

def __mul__(self, other):
return np.multiply(self, other)
Expand Down Expand Up @@ -642,7 +660,7 @@ def replace_fillval_by_nan(self, inplace=False) -> "SpeasyVariable":
else:
res = deepcopy(self)
if (fill_value := self.fill_value) is not None:
res.__values_container.replace_val_by_nan(fill_value)
res[res == fill_value] = np.nan
return res

def clamp_with_nan(self, inplace=False, valid_min=None, valid_max=None) -> "SpeasyVariable":
Expand Down Expand Up @@ -673,7 +691,7 @@ def clamp_with_nan(self, inplace=False, valid_min=None, valid_max=None) -> "Spea
res = deepcopy(self)
valid_min = valid_min or self.valid_range[0]
valid_max = valid_max or self.valid_range[1]
res.__values_container.clamp_by_nan((valid_min, valid_max))
res[np.logical_or(res > valid_max, res < valid_min)] = np.nan
return res

def sanitized(self, drop_fill_values=True, drop_invalid_values=True, drop_nan=True, inplace=False, valid_min=None,
Expand Down Expand Up @@ -714,24 +732,21 @@ def sanitized(self, drop_fill_values=True, drop_invalid_values=True, drop_nan=Tr
indexes_without_fill = None
indexes_without_invalid = None
if drop_nan:
indexes_without_nan = res.values != np.nan
indexes_without_nan = res != np.nan
if drop_fill_values and res.fill_value is not None:
indexes_without_fill = res.values != res.fill_value
indexes_without_fill = res != res.fill_value
if drop_invalid_values:
valid_min = valid_min or res.valid_range[0]
valid_max = valid_max or res.valid_range[1]
if valid_min is not None and valid_max is not None:
indexes_without_invalid = np.logical_and(
res.values >= valid_min, res.values <= valid_max
res >= valid_min, res <= valid_max
)
keep_indexes = np.where(np.logical_and(
indexes_without_nan, indexes_without_fill, indexes_without_invalid
))
res.__values_container.select(keep_indexes, inplace=True)
for axis in res.__axes:
if axis.is_time_dependent:
axis.select(keep_indexes[0], inplace=True)
return res
return res[
np.logical_and(
indexes_without_nan, indexes_without_fill, indexes_without_invalid
)
]

@staticmethod
def reserve_like(other: "SpeasyVariable", length: int = 0) -> "SpeasyVariable":
Expand Down
20 changes: 16 additions & 4 deletions tests/test_speasy_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,18 @@ def test_can_slice_columns(self):
self.assertTrue(np.all(y.values[:, 0] == var.values[:, 1]))
self.assertTrue(np.all(y.axes == var.axes))

def test_can_slice_with_numpy_comparison(self):
var = make_simple_var(1., 10., 1., 1.)
sliced = var[var > 5]
self.assertEqual(len(sliced), 4)
self.assertTrue(np.all(sliced.values > 5))

def test_can_set_values_where_condition_is_true(self):
var = make_simple_var(1., 10., 1., 1.)
var[var < 5] = np.nan
self.assertTrue(np.all(np.isnan(var.values[:4])))
self.assertTrue(not np.any(np.isnan(var.values[4:])))


@ddt
class SpeasyVariableMerge(unittest.TestCase):
Expand Down Expand Up @@ -312,10 +324,10 @@ def test_replaces_fill_value(self):
var = make_simple_var(1., 10., 1., 10., meta={"FILLVAL": 50.})
self.assertEqual(var.fill_value, 50.)
cleaned_copy = var.replace_fillval_by_nan(inplace=False)
self.assertTrue(np.isnan(cleaned_copy.values[4,0]))
self.assertFalse(np.isnan(var.values[4,0]))
self.assertTrue(np.isnan(cleaned_copy.values[4, 0]))
self.assertFalse(np.isnan(var.values[4, 0]))
var.replace_fillval_by_nan(inplace=True)
self.assertTrue(np.isnan(var.values[4,0]))
self.assertTrue(np.isnan(var.values[4, 0]))

def test_clamps(self):
var = make_simple_var(1., 10., 1., 10., meta={"VALIDMIN": 20., "VALIDMAX": 80.})
Expand All @@ -334,6 +346,7 @@ def test_cleans(self):
self.assertFalse(np.any(np.isnan(cleaned_copy.values)))
self.assertTrue(len(cleaned_copy) < len(var))

Check notice

Code scanning / CodeQL

Imprecise assert Note test

assertTrue(a < b) cannot provide an informative message. Using assertLess(a, b) instead will give more informative messages.


class TestSpeasyVariableMath(unittest.TestCase):
def setUp(self):
self.var = make_simple_var(1., 10., 1., 10.)
Expand Down Expand Up @@ -446,7 +459,6 @@ def test_scalar_result(self, func):
self.assertIsInstance(func(v), float)



class SpeasyVariableCompare(unittest.TestCase):
def setUp(self):
pass
Expand Down

0 comments on commit bd751d8

Please sign in to comment.