Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding azure openai evaluator #38

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions needlehaystack/evaluators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .evaluator import Evaluator
from .openai import OpenAIEvaluator
from .langsmith import LangSmithEvaluator
from .azure_openai import AzureOpenAIEvaluator
59 changes: 59 additions & 0 deletions needlehaystack/evaluators/azure_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from .openai import OpenAIEvaluator
from ..utils import get_from_env_or_error

from langchain_openai import AzureChatOpenAI


class AzureOpenAIEvaluator(OpenAIEvaluator):
    """Evaluator backed by an Azure OpenAI chat deployment.

    Reuses the scoring logic of :class:`OpenAIEvaluator` but swaps the
    underlying client for ``AzureChatOpenAI``, authenticating against an
    Azure OpenAI resource endpoint instead of the public OpenAI API.
    """

    def __init__(
        self,
        model_name: str = None,
        model_kwargs: dict = OpenAIEvaluator.DEFAULT_MODEL_KWARGS,
        true_answer: str = None,
        question_asked: str = None,
        azure_openai_endpoint: str = None,
        azure_openai_api_version: str = "2024-02-01",
    ):
        """
        :param model_name: The deployment name of the model in Azure OpenAI Studio.
        :param model_kwargs: Model configuration. Default is {temperature: 0}
        :param true_answer: The true answer to the question asked.
        :param question_asked: The question asked to the model.
        :param azure_openai_endpoint: The endpoint of the Azure OpenAI resource.
            Falls back to the AZURE_OPENAI_ENDPOINT environment variable.
        :param azure_openai_api_version: The Azure OpenAI API version to use.
            Falls back to the OPENAI_API_VERSION environment variable only
            when a falsy value is passed explicitly.
        :raises ValueError: If model_name, true_answer or question_asked is
            missing, or a required environment variable is not set.
        """

        if not model_name:
            raise ValueError("model_name must be supplied with init.")
        if (not true_answer) or (not question_asked):
            raise ValueError(
                "true_answer and question_asked must be supplied with init."
            )

        self.model_name = model_name
        # Copy the (possibly shared, class-level) kwargs dict so that
        # per-instance mutation cannot leak into other evaluator instances.
        self.model_kwargs = dict(model_kwargs)
        self.true_answer = true_answer
        self.question_asked = question_asked

        self.api_key = get_from_env_or_error(
            env_key="NIAH_EVALUATOR_API_KEY",
            error_message="{env_key} must be in env for using openai evaluator."
        )

        self.azure_openai_endpoint = azure_openai_endpoint or get_from_env_or_error(
            env_key="AZURE_OPENAI_ENDPOINT",
            error_message="azure_openai_endpoint must be supplied with init or {env_key} must be in your env."
        )

        self.azure_openai_api_version = azure_openai_api_version or get_from_env_or_error(
            env_key="OPENAI_API_VERSION",
            error_message="azure_openai_api_version must be supplied with init or {env_key} must be in your env."
        )

        self.evaluator = AzureChatOpenAI(
            azure_deployment=self.model_name,
            azure_endpoint=self.azure_openai_endpoint,
            openai_api_version=self.azure_openai_api_version,
            api_key=self.api_key,
            **self.model_kwargs
        )
12 changes: 5 additions & 7 deletions needlehaystack/evaluators/openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os

from .evaluator import Evaluator
from ..utils import get_from_env_or_error

from langchain.evaluation import load_evaluator
from langchain_community.chat_models import ChatOpenAI
Expand Down Expand Up @@ -35,11 +34,10 @@ def __init__(self,
self.true_answer = true_answer
self.question_asked = question_asked

api_key = os.getenv('NIAH_EVALUATOR_API_KEY')
if (not api_key):
raise ValueError("NIAH_EVALUATOR_API_KEY must be in env for using openai evaluator.")

self.api_key = api_key
self.api_key = get_from_env_or_error(
env_key="NIAH_EVALUATOR_API_KEY",
error_message="{env_key} must be in env for using openai evaluator."
)

self.evaluator = ChatOpenAI(model=self.model_name,
openai_api_key=self.api_key,
Expand Down
16 changes: 7 additions & 9 deletions needlehaystack/providers/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import os
import pkg_resources

from .model import ModelProvider
from ..utils import get_from_env_or_error

from operator import itemgetter
from typing import Optional

from anthropic import AsyncAnthropic
from anthropic import Anthropic as AnthropicModel
from langchain_anthropic import ChatAnthropic
from langchain.prompts import PromptTemplate

from .model import ModelProvider

class Anthropic(ModelProvider):
DEFAULT_MODEL_KWARGS: dict = dict(max_tokens_to_sample = 300,
temperature = 0)
Expand All @@ -26,13 +25,12 @@ def __init__(self,
if "claude" not in model_name:
raise ValueError("If the model provider is 'anthropic', the model name must include 'claude'. See https://docs.anthropic.com/claude/reference/selecting-a-model for more details on Anthropic models")

api_key = os.getenv('NIAH_MODEL_API_KEY')
if (not api_key):
raise ValueError("NIAH_MODEL_API_KEY must be in env.")

self.model_name = model_name
self.model_kwargs = model_kwargs
self.api_key = api_key
self.api_key = get_from_env_or_error(
env_key="NIAH_MODEL_API_KEY",
error_message="{env_key} must be in env for using Anthropic model."
)

self.model = AsyncAnthropic(api_key=self.api_key)
self.tokenizer = AnthropicModel().get_tokenizer()
Expand Down
19 changes: 9 additions & 10 deletions needlehaystack/providers/openai.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import os
import tiktoken

from .model import ModelProvider
from ..utils import get_from_env_or_error

from operator import itemgetter
from typing import Optional

from openai import AsyncOpenAI
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
import tiktoken

from .model import ModelProvider


class OpenAI(ModelProvider):
"""
Expand Down Expand Up @@ -37,13 +36,13 @@ def __init__(self,
Raises:
ValueError: If NIAH_MODEL_API_KEY is not found in the environment.
"""
api_key = os.getenv('NIAH_MODEL_API_KEY')
if (not api_key):
raise ValueError("NIAH_MODEL_API_KEY must be in env.")

self.model_name = model_name
self.model_kwargs = model_kwargs
self.api_key = api_key
self.api_key = get_from_env_or_error(
env_key="NIAH_MODEL_API_KEY",
error_message="{env_key} must be in env for using OpenAI model."
)
self.model = AsyncOpenAI(api_key=self.api_key)
self.tokenizer = tiktoken.encoding_for_model(self.model_name)

Expand Down
11 changes: 10 additions & 1 deletion needlehaystack/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from jsonargparse import CLI

from . import LLMNeedleHaystackTester, LLMMultiNeedleHaystackTester
from .evaluators import Evaluator, LangSmithEvaluator, OpenAIEvaluator
from .evaluators import Evaluator, LangSmithEvaluator, OpenAIEvaluator, AzureOpenAIEvaluator
from .providers import Anthropic, ModelProvider, OpenAI

load_dotenv()
Expand Down Expand Up @@ -35,6 +35,9 @@ class CommandArgs():
final_context_length_buffer: Optional[int] = 200
seconds_to_sleep_between_completions: Optional[float] = None
print_ongoing_status: Optional[bool] = True
# Azure OpenAI parameters
azure_openai_endpoint: Optional[str] = None
azure_openai_api_version: Optional[str] = "2024-02-01"
# LangSmith parameters
eval_set: Optional[str] = "multi-needle-eval-pizza-3"
# Multi-needle parameters
Expand Down Expand Up @@ -84,6 +87,12 @@ def get_evaluator(args: CommandArgs) -> Evaluator:
return OpenAIEvaluator(model_name=args.evaluator_model_name,
question_asked=args.retrieval_question,
true_answer=args.needle)
case "azure" | "azure_openai" | "aoai":
return AzureOpenAIEvaluator(model_name=args.evaluator_model_name,
question_asked=args.retrieval_question,
true_answer=args.needle,
azure_openai_endpoint=args.azure_openai_endpoint,
azure_openai_api_version=args.azure_openai_api_version)
case "langsmith":
return LangSmithEvaluator()
case _:
Expand Down
8 changes: 8 additions & 0 deletions needlehaystack/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import os

def get_from_env_or_error(env_key: str, error_message: str, error_class = ValueError):
    """Return the value of environment variable *env_key*, or raise.

    :param env_key: Name of the environment variable to read.
    :param error_message: Message template; any ``{env_key}`` placeholder is
        substituted with the variable name before raising.
    :param error_class: Exception type raised when the variable is unset or
        empty (default: ValueError).
    :return: The non-empty value of the environment variable.
    """
    value = os.environ.get(env_key)
    if value:
        return value
    # Unset and empty-string are treated the same: both are "missing".
    raise error_class(error_message.format(env_key=env_key))