Skip to content

Commit

Permalink
refactor(All APIs): Added new pre-commit hook for ordering functions
Browse files Browse the repository at this point in the history
  • Loading branch information
NripeshN committed Aug 31, 2023
1 parent 38338fe commit 6515e43
Show file tree
Hide file tree
Showing 182 changed files with 36,023 additions and 36,140 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ repos:
# Exclude everything in frontends except __init__.py, and func_wrapper.py
exclude: 'ivy/functional/(frontends|backends)/(?!.*/func_wrapper\.py$).*(?!__init__\.py$)'
- repo: https://github.com/unifyai/lint-hook
rev: 27646397c5390f644a645f439535b1061b9c0105
rev: 5abf78187bb7a5f839174ce4febd4e63f5d758f6
hooks:
- id: ivy-lint
32 changes: 16 additions & 16 deletions ivy/functional/backends/jax/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def gelu(
return jax.nn.gelu(x, approximate)


def hardswish(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jax.nn.hard_swish(x)


def leaky_relu(
x: JaxArray,
/,
Expand All @@ -34,6 +38,18 @@ def leaky_relu(
return jnp.asarray(jnp.where(x > 0, x, jnp.multiply(x, alpha)), x.dtype)


def log_softmax(
x: JaxArray, /, *, axis: Optional[int] = None, out: Optional[JaxArray] = None
):
if axis is None:
axis = -1
return jax.nn.log_softmax(x, axis)


def mish(x: JaxArray, /, *, out: Optional[JaxArray] = None):
return x * jnp.tanh(jax.nn.softplus(x))


def relu(
x: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None
) -> JaxArray:
Expand Down Expand Up @@ -78,19 +94,3 @@ def softplus(
if threshold is not None:
return jnp.where(x_beta > threshold, x, res).astype(x.dtype)
return res.astype(x.dtype)


def log_softmax(
x: JaxArray, /, *, axis: Optional[int] = None, out: Optional[JaxArray] = None
):
if axis is None:
axis = -1
return jax.nn.log_softmax(x, axis)


def mish(x: JaxArray, /, *, out: Optional[JaxArray] = None):
return x * jnp.tanh(jax.nn.softplus(x))


def hardswish(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jax.nn.hard_swish(x)
138 changes: 69 additions & 69 deletions ivy/functional/backends/jax/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,19 @@ def asarray(
return jnp.asarray(obj, dtype=dtype)


def copy_array(
x: JaxArray, *, to_ivy_array: bool = True, out: Optional[JaxArray] = None
) -> JaxArray:
x = (
jax.core.ShapedArray(x.shape, x.dtype)
if isinstance(x, jax.core.ShapedArray)
else jnp.array(x)
)
if to_ivy_array:
return ivy.to_ivy(x)
return x


def empty(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
Expand Down Expand Up @@ -127,6 +140,15 @@ def from_dlpack(x, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jax.dlpack.from_dlpack(capsule)


def frombuffer(
buffer: bytes,
dtype: Optional[jnp.dtype] = float,
count: Optional[int] = -1,
offset: Optional[int] = 0,
) -> JaxArray:
return jnp.frombuffer(buffer, dtype=dtype, count=count, offset=offset)


def full(
shape: Union[ivy.NativeShape, Sequence[int]],
fill_value: Union[int, float, bool],
Expand Down Expand Up @@ -234,6 +256,42 @@ def meshgrid(
return jnp.meshgrid(*arrays, sparse=sparse, indexing=indexing)


def one_hot(
indices: JaxArray,
depth: int,
/,
*,
on_value: Optional[Number] = None,
off_value: Optional[Number] = None,
axis: Optional[int] = None,
dtype: Optional[jnp.dtype] = None,
device: jaxlib.xla_extension.Device,
out: Optional[JaxArray] = None,
) -> JaxArray:
on_none = on_value is None
off_none = off_value is None

if dtype is None:
if on_none and off_none:
dtype = jnp.float32
else:
if not on_none:
dtype = jnp.array(on_value).dtype
elif not off_none:
dtype = jnp.array(off_value).dtype

res = jnp.eye(depth, dtype=dtype)[jnp.array(indices, dtype="int64").reshape(-1)]
res = res.reshape(list(indices.shape) + [depth])

if not on_none and not off_none:
res = jnp.where(res == 1, on_value, off_value)

if axis is not None:
res = jnp.moveaxis(res, -1, axis)

return res


def ones(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
Expand Down Expand Up @@ -263,6 +321,17 @@ def triu(x: JaxArray, /, *, k: int = 0, out: Optional[JaxArray] = None) -> JaxAr
return jnp.triu(x, k)


def triu_indices(
n_rows: int,
n_cols: Optional[int] = None,
k: int = 0,
/,
*,
device: jaxlib.xla_extension.Device,
) -> Tuple[JaxArray]:
return jnp.triu_indices(n=n_rows, k=k, m=n_cols)


def zeros(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
Expand All @@ -289,72 +358,3 @@ def zeros_like(


array = asarray


def copy_array(
x: JaxArray, *, to_ivy_array: bool = True, out: Optional[JaxArray] = None
) -> JaxArray:
x = (
jax.core.ShapedArray(x.shape, x.dtype)
if isinstance(x, jax.core.ShapedArray)
else jnp.array(x)
)
if to_ivy_array:
return ivy.to_ivy(x)
return x


def one_hot(
indices: JaxArray,
depth: int,
/,
*,
on_value: Optional[Number] = None,
off_value: Optional[Number] = None,
axis: Optional[int] = None,
dtype: Optional[jnp.dtype] = None,
device: jaxlib.xla_extension.Device,
out: Optional[JaxArray] = None,
) -> JaxArray:
on_none = on_value is None
off_none = off_value is None

if dtype is None:
if on_none and off_none:
dtype = jnp.float32
else:
if not on_none:
dtype = jnp.array(on_value).dtype
elif not off_none:
dtype = jnp.array(off_value).dtype

res = jnp.eye(depth, dtype=dtype)[jnp.array(indices, dtype="int64").reshape(-1)]
res = res.reshape(list(indices.shape) + [depth])

if not on_none and not off_none:
res = jnp.where(res == 1, on_value, off_value)

if axis is not None:
res = jnp.moveaxis(res, -1, axis)

return res


def frombuffer(
buffer: bytes,
dtype: Optional[jnp.dtype] = float,
count: Optional[int] = -1,
offset: Optional[int] = 0,
) -> JaxArray:
return jnp.frombuffer(buffer, dtype=dtype, count=count, offset=offset)


def triu_indices(
n_rows: int,
n_cols: Optional[int] = None,
k: int = 0,
/,
*,
device: jaxlib.xla_extension.Device,
) -> Tuple[JaxArray]:
return jnp.triu_indices(n=n_rows, k=k, m=n_cols)
Loading

0 comments on commit 6515e43

Please sign in to comment.