diff --git a/hemm/metrics/prompt_alignment/blip_score.py b/hemm/metrics/prompt_alignment/blip_score.py
index 47fb774..4a32c03 100644
--- a/hemm/metrics/prompt_alignment/blip_score.py
+++ b/hemm/metrics/prompt_alignment/blip_score.py
@@ -1,14 +1,11 @@
-from typing import Any, Dict, Union
+from typing import Any, Dict
 
 import weave
-from PIL import Image
 from torch.nn import functional as F
 from transformers import BlipForConditionalGeneration, BlipProcessor
 
-from .base import BasePromptAlignmentMetric
-
 
-class BLIPScoreMertric(BasePromptAlignmentMetric):
+class BLIPScoreMertric(weave.Scorer):
     model_name: str = "Salesforce/blip-image-captioning-base"
     device: str = "cuda"
     _blip_processor: BlipProcessor
@@ -26,30 +23,26 @@ def __init__(
         )
 
     @weave.op()
-    def compute_metric(
-        self, pil_image: Image, prompt: str
-    ) -> Union[float, Dict[str, Any]]:
+    def score(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
         pixel_values = self.blip_processor(
-            images=pil_image, return_tensors="pt"
+            images=model_output["image"], return_tensors="pt"
         ).pixel_values
-        text_input_ids = self.blip_processor(
+        text_input_ids = self._blip_processor(
             text=prompt, return_tensors="pt", padding=True, truncation=True
         ).input_ids
-        outputs = self.blip_model(
+        outputs = self._blip_model(
             pixel_values=pixel_values.to(self.device),
             input_ids=text_input_ids.to(self.device),
         )
         logits = outputs.logits[:, :-1, :]
         shift_labels = text_input_ids[..., 1:].contiguous()
-        return float(
-            F.cross_entropy(
-                logits.view(-1, logits.size(-1)).to(self.device),
-                shift_labels.view(-1).to(self.device),
+        return {
+            "score": float(
+                F.cross_entropy(
+                    logits.view(-1, logits.size(-1)).to(self.device),
+                    shift_labels.view(-1).to(self.device),
+                )
+                .detach()
+                .item()
             )
-            .detach()
-            .item()
-        )
-
-    @weave.op()
-    def score(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
-        return super().evaluate(prompt, model_output)
+        }
diff --git a/hemm/metrics/prompt_alignment/clip_iqa_score.py b/hemm/metrics/prompt_alignment/clip_iqa_score.py
index 6c70863..8a7a1b0 100644
--- a/hemm/metrics/prompt_alignment/clip_iqa_score.py
+++ b/hemm/metrics/prompt_alignment/clip_iqa_score.py
@@ -1,17 +1,14 @@
 from functools import partial
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Callable, Dict, List
 
 import numpy as np
 import torch
 import weave
-from PIL import Image
 from torchmetrics.functional.multimodal import clip_image_quality_assessment
 from tqdm.auto import tqdm
 
-from .base import BasePromptAlignmentMetric
-
 
-class CLIPImageQualityScoreMetric(BasePromptAlignmentMetric):
+class CLIPImageQualityScoreMetric(weave.Scorer):
     """[CLIP Image Quality Assessment](https://arxiv.org/abs/2207.12396) metric
     for to measuring the visual content of images.
 
@@ -61,23 +58,20 @@ def __init__(self, model_name: str = "clip_iqa") -> None:
         )
 
     @weave.op()
-    def compute_metric(
-        self, pil_image: Image, prompt: str
-    ) -> Union[float, Dict[str, float]]:
-        images = np.expand_dims(np.array(pil_image), axis=0).astype(np.uint8) / 255.0
+    def score(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
+        images = (
+            np.expand_dims(np.array(model_output["image"]), axis=0).astype(np.uint8)
+            / 255.0
+        )
         score_dict = {}
         for prompt in tqdm(
             self.built_in_prompts, desc="Calculating IQA scores", leave=False
         ):
             clip_iqa_score = float(
-                self.clip_iqa_fn(
+                self._clip_iqa_fn(
                     images=torch.from_numpy(images).permute(0, 3, 1, 2),
                     prompts=tuple([prompt] * images.shape[0]),
                 ).detach()
             )
             score_dict[f"{self.name}_{prompt}"] = clip_iqa_score
         return score_dict
-
-    @weave.op()
-    def score(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
-        return super().evaluate(prompt, model_output)
diff --git a/hemm/metrics/prompt_alignment/clip_score.py b/hemm/metrics/prompt_alignment/clip_score.py
index d0ae39d..897f955 100644
--- a/hemm/metrics/prompt_alignment/clip_score.py
+++ b/hemm/metrics/prompt_alignment/clip_score.py
@@ -1,16 +1,13 @@
 from functools import partial
-from typing import Any, Callable, Dict, Union
+from typing import Any, Callable, Dict
 
 import numpy as np
 import torch
 import weave
-from PIL import Image
 from torchmetrics.functional.multimodal import clip_score
 
-from .base import BasePromptAlignmentMetric
-
 
-class CLIPScoreMetric(BasePromptAlignmentMetric):
+class CLIPScoreMetric(weave.Scorer):
     """[CLIP score](https://arxiv.org/abs/2104.08718) metric for text-to-image similarity.
     CLIP Score is a reference free metric that can be used to evaluate the correlation between
     a generated caption for an image and the actual content of the image. It has been found to
@@ -27,17 +24,13 @@ def __init__(self, model_name: str = "openai/clip-vit-base-patch16") -> None:
         super().__init__(model_name=model_name)
         self._clip_score_fn = partial(clip_score, model_name_or_path=model_name)
 
-    @weave.op()
-    def compute_metric(
-        self, pil_image: Image.Image, prompt: str
-    ) -> Union[float, Dict[str, float]]:
-        images = np.expand_dims(np.array(pil_image), axis=0)
-        return float(
-            self.clip_score_fn(
-                torch.from_numpy(images).permute(0, 3, 1, 2), prompt
-            ).detach()
-        )
-
     @weave.op()
     def score(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
-        return super().evaluate(prompt, model_output)
+        images = np.expand_dims(np.array(model_output["image"]), axis=0)
+        return {
+            "score": float(
+                self._clip_score_fn(
+                    torch.from_numpy(images).permute(0, 3, 1, 2), prompt
+                ).detach()
+            )
+        }
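
Note: with these changes each metric subclasses weave.Scorer directly, so it can be passed to a weave.Evaluation without the old BasePromptAlignmentMetric plumbing. A minimal usage sketch, not part of this patch (the project name, dataset row, and model are illustrative; it assumes a weave version whose scorers receive the model output as model_output, and a model whose predict call returns a dict with an "image" key holding a PIL image, matching the model_output["image"] access above):

    import asyncio

    import weave

    from hemm.metrics.prompt_alignment.clip_score import CLIPScoreMetric

    # illustrative project name
    weave.init("hemm-prompt-alignment")

    evaluation = weave.Evaluation(
        dataset=[{"prompt": "a photograph of an astronaut riding a horse"}],
        scorers=[CLIPScoreMetric()],
    )

    # `model` must be a weave.Model (or @weave.op function) whose predict
    # returns {"image": <PIL.Image>} so that score() can read model_output["image"].
    # asyncio.run(evaluation.evaluate(model))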