Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 12, 2024
1 parent 7ee0d89 commit 475b16e
Show file tree
Hide file tree
Showing 14 changed files with 892 additions and 2,323 deletions.
267 changes: 162 additions & 105 deletions brainunit/math/_compat_numpy_array_manipulation.py

Large diffs are not rendered by default.

212 changes: 99 additions & 113 deletions brainunit/math/_compat_numpy_funcs_accept_unitless.py

Large diffs are not rendered by default.

86 changes: 35 additions & 51 deletions brainunit/math/_compat_numpy_funcs_bit_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import jax
import jax.numpy as jnp
import numpy as np
from brainstate._utils import set_module_as
from jax import Array

from .._base import (Quantity,
Expand All @@ -36,73 +37,56 @@
# Elementwise bit operations (unary)
# ----------------------------------

def wrap_elementwise_bit_operation_unary(func):
@wraps(func)
def f(x, *args, **kwargs):
if isinstance(x, Quantity):
raise ValueError(f'Expected integers, got {x}')
elif isinstance(x, (jax.Array, np.ndarray)):
return func(x, *args, **kwargs)
else:
raise ValueError(f'Unsupported types {type(x)} for {func.__name__}')
def elementwise_bit_operation_unary(func, x, *args, **kwargs):
if isinstance(x, Quantity):
raise ValueError(f'Expected integers, got {x}')
elif isinstance(x, (jax.Array, np.ndarray)):
return func(x, *args, **kwargs)
else:
raise ValueError(f'Unsupported types {type(x)} for {func.__name__}')

f.__module__ = 'brainunit.math'
return f


@wrap_elementwise_bit_operation_unary
@set_module_as('brainunit.math')
def bitwise_not(x: Union[Quantity, jax.typing.ArrayLike]) -> Array:
return jnp.bitwise_not(x)


@wrap_elementwise_bit_operation_unary
def invert(x: Union[Quantity, jax.typing.ArrayLike]) -> Array:
return jnp.invert(x)


# docs for functions above
bitwise_not.__doc__ = '''
'''
Compute the bit-wise NOT of an array, element-wise.
Args:
x: array_like
Returns:
jax.Array: an array
'''
'''
return elementwise_bit_operation_unary(jnp.bitwise_not, x)


invert.__doc__ = '''
@set_module_as('brainunit.math')
def invert(x: Union[Quantity, jax.typing.ArrayLike]) -> Array:
'''
Compute bit-wise inversion, or bit-wise NOT, element-wise.
Args:
x: array_like
Returns:
jax.Array: an array
'''
'''
return elementwise_bit_operation_unary(jnp.invert, x)


# Elementwise bit operations (binary)
# -----------------------------------

def wrap_elementwise_bit_operation_binary(func):
@wraps(func)
def decorator(*args, **kwargs):
def f(x, y, *args, **kwargs):
if isinstance(x, Quantity) or isinstance(y, Quantity):
raise ValueError(f'Expected integers, got {x} and {y}')
elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray, int, float)):
return func(x, y, *args, **kwargs)
else:
raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}')

f.__module__ = 'brainunit.math'
return f

return decorator

def elementwise_bit_operation_binary(func, x, y, *args, **kwargs):
if isinstance(x, Quantity) or isinstance(y, Quantity):
raise ValueError(f'Expected integers, got {x} and {y}')
elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray, int, float)):
return func(x, y, *args, **kwargs)
else:
raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}')

@wrap_elementwise_bit_operation_binary(jnp.bitwise_and)
@set_module_as('brainunit.math')
def bitwise_and(
x: Union[Quantity, jax.typing.ArrayLike],
y: Union[Quantity, jax.typing.ArrayLike]
Expand All @@ -117,10 +101,10 @@ def bitwise_and(
Returns:
jax.Array: an array
'''
...
return elementwise_bit_operation_binary(jnp.bitwise_and, x, y)


@wrap_elementwise_bit_operation_binary(jnp.bitwise_or)
@set_module_as('brainunit.math')
def bitwise_or(
x: Union[Quantity, jax.typing.ArrayLike],
y: Union[Quantity, jax.typing.ArrayLike]
Expand All @@ -135,10 +119,10 @@ def bitwise_or(
Returns:
jax.Array: an array
'''
...
return elementwise_bit_operation_binary(jnp.bitwise_or, x, y)


@wrap_elementwise_bit_operation_binary(jnp.bitwise_xor)
@set_module_as('brainunit.math')
def bitwise_xor(
x: Union[Quantity, jax.typing.ArrayLike],
y: Union[Quantity, jax.typing.ArrayLike]
Expand All @@ -153,10 +137,10 @@ def bitwise_xor(
Returns:
jax.Array: an array
'''
...
return elementwise_bit_operation_binary(jnp.bitwise_xor, x, y)


@wrap_elementwise_bit_operation_binary(jnp.left_shift)
@set_module_as('brainunit.math')
def left_shift(
x: Union[Quantity, jax.typing.ArrayLike],
y: Union[Quantity, jax.typing.ArrayLike]
Expand All @@ -171,10 +155,10 @@ def left_shift(
Returns:
jax.Array: an array
'''
...
return elementwise_bit_operation_binary(jnp.left_shift, x, y)


@wrap_elementwise_bit_operation_binary(jnp.right_shift)
@set_module_as('brainunit.math')
def right_shift(
x: Union[Quantity, jax.typing.ArrayLike],
y: Union[Quantity, jax.typing.ArrayLike]
Expand All @@ -189,4 +173,4 @@ def right_shift(
Returns:
jax.Array: an array
'''
...
return elementwise_bit_operation_binary(jnp.right_shift, x, y)
Loading

0 comments on commit 475b16e

Please sign in to comment.