diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ccf8cae705..a4f853bae72 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Added `Upper Face Dynamics Deviation (FDD)` metric to multimodal domain. ([#3097](https://github.com/Lightning-AI/torchmetrics/issues/3097)) ### Changed diff --git a/docs/source/links.rst b/docs/source/links.rst index 539d2728e74..abb08b6b9ee 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -150,6 +150,7 @@ .. _CLIP-IQA: https://arxiv.org/abs/2207.12396 .. _CLIP: https://arxiv.org/abs/2103.00020 .. _LVE: https://openaccess.thecvf.com/content/ICCV2021/papers/Richard_MeshTalk_3D_Face_Animation_From_Speech_Using_Cross-Modality_Disentanglement_ICCV_2021_paper.pdf +.. _FDD: https://openaccess.thecvf.com/content/CVPR2023/papers/Xing_CodeTalker_Speech-Driven_3D_Facial_Animation_With_Discrete_Motion_Prior_CVPR_2023_paper.pdf .. _PPL : https://arxiv.org/abs/1812.04948 .. _CIOU: https://arxiv.org/abs/2005.03572 .. _DIOU: https://arxiv.org/abs/1911.08287v1 diff --git a/docs/source/multimodal/fdd.rst b/docs/source/multimodal/fdd.rst new file mode 100644 index 00000000000..8a29367bac6 --- /dev/null +++ b/docs/source/multimodal/fdd.rst @@ -0,0 +1,20 @@ +.. customcarditem:: + :header: Upper Face Dynamics Deviation (FDD) + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg + :tags: Multimodal + +.. include:: ../links.rst + +################################### +Upper Face Dynamics Deviation (FDD) +################################### + +Module Interface +________________ + +.. autoclass:: torchmetrics.multimodal.fdd.UpperFaceDynamicsDeviation + +Functional Interface +____________________ + +.. 
autofunction:: torchmetrics.functional.multimodal.fdd.upper_face_dynamics_deviation diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index d3847b37ce1..2cec2309623 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -74,7 +74,7 @@ ) from torchmetrics.functional.image._deprecated import _total_variation as total_variation from torchmetrics.functional.image._deprecated import _universal_image_quality_index as universal_image_quality_index -from torchmetrics.functional.multimodal import lip_vertex_error +from torchmetrics.functional.multimodal import lip_vertex_error, upper_face_dynamics_deviation from torchmetrics.functional.nominal import ( cramers_v, cramers_v_matrix, @@ -246,6 +246,7 @@ "tschuprows_t_matrix", "tweedie_deviance_score", "universal_image_quality_index", + "upper_face_dynamics_deviation", "weighted_mean_absolute_percentage_error", "word_error_rate", "word_information_lost", diff --git a/src/torchmetrics/functional/multimodal/__init__.py b/src/torchmetrics/functional/multimodal/__init__.py index ac9f5e199a4..30fda5feb01 100644 --- a/src/torchmetrics/functional/multimodal/__init__.py +++ b/src/torchmetrics/functional/multimodal/__init__.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from torchmetrics.functional.multimodal.fdd import upper_face_dynamics_deviation
 from torchmetrics.functional.multimodal.lve import lip_vertex_error
 from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_10
 
-__all__ = ["lip_vertex_error"]
+__all__ = ["lip_vertex_error", "upper_face_dynamics_deviation"]
 
 if _TRANSFORMERS_GREATER_EQUAL_4_10:
     from torchmetrics.functional.multimodal.clip_iqa import clip_image_quality_assessment
diff --git a/src/torchmetrics/functional/multimodal/fdd.py b/src/torchmetrics/functional/multimodal/fdd.py
new file mode 100644
index 00000000000..c542c8d7f3f
--- /dev/null
+++ b/src/torchmetrics/functional/multimodal/fdd.py
@@ -0,0 +1,116 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List
+
+import torch
+from torch import Tensor
+
+
+def upper_face_dynamics_deviation(
+    vertices_pred: Tensor,
+    vertices_gt: Tensor,
+    template: Tensor,
+    upper_face_map: List[int],
+) -> Tensor:
+    r"""Compute Upper Face Dynamics Deviation (FDD) for 3D talking head evaluation.
+
+    The Upper Face Dynamics Deviation (FDD) metric evaluates the quality of facial expressions in the upper
+    face region for 3D talking head models. It quantifies the deviation in vertex motion dynamics between the
+    predicted and ground truth sequences by comparing the temporal variation (standard deviation) of per-vertex
+    squared displacements relative to a neutral template. The deviation is signed (ground truth minus prediction),
+    so values closer to zero indicate closer alignment of the predicted upper-face dynamics with the ground truth.
+
+    The metric is defined as:
+
+    .. math::
+        \text{FDD} = \frac{1}{|S_U|} \sum_{v \in S_U} \Big( \text{std}(\| x_{1:T,v} -
+        \text{template}_v \|_2^2) - \text{std}(\| \hat{x}_{1:T,v} - \text{template}_v \|_2^2) \Big)
+
+    where :math:`T` is the number of frames, :math:`S_U` is the set of upper-face vertices with :math:`M = |S_U|`,
+    :math:`x_{t,v}` are the 3D coordinates of vertex :math:`v` at frame :math:`t` in the ground truth sequence,
+    and :math:`\hat{x}_{t,v} \in \mathbb{R}^3` are the corresponding predicted vertices. The neutral template coordinate
+    of vertex :math:`v` is denoted as :math:`\text{template}_v \in \mathbb{R}^3`. The operator :math:`\text{std}(\cdot)`
+    computes the standard deviation of the temporal sequence.
+
+    Args:
+        vertices_pred: Predicted vertices tensor of shape (T, V, 3) where T is number of frames,
+            V is number of vertices, and 3 represents XYZ coordinates.
+        vertices_gt: Ground truth vertices tensor of shape (T, V, 3) where T is number of frames,
+            V is number of vertices, and 3 represents XYZ coordinates.
+        template: Template mesh tensor of shape (V, 3) representing the neutral face.
+        upper_face_map: List of vertex indices corresponding to the upper face region.
+
+    Returns:
+        torch.Tensor: Scalar tensor containing the mean FDD value across upper-face vertices.
+
+    Raises:
+        ValueError:
+            If the number of dimensions of `vertices_pred` or `vertices_gt` is not 3.
+            If `template` does not have shape (No_of_vertices, 3).
+            If `vertices_pred` and `vertices_gt` do not have the same vertex and coordinate dimensions.
+            If `template` shape does not match the vertex-coordinate dimensions of `vertices_pred` (and `vertices_gt`).
+            If ``upper_face_map`` is empty or contains invalid vertex indices.
+
+    Example:
+        >>> import torch
+        >>> from torchmetrics.functional.multimodal import upper_face_dynamics_deviation
+        >>> vertices_pred = torch.randn(10, 100, 3, generator=torch.manual_seed(41))
+        >>> vertices_gt = torch.randn(10, 100, 3, generator=torch.manual_seed(42))
+        >>> upper_face_map = [10, 11, 12, 13, 14]
+        >>> template = torch.randn(100, 3, generator=torch.manual_seed(43))
+        >>> upper_face_dynamics_deviation(vertices_pred, vertices_gt, template, upper_face_map)
+        tensor(1.0385)
+
+    """
+    if vertices_pred.ndim != 3 or vertices_gt.ndim != 3:
+        raise ValueError(
+            f"Expected both vertices_pred and vertices_gt to have 3 dimensions but got "
+            f"{vertices_pred.ndim} and {vertices_gt.ndim} dimensions respectively."
+        )
+    if template.ndim != 2 or template.shape[1] != 3:
+        raise ValueError(f"Expected template to have shape (V, 3) but got {template.shape}.")
+    if vertices_pred.shape[1:] != vertices_gt.shape[1:]:
+        raise ValueError(
+            f"Expected vertices_pred and vertices_gt to have same vertex and coordinate dimensions but got "
+            f"shapes {vertices_pred.shape} and {vertices_gt.shape}."
+        )
+    if vertices_pred.shape[1:] != template.shape:
+        raise ValueError(
+            f"Shape mismatch: expected template shape to match the "
+            f"vertex-coordinate dimensions of predictions {vertices_pred.shape[1:]}, "
+            f"but got template shape {template.shape} instead."
+        )
+    if not upper_face_map:
+        raise ValueError("upper_face_map cannot be empty.")
+    if min(upper_face_map) < 0 or max(upper_face_map) >= template.shape[0]:
+        raise ValueError(
+            f"upper_face_map contains out-of-range vertex indices. "
+            f"Valid index range is [0, {template.shape[0] - 1}], "
+            f"but received indices in range [{min(upper_face_map)}, {max(upper_face_map)}]."
+        )
+    min_frames = min(vertices_pred.shape[0], vertices_gt.shape[0])  # compare only the common frame prefix
+    pred = vertices_pred[:min_frames, upper_face_map, :]  # (T, M, 3)
+    gt = vertices_gt[:min_frames, upper_face_map, :]  # (T, M, 3)
+    template = template.to(pred.device)[upper_face_map, :]  # (M, 3)
+
+    pred_disp = pred - template  # (T, M, 3)
+    gt_disp = gt - template  # (T, M, 3)
+
+    pred_norm_sq = torch.sum(pred_disp**2, dim=-1)  # (T, M)
+    gt_norm_sq = torch.sum(gt_disp**2, dim=-1)  # (T, M)
+
+    pred_dyn = torch.std(pred_norm_sq, dim=0, unbiased=False)  # (M,)
+    gt_dyn = torch.std(gt_norm_sq, dim=0, unbiased=False)  # (M,)
+
+    return torch.mean(gt_dyn - pred_dyn)  # scalar; signed, ground truth minus prediction
diff --git a/src/torchmetrics/multimodal/__init__.py b/src/torchmetrics/multimodal/__init__.py
index a00661de9d7..ed8ce419461 100644
--- a/src/torchmetrics/multimodal/__init__.py
+++ b/src/torchmetrics/multimodal/__init__.py
@@ -11,10 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from torchmetrics.multimodal.fdd import UpperFaceDynamicsDeviation
 from torchmetrics.multimodal.lve import LipVertexError
 from torchmetrics.utilities.imports import _TRANSFORMERS_GREATER_EQUAL_4_10
 
-__all__ = ["LipVertexError"]
+__all__ = ["LipVertexError", "UpperFaceDynamicsDeviation"]
 
 if _TRANSFORMERS_GREATER_EQUAL_4_10:
     from torchmetrics.multimodal.clip_iqa import CLIPImageQualityAssessment
diff --git a/src/torchmetrics/multimodal/fdd.py b/src/torchmetrics/multimodal/fdd.py
new file mode 100644
index 00000000000..1c6715d8462
--- /dev/null
+++ b/src/torchmetrics/multimodal/fdd.py
@@ -0,0 +1,202 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, List, Optional, Sequence, Union
+
+from torch import Tensor
+
+from torchmetrics.functional.multimodal.fdd import upper_face_dynamics_deviation
+from torchmetrics.metric import Metric
+from torchmetrics.utilities.data import dim_zero_cat
+from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
+from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
+
+if not _MATPLOTLIB_AVAILABLE:
+    __doctest_skip__ = ["UpperFaceDynamicsDeviation.plot"]
+
+
+class UpperFaceDynamicsDeviation(Metric):
+    r"""Implements the Upper Face Dynamics Deviation (FDD) metric for 3D talking head evaluation.
+
+    The Upper Face Dynamics Deviation (FDD) metric evaluates the quality of facial expressions in the upper
+    face region for 3D talking head models. It quantifies the deviation in vertex motion dynamics between the
+    predicted and ground truth sequences by comparing the temporal variation (standard deviation) of per-vertex
+    squared displacements relative to a neutral template. The deviation is signed (ground truth minus prediction),
+    so values closer to zero indicate closer alignment of the predicted upper-face dynamics with the ground truth.
+
+    The metric is defined as:
+
+    .. math::
+        \text{FDD} = \frac{1}{|S_U|} \sum_{v \in S_U} \Big( \text{std}(\| x_{1:T,v} -
+        \text{template}_v \|_2^2) - \text{std}(\| \hat{x}_{1:T,v} - \text{template}_v \|_2^2) \Big)
+
+    where :math:`T` is the number of frames, :math:`S_U` is the set of upper-face vertices with :math:`M = |S_U|`,
+    :math:`x_{t,v}` are the 3D coordinates of vertex :math:`v` at frame :math:`t` in the ground truth sequence,
+    and :math:`\hat{x}_{t,v} \in \mathbb{R}^3` are the corresponding predicted vertices. The neutral template coordinate
+    of vertex :math:`v` is denoted as :math:`\text{template}_v \in \mathbb{R}^3`. The operator :math:`\text{std}(\cdot)`
+    computes the standard deviation of the temporal sequence.
+
+    As input to ``forward`` and ``update``, the metric accepts the following input:
+
+    - ``vertices_pred`` (:class:`~torch.Tensor`): Predicted vertices tensor of shape (T, V, 3) where T is the number of
+      frames, V is the number of vertices, and 3 represents XYZ coordinates.
+    - ``vertices_gt`` (:class:`~torch.Tensor`): Ground truth vertices tensor of shape (T, V, 3) where T is the
+      number of frames, V is the number of vertices, and 3 represents XYZ coordinates.
+
+    As output of ``forward`` and ``compute``, the metric returns the following output:
+
+    - ``fdd_score`` (:class:`~torch.Tensor`): A scalar tensor containing the mean Face Dynamics Deviation
+      across all upper-face vertices.
+
+    Args:
+        template: Template mesh tensor of shape (V, 3) representing the neutral face.
+        upper_face_map: List of vertex indices for the upper-face region.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Raises:
+        ValueError:
+            If the number of dimensions of `vertices_pred` or `vertices_gt` is not 3.
+            If `template` does not have shape (No_of_vertices, 3).
+            If `vertices_pred` and `vertices_gt` do not have the same vertex and coordinate dimensions.
+            If `template` shape does not match the vertex-coordinate dimensions of `vertices_pred` (and `vertices_gt`).
+            If ``upper_face_map`` is empty or contains invalid vertex indices.
+
+    Example:
+        >>> import torch
+        >>> from torchmetrics.multimodal.fdd import UpperFaceDynamicsDeviation
+        >>> template = torch.randn(100, 3, generator=torch.manual_seed(41))
+        >>> metric = UpperFaceDynamicsDeviation(template=template, upper_face_map=[0, 1, 2, 3, 4])
+        >>> vertices_pred = torch.randn(10, 100, 3, generator=torch.manual_seed(42))
+        >>> vertices_gt = torch.randn(10, 100, 3, generator=torch.manual_seed(43))
+        >>> metric(vertices_pred, vertices_gt)
+        tensor(0.2131)
+
+    """
+
+    is_differentiable: bool = True
+    higher_is_better: bool = False
+    full_state_update: bool = False
+
+    vertices_pred_list: List[Tensor]
+    vertices_gt_list: List[Tensor]
+
+    def __init__(
+        self,
+        template: Tensor,
+        upper_face_map: List[int],
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.upper_face_map = upper_face_map
+        self.template = template
+
+        if self.template.ndim != 2 or self.template.shape[1] != 3:
+            raise ValueError(f"Expected template to have shape (V, 3) but got {template.shape}.")
+        if not self.upper_face_map:
+            raise ValueError("upper_face_map cannot be empty.")
+        if min(self.upper_face_map) < 0 or max(self.upper_face_map) >= self.template.shape[0]:
+            raise ValueError(
+                f"upper_face_map contains out-of-range vertex indices. "
+                f"Valid index range is [0, {self.template.shape[0] - 1}], "
+                f"but received indices in range [{min(self.upper_face_map)}, {max(self.upper_face_map)}]."
+            )
+        self.add_state("vertices_pred_list", default=[], dist_reduce_fx=None)
+        self.add_state("vertices_gt_list", default=[], dist_reduce_fx=None)
+
+    def update(self, vertices_pred: Tensor, vertices_gt: Tensor) -> None:
+        """Update metric states with predictions and targets.
+
+        Args:
+            vertices_pred: Predicted vertices tensor of shape (T, V, 3) where T is number of frames,
+                V is number of vertices, and 3 represents XYZ coordinates
+            vertices_gt: Ground truth vertices tensor of shape (T', V, 3) where T' is number of frames,
+                V is number of vertices, and 3 represents XYZ coordinates; both inputs are cut to min(T, T') frames
+
+        """
+        if vertices_pred.ndim != 3 or vertices_gt.ndim != 3:
+            raise ValueError(
+                f"Expected both vertices_pred and vertices_gt to have 3 dimensions but got "
+                f"{vertices_pred.ndim} and {vertices_gt.ndim} dimensions respectively."
+            )
+        if vertices_pred.shape[1:] != vertices_gt.shape[1:]:
+            raise ValueError(
+                f"Expected vertices_pred and vertices_gt to have same vertex and coordinate dimensions but got "
+                f"shapes {vertices_pred.shape} and {vertices_gt.shape}."
+            )
+        if vertices_pred.shape[1:] != self.template.shape:
+            raise ValueError(
+                f"Shape mismatch: expected template shape to match the "
+                f"vertex-coordinate dimensions of predictions {vertices_pred.shape[1:]}, "
+                f"but got template shape {self.template.shape} instead."
+            )
+
+        min_frames = min(vertices_pred.shape[0], vertices_gt.shape[0])
+        self.vertices_pred_list.append(vertices_pred[:min_frames])
+        self.vertices_gt_list.append(vertices_gt[:min_frames])
+
+    def compute(self) -> Tensor:
+        """Compute the Upper Face Dynamics Deviation over all accumulated states.
+
+        Returns:
+            torch.Tensor: A scalar tensor with the mean FDD value
+
+        """
+        vertices_pred = dim_zero_cat(self.vertices_pred_list)
+        vertices_gt = dim_zero_cat(self.vertices_gt_list)
+        return upper_face_dynamics_deviation(vertices_pred, vertices_gt, self.template, self.upper_face_map)
+
+    def plot(
+        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided will add plot to that axis
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting a single value
+            >>> import torch
+            >>> from torchmetrics.multimodal.fdd import UpperFaceDynamicsDeviation
+            >>> metric = UpperFaceDynamicsDeviation(template=torch.randn(100, 3), upper_face_map=[0, 1, 2, 3, 4])
+            >>> vertices_pred = torch.randn(10, 100, 3, generator=torch.manual_seed(42))
+            >>> vertices_gt = torch.randn(10, 100, 3, generator=torch.manual_seed(43))
+            >>> metric.update(vertices_pred, vertices_gt)
+            >>> fig_, ax_ = metric.plot()
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> import torch
+            >>> from torchmetrics.multimodal.fdd import UpperFaceDynamicsDeviation
+            >>> metric = UpperFaceDynamicsDeviation(template=torch.randn(100, 3), upper_face_map=[0, 1, 2, 3, 4])
+            >>> values = []
+            >>> for _ in range(10):
+            ...     vertices_pred = torch.randn(10, 100, 3, generator=torch.manual_seed(42+_))
+            ...     vertices_gt = torch.randn(10, 100, 3, generator=torch.manual_seed(43+_))
+            ...     values.append(metric(vertices_pred, vertices_gt))
+            >>> fig_, ax_ = metric.plot(values)
+
+        """
+        return self._plot(val, ax)
diff --git a/tests/unittests/multimodal/test_fdd.py b/tests/unittests/multimodal/test_fdd.py
new file mode 100644
index 00000000000..28c65293a61
--- /dev/null
+++ b/tests/unittests/multimodal/test_fdd.py
@@ -0,0 +1,158 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+from typing import NamedTuple
+
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import pytest
+import torch
+from torch import Tensor
+
+from torchmetrics.functional.multimodal.fdd import upper_face_dynamics_deviation
+from torchmetrics.multimodal.fdd import UpperFaceDynamicsDeviation
+from unittests._helpers import seed_all
+from unittests._helpers.testers import MetricTester
+
+seed_all(42)
+
+
+class _InputVertices(NamedTuple):
+    vertices_pred: Tensor
+    vertices_gt: Tensor
+    template: Tensor
+
+
+def _generate_vertices(batch_size: int = 1) -> _InputVertices:
+    """Generate random vertices for testing."""
+    return _InputVertices(
+        vertices_pred=torch.randn(batch_size, 10, 100, 3),
+        vertices_gt=torch.randn(batch_size, 10, 100, 3),
+        template=torch.randn(100, 3),
+    )
+
+
+def _reference_fdd(vertices_pred, vertices_gt, template, upper_face_map):
+    """Reference implementation for FDD metric using numpy."""
+    min_frames = min(vertices_pred.shape[0], vertices_gt.shape[0])
+    pred = vertices_pred[:min_frames, upper_face_map, :].detach().cpu().numpy()  # (T, M, 3)
+    gt = vertices_gt[:min_frames, upper_face_map, :].detach().cpu().numpy()  # (T, M, 3)
+    template = template[upper_face_map, :].detach().cpu().numpy()  # (M, 3)
+
+    displacements_gt = gt - template  # (T, M, 3)
+    displacements_pred = pred - template
+
+    l2_gt = np.sum(displacements_gt**2, axis=-1)  # (T, M), squared L2 norm
+    l2_pred = np.sum(displacements_pred**2, axis=-1)
+
+    std_diff = np.std(l2_gt, axis=0) - np.std(l2_pred, axis=0)  # (M,)
+
+    fdd = np.mean(std_diff)
+
+    return torch.tensor(fdd)
+
+
+class TestUpperFaceDynamicsDeviation(MetricTester):
+    """Test class for `UpperFaceDynamicsDeviation` metric (FDD)."""
+
+    atol: float = 1e-2
+
+    @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
+    def test_fdd_metric_class(self, ddp):
+        """Test class implementation of metric."""
+        upper_face_map = [0, 1, 2, 3, 4]
+        vertices_pred, vertices_gt, template = _generate_vertices(batch_size=4)
+
+        self.run_class_metric_test(
+            ddp=ddp,
+            preds=vertices_pred,
+            target=vertices_gt,
+            metric_class=UpperFaceDynamicsDeviation,
+            reference_metric=partial(_reference_fdd, template=template, upper_face_map=upper_face_map),
+            metric_args={"template": template, "upper_face_map": upper_face_map},
+        )
+
+    def test_fdd_functional(self):
+        """Test functional implementation of metric."""
+        upper_face_map = [0, 1, 2, 3, 4]
+        vertices_pred, vertices_gt, template = _generate_vertices(batch_size=4)
+
+        self.run_functional_metric_test(
+            preds=vertices_pred,
+            target=vertices_gt,
+            metric_functional=upper_face_dynamics_deviation,
+            reference_metric=partial(_reference_fdd, template=template, upper_face_map=upper_face_map),
+            metric_args={"template": template, "upper_face_map": upper_face_map},
+        )
+
+    def test_fdd_differentiability(self):
+        """Test differentiability of FDD metric."""
+        upper_face_map = [0, 1, 2, 3, 4]
+        vertices_pred, vertices_gt, template = _generate_vertices(batch_size=4)
+
+        self.run_differentiability_test(
+            preds=vertices_pred,
+            target=vertices_gt,
+            metric_module=UpperFaceDynamicsDeviation,
+            metric_functional=upper_face_dynamics_deviation,
+            metric_args={"template": template, "upper_face_map": upper_face_map},
+        )
+
+    def test_error_on_wrong_dimensions(self):
+        """Test that an error is raised for wrong input dimensions."""
+        metric = UpperFaceDynamicsDeviation(template=torch.randn(100, 3), upper_face_map=[0, 1, 2, 3, 4])
+        with pytest.raises(
+            ValueError, match="Expected both vertices_pred and vertices_gt to have 3 dimensions but got.*"
+        ):
+            metric(torch.randn(10, 100), torch.randn(10, 100, 3))
+
+    def test_error_on_mismatched_dimensions(self):
+        """Test that an error is raised for mismatched vertex dimensions."""
+        metric = UpperFaceDynamicsDeviation(template=torch.randn(100, 3), upper_face_map=[0, 1, 2, 3, 4])
+        with pytest.raises(
+            ValueError,
+            match="Expected vertices_pred and vertices_gt to have same vertex and coordinate dimensions but got.*",
+        ):
+            metric(torch.randn(10, 80, 3), torch.randn(10, 100, 3))
+
+    def test_error_on_template_shape_mismatch(self):
+        """Test that an error is raised when template shape does not match vertex-coordinate dimensions."""
+        metric = UpperFaceDynamicsDeviation(template=torch.randn(100, 3), upper_face_map=[0, 1, 2, 3, 4])
+        with pytest.raises(ValueError, match="Shape mismatch: expected template shape.*to match.*"):
+            metric(torch.randn(10, 120, 3), torch.randn(10, 120, 3))
+
+    def test_error_on_empty_upper_face_map(self):
+        """Test that an error is raised if upper_face_map is empty."""
+        with pytest.raises(ValueError, match="upper_face_map cannot be empty."):
+            UpperFaceDynamicsDeviation(template=torch.randn(100, 3), upper_face_map=[])
+
+    def test_error_on_invalid_upper_face_indices(self):
+        """Test that an error is raised if upper_face_map has invalid indices."""
+        with pytest.raises(ValueError, match="upper_face_map contains out-of-range vertex indices.*"):
+            UpperFaceDynamicsDeviation(template=torch.randn(100, 3), upper_face_map=[98, 99, 100])
+
+    def test_different_sequence_lengths(self):
+        """Test that the metric handles different sequence lengths correctly."""
+        metric = UpperFaceDynamicsDeviation(template=torch.randn(100, 3), upper_face_map=[0, 1, 2, 3, 4])
+        assert torch.isfinite(metric(torch.randn(10, 100, 3), torch.randn(8, 100, 3)))
+
+    def test_plot_method(self):
+        """Test the plot method of FDD."""
+        vertices_pred, vertices_gt, template = _generate_vertices()
+        metric = UpperFaceDynamicsDeviation(template=template, upper_face_map=[0, 1, 2, 3, 4])
+        metric.update(vertices_pred[0], vertices_gt[0])
+        fig, ax = metric.plot()
+        assert isinstance(fig, plt.Figure)
+        assert isinstance(ax, matplotlib.axes.Axes)
diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py
index 4d973a790c9..625ed609505 100644
--- a/tests/unittests/utilities/test_plot.py
+++ b/tests/unittests/utilities/test_plot.py
@@ -128,7 +128,7 @@
     TotalVariation,
     UniversalImageQualityIndex,
 )
-from torchmetrics.multimodal import LipVertexError
+from torchmetrics.multimodal import LipVertexError, UpperFaceDynamicsDeviation
 from torchmetrics.nominal import CramersV, FleissKappa, PearsonsContingencyCoefficient, TheilsU, TschuprowsT
 from torchmetrics.regression import (
     ConcordanceCorrCoef,
@@ -692,6 +692,12 @@
             lambda: torch.randn(10, 100, 3),
             id="lip vertex error",
         ),
+        pytest.param(
+            partial(UpperFaceDynamicsDeviation, template=torch.randn(100, 3), upper_face_map=[0, 1, 2, 3, 4]),
+            lambda: torch.randn(10, 100, 3),
+            lambda: torch.randn(10, 100, 3),
+            id="upper face dynamic deviation",
+        ),
     ],
 )
 @pytest.mark.parametrize("num_vals", [1, 3])