Skip to content

Commit

Permalink
[SpeasyVariable] Adds clamping, sanitizing and some refactoring
Browse files Browse the repository at this point in the history
Signed-off-by: Alexis Jeandet <[email protected]>
  • Loading branch information
jeandet committed Dec 13, 2024
1 parent 6140299 commit d8a49af
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 8 deletions.
100 changes: 94 additions & 6 deletions speasy/core/data_containers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from copy import deepcopy
from datetime import datetime
from sys import getsizeof
from typing import Dict, List
from typing import Dict, List, Tuple, Protocol, TypeVar

import astropy.units
import numpy as np
Expand All @@ -19,7 +20,65 @@ def _to_index(key, time):
return np.searchsorted(time, key, side='left')


class DataContainer(object):
T = TypeVar("T")  # keep until we drop python 3.11 support


class DataContainerProtocol(Protocol[T]):
    """Structural interface shared by the container classes of this module.

    ``T`` stands for the concrete container type, so that methods returning
    "the same kind of object" are typed precisely for each implementer.

    Every member below is a declaration only. Bodies are docstrings rather
    than bare ``...`` expressions, which static analysers report as
    statements without effect.
    """

    def select(self, indices, inplace=False) -> T:
        """Return the container restricted to ``indices`` (``self`` when ``inplace`` is true)."""

    def to_dictionary(self, array_to_list=False) -> Dict[str, object]:
        """Return a dictionary representation of the container."""

    # NOTE(review): ``str or Dict[str, str] or List`` evaluates to plain ``str``
    # at runtime; a union of the three types is presumably intended.
    @staticmethod
    def from_dictionary(dictionary: Dict[str, str or Dict[str, str] or List], dtype=np.float64) -> T:
        """Build a container from its dictionary representation."""

    @staticmethod
    def reserve_like(other: T, length: int = 0) -> T:
        """Create a container of ``length`` entries modelled on ``other``."""

    def __getitem__(self, key) -> T:
        """Return the part of the container designated by ``key``."""

    def __setitem__(self, k, v: T):
        """Assign ``v`` to the part of the container designated by ``k``."""

    def __len__(self) -> int:
        """Return the length of the container."""

    def view(self, index_range: slice) -> T:
        """Return the part of the container designated by ``index_range``."""

    def __eq__(self, other: T) -> bool:
        """Compare the container with ``other``."""

    @property
    def unit(self) -> str:
        """Unit of the stored values."""

    @property
    def is_time_dependent(self) -> bool:
        """Whether the container is time dependent."""

    @property
    def values(self) -> np.ndarray:
        """Stored values."""

    @property
    def shape(self):
        """Shape of the stored values."""

    @property
    def name(self) -> str:
        """Name of the container."""

    @property
    def nbytes(self) -> int:
        """Memory usage in bytes."""


class DataContainer(DataContainerProtocol['DataContainer']):
__slots__ = ['__values', '__name', '__meta', '__is_time_dependent']

def __init__(self, values: np.array, meta: Dict = None, name: str = None, is_time_dependent: bool = True):
Expand All @@ -34,6 +93,13 @@ def reshape(self, new_shape) -> "DataContainer":
self.__values = self.__values.reshape(new_shape)
return self

def select(self, indices, inplace=False) -> "DataContainer":
    """Keep only the values designated by ``indices``.

    Parameters
    ----------
    indices
        anything accepted by numpy indexing (index array, boolean mask, slice...)
    inplace : bool, optional
        Modifies this container when true else works on a deep copy, by default False

    Returns
    -------
    DataContainer
        this container or the modified copy
    """
    res = self if inplace else deepcopy(self)
    res.__values = res.__values[indices]
    # Without this return the non-inplace result was lost and callers got None.
    return res

@property
def is_time_dependent(self) -> bool:
return self.__is_time_dependent
Expand Down Expand Up @@ -147,10 +213,16 @@ def __eq__(self, other: 'DataContainer') -> bool:
np.array_equal(self.__values, other.__values, equal_nan=True)

def replace_val_by_nan(self, val):
# Replaces, in place, every stored value equal to `val` by NaN.
# NOTE(review): the next two statements convert any non-float64 array to float64.
if self.__values.dtype != np.float64:
self.__values = self.__values.astype(np.float64)
# NOTE(review): after the conversion above the dtype is always float64, so this
# guard can never raise. Presumably the conversion and the guard are the removed
# and added sides of the same diff hunk shown together - confirm which one is intended.
if not np.issubdtype(self.__values.dtype, np.floating):
raise ValueError("DataContainer must be a floating type to replace val by nan")
# Boolean-mask assignment: only the elements comparing equal to `val` are overwritten.
self.__values[self.__values == val] = np.nan

def clamp_by_nan(self, valid_range: Tuple[float, float]):
    """Replace, in place, every value outside ``valid_range`` by NaN.

    Parameters
    ----------
    valid_range : Tuple[float, float]
        ``(minimum, maximum)`` of the accepted values, both bounds are inclusive.
        A bound set to ``None`` is ignored.

    Raises
    ------
    ValueError
        if the stored values are not of a floating point type, since NaN
        can't be stored in them
    """
    # Any floating dtype can hold NaN; testing against np.float64 wrongly
    # rejected float32 arrays and contradicted the error message.
    if not np.issubdtype(self.__values.dtype, np.floating):
        raise ValueError("DataContainer must be a floating type to clamp by nan")
    lower, upper = valid_range
    # A missing bound (None) can't be compared with an array, skip that side.
    if lower is not None:
        self.__values[self.__values < lower] = np.nan
    if upper is not None:
        self.__values[self.__values > upper] = np.nan

@property
def meta(self):
return self.__meta
Expand All @@ -160,7 +232,7 @@ def name(self):
return self.__name


class VariableAxis(object):
class VariableAxis(DataContainerProtocol['VariableAxis']):
__slots__ = ['__data']

def __init__(self, values: np.array = None, meta: Dict = None, name: str = "", is_time_dependent: bool = False,
Expand All @@ -176,6 +248,14 @@ def to_dictionary(self, array_to_list=False) -> Dict[str, object]:
d.update({"type": "VariableAxis"})
return d

def select(self, indices, inplace=False) -> "VariableAxis":
    """Restrict the axis to the entries designated by ``indices``.

    Parameters
    ----------
    indices
        forwarded as is to the wrapped container ``select`` method
    inplace : bool, optional
        Modifies this axis when true else works on a deep copy, by default False

    Returns
    -------
    VariableAxis
        this axis or the modified copy
    """
    target = self if inplace else deepcopy(self)
    # `target` is already either the caller's axis or a private copy, so its
    # wrapped container is always narrowed in place.
    target.__data.select(indices, inplace=True)
    return target

@staticmethod
def from_dictionary(dictionary: Dict[str, str or Dict[str, str] or List], time=None) -> "VariableAxis":
assert dictionary['type'] == "VariableAxis"
Expand Down Expand Up @@ -227,7 +307,7 @@ def nbytes(self) -> int:
return self.__data.nbytes


class VariableTimeAxis(object):
class VariableTimeAxis(DataContainerProtocol['VariableTimeAxis']):
__slots__ = ['__data']

def __init__(self, values: np.array = None, meta: Dict = None, data: DataContainer = None):
Expand All @@ -245,6 +325,14 @@ def to_dictionary(self, array_to_list=False) -> Dict[str, object]:
d.update({"type": "VariableTimeAxis"})
return d

def select(self, indices, inplace=False) -> "VariableTimeAxis":
    """Restrict the time axis to the entries designated by ``indices``.

    Parameters
    ----------
    indices
        forwarded as is to the wrapped container ``select`` method
    inplace : bool, optional
        Modifies this axis when true else works on a deep copy, by default False

    Returns
    -------
    VariableTimeAxis
        this axis or the modified copy
    """
    if not inplace:
        # Work on a private copy and narrow it in place.
        return deepcopy(self).select(indices, inplace=True) if False else _select_on(deepcopy(self), indices)
    return _select_on(self, indices)


def _select_on(axis, indices):
    """Narrow the container wrapped by ``axis`` in place and hand ``axis`` back."""
    data = getattr(axis, "_VariableTimeAxis__data", None)
    if data is None:
        data = getattr(axis, "__data")
    data.select(indices, inplace=True)
    return axis

@property
def shape(self):
return self.__data.shape
Expand Down
115 changes: 114 additions & 1 deletion speasy/products/variable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from copy import deepcopy
from typing import Dict, List, Optional, Any
from typing import Dict, List, Optional, Any, Tuple

import astropy.table
import astropy.units
Expand All @@ -11,6 +11,7 @@
VariableAxis,
VariableTimeAxis,
_to_index,
DataContainerProtocol
)
from speasy.plotting import Plot

Expand Down Expand Up @@ -46,6 +47,10 @@ class SpeasyVariable(SpeasyProduct):
SpeasyVariable name
nbytes: int
memory usage in bytes
fill_value: Any
fill value if found in meta-data
valid_range: Tuple[Any, Any]
valid range if found in meta-data
Methods
-------
Expand Down Expand Up @@ -425,6 +430,17 @@ def fill_value(self) -> Optional[Any]:
"""
return self.meta.get("FILLVAL", None)

@property
def valid_range(self) -> Optional[Tuple[Any, Any]]:
    """SpeasyVariable valid range if found in meta-data

    Returns
    -------
    Tuple[Any, Any]
        ``(VALIDMIN, VALIDMAX)`` as read from the meta-data, a bound missing
        from the meta-data is reported as ``None``
    """
    metadata = self.meta
    lower = metadata.get("VALIDMIN", None)
    upper = metadata.get("VALIDMAX", None)
    return lower, upper

def unit_applied(self, unit: str or None = None, copy=True) -> "SpeasyVariable":
"""Returns a SpeasyVariable with given or automatically found unit applied to values
Expand All @@ -440,6 +456,10 @@ def unit_applied(self, unit: str or None = None, copy=True) -> "SpeasyVariable":
SpeasyVariable
SpeasyVariable identical to the source one, with values converted to astropy.units.Quantity according to the given or found unit
Notes
-----
This interface assumes that there is only one unit for the whole variable, since all values are stored in the same array
See Also
--------
unit: returns variable unit if found in meta-data
Expand Down Expand Up @@ -611,6 +631,11 @@ def replace_fillval_by_nan(self, inplace=False) -> "SpeasyVariable":
-------
SpeasyVariable
source variable or copy with fill values replaced by NaN
See Also
--------
clamp_with_nan: replaces values outside valid range by NaN
sanitized: removes fill and invalid values
"""
if inplace:
res = self
Expand All @@ -620,6 +645,94 @@ def replace_fillval_by_nan(self, inplace=False) -> "SpeasyVariable":
res.__values_container.replace_val_by_nan(fill_value)
return res

def clamp_with_nan(self, inplace=False, valid_min=None, valid_max=None) -> "SpeasyVariable":
    """Replaces values outside valid range by NaN, valid range is taken from metadata fields "VALIDMIN" and "VALIDMAX"

    Parameters
    ----------
    inplace : bool, optional
        Modifies source variable when true else modifies and returns a copy, by default False
    valid_min : Float, optional
        Optional minimum valid value, takes metadata field "VALIDMIN" if not provided, by default None
    valid_max : Float, optional
        Optional maximum valid value, takes metadata field "VALIDMAX" if not provided, by default None

    Returns
    -------
    SpeasyVariable
        source variable or copy with values clamped by NaN

    Notes
    -----
    A bound found neither in the arguments nor in the metadata leaves that side unbounded.

    See Also
    --------
    replace_fillval_by_nan: replaces fill values by NaN
    sanitized: removes fill and invalid values
    """
    res = self if inplace else deepcopy(self)
    meta_min, meta_max = self.valid_range
    # Compare against None explicitly: `valid_min or ...` silently discarded
    # a legitimate bound of 0.
    if valid_min is None:
        valid_min = meta_min
    if valid_max is None:
        valid_max = meta_max
    # An infinite bound never matches any value, unlike None which can't be
    # compared with the stored array.
    if valid_min is None:
        valid_min = -np.inf
    if valid_max is None:
        valid_max = np.inf
    res.__values_container.clamp_by_nan((valid_min, valid_max))
    return res

def sanitized(self, drop_fill_values=True, drop_invalid_values=True, drop_nan=True, inplace=False, valid_min=None,
              valid_max=None) -> "SpeasyVariable":
    """Returns a copy of the variable with fill values and invalid values removed

    A record (index along the first dimension) is removed as soon as one of its
    values is rejected by an enabled filter, the time dependent axes are
    shortened accordingly.

    Parameters
    ----------
    drop_fill_values : bool, optional
        Remove fill values, by default True
    drop_invalid_values : bool, optional
        Remove values outside valid range, by default True
    drop_nan : bool, optional
        Remove NaN values, by default True
    inplace : bool, optional
        Modifies source variable when true else modifies and returns a copy, by default False
    valid_min : Float, optional
        Minimum valid value, takes metadata field "VALIDMIN" if not provided, by default None
    valid_max : Float, optional
        Maximum valid value, takes metadata field "VALIDMAX" if not provided, by default None

    Returns
    -------
    SpeasyVariable
        source variable or copy with fill and invalid values removed

    See Also
    --------
    replace_fillval_by_nan: replaces fill values by NaN
    clamp_with_nan: replaces values outside valid range by NaN
    """
    res = self if inplace else deepcopy(self)

    values = np.asarray(res.values)
    # Start from "keep everything" so a disabled filter simply contributes
    # nothing (None masks used to reach np.logical_and).
    keep = np.ones(values.shape, dtype=bool)

    # NaN never compares equal to anything, itself included, so `!= np.nan`
    # kept every value; np.isnan is the only reliable test.
    if drop_nan and np.issubdtype(values.dtype, np.inexact):
        keep &= ~np.isnan(values)

    fill_value = res.fill_value
    if drop_fill_values and fill_value is not None:
        keep &= values != fill_value

    if drop_invalid_values:
        meta_min, meta_max = res.valid_range
        # Explicit None tests: `valid_min or ...` discarded a bound of 0.
        lower = meta_min if valid_min is None else valid_min
        upper = meta_max if valid_max is None else valid_max
        # Written as "not out of range" so NaN is only removed by drop_nan.
        if lower is not None:
            keep &= ~(values < lower)
        if upper is not None:
            keep &= ~(values > upper)

    # One decision per record: collapse every dimension but the first, so the
    # same indices apply to the values and to the time dependent axes.
    if keep.ndim > 1:
        keep = keep.all(axis=tuple(range(1, keep.ndim)))
    keep_indexes = np.flatnonzero(keep)

    res.__values_container.select(keep_indexes, inplace=True)
    for axis in res.__axes:
        if axis.is_time_dependent:
            axis.select(keep_indexes, inplace=True)
    return res

@staticmethod
def reserve_like(other: "SpeasyVariable", length: int = 0) -> "SpeasyVariable":
"""Create a SpeasyVariable of given length with the same properties as the given variable but unset values
Expand Down
28 changes: 27 additions & 1 deletion tests/test_speasy_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,35 @@ def test_overrides_plot_arguments(self):
ax = var.plot(xaxis_label="Time", yaxis_label="Values", yaxis_units="nT", zaxis_label="Values",
zaxis_units="nT")
self.assertIsNotNone(ax)

except ImportError:
self.skipTest("Can't import matplotlib")

def test_replaces_fill_value(self):
    """The fill value is replaced by NaN, on a copy first and then in place."""
    row, column = 4, 0  # position expected to hold the fill value
    var = make_simple_var(1., 10., 1., 10., meta={"FILLVAL": 50.})
    self.assertEqual(var.fill_value, 50.)

    # Out of place: only the returned copy is modified.
    cleaned_copy = var.replace_fillval_by_nan(inplace=False)
    self.assertTrue(np.isnan(cleaned_copy.values[row, column]))
    self.assertFalse(np.isnan(var.values[row, column]))

    # In place: the source variable itself is modified.
    var.replace_fillval_by_nan(inplace=True)
    self.assertTrue(np.isnan(var.values[row, column]))

def test_clamps(self):
    """Values outside [VALIDMIN, VALIDMAX] become NaN, on a copy and in place."""

    def check_clamped(values):
        # Rows 0 and 8-9 of the first column are expected to be out of range,
        # rows 1-7 are expected to be left untouched.
        self.assertTrue(np.all(np.isnan(values[0:1, 0])))
        self.assertTrue(np.all(np.isnan(values[8:10, 0])))
        self.assertFalse(np.any(np.isnan(values[1:8, 0])))

    var = make_simple_var(1., 10., 1., 10., meta={"VALIDMIN": 20., "VALIDMAX": 80.})
    check_clamped(var.clamp_with_nan().values)
    var.clamp_with_nan(inplace=True)
    check_clamped(var.values)

def test_cleans(self):
    """``sanitized()`` returns a shorter variable that contains no NaN."""
    var = make_simple_var(1., 10., 1., 10., meta={"FILLVAL": 50., "VALIDMIN": 20., "VALIDMAX": 80.})
    cleaned_copy = var.sanitized()
    self.assertFalse(np.any(np.isnan(cleaned_copy.values)))
    # assertLess reports both lengths on failure, assertTrue(a < b) only says "False is not true".
    self.assertLess(len(cleaned_copy), len(var))

Check notice

Code scanning / CodeQL

Imprecise assert Note test

assertTrue(a < b) cannot provide an informative message. Using assertLess(a, b) instead will give more informative messages.

class TestSpeasyVariableMath(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -421,6 +446,7 @@ def test_scalar_result(self, func):
self.assertIsInstance(func(v), float)



class SpeasyVariableCompare(unittest.TestCase):
def setUp(self):
pass
Expand Down

0 comments on commit d8a49af

Please sign in to comment.