Skip to content

Commit

Permalink
Smooth L1 loss added to Experimental API
Browse files Browse the repository at this point in the history
  • Loading branch information
DebadityaPal committed Mar 30, 2023
1 parent 20f19da commit 9eabb03
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 0 deletions.
93 changes: 93 additions & 0 deletions ivy/functional/ivy/experimental/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,96 @@ def binary_cross_entropy_with_logits(
)

return result


@handle_nestable
@handle_exceptions
@handle_array_like_without_promotion
def smooth_l1_loss(
x: Union[ivy.Array, ivy.NativeArray],
y: Union[ivy.Array, ivy.NativeArray],
/,
*,
beta: float = 1.0,
reduction: str = "none",
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""Computes the smooth L1 loss between two arrays.
Parameters
----------
x
input array.
y
input array.
beta
a float in [0.0, 1.0] specifying the amount of smoothing when calculating the
loss. If beta is ``0``, no smoothing will be applied. Default: ``1.0``.
reduction
Specifies the reduction to apply to the output: ``'none'`` | ``'sum'`` |
``'mean'``. ``'none'``: no reduction will be applied, ``'sum'``: the sum of the
output will be returned, ``'mean'``: the mean of the output will be returned.
Default: ``'none'``.
out
optional output array, for writing the result to. It must have a shape
that the inputs broadcast to.
Returns
-------
ret
The smooth L1 loss between the given distributions.
Examples
--------
With :class:`ivy.Array` input:
>>> x = ivy.array([1, 2, 3, 4])
>>> y = ivy.array([2, 3, 4, 5])
>>> z = ivy.smooth_l1_loss(x, y)
>>> print(z)
ivy.array([0.5, 0.5, 0.5, 0.5])
>>> x = ivy.array([[1, 2, 3, 4]])
>>> y = ivy.array([[2, 3, 4, 5]])
>>> z = ivy.smooth_l1_loss(x, y, beta=0.5)
>>> print(z)
ivy.array([[0.75, 0.75, 0.75, 0.75]])
>>> x = ivy.array([[1, 2, 3, 4]])
>>> y = ivy.array([[2, 3, 4, 5]])
>>> z = ivy.smooth_l1_loss(x, y, beta=0.0)
>>> print(z)
ivy.array([[1, 1, 1, 1]])
With a mix of :class:`ivy.Array` and :class:`ivy.NativeArray` inputs:
>>> x = ivy.array([1, 2, 3, 4])
>>> y = ivy.native_array([2, 3, 4, 5])
>>> z = ivy.smooth_l1_loss(x, y)
>>> print(z)
ivy.array([0.5, 0.5, 0.5, 0.5])
With :class:`ivy.Container` input:
>>> x = ivy.Container(a=ivy.array([1, 2, 3]), b=ivy.array([4, 5, 6]))
>>> y = ivy.Container(a=ivy.array([2, 3, 4]), b=ivy.array([5, 6, 7]))
>>> z = ivy.smooth_l1_loss(x, y)
>>> print(z)
{
a: ivy.array([0.5, 0.5, 0.5]),
b: ivy.array([0.5, 0.5, 0.5])
}
"""
ivy.utils.assertions.check_elem_in_list(reduction, ["none", "sum", "mean"])
abs_diff = ivy.abs(x - y)
if beta == 0.0:
result = -abs_diff
else:
result = -ivy.where(
abs_diff < beta,
0.5 * ivy.square(abs_diff) / beta,
abs_diff - 0.5 * beta,
)
result = _reduce_loss(reduction, result, None, out)
return result
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,58 @@ def test_binary_cross_entropy_with_logits(
pos_weight=pos_weight[0],
reduction=reduction,
)


# smooth_l1_loss
@handle_test(
fn_tree="functional.ivy.experimental.smooth_l1_loss",
dtype_and_true=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
min_value=0,
max_value=1,
allow_inf=False,
min_num_dims=1,
max_num_dims=1,
min_dim_size=2,
),
dtype_and_pred=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
min_value=0,
max_value=1,
allow_inf=False,
exclude_min=True,
exclude_max=True,
min_num_dims=1,
max_num_dims=1,
min_dim_size=2,
),
reduction=st.sampled_from(["none", "sum", "mean"]),
beta=helpers.floats(min_value=0, max_value=1),
)
def test_smooth_l1_loss(
dtype_and_true,
dtype_and_pred,
reduction,
beta,
test_flags,
backend_fw,
fn_name,
on_device,
ground_truth_backend,
):
pred_dtype, pred = dtype_and_pred
true_dtype, true = dtype_and_true
helpers.test_function(
ground_truth_backend=ground_truth_backend,
input_dtypes=true_dtype + pred_dtype,
test_flags=test_flags,
fw=backend_fw,
fn_name=fn_name,
on_device=on_device,
rtol_=1e-1,
atol_=1e-1,
true=true[0],
pred=pred[0],
beta=beta,
reduction=reduction,
)

0 comments on commit 9eabb03

Please sign in to comment.