From 61ddd4c4d9a8cc833e85560cf70282a1fa403303 Mon Sep 17 00:00:00 2001
From: Vismay Suramwar <83938053+Vismay-dev@users.noreply.github.com>
Date: Mon, 11 Sep 2023 22:48:48 -0500
Subject: [PATCH] Added kl_div loss to ivy experimental api (#23054)

Co-authored-by: Eddy Oyieko <67474838+mobley-trent@users.noreply.github.com>
---
 ivy/data_classes/array/experimental/losses.py |  44 +++++++
 .../container/experimental/losses.py          | 114 ++++++++++++++++++
 .../backends/jax/experimental/losses.py       |  20 +++
 .../backends/numpy/experimental/losses.py     |  23 ++++
 .../backends/paddle/experimental/losses.py    |  26 ++++
 .../tensorflow/experimental/losses.py         |  22 ++++
 .../backends/torch/experimental/losses.py     |  29 +++++
 ivy/functional/ivy/experimental/losses.py     |  69 +++++++++++
 .../test_experimental/test_nn/test_losses.py  |  49 ++++++++
 9 files changed, 396 insertions(+)

diff --git a/ivy/data_classes/array/experimental/losses.py b/ivy/data_classes/array/experimental/losses.py
index 8913cf936cc85..68265a85a16e7 100644
--- a/ivy/data_classes/array/experimental/losses.py
+++ b/ivy/data_classes/array/experimental/losses.py
@@ -251,3 +251,47 @@ def soft_margin_loss(
         ivy.array([0.35667497, 0.22314353, 1.60943791])
         """
         return ivy.soft_margin_loss(self._data, target, reduction=reduction, out=out)
+
+    def kl_div(
+        self: ivy.Array,
+        target: Union[ivy.Array, ivy.NativeArray],
+        /,
+        *,
+        reduction: Optional[str] = "mean",
+        out: Optional[ivy.Array] = None,
+    ) -> ivy.Array:
+        """
+        ivy.Array instance method variant of ivy.kl_div. This method simply wraps the
+        function, and so the docstring for ivy.kl_div also applies to this method with
+        minimal changes.
+
+        Parameters
+        ----------
+        self
+            Array containing input probability distribution.
+        target
+            Array containing target probability distribution.
+        reduction
+            'none': No reduction will be applied to the output.
+            'mean': The output will be averaged.
+            'batchmean': The output will be divided by batch size.
+            'sum': The output will be summed.
+            Default: 'mean'.
+        out
+            Optional output array, for writing the result to.
+            It must have a shape that the inputs broadcast to.
+
+        Returns
+        -------
+        ret
+            The Kullback-Leibler divergence loss between the two input arrays.
+
+        Examples
+        --------
+        >>> input = ivy.array([[0.2, 0.8], [0.5, 0.5]])
+        >>> target = ivy.array([[0.6, 0.4], [0.3, 0.7]])
+        >>> output_array = input.kl_div(target)
+        >>> print(output_array)
+        ivy.array(0.2110)
+        """
+        return ivy.kl_div(self._data, target, reduction=reduction, out=out)
diff --git a/ivy/data_classes/container/experimental/losses.py b/ivy/data_classes/container/experimental/losses.py
index 482d9e82198f7..e4ee3c40b45d7 100644
--- a/ivy/data_classes/container/experimental/losses.py
+++ b/ivy/data_classes/container/experimental/losses.py
@@ -787,3 +787,117 @@ def soft_margin_loss(
             map_sequences=map_sequences,
             out=out,
         )
+
+    @staticmethod
+    def _static_kl_div(
+        input: Union[ivy.Container, ivy.Array, ivy.NativeArray],
+        target: Union[ivy.Container, ivy.Array, ivy.NativeArray],
+        /,
+        *,
+        reduction: Optional[Union[str, ivy.Container]] = "mean",
+        key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
+        to_apply: Union[bool, ivy.Container] = True,
+        prune_unapplied: Union[bool, ivy.Container] = False,
+        map_sequences: Union[bool, ivy.Container] = False,
+        out: Optional[ivy.Container] = None,
+    ) -> ivy.Container:
+        """
+        ivy.Container static method variant of ivy.kl_div. This method simply wraps the
+        function, and so the docstring for ivy.kl_div also applies to this method with
+        minimal changes.
+
+        Parameters
+        ----------
+        input
+            input array or container containing input distribution.
+        target
+            input array or container containing target distribution.
+        reduction
+            the reduction method. Default: "mean".
+        key_chains
+            The key-chains to apply or not apply the method to. Default is None.
+        to_apply
+            If True, the method will be applied to key_chains, otherwise key_chains
+            will be skipped. Default is True.
+        prune_unapplied
+            Whether to prune key_chains for which the function was not applied.
+            Default is False.
+        map_sequences
+            Whether to also map method to sequences (lists, tuples).
+            Default is False.
+        out
+            optional output container, for writing the result to. It must have a shape
+            that the inputs broadcast to.
+
+        Returns
+        -------
+        ret
+            The Kullback-Leibler divergence loss between the given distributions.
+        """
+        return ContainerBase.cont_multi_map_in_function(
+            "kl_div",
+            input,
+            target,
+            reduction=reduction,
+            key_chains=key_chains,
+            to_apply=to_apply,
+            prune_unapplied=prune_unapplied,
+            map_sequences=map_sequences,
+            out=out,
+        )
+
+    def kl_div(
+        self: ivy.Container,
+        target: Union[ivy.Container, ivy.Array, ivy.NativeArray],
+        /,
+        *,
+        reduction: Optional[Union[str, ivy.Container]] = "mean",
+        key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
+        to_apply: Union[bool, ivy.Container] = True,
+        prune_unapplied: Union[bool, ivy.Container] = False,
+        map_sequences: Union[bool, ivy.Container] = False,
+        out: Optional[ivy.Container] = None,
+    ) -> ivy.Container:
+        """
+        ivy.Container instance method variant of ivy.kl_div. This method simply wraps
+        the function, and so the docstring for ivy.kl_div also applies to this method
+        with minimal changes.
+
+        Parameters
+        ----------
+        self
+            input container containing input distribution.
+        target
+            input array or container containing target distribution.
+        reduction
+            the reduction method. Default: "mean".
+        key_chains
+            The key-chains to apply or not apply the method to. Default is None.
+        to_apply
+            If True, the method will be applied to key_chains, otherwise key_chains
+            will be skipped. Default is True.
+        prune_unapplied
+            Whether to prune key_chains for which the function was not applied.
+            Default is False.
+        map_sequences
+            Whether to also map method to sequences (lists, tuples).
+            Default is False.
+        out
+            optional output container, for writing the result to. It must have a shape
+            that the inputs broadcast to.
+
+        Returns
+        -------
+        ret
+            The Kullback-Leibler divergence loss between the given distributions.
+        """
+        return self._static_kl_div(
+            self,
+            target,
+            reduction=reduction,
+            key_chains=key_chains,
+            to_apply=to_apply,
+            prune_unapplied=prune_unapplied,
+            map_sequences=map_sequences,
+            out=out,
+        )
diff --git a/ivy/functional/backends/jax/experimental/losses.py b/ivy/functional/backends/jax/experimental/losses.py
index a5470e0afbc39..3fa778c0dbb9a 100644
--- a/ivy/functional/backends/jax/experimental/losses.py
+++ b/ivy/functional/backends/jax/experimental/losses.py
@@ -56,3 +56,23 @@ def soft_margin_loss(
         return jnp.sum(loss)
     else:
         return loss
+
+
+def kl_div(
+    input: JaxArray,
+    target: JaxArray,
+    /,
+    *,
+    reduction: Optional[str] = "mean",
+) -> JaxArray:
+    size = jnp.shape(input)
+    loss = jnp.sum(input * jnp.log(input / target), axis=-1)
+
+    if reduction == "mean":
+        loss = jnp.mean(loss)
+    elif reduction == "sum":
+        loss = jnp.sum(loss)
+    elif reduction == "batchmean":
+        loss = jnp.divide(jnp.sum(loss), size[0])
+
+    return loss
diff --git a/ivy/functional/backends/numpy/experimental/losses.py b/ivy/functional/backends/numpy/experimental/losses.py
index 0a33b6bcf1dad..76a266109e785 100644
--- a/ivy/functional/backends/numpy/experimental/losses.py
+++ b/ivy/functional/backends/numpy/experimental/losses.py
@@ -70,3 +70,26 @@ def soft_margin_loss(
         return np.sum(loss)
     else:
         return loss
+
+
+@with_unsupported_dtypes({"1.25.2 and below": ("bool", "bfloat16")}, backend_version)
+@_scalar_output_to_0d_array
+def kl_div(
+    input: np.ndarray,
+    target: np.ndarray,
+    /,
+    *,
+    reduction: Optional[str] = "mean",
+) -> np.ndarray:
+    size = np.shape(input)
+
+    loss = np.sum(input * np.log(input / target), axis=-1)
+
+    if reduction == "mean":
+        loss = np.mean(loss)
+    elif reduction == "sum":
+        loss = np.sum(loss)
+    elif reduction == "batchmean":
+        loss = np.divide(np.sum(loss), size[0])
+
+    return loss
diff --git a/ivy/functional/backends/paddle/experimental/losses.py b/ivy/functional/backends/paddle/experimental/losses.py
index b0be8bda8d502..a582b43e15fb2 100644
--- a/ivy/functional/backends/paddle/experimental/losses.py
+++ b/ivy/functional/backends/paddle/experimental/losses.py
@@ -120,3 +120,29 @@ def soft_margin_loss(
     reduction: Optional[str] = "mean",
 ) -> paddle.Tensor:
     return paddle.nn.functional.soft_margin_loss(input, label, reduction=reduction)
+
+
+@with_unsupported_device_and_dtypes(
+    {
+        "2.5.1 and below": {
+            "cpu": (
+                "bfloat16",
+                "float16",
+                "int8",
+                "int16",
+                "int32",
+                "int64",
+                "uint8",
+                "complex64",
+                "complex128",
+                "bool",
+            )
+        }
+    },
+    backend_version,
+)
+def kl_div(
+    input: paddle.Tensor, target: paddle.Tensor, /, *, reduction: Optional[str] = "mean"
+) -> paddle.Tensor:
+    loss = F.kl_div(input, target, reduction=reduction)
+    return loss
diff --git a/ivy/functional/backends/tensorflow/experimental/losses.py b/ivy/functional/backends/tensorflow/experimental/losses.py
index fdd493e40b8e8..e0c2da76b9958 100644
--- a/ivy/functional/backends/tensorflow/experimental/losses.py
+++ b/ivy/functional/backends/tensorflow/experimental/losses.py
@@ -62,3 +62,25 @@ def soft_margin_loss(
         return tf.reduce_mean(loss)
     else:
         return loss
+
+
+@with_unsupported_dtypes({"2.13.0 and below": ("bool", "bfloat16")}, backend_version)
+def kl_div(
+    input: tf.Tensor,
+    target: tf.Tensor,
+    /,
+    *,
+    reduction: Optional[str] = "mean",
+) -> tf.Tensor:
+    size = tf.shape(input)
+
+    loss = tf.reduce_sum(input * tf.math.log(input / target), axis=-1)
+
+    if reduction == "mean":
+        loss = tf.math.reduce_mean(loss)
+    elif reduction == "sum":
+        loss = tf.math.reduce_sum(loss)
+    elif reduction == "batchmean":
+        loss = tf.math.reduce_sum(loss) / tf.cast(size[0], dtype=tf.float32)
+
+    return loss
diff --git a/ivy/functional/backends/torch/experimental/losses.py b/ivy/functional/backends/torch/experimental/losses.py
index 2c24c6afd01e2..a37365922a1b6 100644
--- a/ivy/functional/backends/torch/experimental/losses.py
+++ b/ivy/functional/backends/torch/experimental/losses.py
@@ -97,3 +97,32 @@ def soft_margin_loss(
         target,
         reduction=reduction,
     )
+
+
+@with_unsupported_dtypes(
+    {
+        "2.0.1 and below": (
+            "float16",
+            "uint8",
+            "int8",
+            "int16",
+            "int32",
+            "int64",
+            "bool",
+        )
+    },
+    backend_version,
+)
+def kl_div(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    /,
+    *,
+    reduction: Optional[str] = "mean",
+) -> torch.Tensor:
+    loss = torch.nn.functional.kl_div(
+        input,
+        target,
+        reduction=reduction,
+    )
+    return loss
diff --git a/ivy/functional/ivy/experimental/losses.py b/ivy/functional/ivy/experimental/losses.py
index 8d2ef5d3e8f6a..f2950d682bb0d 100644
--- a/ivy/functional/ivy/experimental/losses.py
+++ b/ivy/functional/ivy/experimental/losses.py
@@ -408,3 +408,72 @@ def soft_margin_loss(
         return ivy.mean(loss, out=out)
     else:
         return ivy.inplace_update(out, loss) if out is not None else loss
+
+
+@handle_exceptions
+@handle_nestable
+@inputs_to_ivy_arrays
+@handle_array_function
+def kl_div(
+    input: Union[ivy.Array, ivy.NativeArray],
+    target: Union[ivy.Array, ivy.NativeArray],
+    /,
+    *,
+    reduction: Optional[str] = "mean",
+    out: Optional[ivy.Array] = None,
+) -> ivy.Array:
+    """
+    Compute the Kullback-Leibler divergence loss between two input tensors
+    (conventionally, probability distributions).
+
+    Parameters
+    ----------
+    input : array_like
+        Input probability distribution (first tensor).
+    target : array_like
+        Target probability distribution (second tensor).
+    reduction : {'mean', 'sum', 'batchmean', 'none'}, optional
+        Type of reduction to apply to the output. Default is 'mean'.
+    out : array_like, optional
+        Optional output array, for writing the result to.
+        It must have a shape that the inputs broadcast to.
+
+    Returns
+    -------
+    ret : array
+        The Kullback-Leibler divergence loss between the two input tensors.
+
+    Examples
+    --------
+    >>> input = ivy.array([[0.2, 0.8], [0.5, 0.5]])
+    >>> target = ivy.array([[0.6, 0.4], [0.3, 0.7]])
+    >>> ivy.kl_div(input, target)
+    ivy.array(0.2110)
+
+    >>> input = ivy.array([[0.2, 0.8], [0.5, 0.5]])
+    >>> target = ivy.array([[0.6, 0.4], [0.3, 0.7]])
+    >>> ivy.kl_div(input, target, reduction='sum')
+    ivy.array(0.4220)
+
+    >>> input = ivy.array([[0.2, 0.8], [0.5, 0.5]])
+    >>> target = ivy.array([[0.6, 0.4], [0.3, 0.7]])
+    >>> ivy.kl_div(input, target, reduction='batchmean')
+    ivy.array(0.2110)
+
+    >>> input = ivy.array([[0.2, 0.8], [0.5, 0.5]])
+    >>> target = ivy.array([[0.6, 0.4], [0.3, 0.7]])
+    >>> ivy.kl_div(input, target, reduction='none')
+    ivy.array([0.3348, 0.0872])
+    """
+    size = ivy.shape(input)
+
+    loss = ivy.sum(input * ivy.log(input / target), axis=-1)
+
+    if reduction == "sum":
+        loss = ivy.sum(loss, out=out)
+    elif reduction == "mean":
+        loss = ivy.mean(loss, out=out)
+    elif reduction == "batchmean":
+        loss = ivy.sum(loss, out=out) / size[0]
+
+    return ivy.inplace_update(out, loss) if out is not None else loss
diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_losses.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_losses.py
index c41b3f0d2d63c..cbddd092c511e 100644
--- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_losses.py
+++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_losses.py
@@ -55,6 +55,55 @@ def test_huber_loss(
     )
 
 
+# kl_div
+@handle_test(
+    fn_tree="functional.ivy.experimental.kl_div",
+    dtype_and_input=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+        min_value=1e-04,
+        max_value=1,
+        allow_inf=False,
+        min_num_dims=1,
+        max_num_dims=3,
+        min_dim_size=3,
+    ),
+    dtype_and_target=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float"),
+        min_value=1e-04,
+        max_value=1,
+        allow_inf=False,
+        min_num_dims=1,
+        max_num_dims=3,
+        min_dim_size=3,
+    ),
+    reduction=st.sampled_from(["none", "sum", "batchmean", "mean"]),
+    test_with_out=st.just(False),
+)
+def test_kl_div(
+    dtype_and_input,
+    dtype_and_target,
+    reduction,
+    test_flags,
+    backend_fw,
+    fn_name,
+    on_device,
+):
+    input_dtype, input = dtype_and_input
+    target_dtype, target = dtype_and_target
+
+    helpers.test_function(
+        input_dtypes=input_dtype + target_dtype,
+        test_flags=test_flags,
+        backend_to_test=backend_fw,
+        fn_name=fn_name,
+        on_device=on_device,
+        atol_=1e-02,
+        input=input[0],
+        target=target[0],
+        reduction=reduction,
+    )
+
+
 @handle_test(
     fn_tree="functional.ivy.experimental.l1_loss",
     dtype_input=helpers.dtype_and_values(
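
Usage sketch (not part of the patch): the snippet below exercises the new ivy.kl_div against the composite per-sample loss sum(input * log(input / target)) implemented above; the printed values are approximate and were computed with the numpy backend.

import ivy

ivy.set_backend("numpy")  # numbers in the comments assume the numpy backend

# A batch of two 2-class probability distributions.
input = ivy.array([[0.2, 0.8], [0.5, 0.5]])
target = ivy.array([[0.6, 0.4], [0.3, 0.7]])

# Per-sample KL divergence: sum(input * log(input / target)) over the last axis.
print(ivy.kl_div(input, target, reduction="none"))       # ~[0.3348, 0.0872]

# Reductions over the per-sample losses.
print(ivy.kl_div(input, target, reduction="mean"))       # ~0.2110
print(ivy.kl_div(input, target, reduction="sum"))        # ~0.4220
print(ivy.kl_div(input, target, reduction="batchmean"))  # ~0.2110 (sum / batch size)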