
Commit

update: align image quality metrics with weave.Scorer + corresponding example
soumik12345 committed Oct 30, 2024
1 parent dd55a90 commit 1b2f817
Showing 6 changed files with 86 additions and 124 deletions.
34 changes: 14 additions & 20 deletions examples/evaluate_weave_image_quality.py
@@ -1,38 +1,32 @@
import asyncio

import fire
import weave

import wandb
from hemm.eval_pipelines import EvaluationPipeline
from hemm.metrics.image_quality import LPIPSMetric, PSNRMetric, SSIMMetric
from hemm.models import BaseDiffusionModel
from hemm.models import DiffusersModel


def main(
project_name: str = "image-quality",
diffusion_model_name_or_path="stabilityai/stable-diffusion-2-1",
dataset_ref: str = "COCO:v0",
image_height: int = 1024,
image_width: int = 1024,
):
wandb.init(project=project_name, job_type="evaluation")
weave.init(project_name=project_name)

model = BaseDiffusionModel(
diffusion_model_name_or_path=diffusion_model_name_or_path
)
evaluation_pipeline = EvaluationPipeline(model=model)

# Add PSNR Metric
psnr_metric = PSNRMetric(image_size=evaluation_pipeline.image_size)
evaluation_pipeline.add_metric(psnr_metric)
model = DiffusersModel(diffusion_model_name_or_path=diffusion_model_name_or_path)

# Add SSIM Metric
ssim_metric = SSIMMetric(image_size=evaluation_pipeline.image_size)
evaluation_pipeline.add_metric(ssim_metric)
psnr_metric = PSNRMetric(image_size=(image_height, image_width))
ssim_metric = SSIMMetric(image_size=(image_height, image_width))
lpips_metric = LPIPSMetric(image_size=(image_height, image_width))

# Add LPIPS Metric
lpips_metric = LPIPSMetric(image_size=evaluation_pipeline.image_size)
evaluation_pipeline.add_metric(lpips_metric)

evaluation_pipeline(dataset=dataset_ref)
dataset = weave.ref(dataset_ref).get()
evaluation = weave.Evaluation(
dataset=dataset, scorers=[psnr_metric, ssim_metric, lpips_metric]
)
asyncio.run(evaluation.evaluate(model))


if __name__ == "__main__":
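For context, the updated example script drops the old EvaluationPipeline wiring and passes the metrics straight to weave.Evaluation as scorers. A condensed sketch of the new flow, using the same defaults that appear in the diff (fire CLI wiring and any wandb run setup omitted):

import asyncio

import weave

from hemm.metrics.image_quality import LPIPSMetric, PSNRMetric, SSIMMetric
from hemm.models import DiffusersModel

weave.init(project_name="image-quality")

# Wrap the diffusion pipeline; weave.Evaluation calls it once per dataset row.
model = DiffusersModel(
    diffusion_model_name_or_path="stabilityai/stable-diffusion-2-1"
)

# Each metric is now a weave.Scorer, so it can be passed directly as a scorer.
scorers = [
    PSNRMetric(image_size=(1024, 1024)),
    SSIMMetric(image_size=(1024, 1024)),
    LPIPSMetric(image_size=(1024, 1024)),
]

# Resolve the published dataset reference and run the async evaluation loop.
dataset = weave.ref("COCO:v0").get()
evaluation = weave.Evaluation(dataset=dataset, scorers=scorers)
asyncio.run(evaluation.evaluate(model))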
25 changes: 4 additions & 21 deletions hemm/metrics/image_quality/base.py
@@ -1,11 +1,10 @@
from abc import abstractmethod
from typing import Any, Dict, Union

import weave
from PIL import Image
from pydantic import BaseModel

from ..base import BaseMetric


class ComputeMetricOutput(BaseModel):
"""Output of the metric computation function."""
@@ -14,33 +13,18 @@ class ComputeMetricOutput(BaseModel):
ground_truth_image: str


class BaseImageQualityMetric(BaseMetric):

def __init__(self, name: str) -> None:
"""Base class for Image Quality Metrics.
Args:
name (str): Name of the metric.
"""
super().__init__()
self.scores = []
self.name = name
self.config = {}
class BaseImageQualityMetric(weave.Scorer):

@abstractmethod
def compute_metric(
self,
ground_truth_pil_image: Image.Image,
generated_pil_image: Image.Image,
prompt: str,
self, ground_truth_pil_image: Image.Image, generated_pil_image: Image.Image
) -> ComputeMetricOutput:
"""Compute the metric for the given images. This is an abstract
method and must be overridden by the child class implementation.
Args:
ground_truth_pil_image (Image.Image): Ground truth image in PIL format.
generated_pil_image (Image.Image): Generated image in PIL format.
prompt (str): Prompt for the image generation.
Returns:
ComputeMetricOutput: Output containing the metric score and ground truth image.
@@ -64,5 +48,4 @@ def evaluate(
metric_output = self.compute_metric(
ground_truth_image, model_output["image"], prompt
)
self.scores.append(metric_output.score)
return {self.name: metric_output.score}
return {"score": metric_output.score}
42 changes: 19 additions & 23 deletions hemm/metrics/image_quality/lpips.py
@@ -1,13 +1,12 @@
from functools import partial
from typing import Any, Dict, Literal, Optional, Tuple, Union
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union

import numpy as np
import torch
import weave
from PIL import Image
from torchmetrics.functional.image import learned_perceptual_image_patch_similarity

from ...utils import base64_encode_image
from .base import BaseImageQualityMetric, ComputeMetricOutput


@@ -22,25 +21,30 @@ class LPIPSMetric(BaseImageQualityMetric):
or "squeeze".
image_size (Tuple[int, int]): The size to which images will be resized before computing
LPIPS.
name (str): The name of the metric.
"""

lpips_net_type: Literal["alex", "vgg", "squeeze"]
image_height: int
image_width: int
_lpips_metric: Callable

def __init__(
self,
lpips_net_type: Literal["alex", "vgg", "squeeze"] = "alex",
image_size: Optional[Tuple[int, int]] = (512, 512),
name: str = "alexnet_learned_perceptual_image_patch_similarity",
) -> None:
super().__init__(name)
self.image_size = image_size
self.lpips_metric = partial(
learned_perceptual_image_patch_similarity, net_type=lpips_net_type
super().__init__(
lpips_net_type=lpips_net_type,
image_height=image_size[0],
image_width=image_size[1],
)
self._lpips_metric = partial(
learned_perceptual_image_patch_similarity, net_type=self.lpips_net_type
)
self.config = {"lpips_net_type": lpips_net_type}

@weave.op()
def compute_metric(
self, ground_truth_pil_image: Image, generated_pil_image: Image, prompt: str
self, ground_truth_pil_image: Image, generated_pil_image: Image
) -> ComputeMetricOutput:
ground_truth_image = (
torch.from_numpy(
@@ -62,23 +66,15 @@ def compute_metric(
)
ground_truth_image = (ground_truth_image / 127.5) - 1.0
generated_image = (generated_image / 127.5) - 1.0
return ComputeMetricOutput(
score=float(
self.lpips_metric(generated_image, ground_truth_image).detach()
return {
"score": float(
self._lpips_metric(generated_image, ground_truth_image).detach()
),
ground_truth_image=base64_encode_image(ground_truth_pil_image),
)
"ground_truth_image": ground_truth_pil_image,
}

@weave.op()
def evaluate(
self, prompt: str, ground_truth_image: Image.Image, model_output: Dict[str, Any]
) -> Union[float, Dict[str, float]]:
_ = "LPIPSMetric"
return super().evaluate(prompt, ground_truth_image, model_output)

@weave.op()
async def evaluate_async(
self, prompt: str, ground_truth_image: Image.Image, model_output: Dict[str, Any]
) -> Union[float, Dict[str, float]]:
_ = "LPIPSMetric"
return self.evaluate(prompt, ground_truth_image, model_output)
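The (x / 127.5) - 1.0 scaling in compute_metric maps uint8 pixel values into the [-1, 1] range that torchmetrics' functional LPIPS expects by default. A standalone sketch of the same call outside the scorer (the random NCHW tensors are placeholders, not part of the diff):

import torch
from torchmetrics.functional.image import learned_perceptual_image_patch_similarity

# Fake RGB batches in NCHW layout with values in [0, 255].
generated = torch.randint(0, 256, (1, 3, 512, 512)).float()
reference = torch.randint(0, 256, (1, 3, 512, 512)).float()

# Rescale to [-1, 1], matching the normalization in LPIPSMetric.compute_metric.
generated = (generated / 127.5) - 1.0
reference = (reference / 127.5) - 1.0

score = learned_perceptual_image_patch_similarity(generated, reference, net_type="alex")
print(float(score))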
52 changes: 23 additions & 29 deletions hemm/metrics/image_quality/psnr.py
@@ -1,52 +1,51 @@
from functools import partial
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

import numpy as np
import torch
import weave
from PIL import Image
from torchmetrics.functional.image import peak_signal_noise_ratio

from ...utils import base64_encode_image
from .base import BaseImageQualityMetric, ComputeMetricOutput


class PSNRMetric(BaseImageQualityMetric):
"""PSNR Metric to compute the Peak Signal-to-Noise Ratio (PSNR) between two images.
Args:
psnr_base (float): The base of the logarithm in the PSNR formula.
psnr_data_range (Optional[Union[float, Tuple[float, float]]]): The data range of the input
image (min, max). If None, the data range is determined from the image data type.
psnr_base (float): The base of the logarithm in the PSNR formula.
image_size (Tuple[int, int]): The size to which images will be resized before computing
PSNR.
name (str): The name of the metric.
"""

psnr_base: float
psnr_data_range: Optional[Union[float, Tuple[float, float]]]
image_height: int
image_width: int
_psnr_metric: Callable

def __init__(
self,
psnr_data_range: Optional[Union[float, Tuple[float, float]]] = None,
psnr_base: float = 10.0,
image_size: Optional[Tuple[int, int]] = (512, 512),
name: str = "peak_signal_noise_ratio",
) -> None:
super().__init__(name)
self.image_size = image_size
self.psnr_metric = partial(
peak_signal_noise_ratio, data_range=psnr_data_range, base=psnr_base
super().__init__(
psnr_data_range=psnr_data_range,
psnr_base=psnr_base,
image_height=image_size[0],
image_width=image_size[1],
)
self._psnr_metric = partial(
peak_signal_noise_ratio, data_range=self.psnr_data_range, base=psnr_base
)
self.config = {
"psnr_base": psnr_base,
"psnr_data_range": psnr_data_range,
"image_size": image_size,
}

@weave.op()
def compute_metric(
self,
ground_truth_pil_image: Image.Image,
generated_pil_image: Image.Image,
prompt: str,
self, ground_truth_pil_image: Image.Image, generated_pil_image: Image.Image
) -> ComputeMetricOutput:
ground_truth_image = torch.from_numpy(
np.expand_dims(
@@ -58,21 +57,16 @@ def compute_metric(
np.array(generated_pil_image.resize(self.image_size)), axis=0
).astype(np.uint8)
).float()
return ComputeMetricOutput(
score=float(self.psnr_metric(generated_image, ground_truth_image).detach()),
ground_truth_image=base64_encode_image(ground_truth_pil_image),
)
return {
"score": float(
self._psnr_metric(generated_image, ground_truth_image).detach()
),
"ground_truth_image": ground_truth_pil_image,
}

@weave.op()
def evaluate(
self, prompt: str, ground_truth_image: Image.Image, model_output: Dict[str, Any]
) -> Union[float, Dict[str, float]]:
_ = "PSNRMetric"
return super().evaluate(prompt, ground_truth_image, model_output)

@weave.op()
async def evaluate_async(
self, prompt: str, ground_truth_image: Image.Image, model_output: Dict[str, Any]
) -> Union[float, Dict[str, float]]:
_ = "PSNRMetric"
return self.evaluate(prompt, ground_truth_image, model_output)
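For reference, the partial built in __init__ is equivalent to calling torchmetrics' functional PSNR directly. A small sketch with an explicit data range (the tensors and the 255.0 range are illustrative assumptions):

from functools import partial

import torch
from torchmetrics.functional.image import peak_signal_noise_ratio

# Same configuration pattern as PSNRMetric: fix the data range and log base up front.
psnr = partial(peak_signal_noise_ratio, data_range=255.0, base=10.0)

generated = torch.randint(0, 256, (1, 512, 512, 3)).float()
reference = torch.randint(0, 256, (1, 512, 512, 3)).float()

print(float(psnr(generated, reference)))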
55 changes: 25 additions & 30 deletions hemm/metrics/image_quality/ssim.py
@@ -1,13 +1,12 @@
from functools import partial
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union

import numpy as np
import torch
import weave
from PIL import Image
from torchmetrics.functional.image import structural_similarity_index_measure

from ...utils import base64_encode_image
from .base import BaseImageQualityMetric, ComputeMetricOutput


@@ -29,6 +28,14 @@ class SSIMMetric(BaseImageQualityMetric):
name (str): The name of the metric.
"""

ssim_gaussian_kernel: bool
ssim_sigma: float
ssim_kernel_size: int
ssim_data_range: Union[float, Tuple[float, float], None]
ssim_k1: float
ssim_k2: float
_ssim_metric: Callable

def __init__(
self,
ssim_gaussian_kernel: bool = True,
@@ -38,11 +45,18 @@ def __init__(
ssim_k1: float = 0.01,
ssim_k2: float = 0.03,
image_size: Optional[Tuple[int, int]] = (512, 512),
name: str = "structural_similarity_index_measure",
) -> None:
super().__init__(name)
self.image_size = image_size
self.ssim_metric = partial(
super().__init__(
ssim_gaussian_kernel=ssim_gaussian_kernel,
ssim_sigma=ssim_sigma,
ssim_kernel_size=ssim_kernel_size,
ssim_data_range=ssim_data_range,
ssim_k1=ssim_k1,
ssim_k2=ssim_k2,
image_height=image_size[0],
image_width=image_size[1],
)
self._ssim_metric = partial(
structural_similarity_index_measure,
gaussian_kernel=ssim_gaussian_kernel,
sigma=ssim_sigma,
@@ -51,21 +65,10 @@ def __init__(
k1=ssim_k1,
k2=ssim_k2,
)
self.config = {
"ssim_gaussian_kernel": ssim_gaussian_kernel,
"ssim_sigma": ssim_sigma,
"ssim_kernel_size": ssim_kernel_size,
"ssim_data_range": ssim_data_range,
"ssim_k1": ssim_k1,
"ssim_k2": ssim_k2,
}

@weave.op()
def compute_metric(
self,
ground_truth_pil_image: Image.Image,
generated_pil_image: Image.Image,
prompt: str,
self, ground_truth_pil_image: Image.Image, generated_pil_image: Image.Image
) -> ComputeMetricOutput:
ground_truth_image = (
torch.from_numpy(
@@ -85,21 +88,13 @@ def compute_metric(
.permute(0, 3, 1, 2)
.float()
)
return ComputeMetricOutput(
score=float(self.ssim_metric(generated_image, ground_truth_image)),
ground_truth_image=base64_encode_image(ground_truth_pil_image),
)
return {
"score": float(self._ssim_metric(generated_image, ground_truth_image)),
"ground_truth_image": ground_truth_pil_image,
}

@weave.op()
def evaluate(
self, prompt: str, ground_truth_image: Image.Image, model_output: Dict[str, Any]
) -> Union[float, Dict[str, float]]:
_ = "SSIMMetric"
return super().evaluate(prompt, ground_truth_image, model_output)

@weave.op()
async def evaluate_async(
self, prompt: str, ground_truth_image: Image.Image, model_output: Dict[str, Any]
) -> Union[float, Dict[str, float]]:
_ = "SSIMMetric"
return self.evaluate(prompt, ground_truth_image, model_output)
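As with the other metrics, the permute(0, 3, 1, 2) in compute_metric converts HWC image arrays into the NCHW layout torchmetrics expects. A standalone sketch of the configured call (sigma, kernel size, and the random tensors are typical torchmetrics values used for illustration, not read from the diff):

from functools import partial

import torch
from torchmetrics.functional.image import structural_similarity_index_measure

# Same keyword configuration that SSIMMetric builds in __init__.
ssim = partial(
    structural_similarity_index_measure,
    gaussian_kernel=True,
    sigma=1.5,
    kernel_size=11,
    data_range=None,
    k1=0.01,
    k2=0.03,
)

# NCHW float tensors, as produced by the permute(0, 3, 1, 2) step.
generated = torch.rand(1, 3, 256, 256)
reference = torch.rand(1, 3, 256, 256)

print(float(ssim(generated, reference)))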
2 changes: 1 addition & 1 deletion hemm/metrics/prompt_alignment/base.py
@@ -35,4 +35,4 @@ def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
Dict[str, float]: Metric score.
"""
score = self.compute_metric(model_output["image"], prompt)
return {self.name: score}
return {"score": score}
