Fix incorrect CLIP-IQA type hints #2952

Merged: 10 commits (Feb 24, 2025)
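The underlying typing issue: `tuple[X]` denotes a tuple of exactly one element, whereas `tuple[X, ...]` denotes a variable-length tuple of `X`. The old hint therefore made type checkers reject valid calls that pass more than one prompt. A minimal, library-independent sketch of the difference (names here are illustrative, not from the diff):

```python
from typing import Union

Prompt = Union[str, tuple[str, str]]

def old_hint(prompts: tuple[Prompt] = ("quality",)) -> None:
    """tuple[Prompt] only admits single-element tuples."""

def new_hint(prompts: tuple[Prompt, ...] = ("quality",)) -> None:
    """tuple[Prompt, ...] admits any number of prompts."""

old_hint(prompts=("quality", "brightness"))  # flagged by mypy/pyright
new_hint(prompts=("quality", "brightness"))  # accepted
```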
23 changes: 21 additions & 2 deletions src/torchmetrics/functional/multimodal/clip_iqa.py
@@ -89,7 +89,9 @@ def _get_clip_iqa_model_and_processor(
    return _get_clip_model_and_processor(model_name_or_path)


-def _clip_iqa_format_prompts(prompts: tuple[Union[str, tuple[str, str]]] = ("quality",)) -> tuple[list[str], list[str]]:
+def _clip_iqa_format_prompts(
+    prompts: tuple[Union[str, tuple[str, str]], ...] = ("quality",),
+) -> tuple[list[str], list[str]]:
    """Converts the provided keywords into a list of prompts for the model to calculate the anchor vectors.

    Args:
@@ -225,7 +227,7 @@ def clip_image_quality_assessment(
        "openai/clip-vit-large-patch14",
    ] = "clip_iqa",
    data_range: float = 1.0,
-    prompts: tuple[Union[str, tuple[str, str]]] = ("quality",),
+    prompts: tuple[Union[str, tuple[str, str]], ...] = ("quality",),
) -> Union[Tensor, dict[str, Tensor]]:
    """Calculates `CLIP-IQA`_, that can be used to measure the visual content of images.

@@ -329,3 +331,20 @@ def clip_image_quality_assessment(
    anchors = _clip_iqa_get_anchor_vectors(model_name_or_path, model, processor, prompts_list, device)
    img_features = _clip_iqa_update(model_name_or_path, images, model, processor, data_range, device)
    return _clip_iqa_compute(img_features, anchors, prompts_names)
+
+
+if TYPE_CHECKING:
+    from functools import partial
+    from typing import Any, cast
+
+    images = cast(Any, None)
+
+    f = partial(clip_image_quality_assessment, images=images)
+    f(prompts=("colorfullness",))
+    f(
+        prompts=("quality", "brightness", "noisiness"),
+    )
+    f(
+        prompts=("quality", "brightness", "noisiness", "colorfullness"),
+    )
+    f(prompts=(("Photo of a cat", "Photo of a dog"), "quality", ("Colorful photo", "Black and white photo")))
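A note on the block above: `typing.TYPE_CHECKING` is `False` at runtime, so calls placed under `if TYPE_CHECKING:` never execute. They serve purely as static assertions that the corrected signature accepts one prompt, several prompts, and custom `(positive, negative)` pairs; reverting the hint would make type checking fail. A standalone sketch of the same pattern (the `assess` stand-in is hypothetical):

```python
from typing import TYPE_CHECKING, Union

def assess(prompts: tuple[Union[str, tuple[str, str]], ...] = ("quality",)) -> None:
    """Hypothetical stand-in for clip_image_quality_assessment."""

if TYPE_CHECKING:
    # Never executed; mypy/pyright still check these calls, so a
    # regression to tuple[...] without the Ellipsis breaks type checking.
    assess(prompts=("quality", "brightness"))
    assess(prompts=(("Bright photo", "Dark photo"), "quality"))
```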
16 changes: 14 additions & 2 deletions src/torchmetrics/multimodal/clip_iqa.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
-from typing import Any, List, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union

import torch
from torch import Tensor
@@ -179,7 +179,7 @@ def __init__(
            "openai/clip-vit-large-patch14",
        ] = "clip_iqa",
        data_range: float = 1.0,
-        prompts: tuple[Union[str, tuple[str, str]]] = ("quality",),
+        prompts: tuple[Union[str, tuple[str, str]], ...] = ("quality",),
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
@@ -259,3 +259,15 @@ def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_

        """
        return self._plot(val, ax)
+
+
+if TYPE_CHECKING:
+    f = CLIPImageQualityAssessment
+    f(prompts=("colorfullness",))
+    f(
+        prompts=("quality", "brightness", "noisiness"),
+    )
+    f(
+        prompts=("quality", "brightness", "noisiness", "colorfullness"),
+    )
+    f(prompts=(("Photo of a cat", "Photo of a dog"), "quality", ("Colorful photo", "Black and white photo")))
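For context, a hedged usage sketch of the class-based metric with a multi-prompt tuple, which the corrected hint now types correctly (a sketch only: it assumes `transformers` is installed and that images are float tensors scaled to the default `data_range=1.0`):

```python
import torch
from torchmetrics.multimodal import CLIPImageQualityAssessment

# Mix a built-in keyword with a custom (positive, negative) prompt pair.
metric = CLIPImageQualityAssessment(
    prompts=("quality", ("Colorful photo", "Black and white photo")),
)

imgs = torch.rand(2, 3, 224, 224)  # values in [0, 1] to match data_range=1.0
scores = metric(imgs)  # multiple prompts -> dict of per-image score tensors
```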