Skip to content

Commit

Permalink
update for batch processing
Browse files Browse the repository at this point in the history
  • Loading branch information
dannylee1020 committed Dec 18, 2024
1 parent 30aa889 commit eba6e5a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 146 deletions.
159 changes: 13 additions & 146 deletions openpo/client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import json
import os
from typing import Any, Dict, List, Optional, Union

import pandas as pd
from typing import Any, Dict, List, Optional

from .internal.error import AuthenticationError, ProviderError
from .internal.response import ChatCompletionOutput, ChatCompletionStreamOutput
from .resources.provider.anthropic import Anthropic
from .resources.provider.huggingface import HuggingFace
from .resources.provider.openai import OpenAI
from .resources.provider.openrouter import OpenRouter
from .resources.batch.batch import Batch
from .resources.eval.eval import Evaluation
from .resources.provider import Anthropic, HuggingFace, OpenAI, OpenRouter


class OpenPO:
Expand All @@ -33,6 +30,9 @@ def __init__(
self.openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
self.anthropic_api_key = anthropic_api_key or os.getenv("ANTHROPIC_API_KEY")

self._eval = Evaluation(self)
self._batch = Batch(self)

def _get_model_provider(self, model: str) -> str:
try:
return model.split("/")[0]
Expand Down Expand Up @@ -68,19 +68,6 @@ def _get_provider_instance(self, provider: str):

raise ProviderError(provider, "Unsupported model provider")

def _get_model_consensus(
self,
res_a: List[Dict],
res_b: List[Dict],
) -> List[int]:

matching_indices = []
for i, (a, b) in enumerate(zip(res_a, res_b)):
if a.get("q_index") == b.get("q_index") and a["rank"] == b["rank"]:
matching_indices.append(a.get("q_index", i))

return matching_indices

def completions(
self,
models: List[str],
Expand Down Expand Up @@ -124,130 +111,10 @@ def completions(

return responses

def eval_single(
self,
model: str,
questions: List[str],
responses: List[List[str]],
prompt: Optional[str] = None,
):
"""Use single LLM-as-a-judge method to evaluate responses for building preference data.
Args:
model (str): Model identifier to use as a judge. Follows provider/model-identifier format.
questions (List(str)): Questions for each response pair.
responses (List[List[str]]): Pairwise responses to evaluate.
prompt (str): Optional custom prompt for judge model to follow.
Returns (Dict): The evaluation data for responses with preferred, rejected, confidence_score and reason.
@property
def eval(self):
    # Read-only accessor for the Evaluation resource constructed in __init__
    # (self._eval = Evaluation(self)).
    return self._eval

Raises:
AuthenticationError: If required API keys are missing or invalid.
ProviderError: For provider-specific errors during evaluation.
ValueError: If the model format is invalid or provider is not supported.
"""
try:
provider = self._get_model_provider(model)
model_id = self._get_model_id(model)

if provider not in ["openai", "anthropic"]:
raise ProviderError(provider, "Provider not supported for evaluation")

llm = self._get_provider_instance(provider=provider)
res = llm.generate(
model=model_id,
questions=questions,
responses=responses,
prompt=prompt if prompt else None,
)

if provider == "anthropic":
result = res.content[0].input['"evaluation']
result = json.loads(res.choices[0].message.content)["evaluation"]

return {"evaluation": result}
except (AuthenticationError, ValueError) as e:
raise e
except Exception as e:
raise ProviderError(
provider=provider, message=f"Error during evaluation: {str(e)}"
)

def eval_multi(
    self,
    models: List[str],
    questions: List[str],
    responses: List[List],
    prompt: Optional[str] = None,
):
    """Use multiple LLMs as a judge for model consensus to evaluate responses for building preference data.

    Args:
        models (List): List of models to use as a judge. Follows provider/model-identifier format.
            Exactly one Anthropic and one OpenAI model must be supplied.
        questions (List(str)): Questions for each response pair.
        responses (List[List[str]]): Pairwise responses to evaluate.
        prompt (str): Optional custom prompt for judge model to follow.

    Returns (Dict): The evaluation data for responses that all models agree on.
        - evaluation: Evaluation entries (from the OpenAI judge) for the consensus questions.
        - q_index: Index of questions that reached consensus by the models.

    Raises:
        AuthenticationError: If required API keys are missing or invalid.
        ProviderError: For provider-specific errors during evaluation.
        ValueError: If the model format is invalid or required models are missing.
    """
    try:
        # One judge client per provider; both are required below.
        judge_a = self._get_provider_instance("anthropic")
        judge_o = self._get_provider_instance("openai")

        a_model = ""
        o_model = ""

        # Partition the requested models into the Anthropic and OpenAI ids;
        # any other provider is rejected outright.
        for m in models:
            provider = self._get_model_provider(m)
            if provider == "anthropic":
                a_model = self._get_model_id(m)
            elif provider == "openai":
                o_model = self._get_model_id(m)
            else:
                raise ProviderError(
                    provider, "Provider not supported for evaluation"
                )

        if not a_model or not o_model:
            raise ValueError("Both Anthropic and OpenAI models must be provided")

        res_a = judge_a.generate(
            model=a_model,
            questions=questions,
            responses=responses,
            prompt=prompt if prompt else None,
        )
        # Anthropic response: evaluation list is read from the first content
        # block's tool input.
        parsed_res_a = res_a.content[0].input["evaluation"]

        res_o = judge_o.generate(
            model=o_model,
            questions=questions,
            responses=responses,
            prompt=prompt if prompt else None,
        )
        # OpenAI response: message content is a JSON string holding an
        # "evaluation" key.
        parsed_res_o = json.loads(res_o.choices[0].message.content)["evaluation"]

        # Question indices where both judges produced the same ranking.
        idx = self._get_model_consensus(
            parsed_res_a,
            parsed_res_o,
        )

        # NOTE(review): idx contains q_index values and is used directly as
        # positional indices into parsed_res_o — this assumes q_index always
        # equals the entry's list position; verify against the judge prompt.
        return {
            "evaluation": [parsed_res_o[i] for i in idx],
            "q_index": idx,
        }
    except (AuthenticationError, ValueError) as e:
        # Configuration/usage errors propagate unchanged.
        raise e
    except Exception as e:
        # Anything else (network, parsing, provider SDK) is wrapped uniformly.
        raise ProviderError(
            provider="eval-multi",
            message=f"Error during multi-model evaluation: {str(e)}",
        )
@property
def batch(self):
    # Read-only accessor for the Batch resource constructed in __init__
    # (self._batch = Batch(self)).
    return self._batch
8 changes: 8 additions & 0 deletions openpo/internal/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,11 @@
Consider each and every question with corresponding responses and make evaluation. The length of evaluation result must equal to the number of input questions.
"""

# Per-item user query template for batch evaluation requests: the first
# placeholder is filled with the question, the second with its pair of
# responses to judge.
EVALUATION_QUERY_BATCH = """
Here is the question: {}
Here is the pair of responses to evaluate: {}.
Make evaluation on question and responses by following the system prompt.
"""

0 comments on commit eba6e5a

Please sign in to comment.