
Commit

fix: OpenAIJudge
soumik12345 committed Oct 17, 2024
1 parent 673cfbd commit e71a14c
Showing 6 changed files with 2,898 additions and 2,624 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -166,4 +166,5 @@ dump/
**generated_images/
.ruff_cache/
**.jsonl
test.py
test.py
inference_dataset/
5 changes: 4 additions & 1 deletion examples/multimodal_llm_eval/evaluate_mllm_metric_complex.py
@@ -21,6 +21,7 @@ def main(
image_height: int = 1024,
image_width: int = 1024,
num_inference_steps: int = 50,
save_inference_dataset_name: Optional[str] = None,
):
wandb.init(project=project, entity=entity, job_type="evaluation")
weave.init(project_name=f"{entity}/{project}")
@@ -36,7 +37,9 @@
num_inference_steps=num_inference_steps,
)
diffusion_model._pipeline.set_progress_bar_config(disable=True)
evaluation_pipeline = EvaluationPipeline(model=diffusion_model)
evaluation_pipeline = EvaluationPipeline(
model=diffusion_model, save_inference_dataset_name=save_inference_dataset_name
)

judge = OpenAIJudge(
prompt_property=PromptCategory.action, openai_model=openai_judge_model
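
For orientation, here is a sketch of how the updated example might be driven with the new flag. The import path simply mirrors the file's location in the repository, and every value below is a placeholder; image_height, image_width, num_inference_steps, and save_inference_dataset_name appear in the visible signature, while project, entity, and openai_judge_model are inferred from how they are used in the body and may differ.

# Hypothetical driver for the updated example script; all values are placeholders.
from examples.multimodal_llm_eval.evaluate_mllm_metric_complex import main

main(
    project="hemm-eval",                      # assumed W&B project
    entity="my-team",                         # assumed W&B entity
    openai_judge_model="gpt-4o-2024-08-06",   # matches the judge's new default
    image_height=1024,
    image_width=1024,
    num_inference_steps=50,
    save_inference_dataset_name="complex-prompt-inference",  # new optional argument
)
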
50 changes: 47 additions & 3 deletions hemm/eval_pipelines/eval_pipeline.py
@@ -1,14 +1,16 @@
import asyncio
import os
import shutil
from abc import ABC
from typing import Dict, List, Union
from typing import Any, Dict, List, Optional, Union

import wandb
import weave
from PIL import Image

from ..metrics.base import BaseMetric
from ..models import BaseDiffusionModel, FalAIModel, StabilityAPIModel


MODEL_TYPE = Union[BaseDiffusionModel, FalAIModel, StabilityAPIModel]


@@ -18,14 +20,28 @@ class EvaluationPipeline(ABC):
Args:
model (BaseDiffusionModel): The model to evaluate.
seed (int): Seed value for the random number generator.
save_inference_dataset_name (Optional[str]): A weave dataset name which if provided will
save inference results as a separate weave dataset.
"""

def __init__(self, model: MODEL_TYPE, seed: int = 42) -> None:
def __init__(
self,
model: MODEL_TYPE,
seed: int = 42,
save_inference_dataset_name: Optional[str] = None,
) -> None:
super().__init__()
self.model = model

self.image_size = (self.model.image_height, self.model.image_width)
self.seed = seed
self.save_inference_dataset_name = save_inference_dataset_name

if self.save_inference_dataset_name:
os.makedirs(
os.path.join("inference_dataset", self.save_inference_dataset_name),
exist_ok=True,
)

self.inference_counter = 1
self.table_columns = ["model", "prompt", "generated_image"]
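
Taken together, the new constructor argument gives roughly the following usage, sketched under the assumption that the import paths follow the file layout in this commit and that BaseDiffusionModel accepts the constructor arguments used in the example script; checkpoint and dataset names are placeholders, not values from this repository.

# Minimal usage sketch of save_inference_dataset_name.
from hemm.eval_pipelines.eval_pipeline import EvaluationPipeline
from hemm.models import BaseDiffusionModel

model = BaseDiffusionModel(
    diffusion_model_name_or_path="stabilityai/stable-diffusion-2-1",  # placeholder checkpoint
    image_height=1024,
    image_width=1024,
    num_inference_steps=50,
)
pipeline = EvaluationPipeline(
    model=model,
    seed=42,
    # With a name set, every generated image is also written to
    # inference_dataset/<name>/<counter>.png during inference.
    save_inference_dataset_name="sd-2-1-inference",
)
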
@@ -73,6 +89,14 @@ def infer(self, prompt: str) -> Dict[str, str]:
self.table_rows.append(
[self.model.diffusion_model_name_or_path, prompt, output["image"]]
)
if self.save_inference_dataset_name:
output["image"].save(
os.path.join(
"inference_dataset",
self.save_inference_dataset_name,
f"{self.inference_counter - 1}.png",
)
)
return output

@weave.op()
@@ -107,6 +131,24 @@ def log_summary(self, summary: Dict[str, float]) -> None:
}
)

def save_inference_results(self, dataset: Any):
inference_dataset_rows = []
for idx, row in enumerate(dataset):
generated_image = Image.open(
os.path.join(
"inference_dataset", self.save_inference_dataset_name, f"{idx}.png"
)
)
inference_dataset_rows.append(
{"generated_image": generated_image, "seed": self.seed, **dict(row)}
)
weave.publish(
weave.Dataset(
name=self.save_inference_dataset_name, rows=inference_dataset_rows
)
)
shutil.rmtree("inference_dataset")

def __call__(
self, dataset: Union[List[Dict], str], async_infer: bool = False
) -> Dict[str, float]:
@@ -128,4 +170,6 @@ def __call__(
self.model.configs.update(self.evaluation_configs)
summary = asyncio.run(evaluation.evaluate(self.infer_async))
self.log_summary(summary)
if self.save_inference_dataset_name:
self.save_inference_results(dataset=dataset)
return summary
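
Because the rows are published with weave.publish, the saved inference results can later be pulled back by name. A small sketch, assuming the standard Weave client API for retrieving published datasets; entity, project, and dataset names are placeholders.

import weave

# Retrieve the published inference dataset and inspect one row; the row keys
# come from the evaluation dataset plus the generated_image and seed fields
# added by save_inference_results.
weave.init("my-entity/my-project")
inference_dataset = weave.ref("sd-2-1-inference").get()
row = inference_dataset.rows[0]
print(row["seed"], row["generated_image"])
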
59 changes: 26 additions & 33 deletions hemm/metrics/vqa/judges/mmllm_judges/openai_judge.py
@@ -2,20 +2,14 @@
import subprocess
from typing import List

import instructor
import spacy
import weave
from openai import OpenAI
from PIL import Image
from pydantic import BaseModel

from .commons import (
PromptCategory,
TaggedPromptParts,
JudgeMent,
JudgeQuestion,
)
from .....utils import base64_encode_image
from .commons import JudgeMent, JudgeQuestion, PromptCategory, TaggedPromptParts


class OpenAIJudgeMent(BaseModel):
@@ -48,14 +42,13 @@ class OpenAIJudge(weave.Model):
seed: int
_nlp_pipeline: spacy.Language = None
_openai_client: OpenAI = None
_instructor_openai_client: instructor.Instructor = None
_total_score: int = 4

def __init__(
self,
prompt_pipeline: str = "en_core_web_sm",
prompt_property: PromptCategory = PromptCategory.color,
openai_model: str = "gpt-4-turbo",
openai_model: str = "gpt-4o-2024-08-06",
max_retries: int = 5,
seed: int = 42,
):
@@ -69,9 +62,6 @@ def __init__(
subprocess.run(["spacy", "download", "en_core_web_sm"])
self._nlp_pipeline = spacy.load(self.prompt_pipeline)
self._openai_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
self._instructor_openai_client = instructor.from_openai(
OpenAI(api_key=os.environ["OPENAI_API_KEY"])
)

@weave.op()
def extract_prompt_parts(self, prompt: str) -> List[TaggedPromptParts]:
@@ -318,27 +308,30 @@ def execute_chain_of_thought(
Provide your analysis and explanation to justify the score.
"""
judgement_response = self._instructor_openai_client.chat.completions.create(
model=self.openai_model,
response_model=JudgeMent,
max_retries=self.max_retries,
seed=self.seed,
messages=[
{
"role": "system",
"content": question.judgement_question_system_prompt,
},
{
"role": "user",
"content": [
{"type": "text", "text": question.judgement_question},
{
"type": "image_url",
"image_url": {"url": base64_encode_image(image)},
},
],
},
],
judgement_response = (
self._openai_client.beta.chat.completions.parse(
model=self.openai_model,
response_format=JudgeMent,
seed=self.seed,
messages=[
{
"role": "system",
"content": question.judgement_question_system_prompt,
},
{
"role": "user",
"content": [
{"type": "text", "text": question.judgement_question},
{
"type": "image_url",
"image_url": {"url": base64_encode_image(image)},
},
],
},
],
)
.choices[0]
.message.parsed
)
return judgement_response
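
This is the substance of the fix: structured judgements now come from the OpenAI SDK's native parsing (client.beta.chat.completions.parse with a Pydantic response_format), replacing the instructor wrapper. A stripped-down sketch of the same call pattern outside the judge, with a placeholder schema and prompt in place of the repository's JudgeMent model and question builder:

import os

from openai import OpenAI
from pydantic import BaseModel


class ExampleJudgeMent(BaseModel):
    # Placeholder schema standing in for the repo's JudgeMent response model.
    explanation: str
    score: int


client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
completion = client.beta.chat.completions.parse(
    model="gpt-4o-2024-08-06",
    response_format=ExampleJudgeMent,
    seed=42,
    messages=[
        {"role": "system", "content": "You are a strict text-to-image judge."},
        {"role": "user", "content": "Score the image against the prompt from 0 to 4 and explain."},
    ],
)
judgement = completion.choices[0].message.parsed  # an ExampleJudgeMent instance
print(judgement.score, judgement.explanation)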
