From 9a416da9d143295d2dd15e98026250fbc915703a Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Tue, 11 Jun 2024 14:52:37 +0800 Subject: [PATCH] fix tree flatten and unflatten error --- brainunit/_base.py | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/brainunit/_base.py b/brainunit/_base.py index 75793c0..0efd32a 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -958,9 +958,9 @@ class Quantity(object): def __init__( self, value: Any, - dtype: bst.typing.DTypeLike = None, + dtype: Optional[bst.typing.DTypeLike] = None, dim: Dimension = DIMENSIONLESS, - unit: 'Unit' = None, + unit: Optional['Unit'] = None, ): scale, dim = _get_dim(dim, unit) @@ -995,6 +995,9 @@ def __init__( elif isinstance(value, (jnp.number, numbers.Number)): value = jnp.array(value, dtype=dtype) + elif isinstance(value, (jax.core.ShapedArray, jax.ShapeDtypeStruct)): + value = value + else: raise TypeError(f"Invalid type for value: {type(value)}") @@ -1464,22 +1467,22 @@ def _binary_operation( Whether to do the operation in-place (defaults to ``False``). """ other = _to_quantity(other) - other_unit = None + other_dim = None if fail_for_mismatch: if inplace: message = "Cannot calculate ... %s {value}, units do not match" % operator_str - _, other_unit = fail_for_dimension_mismatch(self, other, message, value=other) + _, other_dim = fail_for_dimension_mismatch(self, other, message, value=other) else: message = "Cannot calculate {value1} %s {value2}, units do not match" % operator_str - _, other_unit = fail_for_dimension_mismatch(self, other, message, value1=self, value2=other) + _, other_dim = fail_for_dimension_mismatch(self, other, message, value1=self, value2=other) - if other_unit is None: - other_unit = get_unit(other) + if other_dim is None: + other_dim = get_unit(other) - new_unit = unit_operation(self.dim, other_unit) + new_dim = unit_operation(self.dim, other_dim) result = value_operation(self.value, other.value) - r = Quantity(result, dim=new_unit) + r = Quantity(result, dim=new_dim) if inplace: self.update_value(r.value) return self @@ -2315,23 +2318,23 @@ def tree_flatten(self) -> Tuple[jax.Array | numbers.Number, Any]: Tree flattens the data. Returns: - The data and the unit. + The data and the dimension. """ - return self.value, self.dim + return (self.value,), self.dim @classmethod - def tree_unflatten(cls, unit, value) -> 'Quantity': + def tree_unflatten(cls, dim, value) -> 'Quantity': """ Tree unflattens the data. Args: - unit: The unit. + dim: The dimension. value: The data. Returns: The Quantity object. """ - return cls(value, dim=unit) + return cls(*value, dim=dim) def cuda(self, deice=None) -> 'Quantity': deice = jax.devices('cuda')[0] if deice is None else deice