Merge pull request #23 from wandb/refactor/models
feat(models): Refactor models API
soumik12345 authored Oct 18, 2024
2 parents 782ca64 + 0d085d6 commit 40b2095
Showing 32 changed files with 413 additions and 7,173 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -165,4 +165,7 @@ dump/
.vscode/
**generated_images/
.ruff_cache/
**.jsonl
**.jsonl
test.py
inference_dataset/
uv.lock
3 changes: 2 additions & 1 deletion README.md
@@ -41,8 +41,9 @@ import wandb
import weave


from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline
from hemm.eval_pipelines import EvaluationPipeline
from hemm.metrics.prompt_alignment import CLIPImageQualityScoreMetric, CLIPScoreMetric
from hemm.models import BaseDiffusionModel


# Initialize Weave and WandB
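With this refactor, `BaseDiffusionModel` now lives in the new `hemm.models` package instead of `hemm.eval_pipelines`, which is what the README change above reflects. A minimal usage sketch under the updated imports (the entity, project, and dataset reference below are illustrative placeholders, not values from this commit):

import wandb
import weave

from hemm.eval_pipelines import EvaluationPipeline
from hemm.models import BaseDiffusionModel

# Initialize Weave and WandB (placeholder entity/project names)
wandb.init(project="hemm-eval", entity="my-entity", job_type="evaluation")
weave.init(project_name="my-entity/hemm-eval")

# Wrap a diffusers checkpoint with the relocated BaseDiffusionModel class
model = BaseDiffusionModel(
    diffusion_model_name_or_path="stabilityai/stable-diffusion-2-1",
    enable_cpu_offfload=False,
    image_height=1024,
    image_width=1024,
    num_inference_steps=50,
)
pipeline = EvaluationPipeline(model=model)

# Add metrics (e.g. from hemm.metrics.prompt_alignment), then run on a Weave dataset reference
# pipeline.add_metric(CLIPScoreMetric(...))
# pipeline(dataset="my_prompt_dataset:v0")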
3 changes: 3 additions & 0 deletions docs/models/diffusion_model.md
@@ -0,0 +1,3 @@
# Diffusion Models

::: hemm.models.diffusion_model
3 changes: 3 additions & 0 deletions docs/models/falai_model.md
@@ -0,0 +1,3 @@
# FalAI Models

::: hemm.models.falai_model
3 changes: 3 additions & 0 deletions docs/models/stability_model.md
@@ -0,0 +1,3 @@
# StabilityAI Model

::: hemm.models.stability_model
@@ -1,12 +1,13 @@
from typing import Optional, Tuple

import fire
import wandb
import weave

from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline
import wandb
from hemm.eval_pipelines import EvaluationPipeline
from hemm.metrics.spatial_relationship import SpatialRelationshipMetric2D
from hemm.metrics.spatial_relationship.judges import DETRSpatialRelationShipJudge
from hemm.models import BaseDiffusionModel


def main(
@@ -1,12 +1,13 @@
from typing import Optional, Tuple

import fire
import wandb
import weave

from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline
import wandb
from hemm.eval_pipelines import EvaluationPipeline
from hemm.metrics.spatial_relationship import SpatialRelationshipMetric2D
from hemm.metrics.spatial_relationship.judges import RTDETRSpatialRelationShipJudge
from hemm.models import BaseDiffusionModel


def main(
5 changes: 3 additions & 2 deletions examples/disentangled_vqa/evaluate_disentangled_vqa.py
@@ -1,12 +1,13 @@
from typing import Optional, Tuple

import fire
import wandb
import weave

from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline
import wandb
from hemm.eval_pipelines import EvaluationPipeline
from hemm.metrics.vqa import DisentangledVQAMetric
from hemm.metrics.vqa.judges import BlipVQAJudge
from hemm.models import BaseDiffusionModel


def main(
5 changes: 3 additions & 2 deletions examples/evaluate_weave_image_quality.py
@@ -1,9 +1,10 @@
import fire
import wandb
import weave

from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline
import wandb
from hemm.eval_pipelines import EvaluationPipeline
from hemm.metrics.image_quality import LPIPSMetric, PSNRMetric, SSIMMetric
from hemm.models import BaseDiffusionModel


def main(
5 changes: 3 additions & 2 deletions examples/evaluate_weave_prompt_alignment.py
@@ -1,9 +1,10 @@
import fire
import wandb
import weave

from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline
import wandb
from hemm.eval_pipelines import EvaluationPipeline
from hemm.metrics.prompt_alignment import CLIPImageQualityScoreMetric, CLIPScoreMetric
from hemm.models import BaseDiffusionModel


def main(
59 changes: 59 additions & 0 deletions examples/multimodal_llm_eval/evaluate_mllm_metric_action.py
@@ -0,0 +1,59 @@
from typing import Optional

import fire
import weave

import wandb
from hemm.eval_pipelines import EvaluationPipeline
from hemm.metrics.vqa import MultiModalLLMEvaluationMetric
from hemm.metrics.vqa.judges.mmllm_judges import OpenAIJudge, PromptCategory
from hemm.models import BaseDiffusionModel


def main(
    project="mllm-eval",
    entity="hemm-eval",
    dataset_ref: Optional[str] = "attribute_binding_dataset:v1",
    dataset_limit: Optional[int] = None,
    diffusion_model_address: str = "stabilityai/stable-diffusion-2-1",
    diffusion_model_enable_cpu_offfload: bool = False,
    openai_judge_model: str = "gpt-4o",
    image_height: int = 1024,
    image_width: int = 1024,
    num_inference_steps: int = 50,
    mock_inference_dataset_address: Optional[str] = None,
    save_inference_dataset_name: Optional[str] = None,
):
    wandb.init(project=project, entity=entity, job_type="evaluation")
    weave.init(project_name=f"{entity}/{project}")

    dataset = weave.ref(dataset_ref).get()
    dataset = dataset.rows[:dataset_limit] if dataset_limit else dataset

    diffusion_model = BaseDiffusionModel(
        diffusion_model_name_or_path=diffusion_model_address,
        enable_cpu_offfload=diffusion_model_enable_cpu_offfload,
        image_height=image_height,
        image_width=image_width,
        num_inference_steps=num_inference_steps,
    )
    diffusion_model._pipeline.set_progress_bar_config(disable=True)
    evaluation_pipeline = EvaluationPipeline(
        model=diffusion_model,
        mock_inference_dataset_address=mock_inference_dataset_address,
        save_inference_dataset_name=save_inference_dataset_name,
    )

    judge = OpenAIJudge(
        prompt_property=PromptCategory.action, openai_model=openai_judge_model
    )
    metric = MultiModalLLMEvaluationMetric(judge=judge)
    evaluation_pipeline.add_metric(metric)

    evaluation_pipeline(dataset=dataset)
    wandb.finish()
    evaluation_pipeline.cleanup()


if __name__ == "__main__":
    fire.Fire(main)
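Because the script hands `main` to `fire.Fire`, every keyword argument of `main` doubles as a command-line flag (e.g. `--dataset_limit`, `--save_inference_dataset_name`). A rough sketch of driving the same evaluation programmatically rather than from the CLI (the import path and argument values are illustrative):

# Roughly equivalent to:
#   python examples/multimodal_llm_eval/evaluate_mllm_metric_action.py \
#       --dataset_limit 10 --save_inference_dataset_name mllm_action_images
from evaluate_mllm_metric_action import main  # assumes the script's directory is on sys.path

main(dataset_limit=10, save_inference_dataset_name="mllm_action_images")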
20 changes: 16 additions & 4 deletions examples/multimodal_llm_eval/evaluate_mllm_metric_complex.py
@@ -1,12 +1,13 @@
from typing import Optional

import fire
import wandb
import weave

from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline
import wandb
from hemm.eval_pipelines import EvaluationPipeline
from hemm.metrics.vqa import MultiModalLLMEvaluationMetric
from hemm.metrics.vqa.judges.mmllm_judges import OpenAIJudge, PromptCategory
from hemm.models import BaseDiffusionModel


def main(
@@ -16,9 +17,12 @@ def main(
    dataset_limit: Optional[int] = None,
    diffusion_model_address: str = "stabilityai/stable-diffusion-2-1",
    diffusion_model_enable_cpu_offfload: bool = False,
    openai_judge_model: str = "gpt-4o",
    image_height: int = 1024,
    image_width: int = 1024,
    num_inference_steps: int = 50,
    mock_inference_dataset_address: Optional[str] = None,
    save_inference_dataset_name: Optional[str] = None,
):
    wandb.init(project=project, entity=entity, job_type="evaluation")
    weave.init(project_name=f"{entity}/{project}")
@@ -34,13 +38,21 @@
        num_inference_steps=num_inference_steps,
    )
    diffusion_model._pipeline.set_progress_bar_config(disable=True)
    evaluation_pipeline = EvaluationPipeline(model=diffusion_model)
    evaluation_pipeline = EvaluationPipeline(
        model=diffusion_model,
        mock_inference_dataset_address=mock_inference_dataset_address,
        save_inference_dataset_name=save_inference_dataset_name,
    )

    judge = OpenAIJudge(prompt_property=PromptCategory.action)
    judge = OpenAIJudge(
        prompt_property=PromptCategory.action, openai_model=openai_judge_model
    )
    metric = MultiModalLLMEvaluationMetric(judge=judge)
    evaluation_pipeline.add_metric(metric)

    evaluation_pipeline(dataset=dataset)
    wandb.finish()
    evaluation_pipeline.cleanup()


if __name__ == "__main__":
3 changes: 1 addition & 2 deletions hemm/eval_pipelines/__init__.py
@@ -1,4 +1,3 @@
from .eval_pipeline import EvaluationPipeline
from .model import BaseDiffusionModel

__all__ = ["BaseDiffusionModel", "EvaluationPipeline"]
__all__ = ["EvaluationPipeline"]
92 changes: 81 additions & 11 deletions hemm/eval_pipelines/eval_pipeline.py
@@ -1,12 +1,18 @@
import asyncio
import os
import shutil
from abc import ABC
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union

import wandb
import weave
from PIL import Image

import wandb

from ..metrics.base import BaseMetric
from .model import BaseDiffusionModel
from ..models import BaseDiffusionModel, FalAIModel, StabilityAPIModel

MODEL_TYPE = Union[BaseDiffusionModel, FalAIModel, StabilityAPIModel]


class EvaluationPipeline(ABC):
@@ -15,16 +21,43 @@ class EvaluationPipeline(ABC):
    Args:
        model (BaseDiffusionModel): The model to evaluate.
        seed (int): Seed value for the random number generator.
        mock_inference_dataset_address (Optional[str]): A wandb dataset artifact address which, if
            provided, is used to mock inference results. This prevents redundant generations
            when switching metrics/judges over the same evaluation dataset(s).
        save_inference_dataset_name (Optional[str]): A weave dataset name which, if provided, is
            used to save the inference results as a separate weave dataset.
    """

    def __init__(self, model: BaseDiffusionModel, seed: int = 42) -> None:
    def __init__(
        self,
        model: MODEL_TYPE,
        seed: int = 42,
        mock_inference_dataset_address: Optional[str] = None,
        save_inference_dataset_name: Optional[str] = None,
    ) -> None:
        super().__init__()
        self.model = model

        self.image_size = (self.model.image_height, self.model.image_width)
        self.seed = seed

        self.inference_counter = 1
        self.mock_inference_dataset_address = mock_inference_dataset_address
        if mock_inference_dataset_address:
            self.save_inference_dataset_name = None
            artifact = wandb.use_artifact(
                self.mock_inference_dataset_address, type="dataset"
            )
            self.mock_inference_dataset_dir = artifact.download()

        else:
            self.save_inference_dataset_name = save_inference_dataset_name

        if self.save_inference_dataset_name:
            os.makedirs(
                os.path.join("inference_dataset", self.save_inference_dataset_name),
                exist_ok=True,
            )

        self.inference_counter = 0
        self.table_columns = ["model", "prompt", "generated_image"]
        self.table_rows: List = []
        self.evaluation_table: wandb.Table = None
@@ -63,13 +96,29 @@ def infer(self, prompt: str) -> Dict[str, str]:
            Dict[str, str]: Dictionary containing base64 encoded image to be logged as
                a Weave object.
        """
        if self.inference_counter == 1:
        if self.inference_counter == 0:
            self.evaluation_table = wandb.Table(columns=self.table_columns)
        self.inference_counter += 1
        output = self.model.predict(prompt, seed=self.seed)
        if self.mock_inference_dataset_address:
            image = Image.open(
                os.path.join(
                    self.mock_inference_dataset_dir, f"{self.inference_counter}.png"
                )
            )
            output = {"image": image}
        else:
            output = self.model.predict(prompt, seed=self.seed)
        self.table_rows.append(
            [self.model.diffusion_model_name_or_path, prompt, output["image"]]
        )
        if self.save_inference_dataset_name:
            output["image"].save(
                os.path.join(
                    "inference_dataset",
                    self.save_inference_dataset_name,
                    f"{self.inference_counter}.png",
                )
            )
        self.inference_counter += 1
        return output

@weave.op()
@@ -104,19 +153,40 @@ def log_summary(self, summary: Dict[str, float]) -> None:
            }
        )

    def __call__(self, dataset: Union[List[Dict], str]) -> Dict[str, float]:
    def save_inference_results(self):
        artifact = wandb.Artifact(name=self.save_inference_dataset_name, type="dataset")
        artifact.add_dir(
            os.path.join("inference_dataset", self.save_inference_dataset_name)
        )
        artifact.save()

    def cleanup(self):
        """Clean up the inference dataset directory. Should be called after the evaluation is
        complete and `wandb.finish()` has been called."""
        if os.path.exists("inference_dataset"):
            shutil.rmtree("inference_dataset")

    def __call__(
        self, dataset: Union[List[Dict], str], async_infer: bool = False
    ) -> Dict[str, float]:
        """Evaluate the Stable Diffusion model on the given dataset.
        Args:
            dataset (Union[List[Dict], str]): Dataset to evaluate the model on. If a string is
                passed, it is assumed to be a Weave dataset reference.
            async_infer (bool, optional): Whether to use async inference. Defaults to False.
        """
        dataset = weave.ref(dataset).get() if isinstance(dataset, str) else dataset
        evaluation = weave.Evaluation(
            dataset=dataset,
            scorers=[metric_fn.evaluate_async for metric_fn in self.metric_functions],
            scorers=[
                metric_fn.evaluate_async if async_infer else metric_fn.evaluate
                for metric_fn in self.metric_functions
            ],
        )
        self.model.configs.update(self.evaluation_configs)
        summary = asyncio.run(evaluation.evaluate(self.infer_async))
        self.log_summary(summary)
        if self.save_inference_dataset_name:
            self.save_inference_results()
        return summary
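Taken together, the new constructor arguments give `EvaluationPipeline` a two-phase workflow: one run generates images and uploads them as a wandb dataset artifact, and later runs reuse that artifact instead of calling the diffusion model again. A rough sketch of that flow, assuming `wandb.init`/`weave.init` have already been called and `diffusion_model` is a `BaseDiffusionModel` as in the examples above (artifact address, dataset reference, and metric objects are illustrative placeholders):

# Phase 1: generate images once and upload them as a wandb dataset artifact
pipeline = EvaluationPipeline(
    model=diffusion_model,
    save_inference_dataset_name="sd21_inference_images",  # illustrative name
)
pipeline.add_metric(first_metric)  # e.g. a prompt-alignment metric
pipeline(dataset="attribute_binding_dataset:v1")
wandb.finish()
pipeline.cleanup()  # removes the local inference_dataset/ directory

# Phase 2: switch the metric/judge and mock inference from the saved artifact
pipeline = EvaluationPipeline(
    model=diffusion_model,
    mock_inference_dataset_address="entity/project/sd21_inference_images:latest",  # illustrative address
)
pipeline.add_metric(second_metric)
pipeline(dataset="attribute_binding_dataset:v1")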
3 changes: 2 additions & 1 deletion hemm/metrics/spatial_relationship/analyze_prompt.py
@@ -4,11 +4,12 @@
from typing import Dict, List, Optional

import jsonlines
import wandb
import weave
from openai import OpenAI
from weave import Dataset, Evaluation, Model

import wandb


def chunk_all_prompts(prompt_list: List[str], chunk_size: int) -> List[str]:
"""Chunk all prompts into smaller chunks of size `chunk_size`.
3 changes: 2 additions & 1 deletion hemm/metrics/spatial_relationship/spatial_relationship_2d.py
@@ -1,9 +1,10 @@
from typing import Any, Dict, List, Optional, Union

import wandb
import weave
from PIL import Image

import wandb

from ..base import BaseMetric
from .judges import DETRSpatialRelationShipJudge
from .judges.commons import BoundingBox
