Commit

added the loss function

sushmanthreddy committed Sep 3, 2023
1 parent 4e60e57 commit 814904d

Showing 9 changed files with 461 additions and 0 deletions.
54 changes: 54 additions & 0 deletions ivy/data_classes/array/experimental/losses.py
@@ -187,3 +187,57 @@ def soft_margin_loss(
ivy.array([0.35667497, 0.22314353, 1.60943791])
"""
return ivy.soft_margin_loss(self._data, target, reduction=reduction, out=out)

def margin_ranking_loss(
self: ivy.Array,
pred: Union[ivy.Array, ivy.NativeArray],
target: Union[ivy.Array, ivy.NativeArray],
/,
*,
margin: Optional[float] = 0.0,
reduction: Optional[str] = "mean",
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.margin_ranking_loss. This method simply
wraps the function, and so the docstring for ivy.margin_ranking_loss also
applies to this method with minimal changes.

        Parameters
        ----------
        self
            input array containing predictions for the first input.
pred
input array or container containing predictions for the second input.
target
input array or container containing the binary labels (1 or -1).
margin
a float margin for loss. Default: ``0.0``.
reduction
the reduction type to apply to the loss. Default: ``'mean'``.
out
optional output array, for writing the result to. It must have a shape
that the inputs broadcast to.

        Returns
-------
ret
The margin ranking loss.

        Examples
--------
>>> true = ivy.array([0.5, 0.8, 0.6])
>>> pred = ivy.array([0.3, 0.4, 0.2])
>>> target = ivy.array([1.0, -1.0, -1.0])
>>> loss = true.margin_ranking_loss(pred, target, margin=0.1)
>>> print(loss)
        ivy.array(0.33333334)
"""
return ivy.margin_ranking_loss(
self._data,
pred,
target,
margin=margin,
reduction=reduction,
out=out,
)
161 changes: 161 additions & 0 deletions ivy/data_classes/container/experimental/losses.py
@@ -612,3 +612,164 @@ def soft_margin_loss(
map_sequences=map_sequences,
out=out,
)

@staticmethod
def _static_margin_ranking_loss(
true: Union[ivy.Container, ivy.Array, ivy.NativeArray],
pred: Union[ivy.Container, ivy.Array, ivy.NativeArray],
target: Union[ivy.Container, ivy.Array, ivy.NativeArray],
/,
*,
margin: Optional[float] = 0.0,
reduction: Optional[Union[str, ivy.Container]] = "mean",
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.margin_ranking_loss. This method
simply wraps the function, and so the docstring for ivy.margin_ranking_loss also
applies to this method with minimal changes.

        Parameters
----------
true
input array or container containing predictions for the first input.
pred
input array or container containing predictions for the second input.
target
input array or container containing the binary labels (1 or -1).
margin
margin for the loss. Default is ``0``.
reduction
Specifies the reduction to apply to the output. Default is ``"mean"``.
key_chains
The key-chains to apply or not apply the method to. Default is ``None``.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.

        Returns
-------
ret
The margin ranking loss.

        Examples
        --------
        With :class:`ivy.Container` inputs:

        >>> true = ivy.Container(a=ivy.array([0.5, 0.2, 0.8]),
        ...                      b=ivy.array([0.7, 0.1, 0.3]))
        >>> pred = ivy.Container(a=ivy.array([0.3, 0.9, 0.6]),
        ...                      b=ivy.array([0.2, 0.6, 0.4]))
        >>> target = ivy.Container(a=ivy.array(1.0),
        ...                        b=ivy.array(-1.0))
        >>> loss = ivy.Container._static_margin_ranking_loss(true, pred, target)
        >>> print(loss)
        {
            a: ivy.array(0.23333334),
            b: ivy.array(0.16666667)
        }
"""
return ContainerBase.cont_multi_map_in_function(
"margin_ranking_loss",
true,
pred,
target,
margin=margin,
reduction=reduction,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)

def margin_ranking_loss(
self: ivy.Container,
pred: Union[ivy.Container, ivy.Array, ivy.NativeArray],
target: Union[ivy.Container, ivy.Array, ivy.NativeArray],
/,
*,
margin: Optional[Union[float, ivy.Container]] = 0.0,
reduction: Optional[Union[str, ivy.Container]] = "mean",
key_chains: Optional[Union[List[str], Dict[str, str], ivy.Container]] = None,
to_apply: Union[bool, ivy.Container] = True,
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.margin_ranking_loss. This method
simply wraps the function, and so the docstring for ivy.margin_ranking_loss also
applies to this method with minimal changes.

        Parameters
----------
self
input container containing predictions for the first input.
        pred
            input array or container containing predictions for the second input.
target
input array or container containing the binary labels (1 or -1).
margin
margin for the loss. Default is ``0``.
reduction
Specifies the reduction to apply to the output. Default is ``"mean"``.
key_chains
The key-chains to apply or not apply the method to. Default is ``None``.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.

        Returns
-------
ret
The margin ranking loss.

        Examples
        --------
        >>> true = ivy.Container(a=ivy.array([0.5, 0.2, 0.8]),
        ...                      b=ivy.array([0.7, 0.1, 0.3]))
        >>> pred = ivy.Container(a=ivy.array([0.3, 0.9, 0.6]),
        ...                      b=ivy.array([0.2, 0.6, 0.4]))
        >>> target = ivy.Container(a=ivy.array(1.0), b=ivy.array(-1.0))
        >>> loss = true.margin_ranking_loss(pred, target)
        >>> print(loss)
        {
            a: ivy.array(0.23333334),
            b: ivy.array(0.16666667)
        }
"""
return self._static_margin_ranking_loss(
self,
pred,
target,
margin=margin,
reduction=reduction,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)
20 changes: 20 additions & 0 deletions ivy/functional/backends/jax/experimental/losses.py
@@ -56,3 +56,23 @@ def soft_margin_loss(
return jnp.sum(loss)
else:
return loss


def margin_ranking_loss(
input1: JaxArray,
input2: JaxArray,
target: JaxArray,
/,
*,
    margin: Optional[float] = 0.0,
    reduction: Optional[str] = "mean",
) -> JaxArray:
pairwise_margin = margin - target * (input1 - input2)
loss = jnp.where(pairwise_margin > 0, pairwise_margin, 0)

if reduction == "mean":
return jnp.mean(loss)
elif reduction == "sum":
return jnp.sum(loss)
else:
return loss
23 changes: 23 additions & 0 deletions ivy/functional/backends/numpy/experimental/losses.py
@@ -70,3 +70,26 @@ def soft_margin_loss(
return np.sum(loss)
else:
return loss


@with_unsupported_dtypes({"1.25.2 and below": ("bool",)}, backend_version)
@_scalar_output_to_0d_array
def margin_ranking_loss(
input1: np.ndarray,
input2: np.ndarray,
target: np.ndarray,
/,
*,
    margin: Optional[float] = 0.0,
reduction: Optional[str] = "mean",
) -> np.ndarray:
pairwise_margin = margin - target * (input1 - input2)
loss = np.where(pairwise_margin > 0, pairwise_margin, 0)

if reduction == "mean":
return np.mean(loss)
elif reduction == "sum":
return np.sum(loss)
else:
return loss
32 changes: 32 additions & 0 deletions ivy/functional/backends/paddle/experimental/losses.py
@@ -120,3 +120,35 @@ def soft_margin_loss(
reduction: Optional[str] = "mean",
) -> paddle.Tensor:
return paddle.nn.functional.soft_margin_loss(input, label, reduction=reduction)


@with_unsupported_device_and_dtypes(
{
"2.5.1 and below": {
"cpu": (
"float16",
"int8",
"int16",
"int32",
"int64",
"uint8",
"complex64",
"complex128",
"bool",
)
}
},
backend_version,
)
def margin_ranking_loss(
input1: paddle.Tensor,
input2: paddle.Tensor,
target: paddle.Tensor,
/,
*,
margin: Optional[float] = 0.0,
reduction: Optional[str] = "mean",
) -> paddle.Tensor:
return paddle.nn.functional.margin_ranking_loss(
input1, input2, target, margin=margin, reduction=reduction
)
21 changes: 21 additions & 0 deletions ivy/functional/backends/tensorflow/experimental/losses.py
@@ -62,3 +62,24 @@ def soft_margin_loss(
return tf.reduce_mean(loss)
else:
return loss


@with_unsupported_dtypes({"2.13.0 and below": ("bool",)}, backend_version)
def margin_ranking_loss(
input1: tf.Tensor,
input2: tf.Tensor,
target: tf.Tensor,
/,
*,
    margin: Optional[float] = 0.0,
    reduction: Optional[str] = "mean",
) -> tf.Tensor:
pairwise_margin = margin - target * (input1 - input2)
loss = tf.where(pairwise_margin > 0, pairwise_margin, 0.0)

if reduction == "sum":
return tf.reduce_sum(loss)
elif reduction == "mean":
return tf.reduce_mean(loss)
else:
return loss
19 changes: 19 additions & 0 deletions ivy/functional/backends/torch/experimental/losses.py
@@ -97,3 +97,22 @@ def soft_margin_loss(
target,
reduction=reduction,
)


@with_unsupported_dtypes({"2.0.1 and below": ("float16", "bool")}, backend_version)
def margin_ranking_loss(
input1: torch.Tensor,
input2: torch.Tensor,
target: torch.Tensor,
/,
*,
margin: Optional[float] = 0.0,
reduction: Optional[str] = "mean",
) -> torch.Tensor:
return torch.nn.functional.margin_ranking_loss(
input1,
input2,
target,
margin=margin,
reduction=reduction,
)