diff --git a/needlehaystack/evaluators/openai.py b/needlehaystack/evaluators/openai.py index 91baa955..a08cbe68 100644 --- a/needlehaystack/evaluators/openai.py +++ b/needlehaystack/evaluators/openai.py @@ -39,10 +39,15 @@ def __init__(self, if (not api_key): raise ValueError("NIAH_EVALUATOR_API_KEY must be in env for using openai evaluator.") - self.api_key = api_key + api_base = os.getenv("NIAH_MODEL_API_BASE") + if api_base is None: + api_base = f"https://api.openai.com/v1" + self.api_key = api_key + self.api_base = api_base self.evaluator = ChatOpenAI(model=self.model_name, openai_api_key=self.api_key, + openai_api_base=self.api_base, **self.model_kwargs) def evaluate_response(self, response: str) -> int: diff --git a/needlehaystack/providers/openai.py b/needlehaystack/providers/openai.py index 1bc841e4..cd1052a4 100644 --- a/needlehaystack/providers/openai.py +++ b/needlehaystack/providers/openai.py @@ -41,10 +41,15 @@ def __init__(self, if (not api_key): raise ValueError("NIAH_MODEL_API_KEY must be in env.") + api_base = os.getenv("NIAH_MODEL_API_BASE") + if api_base is None: + api_base = f"https://api.openai.com/v1" + self.model_name = model_name self.model_kwargs = model_kwargs self.api_key = api_key - self.model = AsyncOpenAI(api_key=self.api_key) + self.api_base = api_base + self.model = AsyncOpenAI(api_key=self.api_key, base_url=self.api_base) self.tokenizer = tiktoken.encoding_for_model(self.model_name) async def evaluate_model(self, prompt: str) -> str: