From 8919787a724d5f3a3c49934054996310f5526890 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 2 Oct 2022 13:39:57 +0800 Subject: [PATCH 1/2] operations with JaxArray and numpy ndarray do not cuase errors --- brainpy/math/jaxarray.py | 245 ++++++++++++++-------------- brainpy/math/tests/test_jaxarray.py | 7 + 2 files changed, 129 insertions(+), 123 deletions(-) diff --git a/brainpy/math/jaxarray.py b/brainpy/math/jaxarray.py index f9003df16..09ad63bbf 100644 --- a/brainpy/math/jaxarray.py +++ b/brainpy/math/jaxarray.py @@ -49,6 +49,15 @@ def turn_off_global_jit(): _global_jit_mode = False +def _check_input_array(array): + if isinstance(array, JaxArray): + return array.value + elif isinstance(array, np.ndarray): + return jnp.asarray(array) + else: + return array + + class JaxArray(object): """Multiple-dimensional array in JAX backend. """ @@ -174,7 +183,7 @@ def __getitem__(self, index): if isinstance(index, slice) and (index == _all_slice): return self.value elif isinstance(index, tuple): - index = tuple(x.value if isinstance(x, JaxArray) else x for x in index) + index = tuple(_check_input_array(x) for x in index) elif isinstance(index, JaxArray): index = index.value return self.value[index] @@ -189,7 +198,7 @@ def __setitem__(self, index, value): # tuple index if isinstance(index, tuple): - index = tuple(x.value if isinstance(x, JaxArray) else x for x in index) + index = tuple(_check_input_array(x) for x in index) # JaxArray index elif isinstance(index, JaxArray): @@ -221,199 +230,199 @@ def __invert__(self): return JaxArray(self._value.__invert__()) def __eq__(self, oc): - return JaxArray(self._value == (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value == _check_input_array(oc)) def __ne__(self, oc): - return JaxArray(self._value != (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value != _check_input_array(oc)) def __lt__(self, oc): - return JaxArray(self._value < (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value < _check_input_array(oc)) def __le__(self, oc): - return JaxArray(self._value <= (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value <= _check_input_array(oc)) def __gt__(self, oc): - return JaxArray(self._value > (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value > _check_input_array(oc)) def __ge__(self, oc): - return JaxArray(self._value >= (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value >= _check_input_array(oc)) def __add__(self, oc): - return JaxArray(self._value + (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value + _check_input_array(oc)) def __radd__(self, oc): - return JaxArray(self._value + (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value + _check_input_array(oc)) def __iadd__(self, oc): # a += b if self._outside_global_jit and _global_jit_mode: raise MathError(msg) - self._value += (oc.value if isinstance(oc, JaxArray) else oc) + self._value += _check_input_array(oc) return self def __sub__(self, oc): - return JaxArray(self._value - (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value - _check_input_array(oc)) def __rsub__(self, oc): - return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) - self._value) + return JaxArray(_check_input_array(oc) - self._value) def __isub__(self, oc): # a -= b if self._outside_global_jit and _global_jit_mode: raise MathError(msg) - self._value = self._value - (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self._value - _check_input_array(oc) return self def __mul__(self, oc): - return JaxArray(self._value * (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value * _check_input_array(oc)) def __rmul__(self, oc): - return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) * self._value) + return JaxArray(_check_input_array(oc) * self._value) def __imul__(self, oc): # a *= b if self._outside_global_jit and _global_jit_mode: raise MathError(msg) - self._value = self._value * (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self._value * _check_input_array(oc) return self def __rdiv__(self, oc): - return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) / self._value) + return JaxArray(_check_input_array(oc) / self._value) def __truediv__(self, oc): - return JaxArray(self._value / (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value / _check_input_array(oc)) def __rtruediv__(self, oc): - return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) / self._value) + return JaxArray(_check_input_array(oc) / self._value) def __itruediv__(self, oc): # a /= b if self._outside_global_jit and _global_jit_mode: raise MathError(msg) - self._value = self._value / (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self._value / _check_input_array(oc) return self def __floordiv__(self, oc): - return JaxArray(self._value // (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value // _check_input_array(oc)) def __rfloordiv__(self, oc): - return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) // self._value) + return JaxArray(_check_input_array(oc) // self._value) def __ifloordiv__(self, oc): # a //= b if self._outside_global_jit and _global_jit_mode: raise MathError(msg) - self._value = self._value // (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self._value // _check_input_array(oc) return self def __divmod__(self, oc): - return JaxArray(self._value.__divmod__(oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value.__divmod__(_check_input_array(oc))) def __rdivmod__(self, oc): - return JaxArray(self._value.__rdivmod__(oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value.__rdivmod__(_check_input_array(oc))) def __mod__(self, oc): - return JaxArray(self._value % (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value % _check_input_array(oc)) def __rmod__(self, oc): - return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) % self._value) + return JaxArray(_check_input_array(oc) % self._value) def __imod__(self, oc): # a %= b if self._outside_global_jit and _global_jit_mode: raise MathError(msg) - self._value = self._value % (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self._value % _check_input_array(oc) return self def __pow__(self, oc): - return JaxArray(self._value ** (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value ** _check_input_array(oc)) def __rpow__(self, oc): - return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) ** self._value) + return JaxArray(_check_input_array(oc) ** self._value) def __ipow__(self, oc): # a **= b if self._outside_global_jit and _global_jit_mode: raise MathError(msg) - self._value = self._value ** (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self._value ** _check_input_array(oc) return self def __matmul__(self, oc): - return JaxArray(self._value @ (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value @ _check_input_array(oc)) def __rmatmul__(self, oc): - return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) @ self._value) + return JaxArray(_check_input_array(oc) @ self._value) def __imatmul__(self, oc): # a @= b if self._outside_global_jit and _global_jit_mode: raise MathError(msg) - self._value = self._value @ (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self._value @ _check_input_array(oc) return self def __and__(self, oc): - return JaxArray(self._value & (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value & _check_input_array(oc)) def __rand__(self, oc): - return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) & self._value) + return JaxArray(_check_input_array(oc) & self._value) def __iand__(self, oc): # a &= b if self._outside_global_jit and _global_jit_mode: raise MathError(msg) - self._value = self._value & (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self._value & _check_input_array(oc) return self def __or__(self, oc): - return JaxArray(self._value | (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value | _check_input_array(oc)) def __ror__(self, oc): - return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) | self._value) + return JaxArray(_check_input_array(oc) | self._value) def __ior__(self, oc): # a |= b if self._outside_global_jit and _global_jit_mode: raise MathError(msg) - self._value = self._value | (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self._value | _check_input_array(oc) return self def __xor__(self, oc): - return JaxArray(self._value ^ (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value ^ _check_input_array(oc)) def __rxor__(self, oc): - return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) ^ self._value) + return JaxArray(_check_input_array(oc) ^ self._value) def __ixor__(self, oc): # a ^= b if self._outside_global_jit and _global_jit_mode: raise MathError(msg) - self._value = self._value ^ (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self._value ^ _check_input_array(oc) return self def __lshift__(self, oc): - return JaxArray(self._value << (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value << _check_input_array(oc)) def __rlshift__(self, oc): - return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) << self._value) + return JaxArray(_check_input_array(oc) << self._value) def __ilshift__(self, oc): # a <<= b if self._outside_global_jit and _global_jit_mode: raise MathError(msg) - self._value = self._value << (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self._value << _check_input_array(oc) return self def __rshift__(self, oc): - return JaxArray(self._value >> (oc.value if isinstance(oc, JaxArray) else oc)) + return JaxArray(self._value >> _check_input_array(oc)) def __rrshift__(self, oc): - return JaxArray((oc.value if isinstance(oc, JaxArray) else oc) >> self._value) + return JaxArray(_check_input_array(oc) >> self._value) def __irshift__(self, oc): # a >>= b if self._outside_global_jit and _global_jit_mode: raise MathError(msg) - self._value = self._value >> (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self._value >> _check_input_array(oc) return self def __round__(self, ndigits=None): @@ -993,7 +1002,7 @@ def __setitem__(self, index, value): # tuple index if isinstance(index, tuple): - index = tuple(x.value if isinstance(x, JaxArray) else x for x in index) + index = tuple(_check_input_array(x) for x in index) # JaxArray index elif isinstance(index, JaxArray): @@ -1004,77 +1013,67 @@ def __setitem__(self, index, value): def __iadd__(self, oc): # a += b - # self._value += (oc.value if isinstance(oc, JaxArray) else oc) - self._value = self.value + (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self.value + _check_input_array(oc) return self def __isub__(self, oc): # a -= b - self._value = self.value - (oc.value if isinstance(oc, JaxArray) else oc) - # self._value -= (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self.value - _check_input_array(oc) return self def __imul__(self, oc): # a *= b - self._value = self.value * (oc.value if isinstance(oc, JaxArray) else oc) - # self._value *= (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self.value * _check_input_array(oc) return self def __itruediv__(self, oc): # a /= b - self._value = self.value / (oc.value if isinstance(oc, JaxArray) else oc) - # self._value /= (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self.value / _check_input_array(oc) return self def __ifloordiv__(self, oc): # a //= b - self._value = self.value // (oc.value if isinstance(oc, JaxArray) else oc) - # self._value //= (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self.value // _check_input_array(oc) return self def __imod__(self, oc): # a %= b - self._value = self.value % (oc.value if isinstance(oc, JaxArray) else oc) - # self._value %= (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self.value % _check_input_array(oc) return self def __ipow__(self, oc): # a **= b - self._value = self.value ** (oc.value if isinstance(oc, JaxArray) else oc) - # self._value **= (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self.value ** _check_input_array(oc) return self def __imatmul__(self, oc): # a @= b - self._value = self.value @ (oc.value if isinstance(oc, JaxArray) else oc) - # self._value @= (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self.value @ _check_input_array(oc) return self def __iand__(self, oc): # a &= b - self._value = self.value.__and__(oc.value if isinstance(oc, JaxArray) else oc) - # self._value &= (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self.value.__and__(_check_input_array(oc)) return self def __ior__(self, oc): # a |= b - self._value = self.value | (oc.value if isinstance(oc, JaxArray) else oc) - # self._value |= (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self.value | _check_input_array(oc) return self def __ixor__(self, oc): # a ^= b - self._value = self.value ^ (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self.value ^ _check_input_array(oc) return self def __ilshift__(self, oc): # a <<= b - self._value = self.value << (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self.value << _check_input_array(oc) return self def __irshift__(self, oc): # a >>= b - self._value = self.value >> (oc.value if isinstance(oc, JaxArray) else oc) + self._value = self.value >> _check_input_array(oc) return self def fill(self, value): @@ -1108,109 +1107,109 @@ def __invert__(self): return self.value.__invert__() def __eq__(self, oc): - return self.value == (oc.value if isinstance(oc, JaxArray) else oc) + return self.value == _check_input_array(oc) def __ne__(self, oc): - return self.value != (oc.value if isinstance(oc, JaxArray) else oc) + return self.value != _check_input_array(oc) def __lt__(self, oc): - return self.value < (oc.value if isinstance(oc, JaxArray) else oc) + return self.value < _check_input_array(oc) def __le__(self, oc): - return self.value <= (oc.value if isinstance(oc, JaxArray) else oc) + return self.value <= _check_input_array(oc) def __gt__(self, oc): - return self.value > (oc.value if isinstance(oc, JaxArray) else oc) + return self.value > _check_input_array(oc) def __ge__(self, oc): - return self.value >= (oc.value if isinstance(oc, JaxArray) else oc) + return self.value >= _check_input_array(oc) def __add__(self, oc): - return self.value + (oc.value if isinstance(oc, JaxArray) else oc) + return self.value + _check_input_array(oc) def __radd__(self, oc): - return self.value + (oc.value if isinstance(oc, JaxArray) else oc) + return self.value + _check_input_array(oc) def __sub__(self, oc): - return self.value - (oc.value if isinstance(oc, JaxArray) else oc) + return self.value - _check_input_array(oc) def __rsub__(self, oc): - return (oc.value if isinstance(oc, JaxArray) else oc) - self.value + return _check_input_array(oc) - self.value def __mul__(self, oc): - return self.value * (oc.value if isinstance(oc, JaxArray) else oc) + return self.value * _check_input_array(oc) def __rmul__(self, oc): - return (oc.value if isinstance(oc, JaxArray) else oc) * self.value + return _check_input_array(oc) * self.value def __rdiv__(self, oc): - return (oc.value if isinstance(oc, JaxArray) else oc) / self.value + return _check_input_array(oc) / self.value def __truediv__(self, oc): - return self.value / (oc.value if isinstance(oc, JaxArray) else oc) + return self.value / _check_input_array(oc) def __rtruediv__(self, oc): - return (oc.value if isinstance(oc, JaxArray) else oc) / self.value + return _check_input_array(oc) / self.value def __floordiv__(self, oc): - return self.value // (oc.value if isinstance(oc, JaxArray) else oc) + return self.value // _check_input_array(oc) def __rfloordiv__(self, oc): - return (oc.value if isinstance(oc, JaxArray) else oc) // self.value + return _check_input_array(oc) // self.value def __divmod__(self, oc): - return self.value.__divmod__(oc.value if isinstance(oc, JaxArray) else oc) + return self.value.__divmod__(_check_input_array(oc)) def __rdivmod__(self, oc): - return self.value.__rdivmod__(oc.value if isinstance(oc, JaxArray) else oc) + return self.value.__rdivmod__(_check_input_array(oc)) def __mod__(self, oc): - return self.value % (oc.value if isinstance(oc, JaxArray) else oc) + return self.value % _check_input_array(oc) def __rmod__(self, oc): - return (oc.value if isinstance(oc, JaxArray) else oc) % self.value + return _check_input_array(oc) % self.value def __pow__(self, oc): - return self.value ** (oc.value if isinstance(oc, JaxArray) else oc) + return self.value ** _check_input_array(oc) def __rpow__(self, oc): - return (oc.value if isinstance(oc, JaxArray) else oc) ** self.value + return _check_input_array(oc) ** self.value def __matmul__(self, oc): - return self.value @ (oc.value if isinstance(oc, JaxArray) else oc) + return self.value @ _check_input_array(oc) def __rmatmul__(self, oc): - return (oc.value if isinstance(oc, JaxArray) else oc) @ self.value + return _check_input_array(oc) @ self.value def __and__(self, oc): - return self.value & (oc.value if isinstance(oc, JaxArray) else oc) + return self.value & _check_input_array(oc) def __rand__(self, oc): - return (oc.value if isinstance(oc, JaxArray) else oc) & self.value + return _check_input_array(oc) & self.value def __or__(self, oc): - return self.value | (oc.value if isinstance(oc, JaxArray) else oc) + return self.value | _check_input_array(oc) def __ror__(self, oc): - return (oc.value if isinstance(oc, JaxArray) else oc) | self.value + return _check_input_array(oc) | self.value def __xor__(self, oc): - return self.value ^ (oc.value if isinstance(oc, JaxArray) else oc) + return self.value ^ _check_input_array(oc) def __rxor__(self, oc): - return (oc.value if isinstance(oc, JaxArray) else oc) ^ self.value + return _check_input_array(oc) ^ self.value def __lshift__(self, oc): - return self.value << (oc.value if isinstance(oc, JaxArray) else oc) + return self.value << _check_input_array(oc) def __rlshift__(self, oc): - return (oc.value if isinstance(oc, JaxArray) else oc) << self.value + return _check_input_array(oc) << self.value def __rshift__(self, oc): - return self.value >> (oc.value if isinstance(oc, JaxArray) else oc) + return self.value >> _check_input_array(oc) def __rrshift__(self, oc): - return (oc.value if isinstance(oc, JaxArray) else oc) >> self.value + return _check_input_array(oc) >> self.value def __round__(self, ndigits=None): return self.value.__round__(ndigits) @@ -1573,7 +1572,7 @@ def __setitem__(self, index, value): # tuple index if isinstance(index, tuple): - index = tuple(x.value if isinstance(x, JaxArray) else x for x in index) + index = tuple(_check_input_array(x) for x in index) # JaxArray index elif isinstance(index, JaxArray): @@ -1584,67 +1583,67 @@ def __setitem__(self, index, value): def __iadd__(self, oc): # a += b - self._value[self.index] = self.value + (oc.value if isinstance(oc, JaxArray) else oc) + self._value[self.index] = self.value + _check_input_array(oc) return self def __isub__(self, oc): # a -= b - self._value[self.index] = self.value - (oc.value if isinstance(oc, JaxArray) else oc) + self._value[self.index] = self.value - _check_input_array(oc) return self def __imul__(self, oc): # a *= b - self._value[self.index] = self.value * (oc.value if isinstance(oc, JaxArray) else oc) + self._value[self.index] = self.value * _check_input_array(oc) return self def __itruediv__(self, oc): # a /= b - self._value[self.index] = self.value / (oc.value if isinstance(oc, JaxArray) else oc) + self._value[self.index] = self.value / _check_input_array(oc) return self def __ifloordiv__(self, oc): # a //= b - self._value[self.index] = self.value // (oc.value if isinstance(oc, JaxArray) else oc) + self._value[self.index] = self.value // _check_input_array(oc) return self def __imod__(self, oc): # a %= b - self._value[self.index] = self.value % (oc.value if isinstance(oc, JaxArray) else oc) + self._value[self.index] = self.value % _check_input_array(oc) return self def __ipow__(self, oc): # a **= b - self._value[self.index] = self.value ** (oc.value if isinstance(oc, JaxArray) else oc) + self._value[self.index] = self.value ** _check_input_array(oc) return self def __imatmul__(self, oc): # a @= b - self._value[self.index] = self.value @ (oc.value if isinstance(oc, JaxArray) else oc) + self._value[self.index] = self.value @ _check_input_array(oc) return self def __iand__(self, oc): # a &= b - self._value[self.index] = self.value.__and__(oc.value if isinstance(oc, JaxArray) else oc) + self._value[self.index] = self.value.__and__(_check_input_array(oc)) return self def __ior__(self, oc): # a |= b - self._value[self.index] = self.value | (oc.value if isinstance(oc, JaxArray) else oc) + self._value[self.index] = self.value | _check_input_array(oc) return self def __ixor__(self, oc): # a ^= b - self._value[self.index] = self.value ^ (oc.value if isinstance(oc, JaxArray) else oc) + self._value[self.index] = self.value ^ _check_input_array(oc) return self def __ilshift__(self, oc): # a <<= b - self._value[self.index] = self.value << (oc.value if isinstance(oc, JaxArray) else oc) + self._value[self.index] = self.value << _check_input_array(oc) return self def __irshift__(self, oc): # a >>= b - self._value[self.index] = self.value >> (oc.value if isinstance(oc, JaxArray) else oc) + self._value[self.index] = self.value >> _check_input_array(oc) return self def fill(self, value): diff --git a/brainpy/math/tests/test_jaxarray.py b/brainpy/math/tests/test_jaxarray.py index 2f6a9c10e..11017b22a 100644 --- a/brainpy/math/tests/test_jaxarray.py +++ b/brainpy/math/tests/test_jaxarray.py @@ -4,6 +4,7 @@ import unittest import jax.numpy as jnp +import numpy as np from jax.tree_util import tree_flatten, tree_unflatten import brainpy.math as bm @@ -39,6 +40,12 @@ def test_none(self): with self.assertRaises(TypeError): ee = a + e + def test_operation_with_numpy_array(self): + rng = bm.random.RandomState(123) + add = lambda: rng.rand(10) + np.zeros(1) + self.assertTrue(isinstance(add(), bm.JaxArray)) + self.assertTrue(isinstance(bm.jit(add, dyn_vars=rng)(), bm.JaxArray)) + class TestVariable(unittest.TestCase): def test_variable_init(self): From 6685e1aaf035fa688bd9e448442f5914a9598788 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 2 Oct 2022 14:29:25 +0800 Subject: [PATCH 2/2] update NumPy ndarray and JaxArray operations --- brainpy/math/jaxarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainpy/math/jaxarray.py b/brainpy/math/jaxarray.py index 09ad63bbf..196bf8ace 100644 --- a/brainpy/math/jaxarray.py +++ b/brainpy/math/jaxarray.py @@ -183,7 +183,7 @@ def __getitem__(self, index): if isinstance(index, slice) and (index == _all_slice): return self.value elif isinstance(index, tuple): - index = tuple(_check_input_array(x) for x in index) + index = tuple((x.value if isinstance(x, JaxArray) else x) for x in index) elif isinstance(index, JaxArray): index = index.value return self.value[index]