diff --git a/ivy/data_classes/array/experimental/activations.py b/ivy/data_classes/array/experimental/activations.py index 699022d70bdfa..92f997bf66ec9 100644 --- a/ivy/data_classes/array/experimental/activations.py +++ b/ivy/data_classes/array/experimental/activations.py @@ -1,6 +1,6 @@ # global import abc -from typing import Optional, Union +from typing import Optional, Union, Literal # local import ivy @@ -8,7 +8,12 @@ class _ArrayWithActivationsExperimental(abc.ABC): def logit( - self, /, *, eps: Optional[float] = None, out: Optional[ivy.Array] = None + self, + /, + *, + eps: Optional[float] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ ivy.Array instance method variant of ivy.logit. This method simply wraps the @@ -23,6 +28,9 @@ def logit( When eps is None the function outpus NaN where x < 0 or x > 1. and inf or -inf where x = 1 or x = 0, respectively. Otherwise if eps is defined, x is clamped to [eps, 1 - eps] + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out Optional output array. @@ -43,7 +51,7 @@ def logit( >>> print(z) ivy.array([ 1.38629448, 1.38629448, -1.38629436]) """ - return ivy.logit(self, eps=eps, out=out) + return ivy.logit(self, eps=eps, complex_mode=complex_mode, out=out) def thresholded_relu( self: ivy.Array, diff --git a/ivy/data_classes/container/experimental/activations.py b/ivy/data_classes/container/experimental/activations.py index aa60fc5bd9dfd..29c29a519761e 100644 --- a/ivy/data_classes/container/experimental/activations.py +++ b/ivy/data_classes/container/experimental/activations.py @@ -1,5 +1,5 @@ # global -from typing import Union, Optional, List, Dict +from typing import Union, Optional, List, Dict, Literal # local import ivy @@ -13,6 +13,7 @@ def static_logit( /, *, eps: Optional[Union[float, ivy.Container]] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -28,6 +29,9 @@ def static_logit( When eps is None the function outpus NaN where x < 0 or x > 1. and inf or -inf where x = 1 or x = 0, respectively. Otherwise if eps is defined, x is clamped to [eps, 1 - eps] + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out Optional output Contaner. @@ -62,6 +66,7 @@ def static_logit( "logit", x, eps=eps, + complex_mode=complex_mode, out=out, ) @@ -70,6 +75,7 @@ def logit( /, *, eps: Optional[Union[float, ivy.Container]] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Container] = None, ) -> ivy.Container: """ @@ -85,6 +91,9 @@ def logit( When eps is None the function outpus NaN where x < 0 or x > 1. and inf or -inf where x = 1 or x = 0, respectively. Otherwise if eps is defined, x is clamped to [eps, 1 - eps] + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out Optional output Contaner. @@ -115,7 +124,7 @@ def logit( b: ivy.array([-1.38629436, 1.38629448, -1.38629436]) } """ - return self.static_logit(self, eps=eps, out=out) + return self.static_logit(self, eps=eps, complex_mode=complex_mode, out=out) @staticmethod def static_thresholded_relu( diff --git a/ivy/functional/backends/jax/experimental/activations.py b/ivy/functional/backends/jax/experimental/activations.py index 788126f4b44cf..a18f2b240b61c 100644 --- a/ivy/functional/backends/jax/experimental/activations.py +++ b/ivy/functional/backends/jax/experimental/activations.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Union, Literal # global import jax @@ -13,6 +13,7 @@ def logit( /, *, eps: Optional[float] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[JaxArray] = None, ): if eps is None: diff --git a/ivy/functional/backends/numpy/experimental/activations.py b/ivy/functional/backends/numpy/experimental/activations.py index d3f8282a22de2..1ac8c43285a96 100644 --- a/ivy/functional/backends/numpy/experimental/activations.py +++ b/ivy/functional/backends/numpy/experimental/activations.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Union, Literal # global import numpy as np @@ -15,6 +15,7 @@ def logit( /, *, eps: Optional[float] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[np.ndarray] = None, ): x_dtype = x.dtype diff --git a/ivy/functional/backends/paddle/experimental/activations.py b/ivy/functional/backends/paddle/experimental/activations.py index 81bc6bdf25cb7..d4fd80403a331 100644 --- a/ivy/functional/backends/paddle/experimental/activations.py +++ b/ivy/functional/backends/paddle/experimental/activations.py @@ -1,5 +1,5 @@ # global -from typing import Optional, Union +from typing import Optional, Union, Literal import paddle import paddle.nn.functional as F @@ -10,9 +10,16 @@ @with_unsupported_device_and_dtypes( - {"2.5.1 and below": {"cpu": ("float16",)}}, backend_version + {"2.5.1 and below": {"cpu": ("float16", "bfloat16")}}, backend_version ) -def logit(x: paddle.Tensor, /, *, eps: Optional[float] = None, out=None): +def logit( + x: paddle.Tensor, + /, + *, + eps: Optional[float] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out=None, +): if x.dtype in [paddle.float32, paddle.float64]: return paddle.logit(x, eps) if eps is None: diff --git a/ivy/functional/backends/tensorflow/experimental/activations.py b/ivy/functional/backends/tensorflow/experimental/activations.py index f798d7b907d98..5ba64d523be10 100644 --- a/ivy/functional/backends/tensorflow/experimental/activations.py +++ b/ivy/functional/backends/tensorflow/experimental/activations.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Union, Literal # global import tensorflow as tf @@ -15,6 +15,7 @@ def logit( /, *, eps: Optional[float] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[Tensor] = None, ) -> Tensor: x_dtype = x.dtype diff --git a/ivy/functional/backends/torch/experimental/activations.py b/ivy/functional/backends/torch/experimental/activations.py index 98d72ac526c45..4331ac7b57fed 100644 --- a/ivy/functional/backends/torch/experimental/activations.py +++ b/ivy/functional/backends/torch/experimental/activations.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Union, Literal # global import torch @@ -16,6 +16,7 @@ def logit( /, *, eps: Optional[float] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[torch.Tensor] = None, ) -> torch.Tensor: return torch.logit(x, eps=eps, out=out) diff --git a/ivy/functional/ivy/experimental/activations.py b/ivy/functional/ivy/experimental/activations.py index 34203cffbaf6d..07c5279739070 100644 --- a/ivy/functional/ivy/experimental/activations.py +++ b/ivy/functional/ivy/experimental/activations.py @@ -1,5 +1,5 @@ # global -from typing import Union, Optional +from typing import Union, Optional, Callable, Literal # local import ivy @@ -14,9 +14,29 @@ inputs_to_ivy_arrays, handle_device_shifting, handle_backend_invalid, + handle_complex_input, ) +def _logit_jax_like( + x: Union[float, int, ivy.Array], + /, + *, + fn_original: Optional[Callable] = None, + eps: Optional[float] = None, + out: Optional[ivy.Array] = None, +): + real = ivy.real(x) + imag = ivy.imag(x) + if eps is None: + real = ivy.where(ivy.logical_or(real > 1, real < 0), ivy.nan, real) + else: + real = ivy.clip(real, eps, 1 - eps) + z = ivy.add(real, ivy.multiply(ivy.array(1j, dtype=x.dtype), imag)) + z = ivy.log(z / (1 - z)) + return z + + @handle_exceptions @handle_backend_invalid @handle_nestable @@ -24,11 +44,13 @@ @handle_out_argument @to_native_arrays_and_back @handle_device_shifting +@handle_complex_input def logit( x: Union[float, int, ivy.Array], /, *, eps: Optional[float] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", out: Optional[ivy.Array] = None, ) -> ivy.Array: """ @@ -44,6 +66,9 @@ def logit( When eps is None the function outpus NaN where x < 0 or x > 1. and inf or -inf where x = 1 or x = 0, respectively. Otherwise if eps is defined, x is clamped to [eps, 1 - eps] + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. out Optional output array. @@ -67,6 +92,9 @@ def logit( return current_backend(x).logit(x, eps=eps, out=out) +logit.jax_like = _logit_jax_like + + @handle_exceptions @handle_nestable @handle_array_like_without_promotion diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index b75dd43fcce77..b83bac4fe826b 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -374,10 +374,25 @@ def _forward(self, x): class Logit(Module): - def __init__(self, eps=None): - """Apply the LOGIT activation function.""" + def __init__( + self, + eps=None, + complex_mode="jax", + ): + """ + Apply the LOGIT activation function. + + Parameters + ---------- + eps + The epsilon value for the logit formation. Default: ``None``. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. + """ Module.__init__(self) self._eps = eps + self._complex_mode = complex_mode def _forward(self, x): """ @@ -386,15 +401,17 @@ def _forward(self, x): ---------- x Inputs to process *[batch_shape, d]*. - eps - The epsilon value for the logit formation. Default: ``None``. Returns ------- ret The outputs following the LOGIT activation *[batch_shape, d]* """ - return ivy.logit(x, eps=self._eps) + return ivy.logit( + x, + eps=self._eps, + complex_mode=self._complex_mode, + ) class PReLU(Module): diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py index 03a8c0a880b28..7c213994ed770 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_activations.py @@ -44,7 +44,7 @@ def test_elu( @handle_test( fn_tree="functional.ivy.experimental.logit", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log", diff --git a/ivy_tests/test_ivy/test_stateful/test_activations.py b/ivy_tests/test_ivy/test_stateful/test_activations.py index e1fd840ccde11..3630421a1b7cc 100644 --- a/ivy_tests/test_ivy/test_stateful/test_activations.py +++ b/ivy_tests/test_ivy/test_stateful/test_activations.py @@ -288,7 +288,7 @@ def test_log_softmax( @handle_method( method_tree="stateful.activations.Logit.__call__", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), + available_dtypes=helpers.get_dtypes("float_and_complex"), large_abs_safety_factor=8, small_abs_safety_factor=8, safety_factor_scale="log",