From 33ec7b4b8c50cbac2aca1b285ecb68703a564a78 Mon Sep 17 00:00:00 2001 From: Sichao He <1310722434@qq.com> Date: Thu, 16 Jan 2025 17:33:12 +0800 Subject: [PATCH 1/6] Add convert_in_si method --- brainunit/_base.py | 199 +++++++++++++++++-------- brainunit/math/_einops.py | 5 - brainunit/math/_fun_accept_unitless.py | 8 +- brainunit/math/_fun_change_unit.py | 8 +- brainunit/math/_fun_keep_unit.py | 5 - brainunit/math/_fun_remove_unit.py | 9 +- 6 files changed, 143 insertions(+), 91 deletions(-) diff --git a/brainunit/_base.py b/brainunit/_base.py index e6fc540..28279a3 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -15,6 +15,7 @@ from __future__ import annotations +import inspect import numbers import operator from contextlib import contextmanager @@ -64,6 +65,7 @@ # advanced functions 'get_or_create_dimension', + 'convert_in_si', ] StaticScalar = Union[ @@ -76,7 +78,6 @@ A = TypeVar('A') - def compatible_with_equinox(mode: bool = True): """ This function is developed to set the compatibility with equinox. @@ -1169,7 +1170,6 @@ def _wrap_function_keep_unit(func): """ def f(x: Quantity, *args, **kwds): # pylint: disable=C0111 - # x = x.factorless() return Quantity(func(x.mantissa, *args, **kwds), unit=x.unit) f._arg_units = [None] @@ -1195,7 +1195,6 @@ def _wrap_function_change_unit(func, unit_fun): def f(x, *args, **kwds): # pylint: disable=C0111 assert isinstance(x, Quantity), "Only Quantity objects can be passed to this function" - # x = x.factorless() return maybe_decimal(Quantity(func(x.mantissa, *args, **kwds), unit=unit_fun(x.unit, x.unit))) f._arg_units = [None] @@ -1219,7 +1218,6 @@ def _wrap_function_remove_unit(func): def f(x, *args, **kwds): # pylint: disable=C0111 assert isinstance(x, Quantity), "Only Quantity objects can be passed to this function" - # x = x.factorless() return func(x.mantissa, *args, **kwds) f._arg_units = [None] @@ -1808,7 +1806,7 @@ def __mul__(self, other) -> 'Unit' | Quantity: dim, scale=scale, base=self.base, - factor=self.factor, + factor=factor, name=name, dispname=dispname, iscompound=iscompound, @@ -2626,12 +2624,10 @@ def ndim(self) -> int: @property def imag(self) -> 'Quantity': - # self = self.factorless() return Quantity(jnp.imag(self.mantissa), unit=self.unit) @property def real(self) -> 'Quantity': - # self = self.factorless() return Quantity(jnp.real(self.mantissa), unit=self.unit) @property @@ -2640,12 +2636,10 @@ def size(self) -> int: @property def T(self) -> 'Quantity': - # self = self.factorless() return Quantity(jnp.asarray(self.mantissa).T, unit=self.unit) @property def mT(self) -> 'Quantity': - # self = self.factorless() return Quantity(jnp.asarray(self.mantissa).mT, unit=self.unit) @property @@ -2708,7 +2702,6 @@ def __iter__(self): - https://github.com/google/jax/issues/7713 - https://github.com/google/jax/pull/3821 """ - # self = self.factorless() if self.ndim == 0: yield self @@ -2717,7 +2710,6 @@ def __iter__(self): yield Quantity(self.mantissa[i], unit=self.unit) def __getitem__(self, index) -> 'Quantity': - # self = self.factorless() if isinstance(index, slice) and (index == _all_slice): return Quantity(self.mantissa, unit=self.unit) @@ -2765,7 +2757,6 @@ def scatter_add( out : Quantity The scatter-added value. """ - # self = self.factorless() # check value if not isinstance(value, Quantity): @@ -2825,7 +2816,6 @@ def scatter_mul( out : Quantity The scatter-multiplied value. """ - # self = self.factorless() # check value if not isinstance(value, Quantity): @@ -2863,7 +2853,6 @@ def scatter_div( out : Quantity The scatter-divided value. """ - # self = self.factorless() # check value if not isinstance(value, Quantity): @@ -2901,7 +2890,6 @@ def scatter_max( out : Quantity The scatter-maximum value. """ - # self = self.factorless() # check value if not isinstance(value, Quantity): @@ -2939,7 +2927,6 @@ def scatter_min( out : Quantity The scatter-minimum value. """ - # self = self.factorless() # check value if not isinstance(value, Quantity): @@ -2965,19 +2952,19 @@ def __len__(self) -> int: return len(self.mantissa) def __neg__(self) -> 'Quantity': - # self = self.factorless() + return Quantity(self.mantissa.__neg__(), unit=self.unit) def __pos__(self) -> 'Quantity': - # self = self.factorless() + return Quantity(self.mantissa.__pos__(), unit=self.unit) def __abs__(self) -> 'Quantity': - # self = self.factorless() + return Quantity(self.mantissa.__abs__(), unit=self.unit) def __invert__(self) -> 'Quantity': - # self = self.factorless() + return Quantity(self.mantissa.__invert__(), unit=self.unit) def _comparison(self, other: Any, operator_str: str, operation: Callable): @@ -3040,7 +3027,6 @@ def _binary_operation( inplace: bool, optional Whether to do the operation in-place (defaults to ``False``). """ - # self = self.factorless() # format "other" other = _to_quantity(other) @@ -3195,7 +3181,7 @@ def __imatmul__(self, oc): # -------------------- # def __pow__(self, oc): - # self = self.factorless() + if compat_with_equinox: try: from equinox.internal._omega import ω # noqa @@ -3211,7 +3197,7 @@ def __pow__(self, oc): def __rpow__(self, oc): # oc ** self - # self = self.factorless() + assert self.is_unitless, f"Cannot calculate {oc} ** {self}, the exponent has to be dimensionless" return oc ** self.mantissa @@ -3259,7 +3245,7 @@ def __ixor__(self, oc) -> 'Quantity': def __lshift__(self, oc) -> 'Quantity': # self << oc - # self = self.factorless() + if isinstance(oc, Quantity): assert oc.is_unitless, "The shift amount must be dimensionless" oc = oc.mantissa @@ -3268,20 +3254,20 @@ def __lshift__(self, oc) -> 'Quantity': def __rlshift__(self, oc) -> 'Quantity' | jax.typing.ArrayLike: # oc << self - # self = self.factorless() + assert self.is_unitless, "The shift amount must be dimensionless" return oc << self.mantissa def __ilshift__(self, oc) -> 'Quantity': # self <<= oc - # self = self.factorless() + r = self.__lshift__(oc) self.update_mantissa(r.mantissa) return self def __rshift__(self, oc) -> 'Quantity': # self >> oc - # self = self.factorless() + if isinstance(oc, Quantity): assert oc.is_unitless, "The shift amount must be dimensionless" oc = oc.mantissa @@ -3290,13 +3276,13 @@ def __rshift__(self, oc) -> 'Quantity': def __rrshift__(self, oc) -> 'Quantity' | jax.typing.ArrayLike: # oc >> self - # self = self.factorless() + assert self.is_unitless, "The shift amount must be dimensionless" return oc >> self.mantissa def __irshift__(self, oc) -> 'Quantity': # self >>= oc - # self = self.factorless() + r = self.__rshift__(oc) self.update_mantissa(r.mantissa) return self @@ -3308,7 +3294,7 @@ def __round__(self, ndigits: int = None) -> 'Quantity': :param ndigits: The number of decimals to round to. :return: The rounded Quantity. """ - # self = self.factorless() + return Quantity(self.mantissa.__round__(ndigits), unit=self.unit) # def __reduce__(self): @@ -3376,7 +3362,7 @@ def round( The real and imaginary parts of complex numbers are rounded separately. The result of rounding a float is a float. """ - # self = self.factorless() + return Quantity(jnp.round(self.mantissa, decimals), unit=self.unit) def astype( @@ -3390,7 +3376,7 @@ def astype( dtype: str, dtype Typecode or data-type to which the array is cast. """ - # self = self.factorless() + if dtype is None: return Quantity(self.mantissa, unit=self.unit) else: @@ -3404,24 +3390,24 @@ def clip( """ Return an array whose values are limited to [min, max]. One of max or min must be given. """ - # self = self.factorless() + _, min = unit_scale_align_to_first(self, min) _, max = unit_scale_align_to_first(self, max) return Quantity(jnp.clip(self.mantissa, min.mantissa, max.mantissa), unit=self.unit) def conj(self) -> 'Quantity': """Complex-conjugate all elements.""" - # self = self.factorless() + return Quantity(jnp.conj(self.mantissa), unit=self.unit) def conjugate(self) -> 'Quantity': """Return the complex conjugate, element-wise.""" - # self = self.factorless() + return Quantity(jnp.conjugate(self.mantissa), unit=self.unit) def copy(self) -> 'Quantity': """Return a copy of the quantity.""" - # self = self.factorless() + return type(self)(jnp.copy(self.mantissa), unit=self.unit) def dot(self, b) -> 'Quantity': @@ -3431,23 +3417,22 @@ def dot(self, b) -> 'Quantity': def fill(self, value: Quantity) -> 'Quantity': """Fill the array with a scalar mantissa.""" - # self = self.factorless() + fail_for_dimension_mismatch(self, value, "fill") self[:] = value return self def flatten(self) -> 'Quantity': - # self = self.factorless() + return Quantity(jnp.reshape(self.mantissa, -1), unit=self.unit) def item(self, *args) -> 'Quantity': """Copy an element of an array to a standard Python scalar and return it.""" - # self = self.factorless() + return Quantity(self.mantissa.item(*args), unit=self.unit) def prod(self, *args, **kwds) -> 'Quantity': # TODO: check error when axis is not None """Return the product of the array elements over the given axis.""" - # self = self.factorless() prod_res = jnp.prod(self.mantissa, *args, **kwds) # Calculating the correct dimensions is not completly trivial (e.g. @@ -3468,7 +3453,6 @@ def prod(self, *args, **kwds) -> 'Quantity': # TODO: check error when axis is n def nanprod(self, *args, **kwds) -> 'Quantity': # TODO: check error when axis is not None """Return the product of array elements over a given axis treating Not a Numbers (NaNs) as ones.""" - # self = self.factorless() prod_res = jnp.nanprod(self.mantissa, *args, **kwds) nan_mask = jnp.isnan(self.mantissa) @@ -3479,7 +3463,6 @@ def nanprod(self, *args, **kwds) -> 'Quantity': # TODO: check error when axis i return maybe_decimal(r) def cumprod(self, *args, **kwds): # TODO: check error when axis is not None - # self = self.factorless() prod_res = jnp.cumprod(self.mantissa, *args, **kwds) dim_exponent = jnp.ones_like(self.mantissa).cumsum(*args, **kwds) @@ -3489,7 +3472,6 @@ def cumprod(self, *args, **kwds): # TODO: check error when axis is not None return maybe_decimal(r) def nancumprod(self, *args, **kwds): # TODO: check error when axis is not None - # self = self.factorless() prod_res = jnp.nancumprod(self.mantissa, *args, **kwds) nan_mask = jnp.isnan(self.mantissa) @@ -3509,25 +3491,25 @@ def put(self, indices, values) -> 'Quantity': values: array_like Values to place in the array at target indices. """ - # self = self.factorless() + fail_for_dimension_mismatch(self, values, "put") self.__setitem__(indices, values) return self def repeat(self, repeats, axis=None) -> 'Quantity': """Repeat elements of an array.""" - # self = self.factorless() + r = jnp.repeat(self.mantissa, repeats=repeats, axis=axis) return Quantity(r, unit=self.unit) def reshape(self, shape, order='C') -> 'Quantity': """Returns an array containing the same data with a new shape.""" - # self = self.factorless() + return Quantity(jnp.reshape(self.mantissa, shape, order=order), unit=self.unit) def resize(self, new_shape) -> 'Quantity': """Change shape and size of array in-place.""" - # self = self.factorless() + self.update_mantissa(jnp.resize(self.mantissa, new_shape)) return self @@ -3548,18 +3530,18 @@ def sort(self, axis=-1, stable=True, order=None) -> 'Quantity': but unspecified fields will still be used, in the order in which they come up in the dtype, to break ties. """ - # self = self.factorless() + self.update_mantissa(jnp.sort(self.mantissa, axis=axis, stable=stable, order=order)) return self def squeeze(self, axis=None) -> 'Quantity': """Remove axes of length one from ``a``.""" - # self = self.factorless() + return Quantity(jnp.squeeze(self.mantissa, axis=axis), unit=self.unit) def swapaxes(self, axis1, axis2) -> 'Quantity': """Return a view of the array with `axis1` and `axis2` interchanged.""" - # self = self.factorless() + return Quantity(jnp.swapaxes(self.mantissa, axis1, axis2), unit=self.unit) def split(self, indices_or_sections, axis=0) -> List['Quantity']: @@ -3590,7 +3572,7 @@ def split(self, indices_or_sections, axis=0) -> List['Quantity']: sub-arrays : list of ndarrays A list of sub-arrays as views into `ary`. """ - # self = self.factorless() + return [Quantity(a, unit=self.unit) for a in jnp.split(self.mantissa, indices_or_sections, axis=axis)] def take( @@ -3603,7 +3585,6 @@ def take( fill_value=None, ) -> 'Quantity': """Return an array formed from the elements of a at the given indices.""" - # self = self.factorless() if isinstance(fill_value, Quantity): fail_for_dimension_mismatch(self, fill_value, "take") @@ -3670,7 +3651,7 @@ def transpose(self, *axes) -> 'Quantity': out : ndarray View of `a`, with axes suitably permuted. """ - # self = self.factorless() + return Quantity(jnp.transpose(self.mantissa, *axes), unit=self.unit) def tile(self, reps) -> 'Quantity': @@ -3702,7 +3683,7 @@ def tile(self, reps) -> 'Quantity': c : ndarray The tiled output array. """ - # self = self.factorless() + return Quantity(jnp.tile(self.mantissa, reps), unit=self.unit) def view(self, *args, dtype=None) -> 'Quantity': @@ -3841,7 +3822,7 @@ def view(self, *args, dtype=None) -> 'Quantity': [4, 16] """ - # self = self.factorless() + if len(args) == 0: if dtype is None: raise ValueError('Provide dtype or shape.') @@ -3863,7 +3844,7 @@ def view(self, *args, dtype=None) -> 'Quantity': def __array__(self, dtype: Optional[jax.typing.DTypeLike] = None) -> np.ndarray: """Support ``numpy.array()`` and ``numpy.asarray()`` functions.""" - # self = self.factorless() + if self.dim.is_dimensionless: return np.asarray(self.to_decimal(), dtype=dtype) else: @@ -3873,7 +3854,7 @@ def __array__(self, dtype: Optional[jax.typing.DTypeLike] = None) -> np.ndarray: ) def __float__(self): - # self = self.factorless() + if self.dim.is_dimensionless and self.ndim == 0: return float(self.to_decimal()) else: @@ -3883,7 +3864,7 @@ def __float__(self): ) def __int__(self): - # self = self.factorless() + if self.dim.is_dimensionless and self.ndim == 0: return int(self.to_decimal()) else: @@ -3893,7 +3874,7 @@ def __int__(self): ) def __index__(self): - # self = self.factorless() + if self.dim.is_dimensionless: return operator.index(self.to_decimal()) else: @@ -3914,7 +3895,7 @@ def unsqueeze(self, axis: int) -> 'Quantity': See :func:`brainstate.math.unsqueeze` """ - # self = self.factorless() + return Quantity(jnp.expand_dims(self.mantissa, axis), unit=self.unit) def expand_dims(self, axis: Union[int, Sequence[int]]) -> 'Quantity': @@ -3931,7 +3912,7 @@ def expand_dims(self, axis: Union[int, Sequence[int]]) -> 'Quantity': expanded : Quantity A view with the new axis inserted. """ - # self = self.factorless() + return Quantity(jnp.expand_dims(self.mantissa, axis), unit=self.unit) def expand_as(self, array: Union['Quantity', jax.typing.ArrayLike]) -> 'Quantity': @@ -3949,14 +3930,14 @@ def expand_as(self, array: Union['Quantity', jax.typing.ArrayLike]) -> 'Quantity typically not contiguous. Furthermore, more than one element of a expanded array may refer to a single memory location. """ - # self = self.factorless() + if isinstance(array, Quantity): fail_for_dimension_mismatch(self, array, "expand_as (Quantity)") array = array.mantissa return Quantity(jnp.broadcast_to(self.mantissa, array), unit=self.unit) def pow(self, oc) -> 'Quantity': - # self = self.factorless() + return self.__pow__(oc) def clone(self) -> 'Quantity': @@ -3998,15 +3979,15 @@ def cpu(self, device=None) -> 'Quantity': # dtype exchanging # # ---------------- # def half(self) -> 'Quantity': - # self = self.factorless() + return Quantity(jnp.asarray(self.mantissa, dtype=jnp.float16), unit=self.unit) def float(self) -> 'Quantity': - # self = self.factorless() + return Quantity(jnp.asarray(self.mantissa, dtype=jnp.float32), unit=self.unit) def double(self) -> 'Quantity': - # self = self.factorless() + return Quantity(jnp.asarray(self.mantissa, dtype=jnp.float64), unit=self.unit) @@ -4963,3 +4944,89 @@ def _assign_unit(f, val, unit): def _is_quantity(x): return isinstance(x, Quantity) + + +def _convert_in_si(x): + """ + Convert a Quantity to a Quantity in SI units. + """ + if isinstance(x, Quantity) or isinstance(x, Unit): + return x.factorless() + return x + + +def convert_in_si(): + """ + Convert all the local variables in SI units. + + This function traverses the local variables in the calling scope and converts all `Quantity` + instances (including those nested in lists, tuples, or dictionaries) to their SI unit equivalents. + The conversion is performed by calling the `factorless()` method on each `Quantity` instance, + which strips the unit and returns the raw value in SI units. + + Notes: + - This function modifies the local variables in the calling scope. + - Only `Quantity` instances are affected; other types of variables remain unchanged. + - If a `Quantity` instance is nested within a list, tuple, or dictionary, it will be + recursively converted to its SI unit equivalent. + + Examples: + >>> import brainunit as u + >>> time1 = 1 * u.second + >>> time2 = 1 * u.minute + >>> time3 = time1 + time2 + >>> time4 = time2 + time1 + >>> time3 + 61. * second + >>> time4 + 1.0166667 * minute + + >>> u.convert_in_si() # Convert all local variables to SI units + >>> time3 = time1 + time2 + >>> time4 = time2 + time1 + >>> time3 + 61. * second + >>> time4 + 6.1 * dasecond + + >>> length1 = 1 * u.inch + >>> result1 = time1 * length1 + >>> result2 = u.math.multiply(time1, length1) + >>> result1 + 1. * second * inch + >>> result2 + 1. * second * inch + + >>> u.convert_in_si() # Convert all local variables to SI units + >>> result1 = time1 * length1 + >>> result2 = u.math.multiply(time1, length1) + >>> result1 + 0.0254 * second * meter + >>> result2 + 0.0254 * second * meter + + >>> dict1 = { + ... 'time1': 1 * u.second, + ... 'time2': 1 * u.minute, + ... 'length1': 1 * u.inch, + ...} + >>> u.convert_in_si() # Convert all local variables to SI units + >>> dict1 + {'length1': 0.0254 * meter, 'time1': 1 * second, 'time2': 6. * dasecond} + + Raises: + None: This function does not raise any exceptions explicitly, but may propagate + exceptions from `factorless()` or `jax.tree.map()` if they fail. + + See Also: + - `Quantity.factorless()`: Method used to convert `Quantity` instances to SI units. + """ + frame = inspect.currentframe() + try: + caller_frame = frame.f_back + caller_globals = caller_frame.f_globals + + for key, val in list(caller_globals.items()): + caller_globals[key] = jax.tree.map(_convert_in_si, val, is_leaf=lambda x: _is_quantity(x)) + finally: + del frame diff --git a/brainunit/math/_einops.py b/brainunit/math/_einops.py index 0a31a37..3da0ce2 100644 --- a/brainunit/math/_einops.py +++ b/brainunit/math/_einops.py @@ -1212,11 +1212,6 @@ def einsum( .. _opt_einsum: https://github.com/dgasmith/opt_einsum """ - # operands = jax.tree.map( - # lambda x: x.factorless() if isinstance(x, Quantity) else x, - # operands, - # is_leaf=lambda x: isinstance(x, Quantity) - # ) operands = (subscripts, *operands) spec = operands[0] if isinstance(operands[0], str) else None diff --git a/brainunit/math/_fun_accept_unitless.py b/brainunit/math/_fun_accept_unitless.py index bd1bbb6..c72657e 100644 --- a/brainunit/math/_fun_accept_unitless.py +++ b/brainunit/math/_fun_accept_unitless.py @@ -52,7 +52,7 @@ def _fun_accept_unitless_unary( **kwargs ): if isinstance(x, Quantity): - # x = x.factorless() + if unit_to_scale is None: assert x.dim.is_dimensionless, ( f'{func} only support dimensionless input. But we got {x}. \n' @@ -772,7 +772,7 @@ def _fun_accept_unitless_binary( **kwargs ): if isinstance(x, Quantity): - # x = x.factorless() + if unit_to_scale is None: assert x.dim.is_dimensionless, ( f'{func} only support dimensionless input. But we got {x}. \n' @@ -784,7 +784,7 @@ def _fun_accept_unitless_binary( assert isinstance(unit_to_scale, Unit), f'unit_to_scale should be a Unit instance. Got {unit_to_scale}' x = x.to_decimal(unit_to_scale) if isinstance(y, Quantity): - # y = y.factorless() + if unit_to_scale is None: assert y.dim.is_dimensionless, (f'Input should be dimensionless for the function "{func}" ' f'when scaling "unit_to_scale" is not provided.') @@ -1145,11 +1145,9 @@ def invert( def _fun_unitless_binary(func, x, y, *args, **kwargs): if isinstance(x, Quantity): - # x = x.factorless() assert x.dim.is_dimensionless, f'Expected dimensionless array, got {x}' x = x.to_decimal() if isinstance(y, Quantity): - # y = y.factorless() assert y.dim.is_dimensionless, f'Expected dimensionless array, got {y}' y = y.to_decimal() return func(x, y, *args, **kwargs) diff --git a/brainunit/math/_fun_change_unit.py b/brainunit/math/_fun_change_unit.py index 468d42a..84ca3ef 100644 --- a/brainunit/math/_fun_change_unit.py +++ b/brainunit/math/_fun_change_unit.py @@ -46,7 +46,6 @@ def _fun_change_unit_unary(val_fun, unit_fun, x, *args, **kwargs): if isinstance(x, Quantity): - # x = x.factorless() r = Quantity(val_fun(x.mantissa, *args, **kwargs), unit=unit_fun(x.unit)) return maybe_decimal(r) return val_fun(x, *args, **kwargs) @@ -517,18 +516,17 @@ def nancumprod( def _fun_change_unit_binary(val_fun, unit_fun, x, y, *args, **kwargs): if isinstance(x, Quantity) and isinstance(y, Quantity): - # x = x.factorless() - # y = y.factorless() + return maybe_decimal( Quantity(val_fun(x.mantissa, y.mantissa, *args, **kwargs), unit=unit_fun(x.unit, y.unit)) ) elif isinstance(x, Quantity): - # x = x.factorless() + return maybe_decimal( Quantity(val_fun(x.mantissa, y, *args, **kwargs), unit=unit_fun(x.unit, UNITLESS)) ) elif isinstance(y, Quantity): - # y = y.factorless() + return maybe_decimal( Quantity(val_fun(x, y.mantissa, *args, **kwargs), unit=unit_fun(UNITLESS, y.unit)) ) diff --git a/brainunit/math/_fun_keep_unit.py b/brainunit/math/_fun_keep_unit.py index 285a694..6ca3c04 100644 --- a/brainunit/math/_fun_keep_unit.py +++ b/brainunit/math/_fun_keep_unit.py @@ -83,11 +83,6 @@ def _fun_keep_unit_sequence( **kwargs ): leaves, treedef = jax.tree.flatten(args, is_leaf=lambda x: isinstance(x, Quantity)) - # leaves = jax.tree.map( - # lambda x: x.factorless() if isinstance(x, Quantity) else x, - # leaves, - # is_leaf=lambda x: isinstance(x, Quantity) - # ) leaves = unit_scale_align_to_first(*leaves) unit = leaves[0].unit leaves = [x.mantissa for x in leaves] diff --git a/brainunit/math/_fun_remove_unit.py b/brainunit/math/_fun_remove_unit.py index 773dc97..4694bb2 100644 --- a/brainunit/math/_fun_remove_unit.py +++ b/brainunit/math/_fun_remove_unit.py @@ -67,7 +67,7 @@ def get_promote_dtypes( def _fun_remove_unit_unary(func, x, *args, **kwargs): if isinstance(x, Quantity): - # x = x.factorless() + return func(x.mantissa, *args, **kwargs) else: return func(x, *args, **kwargs) @@ -374,15 +374,14 @@ def logical_not( def _fun_logic_binary(func, x, y, *args, **kwargs): if isinstance(x, Quantity) and isinstance(y, Quantity): - # x = x.factorless() - # y = y.factorless() + return func(x.mantissa, y.in_unit(x.unit).mantissa, *args, **kwargs) elif isinstance(x, Quantity): - # x = x.factorless() + assert x.is_unitless, f'Expected unitless array when y is not Quantity, while got {x}' return func(x.mantissa, y, *args, **kwargs) elif isinstance(y, Quantity): - # y = y.factorless() + assert y.is_unitless, f'Expected unitless array when x is not Quantity, while got {y}' return func(x, y.mantissa, *args, **kwargs) else: From ce1944022a6a5442ee681b8341bcffa774b01ce9 Mon Sep 17 00:00:00 2001 From: Sichao He <1310722434@qq.com> Date: Thu, 16 Jan 2025 17:40:06 +0800 Subject: [PATCH 2/6] Update _base.py --- brainunit/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainunit/_base.py b/brainunit/_base.py index 28279a3..b9bec79 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -4962,7 +4962,7 @@ def convert_in_si(): This function traverses the local variables in the calling scope and converts all `Quantity` instances (including those nested in lists, tuples, or dictionaries) to their SI unit equivalents. The conversion is performed by calling the `factorless()` method on each `Quantity` instance, - which strips the unit and returns the raw value in SI units. + which convert the unit and returns the quantities in SI units. Notes: - This function modifies the local variables in the calling scope. From 16001624ffed652cf1dac3372e0ebd3ff706574e Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 20 Jan 2025 23:53:31 +0800 Subject: [PATCH 3/6] Add environ module --- brainunit/__init__.py | 2 + brainunit/_base.py | 96 ++--------------- brainunit/environ.py | 216 ++++++++++++++++++++++++++++++++++++++ brainunit/environ_test.py | 37 +++++++ 4 files changed, 264 insertions(+), 87 deletions(-) create mode 100644 brainunit/environ.py create mode 100644 brainunit/environ_test.py diff --git a/brainunit/__init__.py b/brainunit/__init__.py index 86630e8..6b0047c 100644 --- a/brainunit/__init__.py +++ b/brainunit/__init__.py @@ -18,6 +18,7 @@ from . import _matplotlib_compat from . import autograd from . import constants +from . import environ from . import fft from . import lax from . import linalg @@ -38,6 +39,7 @@ 'math', 'linalg', 'autograd', + 'environ', 'fft', 'constants', 'sparse' diff --git a/brainunit/_base.py b/brainunit/_base.py index b9bec79..7a4c1a6 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -29,6 +29,9 @@ from jax.interpreters.partial_eval import DynamicJaxprTracer from jax.tree_util import register_pytree_node_class +from .environ import (get_compute_mode, + SI_MODE, + NON_SI_MODE) from ._misc import set_module_as from ._sparse_base import SparseMatrix @@ -65,7 +68,6 @@ # advanced functions 'get_or_create_dimension', - 'convert_in_si', ] StaticScalar = Union[ @@ -1233,6 +1235,7 @@ def _assert_same_base(u1, u2): f"But we got {u1.base} != {u1.base}.") +# TODO: Cannot find compound standard unit def _find_standard_unit(dim: Dimension, base, scale, factor) -> Tuple[Optional[str], bool, bool]: """ Find a standard unit for the given dimension, base, scale, and factor. @@ -2199,6 +2202,11 @@ def __init__( # dimension self._unit = unit + if get_compute_mode() == SI_MODE: + self._mantissa = self._mantissa * self._unit.factor + self._unit = self._unit.factorless() + + @property def at(self): """ @@ -4944,89 +4952,3 @@ def _assign_unit(f, val, unit): def _is_quantity(x): return isinstance(x, Quantity) - - -def _convert_in_si(x): - """ - Convert a Quantity to a Quantity in SI units. - """ - if isinstance(x, Quantity) or isinstance(x, Unit): - return x.factorless() - return x - - -def convert_in_si(): - """ - Convert all the local variables in SI units. - - This function traverses the local variables in the calling scope and converts all `Quantity` - instances (including those nested in lists, tuples, or dictionaries) to their SI unit equivalents. - The conversion is performed by calling the `factorless()` method on each `Quantity` instance, - which convert the unit and returns the quantities in SI units. - - Notes: - - This function modifies the local variables in the calling scope. - - Only `Quantity` instances are affected; other types of variables remain unchanged. - - If a `Quantity` instance is nested within a list, tuple, or dictionary, it will be - recursively converted to its SI unit equivalent. - - Examples: - >>> import brainunit as u - >>> time1 = 1 * u.second - >>> time2 = 1 * u.minute - >>> time3 = time1 + time2 - >>> time4 = time2 + time1 - >>> time3 - 61. * second - >>> time4 - 1.0166667 * minute - - >>> u.convert_in_si() # Convert all local variables to SI units - >>> time3 = time1 + time2 - >>> time4 = time2 + time1 - >>> time3 - 61. * second - >>> time4 - 6.1 * dasecond - - >>> length1 = 1 * u.inch - >>> result1 = time1 * length1 - >>> result2 = u.math.multiply(time1, length1) - >>> result1 - 1. * second * inch - >>> result2 - 1. * second * inch - - >>> u.convert_in_si() # Convert all local variables to SI units - >>> result1 = time1 * length1 - >>> result2 = u.math.multiply(time1, length1) - >>> result1 - 0.0254 * second * meter - >>> result2 - 0.0254 * second * meter - - >>> dict1 = { - ... 'time1': 1 * u.second, - ... 'time2': 1 * u.minute, - ... 'length1': 1 * u.inch, - ...} - >>> u.convert_in_si() # Convert all local variables to SI units - >>> dict1 - {'length1': 0.0254 * meter, 'time1': 1 * second, 'time2': 6. * dasecond} - - Raises: - None: This function does not raise any exceptions explicitly, but may propagate - exceptions from `factorless()` or `jax.tree.map()` if they fail. - - See Also: - - `Quantity.factorless()`: Method used to convert `Quantity` instances to SI units. - """ - frame = inspect.currentframe() - try: - caller_frame = frame.f_back - caller_globals = caller_frame.f_globals - - for key, val in list(caller_globals.items()): - caller_globals[key] = jax.tree.map(_convert_in_si, val, is_leaf=lambda x: _is_quantity(x)) - finally: - del frame diff --git a/brainunit/environ.py b/brainunit/environ.py new file mode 100644 index 0000000..75e3bff --- /dev/null +++ b/brainunit/environ.py @@ -0,0 +1,216 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import contextlib +import dataclasses +import functools +import inspect +import os +import re +import threading +from collections import defaultdict +from typing import Any, Callable, Dict, Hashable + +__all__ = [ + # functions for environment settings + 'set', 'context', 'get', 'all', + # functions for getting default behaviors + 'get_compute_mode', + # constants + 'SI_MODE', 'NON_SI_MODE' +] + +SI_MODE: str = 'si' +NON_SI_MODE: str = 'non_si' + + +@dataclasses.dataclass +class DefaultContext(threading.local): + # default environment settings + settings: Dict[Hashable, Any] = dataclasses.field(default_factory=dict) + # current environment settings + contexts: defaultdict[Hashable, Any] = dataclasses.field(default_factory=lambda: defaultdict(list)) + # environment functions + functions: Dict[Hashable, Any] = dataclasses.field(default_factory=dict) + +DEFAULT = DefaultContext() +_NOT_PROVIDE = object() + + +@contextlib.contextmanager +def context(**kwargs): + r""" + Context-manager that sets a computing environment for brainunit. + + For instance:: + + >>> import brainunit as u + >>> global_1 = 2 * u.kmh + >>> global_2 = 0 + >>> def create_a(a): + ... return a.mantissa * 2 * u.minute + >>> with u.environ.context(compute_mode='si'): + ... a = create_a([1, 2, 3] * u.minute) # If input is [1, 2, 3] * u.second, the result would differ + ... b = [4, 5, 6] * u.inch + ... global_2 = (b / a) / global_1 + + """ + if 'compute_mode' in kwargs: + if kwargs['compute_mode'] == SI_MODE: + _convert_to_si_quantity(**kwargs) + else: + pass + + try: + for k, v in kwargs.items(): + + # update the current environment + DEFAULT.contexts[k].append(v) + + # restore the environment functions + if k in DEFAULT.functions: + DEFAULT.functions[k](v) + + # yield the current all environment information + yield all() + finally: + + for k, v in kwargs.items(): + + # restore the current environment + DEFAULT.contexts[k].pop() + + # restore the environment functions + if k in DEFAULT.functions: + DEFAULT.functions[k](get(k)) + + +def get(key: str, default: Any = _NOT_PROVIDE, desc: str = None): + """ + Get one of the default computation environment. + + Returns + ------- + item: Any + The default computation environment. + """ + if key in DEFAULT.contexts: + if len(DEFAULT.contexts[key]) > 0: + return DEFAULT.contexts[key][-1] + if key in DEFAULT.settings: + return DEFAULT.settings[key] + + if default is _NOT_PROVIDE: + if desc is not None: + raise KeyError( + f"'{key}' is not found in the context. \n" + f"You can set it by `brainstate.share.context({key}=value)` " + f"locally or `brainstate.share.set({key}=value)` globally. \n" + f"Description: {desc}" + ) + else: + raise KeyError( + f"'{key}' is not found in the context. \n" + f"You can set it by `brainstate.share.context({key}=value)` " + f"locally or `brainstate.share.set({key}=value)` globally." + ) + return default + + +def all() -> dict: + """ + Get all the current default computation environment. + + Returns + ------- + r: dict + The current default computation environment. + """ + r = dict() + for k, v in DEFAULT.contexts.items(): + if v: + r[k] = v[-1] + for k, v in DEFAULT.settings.items(): + if k not in r: + r[k] = v + return r + + +def get_compute_mode() -> str: + """ + Get the current compute mode. + + Returns + ------- + mode: str + The current compute mode. + """ + return get('compute_mode') + +def set( + compute_mode: str = None, + **kwargs +): + """ + Set the global default computation environment. + + + + Args: + compute_mode: str, optional + The default compute mode. Default is computing in 'si'. + """ + if compute_mode is not None: + assert compute_mode in ['si', 'non_si'], f"compute_mode must be 'si' or 'non_si'. Got: {compute_mode}" + kwargs['compute_mode'] = compute_mode + + # set default environment + DEFAULT.settings.update(kwargs) + + # update the environment functions + for k, v in kwargs.items(): + if k in DEFAULT.functions: + DEFAULT.functions[k](v) + +def _convert_to_si_quantity(**kwargs): + """ + Convert all the local variables in SI units. + + Traverses the local variables in the calling scope and converts all `Quantity` + instances (including those nested in lists, tuples, or dictionaries) to their SI unit equivalents. + The conversion is performed by calling the `factorless()` method on each `Quantity` instance, + which convert the unit and returns the quantities in SI units. + """ + set(compute_mode=kwargs['compute_mode']) + from ._base import Quantity, Unit + frame = inspect.currentframe().f_back.f_back.f_back + original = {k: v for k, v in frame.f_locals.items() + if isinstance(v, (Quantity, Unit))} + + try: + # Convert to SI + for k, v in original.items(): + frame.f_locals[k] = v.factorless() + yield + finally: + # Restore original values + for k, v in original.items(): + frame.f_locals[k] = v + +set(compute_mode=NON_SI_MODE) \ No newline at end of file diff --git a/brainunit/environ_test.py b/brainunit/environ_test.py new file mode 100644 index 0000000..79728f0 --- /dev/null +++ b/brainunit/environ_test.py @@ -0,0 +1,37 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import unittest + +import pytest +import brainunit as u + +class TestEnviron(unittest.TestCase): + def test_compute_mode(self): + global_1 = 2 * u.kmh + global_2 = 0 + + def create_a(a): + return a.mantissa * 2 * u.minute + + with u.environ.context(compute_mode='si'): + a = create_a([1, 2, 3] * u.minute) # If input is [1, 2, 3] * u.second, the result would differ + b = [4, 5, 6] * u.inch + global_2 = (b / a) / global_1 + + assert a.unit.factor == 1. + assert b.unit.factor == 1. + # TODO: need to fix compound standard units + # assert global_1.unit.factor == 1. \ No newline at end of file From e7a43d47a5181efd6c415a26adccfd2a6c70aff4 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 21 Jan 2025 00:03:59 +0800 Subject: [PATCH 4/6] Update _base_test.py --- brainunit/_base_test.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/brainunit/_base_test.py b/brainunit/_base_test.py index 27ede75..cd7b6b1 100644 --- a/brainunit/_base_test.py +++ b/brainunit/_base_test.py @@ -776,7 +776,12 @@ def test_numpy_functions_indices(self): units = [volt, second, siemens, mV, kHz] # numpy functions - keep_dim_funcs = [np.argmin, np.argmax, np.argsort, np.nonzero] + keep_dim_funcs = [ + np.argmin, + np.argmax, + # np.argsort, # TODO: after upgrading jax 0.5.0, argsort will raise an error + np.nonzero + ] for value, unit in itertools.product(values, units): q_ar = value * unit From cb26bb92d1381329915b8a099eb4b705ff06d03e Mon Sep 17 00:00:00 2001 From: Sichao He <1310722434@qq.com> Date: Tue, 21 Jan 2025 00:10:15 +0800 Subject: [PATCH 5/6] Apply suggestions from code review Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- brainunit/environ.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/brainunit/environ.py b/brainunit/environ.py index 75e3bff..b4146c8 100644 --- a/brainunit/environ.py +++ b/brainunit/environ.py @@ -110,9 +110,9 @@ def get(key: str, default: Any = _NOT_PROVIDE, desc: str = None): item: Any The default computation environment. """ - if key in DEFAULT.contexts: - if len(DEFAULT.contexts[key]) > 0: - return DEFAULT.contexts[key][-1] + if key in DEFAULT.contexts and len(DEFAULT.contexts[key]) > 0: + return DEFAULT.contexts[key][-1] + if key in DEFAULT.settings: return DEFAULT.settings[key] @@ -142,7 +142,7 @@ def all() -> dict: r: dict The current default computation environment. """ - r = dict() + r = {} for k, v in DEFAULT.contexts.items(): if v: r[k] = v[-1] @@ -177,7 +177,10 @@ def set( The default compute mode. Default is computing in 'si'. """ if compute_mode is not None: - assert compute_mode in ['si', 'non_si'], f"compute_mode must be 'si' or 'non_si'. Got: {compute_mode}" + assert compute_mode in { + 'si', + 'non_si', + }, f"compute_mode must be 'si' or 'non_si'. Got: {compute_mode}" kwargs['compute_mode'] = compute_mode # set default environment From b92bd17a6ea57229d327f791cc57126f5765d866 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 21 Jan 2025 00:10:52 +0800 Subject: [PATCH 6/6] Update environ.py --- brainunit/environ.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/brainunit/environ.py b/brainunit/environ.py index 75e3bff..c864a1c 100644 --- a/brainunit/environ.py +++ b/brainunit/environ.py @@ -74,8 +74,6 @@ def context(**kwargs): if 'compute_mode' in kwargs: if kwargs['compute_mode'] == SI_MODE: _convert_to_si_quantity(**kwargs) - else: - pass try: for k, v in kwargs.items():