Skip to content

Commit

Permalink
fix tree flatten and unflatten error
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 11, 2024
1 parent 52ae103 commit 9a416da
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,9 +958,9 @@ class Quantity(object):
def __init__(
self,
value: Any,
dtype: bst.typing.DTypeLike = None,
dtype: Optional[bst.typing.DTypeLike] = None,
dim: Dimension = DIMENSIONLESS,
unit: 'Unit' = None,
unit: Optional['Unit'] = None,
):
scale, dim = _get_dim(dim, unit)

Expand Down Expand Up @@ -995,6 +995,9 @@ def __init__(
elif isinstance(value, (jnp.number, numbers.Number)):
value = jnp.array(value, dtype=dtype)

elif isinstance(value, (jax.core.ShapedArray, jax.ShapeDtypeStruct)):
value = value

else:
raise TypeError(f"Invalid type for value: {type(value)}")

Expand Down Expand Up @@ -1464,22 +1467,22 @@ def _binary_operation(
Whether to do the operation in-place (defaults to ``False``).
"""
other = _to_quantity(other)
other_unit = None
other_dim = None

if fail_for_mismatch:
if inplace:
message = "Cannot calculate ... %s {value}, units do not match" % operator_str
_, other_unit = fail_for_dimension_mismatch(self, other, message, value=other)
_, other_dim = fail_for_dimension_mismatch(self, other, message, value=other)
else:
message = "Cannot calculate {value1} %s {value2}, units do not match" % operator_str
_, other_unit = fail_for_dimension_mismatch(self, other, message, value1=self, value2=other)
_, other_dim = fail_for_dimension_mismatch(self, other, message, value1=self, value2=other)

if other_unit is None:
other_unit = get_unit(other)
if other_dim is None:
other_dim = get_unit(other)

new_unit = unit_operation(self.dim, other_unit)
new_dim = unit_operation(self.dim, other_dim)
result = value_operation(self.value, other.value)
r = Quantity(result, dim=new_unit)
r = Quantity(result, dim=new_dim)
if inplace:
self.update_value(r.value)
return self
Expand Down Expand Up @@ -2315,23 +2318,23 @@ def tree_flatten(self) -> Tuple[jax.Array | numbers.Number, Any]:
Tree flattens the data.
Returns:
The data and the unit.
The data and the dimension.
"""
return self.value, self.dim
return (self.value,), self.dim

@classmethod
def tree_unflatten(cls, unit, value) -> 'Quantity':
def tree_unflatten(cls, dim, value) -> 'Quantity':
"""
Tree unflattens the data.
Args:
unit: The unit.
dim: The dimension.
value: The data.
Returns:
The Quantity object.
"""
return cls(value, dim=unit)
return cls(*value, dim=dim)

def cuda(self, deice=None) -> 'Quantity':
deice = jax.devices('cuda')[0] if deice is None else deice
Expand Down

0 comments on commit 9a416da

Please sign in to comment.