
Commit

fix llm as judge warnings (#173)
* commit

* fixes

* fix style

* fixes

* make style

* Fix import error detection for open ai package (llm as a judge metric)

---------

Co-authored-by: Nathan Habib <[email protected]>
Co-authored-by: Clémentine Fourrier <[email protected]>
3 people authored Jul 4, 2024
1 parent 7fcaab3 commit 3a80833
Showing 7 changed files with 48 additions and 21 deletions.
File renamed without changes.
14 changes: 11 additions & 3 deletions src/lighteval/metrics/llm_as_judge.py
@@ -27,9 +27,8 @@
import time
from typing import Optional

from openai import OpenAI

from lighteval.logging.hierarchical_logger import hlog_warn
from lighteval.utils import NO_OPENAI_ERROR_MSG, is_openai_available


class JudgeOpenAI:
@@ -70,7 +69,8 @@ def __init__(
openai_api_key: str,
multi_turn: bool = False,
):
self.client = OpenAI(api_key=openai_api_key)
self.client = None # loaded lazily
self.openai_api_key = openai_api_key
self.model = model
self.seed = seed
self.temperature = temperature
@@ -112,6 +112,14 @@ def evaluate_answer(
Raises:
Exception: If an error occurs during the API call.
"""
if self.client is None:
if not is_openai_available():
raise ImportError(NO_OPENAI_ERROR_MSG)

from openai import OpenAI

self.client = OpenAI(api_key=self.openai_api_key)

prompts = [
self.__get_prompts_single_turn(
questions[0], answers[0], references[0] if references is not None and len(references) > 0 else None
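For readers unfamiliar with the pattern above: the client is now created on first use instead of at construction, so importing or instantiating the judge no longer requires `openai` to be installed. A minimal, self-contained sketch of that lazy-initialization idea (MiniJudge, _ensure_client and the message text are illustrative stand-ins, not lighteval's actual API):

import importlib.util
from typing import Optional


class MiniJudge:
    """Illustrative stand-in for a judge whose OpenAI client is created lazily."""

    def __init__(self, model: str, api_key: Optional[str]):
        self.model = model
        self.api_key = api_key
        self.client = None  # nothing imported or validated yet

    def _ensure_client(self):
        if self.client is None:
            # Check availability first so the failure is an actionable ImportError.
            if importlib.util.find_spec("openai") is None:
                raise ImportError("`openai` is not installed; install it with pip to use LLM-as-judge metrics.")
            from openai import OpenAI  # deferred import, only runs on first evaluation

            self.client = OpenAI(api_key=self.api_key)
        return self.client

    def evaluate(self, prompt: str) -> str:
        client = self._ensure_client()
        response = client.chat.completions.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
        )
        return response.choices[0].message.content

Constructing MiniJudge never touches `openai`; the ImportError surfaces only on the first call to evaluate, which is what lets the calling code drop its defensive try/except (see metrics_sample.py below).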
10 changes: 6 additions & 4 deletions src/lighteval/metrics/metrics.py
@@ -20,6 +20,8 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import os

import numpy as np
from aenum import Enum

@@ -225,29 +227,29 @@ class Metrics(Enum):
corpus_level_fn=np.mean,
higher_is_better=True,
)
llm_judge_multi_turn = SampleLevelMetricGrouping(
llm_judge_multi_turn_openai = SampleLevelMetricGrouping(
metric=["single_turn", "multi_turn"],
higher_is_better=True,
category=MetricCategory.LLM_AS_JUDGE_MULTI_TURN,
use_case=MetricUseCase.SUMMARIZATION,
sample_level_fn=JudgeLLM(
judge_model_name="gpt-3.5-turbo",
template_path="src/lighteval/tasks/extended/mt_bench/judge_prompts.jsonl",
template_path=os.path.join(os.path.dirname(__file__), "judge_prompts.jsonl"),
multi_turn=True,
).compute,
corpus_level_fn={
"single_turn": np.mean,
"multi_turn": np.mean,
},
)
llm_judge = SampleLevelMetricGrouping(
llm_judge_openai = SampleLevelMetricGrouping(
metric=["judge_score"],
higher_is_better=True,
category=MetricCategory.LLM_AS_JUDGE,
use_case=MetricUseCase.SUMMARIZATION,
sample_level_fn=JudgeLLM(
judge_model_name="gpt-3.5-turbo",
template_path="src/lighteval/tasks/extended/mt_bench/judge_prompts.jsonl",
template_path=os.path.join(os.path.dirname(__file__), "", "judge_prompts.jsonl"),
multi_turn=False,
).compute,
corpus_level_fn={
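The template_path change swaps a path relative to the repository checkout for one anchored at the metrics module itself, so judge_prompts.jsonl is still found when lighteval is installed as a package rather than run from source. A rough sketch of the difference, assuming the file now ships next to the module (load_templates is a hypothetical helper, not lighteval code):

import json
import os

# Fragile: resolved against the current working directory, so it only works
# when running from the repository root.
repo_relative = "src/lighteval/tasks/extended/mt_bench/judge_prompts.jsonl"

# Robust: resolved against the module that ships the resource.
module_relative = os.path.join(os.path.dirname(__file__), "judge_prompts.jsonl")


def load_templates(path: str) -> list[dict]:
    """Hypothetical loader for a .jsonl file: one JSON object per non-empty line."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]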
20 changes: 8 additions & 12 deletions src/lighteval/metrics/metrics_sample.py
@@ -631,18 +631,14 @@ def __init__(self, judge_model_name: str, template_path: str, multi_turn: bool =
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
self.multi_turn = multi_turn

try:
self.judge = JudgeOpenAI(
model=judge_model_name,
seed=42,
temperature=0.0,
templates_path=template_path,
openai_api_key=OPENAI_API_KEY,
multi_turn=multi_turn,
)
except Exception as e:
print(f"Could not initialize the JudgeOpenAI model:\n{e}")
self.judge = None
self.judge = JudgeOpenAI(
model=judge_model_name,
seed=42,
temperature=0.0,
templates_path=template_path,
openai_api_key=OPENAI_API_KEY,
multi_turn=multi_turn,
)

def compute(self, predictions: list[str], formatted_doc: Doc, **kwargs) -> dict[str, float]:
"""
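Dropping the try/except follows from the lazy client above: constructing JudgeOpenAI can no longer fail on a missing `openai` package, so swallowing exceptions here only hid real configuration errors. A small self-contained illustration of that construction path (LazyClientHolder is a hypothetical stand-in):

import os
from typing import Optional


class LazyClientHolder:
    """Hypothetical stand-in mirroring the new JudgeOpenAI construction path."""

    def __init__(self, openai_api_key: Optional[str]):
        self.openai_api_key = openai_api_key
        self.client = None  # no import, no network call, no validation yet


judge = LazyClientHolder(openai_api_key=os.getenv("OPENAI_API_KEY"))
assert judge.client is None  # construction always succeeds; errors surface on first use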
2 changes: 1 addition & 1 deletion src/lighteval/tasks/extended/mt_bench/main.py
@@ -45,7 +45,7 @@
evaluation_splits=["train"],
few_shots_split="",
few_shots_select="random",
metric=["llm_judge_multi_turn"],
metric=["llm_judge_multi_turn_openai"],
generation_size=1024,
stop_sequence=[],
)
16 changes: 15 additions & 1 deletion src/lighteval/tasks/lighteval_task.py
@@ -21,6 +21,7 @@
# SOFTWARE.

import collections
import os
import random
from dataclasses import dataclass
from multiprocessing import Pool
@@ -53,7 +54,7 @@
RequestType,
TaskExampleId,
)
from lighteval.utils import as_list
from lighteval.utils import NO_OPENAI_ERROR_MSG, as_list, is_openai_available

from . import tasks_prompt_formatting

@@ -200,8 +201,21 @@ def __init__( # noqa: C901
self.metrics = as_list(cfg.metric)
self.suite = as_list(cfg.suite)
ignored = [metric for metric in self.metrics if Metrics[metric].value.category == MetricCategory.IGNORED]

if len(ignored) > 0:
hlog_warn(f"[WARNING] Not implemented yet: ignoring the metric {' ,'.join(ignored)} for task {self.name}.")

if any(
Metrics[metric].value.category in [MetricCategory.LLM_AS_JUDGE, MetricCategory.LLM_AS_JUDGE_MULTI_TURN]
for metric in self.metrics
):
if not is_openai_available():
raise ImportError(NO_OPENAI_ERROR_MSG)
if os.getenv("OPENAI_API_KEY") is None:
raise ValueError(
"Using llm as judge metric but no OPEN_API_KEY were found, please set it with: export OPEN_API_KEY={yourkey}"
)

current_categories = [Metrics[metric].value.category for metric in self.metrics]
self.has_metric_category = {category: (category in current_categories) for category in MetricCategory}
# Sub-optimal system - we might want to store metric parametrisation in a yaml conf for example
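The new guard in LightevalTask.__init__ validates both prerequisites when the task is built rather than partway through an evaluation. A condensed sketch of the same check with the metric-category plumbing stripped out (the function name is hypothetical; the first message is the one added to utils.py below):

import importlib.util
import os

NO_OPENAI_ERROR_MSG = (
    "You are trying to use an Open AI LLM as a judge, for which you need `openai`, "
    "which is not available in your environment. Please install it using pip."
)


def check_llm_judge_requirements(uses_llm_judge_metric: bool) -> None:
    """Fail fast at task-construction time if an LLM-as-judge metric cannot run."""
    if not uses_llm_judge_metric:
        return
    if importlib.util.find_spec("openai") is None:
        raise ImportError(NO_OPENAI_ERROR_MSG)
    if os.getenv("OPENAI_API_KEY") is None:
        raise ValueError(
            "Using an LLM-as-judge metric but no OPENAI_API_KEY was found; "
            "please set it with: export OPENAI_API_KEY={yourkey}"
        )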
7 changes: 7 additions & 0 deletions src/lighteval/utils.py
@@ -191,6 +191,13 @@ def is_peft_available() -> bool:
NO_PEFT_ERROR_MSG = "You are trying to use adapter weights models, for which you need `peft`, which is not available in your environment. Please install it using pip."


def is_openai_available() -> bool:
return importlib.util.find_spec("openai") is not None


NO_OPENAI_ERROR_MSG = "You are trying to use an Open AI LLM as a judge, for which you need `openai`, which is not available in your environment. Please install it using pip."


def can_load_extended_tasks() -> bool:
imports = []
for package in ["langdetect"]:
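importlib.util.find_spec only asks the import machinery whether the package can be located, without importing it, which keeps this availability check cheap. A standalone sketch of the helper-plus-message pattern used throughout utils.py, extended with a hypothetical require_openai guard:

import importlib.util


def is_openai_available() -> bool:
    # find_spec returns None when the package cannot be located; the package
    # itself is never imported by this check.
    return importlib.util.find_spec("openai") is not None


NO_OPENAI_ERROR_MSG = (
    "You are trying to use an Open AI LLM as a judge, for which you need `openai`, "
    "which is not available in your environment. Please install it using pip."
)


def require_openai() -> None:
    """Hypothetical convenience guard built on the helper above."""
    if not is_openai_available():
        raise ImportError(NO_OPENAI_ERROR_MSG)


if __name__ == "__main__":
    print("openai available:", is_openai_available())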
