Added kl_div loss to ivy experimental api (#23054)
Co-authored-by: Eddy Oyieko <[email protected]>
vismaysur and mobley-trent authored Sep 12, 2023
1 parent 6fad2fa commit 61ddd4c
Showing 9 changed files with 396 additions and 0 deletions.
44 changes: 44 additions & 0 deletions ivy/data_classes/array/experimental/losses.py
@@ -251,3 +251,47 @@ def soft_margin_loss(
ivy.array([0.35667497, 0.22314353, 1.60943791])
"""
return ivy.soft_margin_loss(self._data, target, reduction=reduction, out=out)

def kl_div(
self: ivy.Array,
target: Union[ivy.Array, ivy.NativeArray],
/,
*,
reduction: Optional[str] = "mean",
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.kl_div. This method simply wraps the
function, and so the docstring for ivy.kl_div also applies to this method with
minimal changes.

Parameters
----------
self
Array containing input probability distribution.
target
Array containing target probability distribution.
reduction
'none': No reduction will be applied to the output.
'mean': The output will be averaged.
'batchmean': The output will be divided by batch size.
'sum': The output will be summed.
Default: 'mean'.
out
Optional output array, for writing the result to.
It must have a shape that the inputs broadcast to.

Returns
-------
ret
The Kullback-Leibler divergence loss between the two input arrays.

Examples
--------
>>> input = ivy.array([[0.2, 0.8], [0.5, 0.5]])
>>> target = ivy.array([[0.6, 0.4], [0.3, 0.7]])
>>> output_array = input.kl_div(target)
>>> print(output_array)
ivy.array(0.211)
"""
return ivy.kl_div(self._data, target, reduction=reduction, out=out)
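
For reference, a minimal usage sketch of the new ivy.Array method (a sketch only, assuming an ivy build that includes this commit; the probability rows are illustrative):

import ivy

p = ivy.array([[0.2, 0.8], [0.5, 0.5]])  # input distributions, one per row
q = ivy.array([[0.6, 0.4], [0.3, 0.7]])  # target distributions

print(p.kl_div(q))                    # averaged over rows (default "mean")
print(p.kl_div(q, reduction="none"))  # one divergence value per row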
114 changes: 114 additions & 0 deletions ivy/data_classes/container/experimental/losses.py
@@ -787,3 +787,117 @@ def soft_margin_loss(
map_sequences=map_sequences,
out=out,
)

@staticmethod
def _static_kl_div(
input: Union[ivy.Container, ivy.Array, ivy.NativeArray],
target: Union[ivy.Container, ivy.Array, ivy.NativeArray],
/,
*,
reduction: Optional[Union[str, ivy.Container]] = "mean",
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.kl_div. This method simply wraps the
function, and so the docstring for ivy.kl_div also applies to this method with
minimal changes.

Parameters
----------
input
input array or container containing input distribution.
target
input array or container containing target distribution.
reduction
the reduction method. Default: "mean".
key_chains
The key-chains to apply or not apply the method to. Default is None.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped. Default is True.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
Default is False.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is False.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.

Returns
-------
ret
The Kullback-Leibler divergence loss between the given distributions.
"""
return ContainerBase.cont_multi_map_in_function(
"kl_div",
input,
target,
reduction=reduction,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)

def kl_div(
self: ivy.Container,
target: Union[ivy.Container, ivy.Array, ivy.NativeArray],
/,
*,
reduction: Optional[Union[str, ivy.Container]] = "mean",
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.kl_div. This method simply wraps
the function, and so the docstring for ivy.kl_div also applies to this method
with minimal changes.

Parameters
----------
self
input container containing input distribution.
target
input array or container containing target distribution.
reduction
the reduction method. Default: "mean".
key_chains
The key-chains to apply or not apply the method to. Default is None.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped. Default is True.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
Default is False.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is False.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.

Returns
-------
ret
The Kullback-Leibler divergence loss between the given distributions.
"""
return self._static_kl_div(
self,
target,
reduction=reduction,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)
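
A similar hedged sketch for the container variant; the key names a and b are hypothetical and only serve to show that the loss is mapped over matching leaves:

import ivy

x = ivy.Container(
    a=ivy.array([[0.2, 0.8], [0.5, 0.5]]),
    b=ivy.array([[0.1, 0.9], [0.4, 0.6]]),
)
y = ivy.Container(
    a=ivy.array([[0.6, 0.4], [0.3, 0.7]]),
    b=ivy.array([[0.2, 0.8], [0.5, 0.5]]),
)

# returns a container holding one reduced loss per key
print(x.kl_div(y, reduction="sum"))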
20 changes: 20 additions & 0 deletions ivy/functional/backends/jax/experimental/losses.py
@@ -56,3 +56,23 @@ def soft_margin_loss(
return jnp.sum(loss)
else:
return loss


def kl_div(
input: JaxArray,
target: JaxArray,
/,
*,
reduction: Optional[str] = "mean",
) -> JaxArray:
size = jnp.shape(input)
loss = jnp.sum(input * jnp.log(input / target), axis=-1)

if reduction == "mean":
loss = jnp.mean(loss)
elif reduction == "sum":
loss = jnp.sum(loss)
elif reduction == "batchmean":
loss = jnp.divide(jnp.sum(loss), size[0])

return loss
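
To make the reduction branches concrete, the same arithmetic can be reproduced in plain jax.numpy, independent of ivy (a sketch with arbitrary probability rows):

import jax.numpy as jnp

p = jnp.array([[0.2, 0.8], [0.5, 0.5]])
q = jnp.array([[0.6, 0.4], [0.3, 0.7]])

per_row = jnp.sum(p * jnp.log(p / q), axis=-1)  # reduction="none"
mean_loss = jnp.mean(per_row)                   # reduction="mean"
batchmean = jnp.sum(per_row) / p.shape[0]       # reduction="batchmean"
print(per_row, mean_loss, batchmean)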
23 changes: 23 additions & 0 deletions ivy/functional/backends/numpy/experimental/losses.py
@@ -70,3 +70,26 @@ def soft_margin_loss(
return np.sum(loss)
else:
return loss


@with_unsupported_dtypes({"1.25.2 and below": ("bool", "bfloat16")}, backend_version)
@_scalar_output_to_0d_array
def kl_div(
input: np.ndarray,
target: np.ndarray,
/,
*,
reduction: Optional[str] = "mean",
) -> np.ndarray:
size = np.shape(input)

loss = np.sum(input * np.log(input / target), axis=-1)

if reduction == "mean":
loss = np.mean(loss)
elif reduction == "sum":
loss = np.sum(loss)
elif reduction == "batchmean":
loss = np.divide(np.sum(loss), size[0])

return loss
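
As a sanity check on the NumPy branch, the per-row term it computes matches scipy.stats.entropy, which evaluates the same sum(p * log(p / q)) for normalized rows (SciPy is used here only for comparison and is not a dependency of the backend):

import numpy as np
from scipy.stats import entropy

p = np.array([[0.2, 0.8], [0.5, 0.5]])
q = np.array([[0.6, 0.4], [0.3, 0.7]])

per_row = np.sum(p * np.log(p / q), axis=-1)
reference = np.array([entropy(p[i], q[i]) for i in range(p.shape[0])])
assert np.allclose(per_row, reference)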
26 changes: 26 additions & 0 deletions ivy/functional/backends/paddle/experimental/losses.py
@@ -120,3 +120,29 @@ def soft_margin_loss(
reduction: Optional[str] = "mean",
) -> paddle.Tensor:
return paddle.nn.functional.soft_margin_loss(input, label, reduction=reduction)


@with_unsupported_device_and_dtypes(
{
"2.5.1 and below": {
"cpu": (
"bfloat16",
"float16",
"int8",
"int16",
"int32",
"int64",
"uint8",
"complex64",
"complex128",
"bool",
)
}
},
backend_version,
)
def kl_div(
input: paddle.Tensor, target: paddle.Tensor, /, *, reduction: Optional[str] = "mean"
) -> paddle.Tensor:
# note: paddle.nn.functional.kl_div expects `input` to hold log-probabilities
loss = F.kl_div(input, target, reduction=reduction)
return loss
22 changes: 22 additions & 0 deletions ivy/functional/backends/tensorflow/experimental/losses.py
@@ -62,3 +62,25 @@ def soft_margin_loss(
return tf.reduce_mean(loss)
else:
return loss


@with_unsupported_dtypes({"2.13.0 and below": ("bool", "bfloat16")}, backend_version)
def kl_div(
input: tf.Tensor,
target: tf.Tensor,
/,
*,
reduction: Optional[str] = "mean",
) -> tf.Tensor:
size = tf.shape(input)

loss = tf.reduce_sum(input * tf.math.log(input / target), axis=-1)

if reduction == "mean":
loss = tf.math.reduce_mean(loss)
elif reduction == "sum":
loss = tf.math.reduce_sum(loss)
elif reduction == "batchmean":
loss = tf.math.reduce_sum(loss) / tf.cast(size[0], dtype=loss.dtype)

return loss
29 changes: 29 additions & 0 deletions ivy/functional/backends/torch/experimental/losses.py
@@ -97,3 +97,32 @@ def soft_margin_loss(
target,
reduction=reduction,
)


@with_unsupported_dtypes(
{
"2.0.1 and below": (
"float16",
"uint8",
"int8",
"int16",
"int32",
"int64",
"bool",
)
},
backend_version,
)
def kl_div(
input: torch.Tensor,
target: torch.Tensor,
/,
*,
reduction: Optional[str] = "mean",
) -> torch.Tensor:
# torch.nn.functional.kl_div expects `input` to hold log-probabilities and
# computes target * (log(target) - input) elementwise before reducing
loss = torch.nn.functional.kl_div(
input,
target,
reduction=reduction,
)
return loss
69 changes: 69 additions & 0 deletions ivy/functional/ivy/experimental/losses.py
@@ -408,3 +408,72 @@ def soft_margin_loss(
return ivy.mean(loss, out=out)
else:
return ivy.inplace_update(out, loss) if out is not None else loss


@handle_exceptions
@handle_nestable
@inputs_to_ivy_arrays
@handle_array_function
def kl_div(
input: Union[ivy.Array, ivy.NativeArray],
target: Union[ivy.Array, ivy.NativeArray],
/,
*,
reduction: Optional[str] = "mean",
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Compute the Kullback-Leibler divergence loss between two input tensors
(conventionally, probability distributions).

Parameters
----------
input : array_like
Input probability distribution (first tensor).
target : array_like
Target probability distribution (second tensor).
reduction : {'mean', 'sum', 'batchmean', 'none'}, optional
Type of reduction to apply to the output. Default is 'mean'.
out : array_like, optional
Optional output array, for writing the result to.
It must have a shape that the inputs broadcast to.

Returns
-------
ret : array
The Kullback-Leibler divergence loss between the two input tensors.

Examples
--------
>>> input = ivy.array([[0.2, 0.8], [0.5, 0.5]])
>>> target = ivy.array([[0.6, 0.4], [0.3, 0.7]])
>>> ivy.kl_div(input, target)
ivy.array(0.211)
>>> input = ivy.array([[0.2, 0.8], [0.5, 0.5]])
>>> target = ivy.array([[0.6, 0.4], [0.3, 0.7]])
>>> ivy.kl_div(input, target, reduction='sum')
ivy.array(0.422)
>>> input = ivy.array([[0.2, 0.8], [0.5, 0.5]])
>>> target = ivy.array([[0.6, 0.4], [0.3, 0.7]])
>>> ivy.kl_div(input, target, reduction='batchmean')
ivy.array(0.211)
>>> input = ivy.array([[0.2, 0.8], [0.5, 0.5]])
>>> target = ivy.array([[0.6, 0.4], [0.3, 0.7]])
>>> ivy.kl_div(input, target, reduction='none')
ivy.array([0.3348, 0.0872])
"""
size = ivy.shape(input)

loss = ivy.sum(input * ivy.log(input / target), axis=-1)

if reduction == "sum":
loss = ivy.sum(loss, out=out)
elif reduction == "mean":
loss = ivy.mean(loss, out=out)
elif reduction == "batchmean":
loss = ivy.sum(loss, out=out) / size[0]

return ivy.inplace_update(out, loss) if out is not None else loss
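
Because the compositional implementation first sums over the last axis and then applies the requested reduction, the reduction modes relate to one another in a simple way; a hedged sketch against the public function (values as in the docstring example, assuming this commit is installed):

import ivy

p = ivy.array([[0.2, 0.8], [0.5, 0.5]])
q = ivy.array([[0.6, 0.4], [0.3, 0.7]])

per_row = ivy.kl_div(p, q, reduction="none")  # D_KL(p_i || q_i) for each row
assert ivy.allclose(ivy.kl_div(p, q, reduction="sum"), ivy.sum(per_row))
assert ivy.allclose(ivy.kl_div(p, q, reduction="mean"), ivy.mean(per_row))
assert ivy.allclose(
    ivy.kl_div(p, q, reduction="batchmean"), ivy.sum(per_row) / p.shape[0]
)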