Skip to content

Commit

Permalink
feat: add ivy.ssim_loss (#27134)
Browse files Browse the repository at this point in the history
Co-authored-by: Sam-Armstrong <[email protected]>
  • Loading branch information
hi-sushanta and Sam-Armstrong authored Jul 13, 2024
1 parent 19b9c7f commit 262472b
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 0 deletions.
60 changes: 60 additions & 0 deletions ivy/functional/ivy/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,63 @@ def sparse_cross_entropy(
return ivy.cross_entropy(
true, pred, axis=axis, epsilon=epsilon, reduction=reduction, out=out
)


@handle_exceptions
@handle_nestable
@handle_array_like_without_promotion
@inputs_to_ivy_arrays
@handle_array_function
def ssim_loss(
true: Union[ivy.Array, ivy.NativeArray],
pred: Union[ivy.Array, ivy.NativeArray],
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""Calculate the Structural Similarity Index (SSIM) loss between two
images.
Parameters
----------
true: A 4D image array of shape (batch_size, channels, height, width).
pred: A 4D image array of shape (batch_size, channels, height, width).
Returns
-------
ivy.Array: The SSIM loss measure similarity between the two images.
Examples
--------
With :class:`ivy.Array` input:
>>> import ivy
>>> x = ivy.ones((5, 3, 28, 28))
>>> y = ivy.zeros((5, 3, 28, 28))
>>> ivy.ssim_loss(x, y)
ivy.array(0.99989986)
"""
# Constants for stability
C1 = 0.01 ** 2
C2 = 0.03 ** 2

# Calculate the mean of the two images
mu_x = ivy.avg_pool2d(pred, (3, 3), (1, 1), "SAME")
mu_y = ivy.avg_pool2d(true, (3, 3), (1, 1), "SAME")

# Calculate variance and covariance
sigma_x2 = ivy.avg_pool2d(pred * pred, (3, 3), (1, 1), "SAME") - mu_x * mu_x
sigma_y2 = ivy.avg_pool2d(true * true, (3, 3), (1, 1), "SAME") - mu_y * mu_y
sigma_xy = ivy.avg_pool2d(pred * true, (3, 3), (1, 1), "SAME") - mu_x * mu_y

# Calculate SSIM
ssim = ((2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)) / (
(mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x2 + sigma_y2 + C2)
)

# Convert SSIM to loss
ssim_loss_value = 1 - ssim

# Return mean SSIM loss
ret = ivy.mean(ssim_loss_value)

if ivy.exists(out):
ret = ivy.inplace_update(out, ret)
return ret
38 changes: 38 additions & 0 deletions ivy_tests/test_ivy/test_functional/test_nn/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,41 @@ def test_sparse_cross_entropy(
epsilon=epsilon,
reduction=reduction,
)


@handle_test(
fn_tree="functional.ivy.ssim_loss",
dtype_and_true=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
min_value=-1,
max_value=1,
min_num_dims=4,
max_num_dims=4,
min_dim_size=2,
),
dtype_and_pred=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
min_value=-1,
max_value=1,
min_num_dims=4,
max_num_dims=4,
min_dim_size=2,
),
)
def test_ssim_loss(
dtype_and_true, dtype_and_pred, test_flags, backend_fw, fn_name, on_device
):
true_dtype, true = dtype_and_true
pred_dtype, pred = dtype_and_pred

helpers.test_function(
input_dtypes=pred_dtype + true_dtype,
test_flags=test_flags,
backend_to_test=backend_fw,
fn_name=fn_name,
on_device=on_device,
true=true[0],
pred=pred[0],
rtol_=1e-02,
atol_=1e-02,
)

0 comments on commit 262472b

Please sign in to comment.