Skip to content

Commit

Permalink
fix(jax-backend): fix failing test for mean (#27084)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aaryan562 authored Nov 1, 2023
1 parent 4626cb1 commit 6e1f438
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
7 changes: 5 additions & 2 deletions ivy/functional/backends/jax/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ def max(
return jnp.max(a=jnp.asarray(x), axis=axis, keepdims=keepdims)


@with_unsupported_dtypes({"0.4.14 and below": "bfloat16"}, backend_version)
@with_unsupported_dtypes(
{"0.4.19 and below": "bfloat16"},
backend_version,
)
def mean(
x: JaxArray,
/,
Expand All @@ -47,7 +50,7 @@ def mean(
out: Optional[JaxArray] = None,
) -> JaxArray:
axis = tuple(axis) if isinstance(axis, list) else axis
return jnp.mean(x, axis=axis, keepdims=keepdims)
return jnp.mean(x, axis=axis, keepdims=keepdims, dtype=x.dtype)


def _infer_dtype(dtype: jnp.dtype):
Expand Down
4 changes: 1 addition & 3 deletions ivy/functional/backends/numpy/statistical.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def mean(
out: Optional[np.ndarray] = None,
) -> np.ndarray:
axis = tuple(axis) if isinstance(axis, list) else axis
return ivy.astype(
np.mean(x, axis=axis, keepdims=keepdims, out=out), x.dtype, copy=False
)
return np.mean(x, axis=axis, keepdims=keepdims, dtype=x.dtype, out=out)


mean.support_native_out = True
Expand Down

0 comments on commit 6e1f438

Please sign in to comment.