3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added support for `dinov2` feature extractor in `FrechetInceptionDistance` ([#3186](https://github.com/Lightning-AI/torchmetrics/pull/3186))


-


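For reference, a minimal usage sketch of the new extractor option (assumes a `torch-fidelity` version that ships the DINOv2 extractor; the tensors and sizes below are placeholders):

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# Select the DINOv2 ViT-B/14 backbone (768-dim features) instead of the default InceptionV3.
fid = FrechetInceptionDistance(feature="dino-768")

# With normalize=False (the default), FID expects uint8 images in Bx3xHxW layout.
real = torch.randint(0, 255, (8, 3, 224, 224), dtype=torch.uint8)
fake = torch.randint(0, 255, (8, 3, 224, 224), dtype=torch.uint8)

fid.update(real, real=True)
fid.update(fake, real=False)
score = fid.compute()
```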
170 changes: 146 additions & 24 deletions src/torchmetrics/image/fid.py
@@ -21,13 +21,15 @@
from torch.nn.functional import adaptive_avg_pool2d

from torchmetrics.metric import Metric
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCH_FIDELITY_AVAILABLE, _TORCHVISION_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
from torchmetrics.utilities.prints import rank_zero_warn

if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["FrechetInceptionDistance.plot"]

if _TORCH_FIDELITY_AVAILABLE:
    from torch_fidelity.feature_extractor_dinov2 import FeatureExtractorDinoV2 as _FeatureExtractorDinoV2
    from torch_fidelity.feature_extractor_inceptionv3 import FeatureExtractorInceptionV3 as _FeatureExtractorInceptionV3
    from torch_fidelity.helpers import vassert
    from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x
@@ -36,6 +38,9 @@
class _FeatureExtractorInceptionV3(Module):  # type: ignore[no-redef]
    pass

class _FeatureExtractorDinoV2(Module):  # type: ignore[no-redef]
    pass

vassert = None
interpolate_bilinear_2d_like_tensorflow1x = None

@@ -171,6 +176,92 @@ def forward(self, x: Tensor) -> Tensor:
        return out[0].reshape(x.shape[0], -1)


class NoTrainDinoV2(_FeatureExtractorDinoV2):
    """Module that never leaves evaluation mode."""

    def __init__(
        self,
        name: str,
        features_list: list[str],
        feature_extractor_weights_path: Optional[str] = None,
        antialias: bool = True,
    ) -> None:
        if not _TORCH_FIDELITY_AVAILABLE:
            raise ModuleNotFoundError(
                "NoTrainDinoV2 module requires that `Torch-fidelity` is installed."
                " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
            )
        if not _TORCHVISION_AVAILABLE:
            raise ModuleNotFoundError(
                "NoTrainDinoV2 module requires that `torchvision` is installed."
                " Please install with `pip install torchmetrics[image]`."
            )

        super().__init__(name, features_list, feature_extractor_weights_path)
        self.use_antialias = antialias
        # put into evaluation mode
        self.eval()

    def train(self, mode: bool) -> "NoTrainDinoV2":
        """Force network to always be in evaluation mode."""
        return super().train(False)

    def _torch_fidelity_forward(self, x: Tensor) -> tuple[Tensor, ...]:
        """Forward method of dinov2 net.

        Copy of the forward method from this file:
        https://github.com/toshas/torch-fidelity/blob/master/torch_fidelity/feature_extractor_dinov2.py
        with a single line change regarding the casting of `x` in the beginning.

        Corresponding license file (Apache License, Version 2.0):
        https://github.com/toshas/torch-fidelity/blob/master/LICENSE.md

        """
        vassert(torch.is_tensor(x) and x.dtype == torch.uint8, "Expecting image as torch.Tensor with dtype=torch.uint8")
        vassert(x.dim() == 4 and x.shape[1] == 3, f"Input is not Bx3xHxW: {x.shape}")

        x = x.to(self.feature_extractor_internal_dtype)
        # N x 3 x ? x ?

        if self.use_antialias:
            x = torch.nn.functional.interpolate(
                x,
                size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
        else:
            x = interpolate_bilinear_2d_like_tensorflow1x(
                x,
                size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
                align_corners=False,
            )
        # N x 3 x 224 x 224
        from torchvision.transforms.functional import normalize

        x = normalize(
            x,
            (255 * 0.485, 255 * 0.456, 255 * 0.406),
            (255 * 0.229, 255 * 0.224, 255 * 0.225),
            inplace=False,
        )
        # N x 3 x 224 x 224

        x = self.model(x)

        out = {
            "dinov2": x.to(torch.float32),
        }

        return tuple(out[a] for a in self.features_list)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of neural network with reshaping of output."""
        out = self._torch_fidelity_forward(x)
        return out[0].reshape(x.shape[0], -1)


def _compute_fid(mu1: Tensor, sigma1: Tensor, mu2: Tensor, sigma2: Tensor) -> Tensor:
    r"""Compute adjusted version of `Fid Score`_.

@@ -194,6 +285,19 @@ def _compute_fid(mu1: Tensor, sigma1: Tensor, mu2: Tensor, sigma2: Tensor) -> Tensor:
    return a + b - 2 * c


# Map feature strings to valid configurations
# Each entry: (extractor class, torch-fidelity model name, feature layer key, expected input size)
FEATURE_MAP = {
    "inception-64": (NoTrainInceptionV3, "64", "64", (3, 299, 299)),
    "inception-192": (NoTrainInceptionV3, "192", "192", (3, 299, 299)),
    "inception-768": (NoTrainInceptionV3, "768", "768", (3, 299, 299)),
    "inception-2048": (NoTrainInceptionV3, "2048", "2048", (3, 299, 299)),
    "dino-384": (NoTrainDinoV2, "dinov2-vit-s-14", "dinov2", (3, 224, 224)),
    "dino-768": (NoTrainDinoV2, "dinov2-vit-b-14", "dinov2", (3, 224, 224)),
    "dino-1024": (NoTrainDinoV2, "dinov2-vit-l-14", "dinov2", (3, 224, 224)),
    "dino-1536": (NoTrainDinoV2, "dinov2-vit-g-14", "dinov2", (3, 224, 224)),
}


class FrechetInceptionDistance(Metric):
    r"""Calculate Fréchet inception distance (FID_) which is used to assess the quality of generated images.

@@ -308,12 +412,12 @@ class FrechetInceptionDistance(Metric):
    fake_features_cov_sum: Tensor
    fake_features_num_samples: Tensor

    feature_extractor: Module
    feature_network: str = "feature_extractor"

    def __init__(
        self,
        feature: Union[int, str, Module] = "inception-2048",
        reset_real_features: bool = True,
        normalize: bool = False,
        input_img_size: tuple[int, int, int] = (3, 299, 299),
@@ -329,42 +433,60 @@ def __init__(
        self.used_custom_model = False
        antialias = antialias

        if isinstance(feature, str):
            if feature not in FEATURE_MAP:
                raise ValueError(
                    f"String input to argument `feature` must be one of {list(FEATURE_MAP.keys())}, but got {feature}."
                )
            feature_extractor_cls, name, feature_layer, default_img_size = FEATURE_MAP[feature]
            input_img_size = default_img_size

            if not _TORCH_FIDELITY_AVAILABLE:
                raise ModuleNotFoundError(
                    "FrechetInceptionDistance metric requires that `Torch-fidelity` is installed."
                    " Either install as `pip install torchmetrics[image]` or `pip install torch-fidelity`."
                )

            self.feature_extractor = feature_extractor_cls(
                name=name,
                features_list=[feature_layer],
                feature_extractor_weights_path=feature_extractor_weights_path,
                antialias=antialias,
            )
            num_features = int(feature.split("-")[-1])
        elif isinstance(feature, int):
            rank_zero_warn(
                "Using an integer input to `feature` is deprecated and will be removed in v1.9."
                " Instead, use a string input like 'inception-2048' or 'dino-768'.",
                DeprecationWarning,
            )
            if feature not in [64, 192, 768, 2048]:
                raise ValueError(
                    f"Integer input to argument `feature` must be one of [64, 192, 768, 2048], but got {feature}."
                )

            self.feature_extractor = NoTrainInceptionV3(
                name=f"inception-{feature}",
                features_list=[str(feature)],
                feature_extractor_weights_path=feature_extractor_weights_path,
                antialias=antialias,
            )
            num_features = feature
        elif isinstance(feature, Module):
            self.feature_extractor = feature
            self.used_custom_model = True
            if hasattr(self.feature_extractor, "num_features"):
                if isinstance(self.feature_extractor.num_features, int):
                    num_features = self.feature_extractor.num_features
                elif isinstance(self.feature_extractor.num_features, Tensor):
                    num_features = int(self.feature_extractor.num_features.item())
                else:
                    raise TypeError("Expected `self.feature_extractor.num_features` to be of type int or Tensor.")
            else:
                if self.normalize:
                    dummy_image = torch.rand(1, *input_img_size, dtype=torch.float32)
                else:
                    dummy_image = torch.randint(0, 255, (1, *input_img_size), dtype=torch.uint8)
                num_features = self.feature_extractor(dummy_image).shape[-1]
        else:
            raise TypeError("Got unknown input to argument `feature`")

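Editorial aside: the `Module` branch above also accepts a fully custom feature extractor. If the module exposes a `num_features` attribute (int or tensor) it is used directly; otherwise a dummy forward pass infers the feature dimension. A minimal sketch (the `FlattenExtractor` class and its sizes are hypothetical, not part of this PR):

```python
import torch
from torch import nn
from torchmetrics.image.fid import FrechetInceptionDistance


class FlattenExtractor(nn.Module):
    """Toy extractor that flattens 32x32 RGB images into a feature vector."""

    num_features: int = 3 * 32 * 32  # read directly, so no dummy forward pass is needed

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.flatten(start_dim=1).float()


fid = FrechetInceptionDistance(feature=FlattenExtractor(), input_img_size=(3, 32, 32))
```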
Expand All @@ -391,7 +513,7 @@ def update(self, imgs: Tensor, real: bool) -> None:

"""
imgs = (imgs * 255).byte() if self.normalize and (not self.used_custom_model) else imgs
features = self.inception(imgs)
features = self.feature_extractor(imgs)
self.orig_dtype = features.dtype
features = features.double()

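Editorial aside: per the first line of `update` above, the built-in extractors convert float inputs back to uint8 when `normalize=True`, while custom models receive their inputs untouched. A sketch of the float path (tensor shapes are placeholders):

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

fid = FrechetInceptionDistance(feature="inception-2048", normalize=True)
imgs = torch.rand(4, 3, 299, 299)  # floats in [0, 1]; scaled by 255 and cast to uint8 internally
fid.update(imgs, real=True)
```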
@@ -440,8 +562,8 @@ def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "Metric":

"""
out = super().set_dtype(dst_type)
if isinstance(out.inception, NoTrainInceptionV3):
out.inception._dtype = dst_type
if isinstance(out.feature_extractor, (NoTrainInceptionV3, NoTrainDinoV2)):
out.feature_extractor._dtype = dst_type
return out

    def plot(