update: remove dependency on evaluation pipeline
soumik12345 committed Oct 30, 2024
1 parent 40b2095 commit dc9579e
Showing 9 changed files with 51 additions and 214 deletions.
170 changes: 9 additions & 161 deletions hemm/eval_pipelines/eval_pipeline.py
@@ -1,192 +1,40 @@
import asyncio
import os
import shutil
from abc import ABC
from typing import Dict, List, Optional, Union
from typing import Dict, List, Union

import weave
from PIL import Image

import wandb

from ..metrics.base import BaseMetric
from ..models import BaseDiffusionModel, FalAIModel, StabilityAPIModel

MODEL_TYPE = Union[BaseDiffusionModel, FalAIModel, StabilityAPIModel]


class EvaluationPipeline(ABC):
"""Evaluation pipeline to evaluate the a multi-modal generative model.
Args:
model (BaseDiffusionModel): The model to evaluate.
seed (int): Seed value for the random number generator.
mock_inference_dataset_address (Optional[str]): A wandb dataset artifact address which if
provided will mock inference results. This prevents the need for redundant generations
when switching metrics/judges with the same evaluation dataset(s).
save_inference_dataset_name (Optional[str]): A weave dataset name which if provided will
save inference results as a separate weave dataset.
"""

def __init__(
self,
model: MODEL_TYPE,
seed: int = 42,
mock_inference_dataset_address: Optional[str] = None,
save_inference_dataset_name: Optional[str] = None,
) -> None:
def __init__(self, model: weave.Model) -> None:
super().__init__()
self.model = model

self.image_size = (self.model.image_height, self.model.image_width)
self.seed = seed
self.mock_inference_dataset_address = mock_inference_dataset_address
if mock_inference_dataset_address:
self.save_inference_dataset_name = None
artifact = wandb.use_artifact(
self.mock_inference_dataset_address, type="dataset"
)
self.mock_inference_dataset_dir = artifact.download()

else:
self.save_inference_dataset_name = save_inference_dataset_name

if self.save_inference_dataset_name:
os.makedirs(
os.path.join("inference_dataset", self.save_inference_dataset_name),
exist_ok=True,
)
self.scorers = []

self.inference_counter = 0
self.table_columns = ["model", "prompt", "generated_image"]
self.table_rows: List = []
self.evaluation_table: wandb.Table = None
self.metric_functions: List[BaseMetric] = []

self.evaluation_configs = {
"pretrained_model_name_or_path": self.model.diffusion_model_name_or_path,
"torch_dtype": str(self.model._torch_dtype),
"enable_cpu_offfload": self.model.enable_cpu_offfload,
"image_size": {
"height": self.image_size[0],
"width": self.image_size[1],
},
"seed": seed,
"diffusion_pipeline": dict(self.model._pipeline.config),
}

def add_metric(self, metric_fn: BaseMetric):
def add_metric(self, metric: Union[callable, weave.Scorer]):
"""Add a metric function to the evaluation pipeline.
Args:
metric_fn (BaseMetric): Metric function to evaluate the generated images.
"""
self.table_columns.append(metric_fn.__class__.__name__)
self.evaluation_configs.update(metric_fn.config)
self.metric_functions.append(metric_fn)

@weave.op()
def infer(self, prompt: str) -> Dict[str, str]:
"""Inference function to generate images for the given prompt.
Args:
prompt (str): Prompt to generate the image.
Returns:
Dict[str, str]: Dictionary containing base64 encoded image to be logged as
a Weave object.
"""
if self.inference_counter == 0:
self.evaluation_table = wandb.Table(columns=self.table_columns)
if self.mock_inference_dataset_address:
image = Image.open(
os.path.join(
self.mock_inference_dataset_dir, f"{self.inference_counter}.png"
)
)
output = {"image": image}
else:
output = self.model.predict(prompt, seed=self.seed)
self.table_rows.append(
[self.model.diffusion_model_name_or_path, prompt, output["image"]]
)
if self.save_inference_dataset_name:
output["image"].save(
os.path.join(
"inference_dataset",
self.save_inference_dataset_name,
f"{self.inference_counter}.png",
)
)
self.inference_counter += 1
return output

@weave.op()
async def infer_async(self, prompt: str) -> Dict[str, str]:
"""Async inference function to generate images for the given prompt.
Args:
prompt (str): Prompt to generate the image.
Returns:
Dict[str, str]: Dictionary containing base64 encoded image to be logged as
a Weave object.
metric (BaseMetric): Metric function to evaluate the generated images.
"""
return self.infer(prompt)

def log_summary(self, summary: Dict[str, float]) -> None:
"""Log the evaluation summary to the Weights & Biases dashboard."""
config = wandb.config
config.update(self.evaluation_configs)
for row_idx, row in enumerate(self.table_rows):
current_row = row
current_row[-1] = wandb.Image(current_row[-1])
for metric_fn in self.metric_functions:
current_row.append(metric_fn.scores[row_idx])
self.evaluation_table.add_data(*current_row)
summary_table = wandb.Table(columns=["summary"], data=[[summary]])
wandb.log(
{
"evalution": self.evaluation_table,
"summary": summary_table,
"evaluation_summary": summary,
}
)

def save_inference_results(self):
artifact = wandb.Artifact(name=self.save_inference_dataset_name, type="dataset")
artifact.add_dir(
os.path.join("inference_dataset", self.save_inference_dataset_name)
)
artifact.save()

def cleanup(self):
"""Cleanup the inference dataset directory. Should be called after the evaluation is complete
and `wandb.finish()` is called."""
if os.path.exists("inference_dataset"):
shutil.rmtree("inference_dataset")
self.scorers.append(metric)

def __call__(
self, dataset: Union[List[Dict], str], async_infer: bool = False
) -> Dict[str, float]:
def __call__(self, dataset: Union[List[Dict], str]) -> Dict[str, float]:
"""Evaluate the Stable Diffusion model on the given dataset.
Args:
dataset (Union[List[Dict], str]): Dataset to evaluate the model on. If a string is
passed, it is assumed to be a Weave dataset reference.
async_infer (bool, optional): Whether to use async inference. Defaults to False.
"""
dataset = weave.ref(dataset).get() if isinstance(dataset, str) else dataset
evaluation = weave.Evaluation(
dataset=dataset,
scorers=[
metric_fn.evaluate_async if async_infer else metric_fn.evaluate
for metric_fn in self.metric_functions
],
)
self.model.configs.update(self.evaluation_configs)
summary = asyncio.run(evaluation.evaluate(self.infer_async))
self.log_summary(summary)
if self.save_inference_dataset_name:
self.save_inference_results()
evaluation = weave.Evaluation(dataset=dataset, scorers=self.scorers)
summary = asyncio.run(evaluation.evaluate())
return summary
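
Taken together, the slimmed-down pipeline now just collects `weave.Scorer`s (or plain callables) via `add_metric` and hands them to `weave.Evaluation`. A minimal usage sketch under that assumption — the project name, model arguments, and toy scorer below are illustrative and not part of this commit:

```python
import weave

from hemm.eval_pipelines import EvaluationPipeline
from hemm.models import DiffusersModel

weave.init(project_name="hemm-eval")  # hypothetical project name

# Any weave.Model is accepted now; the constructor arguments are assumptions
# carried over from the fields of the removed BaseDiffusionModel.
model = DiffusersModel(
    diffusion_model_name_or_path="stabilityai/stable-diffusion-2-1",
    image_height=512,
    image_width=512,
)

pipeline = EvaluationPipeline(model=model)


# add_metric now takes a plain callable or a weave.Scorer instead of a
# BaseMetric subclass; this toy scorer exists only for illustration.
@weave.op()
def prompt_length_score(prompt: str, model_output: dict) -> dict:
    return {"prompt_length": len(prompt)}


pipeline.add_metric(prompt_length_score)

# A list of dicts or a Weave dataset reference string are both accepted.
summary = pipeline(dataset=[{"prompt": "a red cube balanced on a blue sphere"}])
print(summary)
```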
7 changes: 2 additions & 5 deletions hemm/metrics/__init__.py
@@ -1,7 +1,4 @@
from .prompt_alignment import (
BLIPScoreMertric,
CLIPImageQualityScoreMetric,
CLIPScoreMetric,
)
from .prompt_alignment import (BLIPScoreMertric, CLIPImageQualityScoreMetric,
CLIPScoreMetric)

__all__ = ["BLIPScoreMertric", "CLIPImageQualityScoreMetric", "CLIPScoreMetric"]
3 changes: 2 additions & 1 deletion hemm/metrics/image_quality/lpips.py
@@ -5,7 +5,8 @@
import torch
import weave
from PIL import Image
from torchmetrics.functional.image import learned_perceptual_image_patch_similarity
from torchmetrics.functional.image import \
learned_perceptual_image_patch_similarity

from ...utils import base64_encode_image
from .base import BaseImageQualityMetric, ComputeMetricOutput
6 changes: 4 additions & 2 deletions hemm/metrics/vqa/judges/mmllm_judges/openai_judge.py
@@ -9,7 +9,8 @@
from pydantic import BaseModel

from .....utils import base64_encode_image
from .commons import JudgeMent, JudgeQuestion, PromptCategory, TaggedPromptParts
from .commons import (JudgeMent, JudgeQuestion, PromptCategory,
TaggedPromptParts)


class OpenAIJudgeMent(BaseModel):
@@ -102,6 +103,7 @@ def frame_question(self, prompt: str, image: Image.Image) -> List[JudgeQuestion]
Returns:
List[JudgeQuestion]: List of questions to ask for the given prompt.
"""
prompt = str(prompt)
if self.prompt_property in [PromptCategory.spatial, PromptCategory.spatial_3d]:
self._total_score = 5
question = JudgeQuestion(
@@ -309,7 +311,7 @@ def execute_chain_of_thought(
Provide your analysis and explanation to justify the score.
"""
judgement_response = (
self._openai_client.beta.chat.completions.parse(
weave.op()(self._openai_client.beta.chat.completions.parse)(
model=self.openai_model,
response_format=JudgeMent,
seed=self.seed,
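
Beyond the isort-style import reshuffle, the functional change in `openai_judge.py` is that the judge's structured-output call is wrapped in `weave.op()` at the call site so it is traced as its own op. A sketch of that pattern in isolation, assuming an illustrative response schema and model name:

```python
import weave
from openai import OpenAI
from pydantic import BaseModel


class Judgement(BaseModel):
    """Illustrative response schema, not the repository's JudgeMent model."""

    score: int
    explanation: str


weave.init(project_name="hemm-eval")  # hypothetical project name
client = OpenAI()

# Wrapping the bound method in weave.op() records the call (inputs, outputs,
# latency) as a traced op, mirroring the one-line change in the diff above.
traced_parse = weave.op()(client.beta.chat.completions.parse)

response = traced_parse(
    model="gpt-4o",  # illustrative model name
    response_format=Judgement,
    messages=[
        {"role": "user", "content": "Score this caption for spatial accuracy from 1 to 5."}
    ],
)
print(response.choices[0].message.parsed)
```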
29 changes: 5 additions & 24 deletions hemm/metrics/vqa/multi_modal_llm_eval.py
@@ -1,34 +1,22 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List

import weave

from ..base import BaseMetric
from .judges.mmllm_judges import OpenAIJudge
from .judges.mmllm_judges.openai_judge import OpenAIJudgeMent


class MultiModalLLMEvaluationMetric(BaseMetric):
class MultiModalLLMEvaluationMetric(weave.Scorer):
"""Multi-modal LLM-based evaluation metric for an image-generation model.
Args:
judge (Union[weave.Model, OpenAIJudge]): The judge LLM model to evaluate the generated images.
name (Optional[str]): Name of the evaluation.
judge (OpenAIJudge): The judge LLM model to evaluate the generated images.
"""

def __init__(
self,
judge: Union[weave.Model, OpenAIJudge],
name: Optional[str] = "mmllm_eval_metric",
) -> None:
super().__init__()
self.judge = judge
self.config = self.judge.model_dump()
self.prompt_property = judge.prompt_property
self.scores = []
self.name = name
judge: OpenAIJudge

@weave.op()
def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]:
def score(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]:
"""Evaluate the generated image using the judge LLM model.
Args:
@@ -44,11 +32,4 @@ def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]:
"score": score / len(judgements),
"fractional_score": fractional_score / len(judgements),
}
self.scores.append(evaluation_dict)
return evaluation_dict

@weave.op()
async def evaluate_async(
self, prompt: str, model_output: Dict[str, Any]
) -> Dict[str, Any]:
return self.evaluate(prompt, model_output)
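
With the metric rewritten as a `weave.Scorer`, the judge becomes a declared field, per-instance state (`self.scores`, `self.config`) disappears, and `score` is the single synchronous entry point that `weave.Evaluation` invokes. A minimal sketch of the same shape with a stand-in scorer — the class name, field, and scoring logic are illustrative, not from this repository:

```python
from typing import Any, Dict

import weave


class KeywordCoverageMetric(weave.Scorer):
    """Toy scorer: fraction of long prompt words echoed in a generated caption."""

    caption_key: str = "caption"  # key assumed to exist in model_output

    @weave.op()
    def score(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]:
        caption = str(model_output.get(self.caption_key, "")).lower()
        words = [word for word in prompt.lower().split() if len(word) > 3]
        hits = sum(1 for word in words if word in caption)
        return {"fractional_score": hits / max(len(words), 1)}
```

An instance of such a scorer can be passed straight to `EvaluationPipeline.add_metric` (or to `weave.Evaluation(scorers=[...])`), which is how the simplified pipeline wires it up.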
4 changes: 2 additions & 2 deletions hemm/models/__init__.py
@@ -1,5 +1,5 @@
from .diffusion_model import BaseDiffusionModel
from .diffusion_model import DiffusersModel
from .falai_model import FalAIModel
from .stability_model import StabilityAPIModel

__all__ = ["BaseDiffusionModel", "FalAIModel", "StabilityAPIModel"]
__all__ = ["DiffusersModel", "FalAIModel", "StabilityAPIModel"]