From dc9579e7108b58e18fe8e0f5e89dc482e8f580da Mon Sep 17 00:00:00 2001
From: Soumik Rakshit <19soumik.rakshit96@gmail.com>
Date: Wed, 30 Oct 2024 15:39:20 +0000
Subject: [PATCH] update: remove dependency on evaluation pipeline

---
 hemm/eval_pipelines/eval_pipeline.py          | 170 +-----------------
 hemm/metrics/__init__.py                      |   7 +-
 hemm/metrics/image_quality/lpips.py           |   3 +-
 .../vqa/judges/mmllm_judges/openai_judge.py   |   6 +-
 hemm/metrics/vqa/multi_modal_llm_eval.py      |  29 +--
 hemm/models/__init__.py                       |   4 +-
 hemm/models/diffusion_model.py                |  39 ++--
 .../test_2d_spatial_relationship_eval.py      |   4 +-
 hemm/tests/test_prompt_alignment_eval.py      |   3 +-
 9 files changed, 51 insertions(+), 214 deletions(-)

diff --git a/hemm/eval_pipelines/eval_pipeline.py b/hemm/eval_pipelines/eval_pipeline.py
index f241a4e..d031b49 100644
--- a/hemm/eval_pipelines/eval_pipeline.py
+++ b/hemm/eval_pipelines/eval_pipeline.py
@@ -1,18 +1,8 @@
 import asyncio
-import os
-import shutil
 from abc import ABC
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Union
 
 import weave
-from PIL import Image
-
-import wandb
-
-from ..metrics.base import BaseMetric
-from ..models import BaseDiffusionModel, FalAIModel, StabilityAPIModel
-
-MODEL_TYPE = Union[BaseDiffusionModel, FalAIModel, StabilityAPIModel]
 
 
 class EvaluationPipeline(ABC):
@@ -20,173 +10,31 @@ class EvaluationPipeline(ABC):
 
     Args:
         model (BaseDiffusionModel): The model to evaluate.
-        seed (int): Seed value for the random number generator.
-        mock_inference_dataset_address (Optional[str]): A wandb dataset artifact address which if
-            provided will mock inference results. This prevents the need for redundant generations
-            when switching metrics/judges with the same evaluation datset(s).
-        save_inference_dataset_name (Optional[str]): A weave dataset name which if provided will
-            save inference results as a separate weave dataset.
     """
 
-    def __init__(
-        self,
-        model: MODEL_TYPE,
-        seed: int = 42,
-        mock_inference_dataset_address: Optional[str] = None,
-        save_inference_dataset_name: Optional[str] = None,
-    ) -> None:
+    def __init__(self, model: weave.Model) -> None:
         super().__init__()
         self.model = model
 
         self.image_size = (self.model.image_height, self.model.image_width)
-        self.seed = seed
-
-        self.mock_inference_dataset_address = mock_inference_dataset_address
-        if mock_inference_dataset_address:
-            self.save_inference_dataset_name = None
-            artifact = wandb.use_artifact(
-                self.mock_inference_dataset_address, type="dataset"
-            )
-            self.mock_inference_dataset_dir = artifact.download()
-
-        else:
-            self.save_inference_dataset_name = save_inference_dataset_name
-
-        if self.save_inference_dataset_name:
-            os.makedirs(
-                os.path.join("inference_dataset", self.save_inference_dataset_name),
-                exist_ok=True,
-            )
+        self.scorers = []
 
-        self.inference_counter = 0
-        self.table_columns = ["model", "prompt", "generated_image"]
-        self.table_rows: List = []
-        self.evaluation_table: wandb.Table = None
-        self.metric_functions: List[BaseMetric] = []
-
-        self.evaluation_configs = {
-            "pretrained_model_name_or_path": self.model.diffusion_model_name_or_path,
-            "torch_dtype": str(self.model._torch_dtype),
-            "enable_cpu_offfload": self.model.enable_cpu_offfload,
-            "image_size": {
-                "height": self.image_size[0],
-                "width": self.image_size[1],
-            },
-            "seed": seed,
-            "diffusion_pipeline": dict(self.model._pipeline.config),
-        }
-
-    def add_metric(self, metric_fn: BaseMetric):
+    def add_metric(self, metric: Union[callable, weave.Scorer]):
         """Add a metric function to the evaluation pipeline.
 
         Args:
-            metric_fn (BaseMetric): Metric function to evaluate the generated images.
-        """
-        self.table_columns.append(metric_fn.__class__.__name__)
-        self.evaluation_configs.update(metric_fn.config)
-        self.metric_functions.append(metric_fn)
-
-    @weave.op()
-    def infer(self, prompt: str) -> Dict[str, str]:
-        """Inference function to generate images for the given prompt.
-
-        Args:
-            prompt (str): Prompt to generate the image.
-
-        Returns:
-            Dict[str, str]: Dictionary containing base64 encoded image to be logged as
-                a Weave object.
-        """
-        if self.inference_counter == 0:
-            self.evaluation_table = wandb.Table(columns=self.table_columns)
-        if self.mock_inference_dataset_address:
-            image = Image.open(
-                os.path.join(
-                    self.mock_inference_dataset_dir, f"{self.inference_counter}.png"
-                )
-            )
-            output = {"image": image}
-        else:
-            output = self.model.predict(prompt, seed=self.seed)
-        self.table_rows.append(
-            [self.model.diffusion_model_name_or_path, prompt, output["image"]]
-        )
-        if self.save_inference_dataset_name:
-            output["image"].save(
-                os.path.join(
-                    "inference_dataset",
-                    self.save_inference_dataset_name,
-                    f"{self.inference_counter}.png",
-                )
-            )
-        self.inference_counter += 1
-        return output
-
-    @weave.op()
-    async def infer_async(self, prompt: str) -> Dict[str, str]:
-        """Async inference function to generate images for the given prompt.
-
-        Args:
-            prompt (str): Prompt to generate the image.
-
-        Returns:
-            Dict[str, str]: Dictionary containing base64 encoded image to be logged as
-                a Weave object.
+            metric (BaseMetric): Metric function to evaluate the generated images.
         """
-        return self.infer(prompt)
-
-    def log_summary(self, summary: Dict[str, float]) -> None:
-        """Log the evaluation summary to the Weights & Biases dashboard."""
-        config = wandb.config
-        config.update(self.evaluation_configs)
-        for row_idx, row in enumerate(self.table_rows):
-            current_row = row
-            current_row[-1] = wandb.Image(current_row[-1])
-            for metric_fn in self.metric_functions:
-                current_row.append(metric_fn.scores[row_idx])
-            self.evaluation_table.add_data(*current_row)
-        summary_table = wandb.Table(columns=["summary"], data=[[summary]])
-        wandb.log(
-            {
-                "evalution": self.evaluation_table,
-                "summary": summary_table,
-                "evaluation_summary": summary,
-            }
-        )
-
-    def save_inference_results(self):
-        artifact = wandb.Artifact(name=self.save_inference_dataset_name, type="dataset")
-        artifact.add_dir(
-            os.path.join("inference_dataset", self.save_inference_dataset_name)
-        )
-        artifact.save()
-
-    def cleanup(self):
-        """Cleanup the inference dataset directory. Should be called after the evaluation is complete
-        and `wandb.finish()` is called."""
-        if os.path.exists("inference_dataset"):
-            shutil.rmtree("inference_dataset")
+        self.scorers.append(metric)
 
-    def __call__(
-        self, dataset: Union[List[Dict], str], async_infer: bool = False
-    ) -> Dict[str, float]:
+    def __call__(self, dataset: Union[List[Dict], str]) -> Dict[str, float]:
         """Evaluate the Stable Diffusion model on the given dataset.
 
         Args:
            dataset (Union[List[Dict], str]): Dataset to evaluate the model on. If a string is
               passed, it is assumed to be a Weave dataset reference.
-           async_infer (bool, optional): Whether to use async inference. Defaults to False.
        """
        dataset = weave.ref(dataset).get() if isinstance(dataset, str) else dataset
-        evaluation = weave.Evaluation(
-            dataset=dataset,
-            scorers=[
-                metric_fn.evaluate_async if async_infer else metric_fn.evaluate
-                for metric_fn in self.metric_functions
-            ],
-        )
-        self.model.configs.update(self.evaluation_configs)
-        summary = asyncio.run(evaluation.evaluate(self.infer_async))
-        self.log_summary(summary)
-        if self.save_inference_dataset_name:
-            self.save_inference_results()
+        evaluation = weave.Evaluation(dataset=dataset, scorers=self.scorers)
+        summary = asyncio.run(evaluation.evaluate())
        return summary
diff --git a/hemm/metrics/__init__.py b/hemm/metrics/__init__.py
index 3b3e6a4..380ab11 100644
--- a/hemm/metrics/__init__.py
+++ b/hemm/metrics/__init__.py
@@ -1,7 +1,4 @@
-from .prompt_alignment import (
-    BLIPScoreMertric,
-    CLIPImageQualityScoreMetric,
-    CLIPScoreMetric,
-)
+from .prompt_alignment import (BLIPScoreMertric, CLIPImageQualityScoreMetric,
+                               CLIPScoreMetric)
 
 __all__ = ["BLIPScoreMertric", "CLIPImageQualityScoreMetric", "CLIPScoreMetric"]
diff --git a/hemm/metrics/image_quality/lpips.py b/hemm/metrics/image_quality/lpips.py
index fe67f95..45e5b2b 100644
--- a/hemm/metrics/image_quality/lpips.py
+++ b/hemm/metrics/image_quality/lpips.py
@@ -5,7 +5,8 @@
 import torch
 import weave
 from PIL import Image
-from torchmetrics.functional.image import learned_perceptual_image_patch_similarity
+from torchmetrics.functional.image import \
+    learned_perceptual_image_patch_similarity
 
 from ...utils import base64_encode_image
 from .base import BaseImageQualityMetric, ComputeMetricOutput
diff --git a/hemm/metrics/vqa/judges/mmllm_judges/openai_judge.py b/hemm/metrics/vqa/judges/mmllm_judges/openai_judge.py
index b3fbd75..2e8ec35 100644
--- a/hemm/metrics/vqa/judges/mmllm_judges/openai_judge.py
+++ b/hemm/metrics/vqa/judges/mmllm_judges/openai_judge.py
@@ -9,7 +9,8 @@
 from pydantic import BaseModel
 
 from .....utils import base64_encode_image
-from .commons import JudgeMent, JudgeQuestion, PromptCategory, TaggedPromptParts
+from .commons import (JudgeMent, JudgeQuestion, PromptCategory,
+                      TaggedPromptParts)
 
 
 class OpenAIJudgeMent(BaseModel):
@@ -102,6 +103,7 @@ def frame_question(self, prompt: str, image: Image.Image) -> List[JudgeQuestion]
         Returns:
             List[JudgeQuestion]: List of questions to ask for the given prompt.
         """
+        prompt = str(prompt)
         if self.prompt_property in [PromptCategory.spatial, PromptCategory.spatial_3d]:
             self._total_score = 5
             question = JudgeQuestion(
@@ -309,7 +311,7 @@ def execute_chain_of_thought(
         Provide your analysis and explanation to justify the score.
         """
         judgement_response = (
-            self._openai_client.beta.chat.completions.parse(
+            weave.op()(self._openai_client.beta.chat.completions.parse)(
                 model=self.openai_model,
                 response_format=JudgeMent,
                 seed=self.seed,
diff --git a/hemm/metrics/vqa/multi_modal_llm_eval.py b/hemm/metrics/vqa/multi_modal_llm_eval.py
index 94abe84..29af903 100644
--- a/hemm/metrics/vqa/multi_modal_llm_eval.py
+++ b/hemm/metrics/vqa/multi_modal_llm_eval.py
@@ -1,34 +1,22 @@
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List
 
 import weave
 
-from ..base import BaseMetric
 from .judges.mmllm_judges import OpenAIJudge
 from .judges.mmllm_judges.openai_judge import OpenAIJudgeMent
 
 
-class MultiModalLLMEvaluationMetric(BaseMetric):
+class MultiModalLLMEvaluationMetric(weave.Scorer):
     """Multi-modal LLM-based evaluation metric for an image-generation model.
 
     Args:
-        judge (Union[weave.Model, OpenAIJudge]): The judge LLM model to evaluate the generated images.
-        name (Optional[str]): Name of the evaluation.
+        judge (OpenAIJudge): The judge LLM model to evaluate the generated images.
     """
 
-    def __init__(
-        self,
-        judge: Union[weave.Model, OpenAIJudge],
-        name: Optional[str] = "mmllm_eval_metric",
-    ) -> None:
-        super().__init__()
-        self.judge = judge
-        self.config = self.judge.model_dump()
-        self.prompt_property = judge.prompt_property
-        self.scores = []
-        self.name = name
+    judge: OpenAIJudge
 
     @weave.op()
-    def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]:
+    def score(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]:
         """Evaluate the generated image using the judge LLM model.
 
         Args:
@@ -44,11 +32,4 @@ def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]
             "score": score / len(judgements),
             "fractional_score": fractional_score / len(judgements),
         }
-        self.scores.append(evaluation_dict)
         return evaluation_dict
-
-    @weave.op()
-    async def evaluate_async(
-        self, prompt: str, model_output: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        return self.evaluate(prompt, model_output)
diff --git a/hemm/models/__init__.py b/hemm/models/__init__.py
index 883d860..57271a4 100644
--- a/hemm/models/__init__.py
+++ b/hemm/models/__init__.py
@@ -1,5 +1,5 @@
-from .diffusion_model import BaseDiffusionModel
+from .diffusion_model import DiffusersModel
 from .falai_model import FalAIModel
 from .stability_model import StabilityAPIModel
 
-__all__ = ["BaseDiffusionModel", "FalAIModel", "StabilityAPIModel"]
+__all__ = ["DiffusersModel", "FalAIModel", "StabilityAPIModel"]
diff --git a/hemm/models/diffusion_model.py b/hemm/models/diffusion_model.py
index 2cfe0ba..3d85ec0 100644
--- a/hemm/models/diffusion_model.py
+++ b/hemm/models/diffusion_model.py
@@ -5,7 +5,7 @@
 from diffusers import DiffusionPipeline
 
 
-class BaseDiffusionModel(weave.Model):
+class DiffusersModel(weave.Model):
     """`weave.Model` wrapping `diffusers.DiffusionPipeline`.
 
     Args:
@@ -16,19 +16,18 @@ class BaseDiffusionModel(weave.Model):
         num_inference_steps (int): The number of inference steps.
         disable_safety_checker (bool): Disable safety checker for the diffusion model.
         configs (Dict[str, Any]): Additional configs.
-        pipeline_configs (Dict[str, Any]): Diffusion pipeline configs.
         inference_kwargs (Dict[str, Any]): Inference kwargs.
     """
 
     diffusion_model_name_or_path: str
     enable_cpu_offfload: bool = False
-    image_height: int = 512
-    image_width: int = 512
-    num_inference_steps: int = 50
-    disable_safety_checker: bool = True
-    configs: Dict[str, Any] = {}
-    pipeline_configs: Dict[str, Any] = {}
-    inference_kwargs: Dict[str, Any] = {}
+    image_height: int
+    image_width: int
+    num_inference_steps: int
+    seed: int
+    disable_safety_checker: bool
+    configs: Dict[str, Any]
+    inference_kwargs: Dict[str, Any]
     _torch_dtype: torch.dtype = torch.float16
     _pipeline: DiffusionPipeline = None
 
@@ -39,9 +38,9 @@ def __init__(
         image_height: int = 512,
         image_width: int = 512,
         num_inference_steps: int = 50,
+        seed: int = 42,
         disable_safety_checker: bool = True,
         configs: Dict[str, Any] = {},
-        pipeline_configs: Dict[str, Any] = {},
         inference_kwargs: Dict[str, Any] = {},
     ) -> None:
         super().__init__(
@@ -50,17 +49,15 @@ def __init__(
             image_height=image_height,
             image_width=image_width,
             num_inference_steps=num_inference_steps,
+            seed=seed,
             disable_safety_checker=disable_safety_checker,
             configs=configs,
-            pipeline_configs=pipeline_configs,
             inference_kwargs=inference_kwargs,
         )
-        self.configs["torch_dtype"] = str(self._torch_dtype)
         pipeline_init_kwargs = {
             "pretrained_model_name_or_path": self.diffusion_model_name_or_path,
             "torch_dtype": self._torch_dtype,
         }
-        pipeline_init_kwargs.update(self.pipeline_configs)
         if self.disable_safety_checker:
             pipeline_init_kwargs["safety_checker"] = None
         self._pipeline = DiffusionPipeline.from_pretrained(**pipeline_init_kwargs)
@@ -70,14 +67,26 @@ def __init__(
             self._pipeline = self._pipeline.to("cuda")
         self._pipeline.set_progress_bar_config(leave=False, desc="Generating Image")
 
+        self.configs = {
+            **self.configs,
+            "torch_dtype": str(self._torch_dtype),
+            "pretrained_model_name_or_path": self.diffusion_model_name_or_path,
+            "enable_cpu_offfload": self.enable_cpu_offfload,
+            "image_size": {
+                "height": self.image_height,
+                "width": self.image_width,
+            },
+            "diffusion_pipeline": dict(self._pipeline.config),
+        }
+
     @weave.op()
-    def predict(self, prompt: str, seed: int) -> Dict[str, Any]:
+    def predict(self, prompt: str) -> Dict[str, Any]:
         pipeline_output = self._pipeline(
             prompt,
             num_images_per_prompt=1,
             height=self.image_height,
             width=self.image_width,
-            generator=torch.Generator(device="cuda").manual_seed(seed),
+            generator=torch.Generator(device="cuda").manual_seed(self.seed),
             num_inference_steps=self.num_inference_steps,
             **self.inference_kwargs,
         )
diff --git a/hemm/tests/test_2d_spatial_relationship_eval.py b/hemm/tests/test_2d_spatial_relationship_eval.py
index 4cd3f2d..86a2009 100644
--- a/hemm/tests/test_2d_spatial_relationship_eval.py
+++ b/hemm/tests/test_2d_spatial_relationship_eval.py
@@ -6,9 +6,7 @@
 from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline
 from hemm.metrics.spatial_relationship import SpatialRelationshipMetric2D
 from hemm.metrics.spatial_relationship.judges import (
-    DETRSpatialRelationShipJudge,
-    RTDETRSpatialRelationShipJudge,
-)
+    DETRSpatialRelationShipJudge, RTDETRSpatialRelationShipJudge)
 
 
 class Test2DSpatialRelationshipEval(unittest.TestCase):
diff --git a/hemm/tests/test_prompt_alignment_eval.py b/hemm/tests/test_prompt_alignment_eval.py
index a876455..148884c 100644
--- a/hemm/tests/test_prompt_alignment_eval.py
+++ b/hemm/tests/test_prompt_alignment_eval.py
@@ -4,7 +4,8 @@
 import wandb
 
 from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline
-from hemm.metrics.prompt_alignment import CLIPImageQualityScoreMetric, CLIPScoreMetric
+from hemm.metrics.prompt_alignment import (CLIPImageQualityScoreMetric,
+                                           CLIPScoreMetric)
 
 
 class TestPromptAlignmentEvaluation(unittest.TestCase):
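
For reviewers, a minimal usage sketch of the API shape this patch introduces, not part of the diff itself. The Weave project name, the Stable Diffusion checkpoint, the dataset reference, and the default-constructed OpenAIJudge below are placeholder assumptions, shown only to illustrate how DiffusersModel, EvaluationPipeline, and a weave.Scorer metric are expected to wire together after the change.

# Illustrative sketch only -- not part of the patch. Project name, checkpoint,
# judge configuration, and dataset reference are placeholder assumptions.
import weave

from hemm.eval_pipelines import EvaluationPipeline
from hemm.metrics.vqa.judges.mmllm_judges import OpenAIJudge
from hemm.metrics.vqa.multi_modal_llm_eval import MultiModalLLMEvaluationMetric
from hemm.models import DiffusersModel

weave.init(project_name="hemm-eval")  # placeholder Weave project

# The generation seed now lives on the model rather than on EvaluationPipeline.
model = DiffusersModel(
    diffusion_model_name_or_path="stabilityai/stable-diffusion-2-1",  # placeholder checkpoint
    image_height=512,
    image_width=512,
    num_inference_steps=50,
    seed=42,
)

pipeline = EvaluationPipeline(model=model)

# Metrics are plain weave.Scorer objects collected into pipeline.scorers.
metric = MultiModalLLMEvaluationMetric(judge=OpenAIJudge())  # judge kwargs assumed default
pipeline.add_metric(metric)

# A string argument is resolved through weave.ref() as a Weave dataset reference.
summary = pipeline("parti-prompts:v0")  # placeholder dataset reference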