refactor: prompt alignment metrics
soumik12345 committed Oct 31, 2024
1 parent 82e3702 commit 99bbc60
Showing 3 changed files with 33 additions and 53 deletions.
37 changes: 15 additions & 22 deletions hemm/metrics/prompt_alignment/blip_score.py
@@ -1,14 +1,11 @@
-from typing import Any, Dict, Union
+from typing import Any, Dict
 
 import weave
-from PIL import Image
 from torch.nn import functional as F
 from transformers import BlipForConditionalGeneration, BlipProcessor
 
-from .base import BasePromptAlignmentMetric
-
 
-class BLIPScoreMertric(BasePromptAlignmentMetric):
+class BLIPScoreMertric(weave.Scorer):
     model_name: str = "Salesforce/blip-image-captioning-base"
     device: str = "cuda"
     _blip_processor: BlipProcessor
@@ -26,30 +23,26 @@ def __init__(
         )
 
     @weave.op()
-    def compute_metric(
-        self, pil_image: Image, prompt: str
-    ) -> Union[float, Dict[str, Any]]:
+    def score(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
         pixel_values = self.blip_processor(
-            images=pil_image, return_tensors="pt"
+            images=model_output["image"], return_tensors="pt"
         ).pixel_values
-        text_input_ids = self.blip_processor(
+        text_input_ids = self._blip_processor(
             text=prompt, return_tensors="pt", padding=True, truncation=True
         ).input_ids
-        outputs = self.blip_model(
+        outputs = self._blip_model(
             pixel_values=pixel_values.to(self.device),
             input_ids=text_input_ids.to(self.device),
         )
         logits = outputs.logits[:, :-1, :]
         shift_labels = text_input_ids[..., 1:].contiguous()
-        return float(
-            F.cross_entropy(
-                logits.view(-1, logits.size(-1)).to(self.device),
-                shift_labels.view(-1).to(self.device),
-            )
-            .detach()
-            .item()
-        )
-
-    @weave.op()
-    def score(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
-        return super().evaluate(prompt, model_output)
+        return {
+            "score": float(
+                F.cross_entropy(
+                    logits.view(-1, logits.size(-1)).to(self.device),
+                    shift_labels.view(-1).to(self.device),
+                )
+                .detach()
+                .item()
+            )
+        }
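
For context on the new interface: the refactor folds the old compute_metric plus score wrapper into a single weave.Scorer.score op that returns a {"score": ...} dict. A minimal usage sketch, not part of the commit; it assumes the defaults shown above work with the pydantic-style constructor, a CUDA device, and uses a placeholder prompt and image:

from PIL import Image

from hemm.metrics.prompt_alignment.blip_score import BLIPScoreMertric

# Assumes the refactored class above; defaults load Salesforce/blip-image-captioning-base on "cuda".
scorer = BLIPScoreMertric()

# The model output is now a dict carrying the generated image under the "image" key.
result = scorer.score(
    prompt="a photo of an astronaut riding a horse",  # placeholder prompt
    model_output={"image": Image.new("RGB", (512, 512))},  # placeholder image
)
print(result["score"])  # cross-entropy of the prompt tokens under BLIP, conditioned on the image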
22 changes: 8 additions & 14 deletions hemm/metrics/prompt_alignment/clip_iqa_score.py
@@ -1,17 +1,14 @@
 from functools import partial
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Callable, Dict, List
 
 import numpy as np
 import torch
 import weave
-from PIL import Image
 from torchmetrics.functional.multimodal import clip_image_quality_assessment
 from tqdm.auto import tqdm
 
-from .base import BasePromptAlignmentMetric
-
 
-class CLIPImageQualityScoreMetric(BasePromptAlignmentMetric):
+class CLIPImageQualityScoreMetric(weave.Scorer):
     """[CLIP Image Quality Assessment](https://arxiv.org/abs/2207.12396) metric
     for to measuring the visual content of images.
@@ -61,23 +58,20 @@ def __init__(self, model_name: str = "clip_iqa") -> None:
         )
 
     @weave.op()
-    def compute_metric(
-        self, pil_image: Image, prompt: str
-    ) -> Union[float, Dict[str, float]]:
-        images = np.expand_dims(np.array(pil_image), axis=0).astype(np.uint8) / 255.0
+    def score(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
+        images = (
+            np.expand_dims(np.array(model_output["image"]), axis=0).astype(np.uint8)
+            / 255.0
+        )
         score_dict = {}
         for prompt in tqdm(
             self.built_in_prompts, desc="Calculating IQA scores", leave=False
         ):
             clip_iqa_score = float(
-                self.clip_iqa_fn(
+                self._clip_iqa_fn(
                     images=torch.from_numpy(images).permute(0, 3, 1, 2),
                     prompts=tuple([prompt] * images.shape[0]),
                 ).detach()
             )
             score_dict[f"{self.name}_{prompt}"] = clip_iqa_score
         return score_dict
-
-    @weave.op()
-    def score(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
-        return super().evaluate(prompt, model_output)
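
As with the BLIP scorer, the wrapper method is gone and score is the single entry point. Note that the prompt argument is immediately rebound by the loop over self.built_in_prompts, so the returned dict has one entry per built-in CLIP-IQA prompt rather than one for the caller's prompt. A hedged sketch of a direct call, not part of the commit, using a placeholder image and the model_name default shown in the hunk header:

from PIL import Image

from hemm.metrics.prompt_alignment.clip_iqa_score import CLIPImageQualityScoreMetric

scorer = CLIPImageQualityScoreMetric(model_name="clip_iqa")

result = scorer.score(
    prompt="unused by CLIP-IQA",  # rebound inside score() by the built-in prompt loop
    model_output={"image": Image.new("RGB", (224, 224))},  # placeholder image
)
# One key per built-in CLIP-IQA prompt, named f"{self.name}_{prompt}".
print(result)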
27 changes: 10 additions & 17 deletions hemm/metrics/prompt_alignment/clip_score.py
@@ -1,16 +1,13 @@
 from functools import partial
-from typing import Any, Callable, Dict, Union
+from typing import Any, Callable, Dict
 
 import numpy as np
 import torch
 import weave
-from PIL import Image
 from torchmetrics.functional.multimodal import clip_score
 
-from .base import BasePromptAlignmentMetric
-
 
-class CLIPScoreMetric(BasePromptAlignmentMetric):
+class CLIPScoreMetric(weave.Scorer):
     """[CLIP score](https://arxiv.org/abs/2104.08718) metric for text-to-image similarity.
     CLIP Score is a reference free metric that can be used to evaluate the correlation between
     a generated caption for an image and the actual content of the image. It has been found to
@@ -27,17 +24,13 @@ def __init__(self, model_name: str = "openai/clip-vit-base-patch16") -> None:
         super().__init__(model_name=model_name)
         self._clip_score_fn = partial(clip_score, model_name_or_path=model_name)
 
     @weave.op()
-    def compute_metric(
-        self, pil_image: Image.Image, prompt: str
-    ) -> Union[float, Dict[str, float]]:
-        images = np.expand_dims(np.array(pil_image), axis=0)
-        return float(
-            self.clip_score_fn(
-                torch.from_numpy(images).permute(0, 3, 1, 2), prompt
-            ).detach()
-        )
-
-    @weave.op()
     def score(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
-        return super().evaluate(prompt, model_output)
+        images = np.expand_dims(np.array(model_output["image"]), axis=0)
+        return {
+            "score": float(
+                self._clip_score_fn(
+                    torch.from_numpy(images).permute(0, 3, 1, 2), prompt
+                ).detach()
+            )
+        }
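
Since all three metrics are now weave.Scorer subclasses with the same score(prompt, model_output) signature, they can plug directly into a Weave evaluation loop. A minimal sketch under stated assumptions: the project name, dataset row, and stub model below are placeholders, and it assumes a Weave version whose Evaluation passes dataset columns and the model output to Scorer.score by parameter name, as the new signature implies:

import asyncio

import weave
from PIL import Image

from hemm.metrics.prompt_alignment.clip_score import CLIPScoreMetric


class StubDiffusionModel(weave.Model):
    # Placeholder text-to-image model: returns a dict with the image under "image",
    # which is the shape the refactored scorers expect in `model_output`.
    @weave.op()
    def predict(self, prompt: str) -> dict:
        return {"image": Image.new("RGB", (512, 512))}


weave.init("hemm-prompt-alignment")  # placeholder W&B project name

evaluation = weave.Evaluation(
    dataset=[{"prompt": "a photo of an astronaut riding a horse"}],  # placeholder row
    scorers=[CLIPScoreMetric()],
)
asyncio.run(evaluation.evaluate(StubDiffusionModel()))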
