Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(All APIs): Added new pre-commit hook for ordering functions #22830

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@ repos:
# Exclude everything in frontends except __init__.py, and func_wrapper.py
exclude: 'ivy/functional/(frontends|backends)/(?!.*/func_wrapper\.py$).*(?!__init__\.py$)'
- repo: https://github.com/unifyai/lint-hook
rev: 27646397c5390f644a645f439535b1061b9c0105
rev: 2e2dc6c06475b5ec47e4b97c30b23dbb4bd01891
hooks:
- id: ivy-lint
32 changes: 16 additions & 16 deletions ivy/functional/backends/jax/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def gelu(
return jax.nn.gelu(x, approximate)


def hardswish(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jax.nn.hard_swish(x)


def leaky_relu(
x: JaxArray,
/,
Expand All @@ -34,6 +38,18 @@ def leaky_relu(
return jnp.asarray(jnp.where(x > 0, x, jnp.multiply(x, alpha)), x.dtype)


def log_softmax(
x: JaxArray, /, *, axis: Optional[int] = None, out: Optional[JaxArray] = None
):
if axis is None:
axis = -1
return jax.nn.log_softmax(x, axis)


def mish(x: JaxArray, /, *, out: Optional[JaxArray] = None):
return x * jnp.tanh(jax.nn.softplus(x))


def relu(
x: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None
) -> JaxArray:
Expand Down Expand Up @@ -78,19 +94,3 @@ def softplus(
if threshold is not None:
return jnp.where(x_beta > threshold, x, res).astype(x.dtype)
return res.astype(x.dtype)


def log_softmax(
x: JaxArray, /, *, axis: Optional[int] = None, out: Optional[JaxArray] = None
):
if axis is None:
axis = -1
return jax.nn.log_softmax(x, axis)


def mish(x: JaxArray, /, *, out: Optional[JaxArray] = None):
return x * jnp.tanh(jax.nn.softplus(x))


def hardswish(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jax.nn.hard_swish(x)
138 changes: 69 additions & 69 deletions ivy/functional/backends/jax/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,19 @@ def asarray(
return jnp.asarray(obj, dtype=dtype)


def copy_array(
x: JaxArray, *, to_ivy_array: bool = True, out: Optional[JaxArray] = None
) -> JaxArray:
x = (
jax.core.ShapedArray(x.shape, x.dtype)
if isinstance(x, jax.core.ShapedArray)
else jnp.array(x)
)
if to_ivy_array:
return ivy.to_ivy(x)
return x


def empty(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
Expand Down Expand Up @@ -127,6 +140,15 @@ def from_dlpack(x, /, *, out: Optional[JaxArray] = None) -> JaxArray:
return jax.dlpack.from_dlpack(capsule)


def frombuffer(
buffer: bytes,
dtype: Optional[jnp.dtype] = float,
count: Optional[int] = -1,
offset: Optional[int] = 0,
) -> JaxArray:
return jnp.frombuffer(buffer, dtype=dtype, count=count, offset=offset)


def full(
shape: Union[ivy.NativeShape, Sequence[int]],
fill_value: Union[int, float, bool],
Expand Down Expand Up @@ -234,6 +256,42 @@ def meshgrid(
return jnp.meshgrid(*arrays, sparse=sparse, indexing=indexing)


def one_hot(
indices: JaxArray,
depth: int,
/,
*,
on_value: Optional[Number] = None,
off_value: Optional[Number] = None,
axis: Optional[int] = None,
dtype: Optional[jnp.dtype] = None,
device: jaxlib.xla_extension.Device,
out: Optional[JaxArray] = None,
) -> JaxArray:
on_none = on_value is None
off_none = off_value is None

if dtype is None:
if on_none and off_none:
dtype = jnp.float32
else:
if not on_none:
dtype = jnp.array(on_value).dtype
elif not off_none:
dtype = jnp.array(off_value).dtype

res = jnp.eye(depth, dtype=dtype)[jnp.array(indices, dtype="int64").reshape(-1)]
res = res.reshape(list(indices.shape) + [depth])

if not on_none and not off_none:
res = jnp.where(res == 1, on_value, off_value)

if axis is not None:
res = jnp.moveaxis(res, -1, axis)

return res


def ones(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
Expand Down Expand Up @@ -263,6 +321,17 @@ def triu(x: JaxArray, /, *, k: int = 0, out: Optional[JaxArray] = None) -> JaxAr
return jnp.triu(x, k)


def triu_indices(
n_rows: int,
n_cols: Optional[int] = None,
k: int = 0,
/,
*,
device: jaxlib.xla_extension.Device,
) -> Tuple[JaxArray]:
return jnp.triu_indices(n=n_rows, k=k, m=n_cols)


def zeros(
shape: Union[ivy.NativeShape, Sequence[int]],
*,
Expand All @@ -289,72 +358,3 @@ def zeros_like(


array = asarray


def copy_array(
x: JaxArray, *, to_ivy_array: bool = True, out: Optional[JaxArray] = None
) -> JaxArray:
x = (
jax.core.ShapedArray(x.shape, x.dtype)
if isinstance(x, jax.core.ShapedArray)
else jnp.array(x)
)
if to_ivy_array:
return ivy.to_ivy(x)
return x


def one_hot(
indices: JaxArray,
depth: int,
/,
*,
on_value: Optional[Number] = None,
off_value: Optional[Number] = None,
axis: Optional[int] = None,
dtype: Optional[jnp.dtype] = None,
device: jaxlib.xla_extension.Device,
out: Optional[JaxArray] = None,
) -> JaxArray:
on_none = on_value is None
off_none = off_value is None

if dtype is None:
if on_none and off_none:
dtype = jnp.float32
else:
if not on_none:
dtype = jnp.array(on_value).dtype
elif not off_none:
dtype = jnp.array(off_value).dtype

res = jnp.eye(depth, dtype=dtype)[jnp.array(indices, dtype="int64").reshape(-1)]
res = res.reshape(list(indices.shape) + [depth])

if not on_none and not off_none:
res = jnp.where(res == 1, on_value, off_value)

if axis is not None:
res = jnp.moveaxis(res, -1, axis)

return res


def frombuffer(
buffer: bytes,
dtype: Optional[jnp.dtype] = float,
count: Optional[int] = -1,
offset: Optional[int] = 0,
) -> JaxArray:
return jnp.frombuffer(buffer, dtype=dtype, count=count, offset=offset)


def triu_indices(
n_rows: int,
n_cols: Optional[int] = None,
k: int = 0,
/,
*,
device: jaxlib.xla_extension.Device,
) -> Tuple[JaxArray]:
return jnp.triu_indices(n=n_rows, k=k, m=n_cols)
Loading
Loading