diff --git a/e2e/preview/components/test_chatgpt_generator.py b/e2e/preview/components/test_chatgpt_generator.py new file mode 100644 index 0000000000..ae243d1a4b --- /dev/null +++ b/e2e/preview/components/test_chatgpt_generator.py @@ -0,0 +1,60 @@ +import os +import pytest +from haystack.preview.components.generators.openai.chatgpt import ChatGPTGenerator + + +@pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", +) +def test_chatgpt_generator_run(): + component = ChatGPTGenerator(api_key=os.environ.get("OPENAI_API_KEY")) + results = component.run(prompts=["What's the capital of France?", "What's the capital of Germany?"], n=1) + + assert len(results["replies"]) == 2 + assert len(results["replies"][0]) == 1 + assert "Paris" in results["replies"][0][0] + assert len(results["replies"][1]) == 1 + assert "Berlin" in results["replies"][1][0] + + assert len(results["metadata"]) == 2 + assert len(results["metadata"][0]) == 1 + assert "gpt-3.5-turbo" in results["metadata"][0][0]["model"] + assert "stop" == results["metadata"][0][0]["finish_reason"] + assert len(results["metadata"][1]) == 1 + assert "gpt-3.5-turbo" in results["metadata"][1][0]["model"] + assert "stop" == results["metadata"][1][0]["finish_reason"] + + +@pytest.mark.skipif( + not os.environ.get("OPENAI_API_KEY", None), + reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", +) +def test_chatgpt_generator_run_streaming(): + class Callback: + def __init__(self): + self.responses = "" + + def __call__(self, token, event_data): + self.responses += token + return token + + callback = Callback() + component = ChatGPTGenerator(os.environ.get("OPENAI_API_KEY"), stream=True, streaming_callback=callback) + results = component.run(prompts=["What's the capital of France?", "What's the capital of Germany?"], n=1) + + assert len(results["replies"]) == 2 + assert len(results["replies"][0]) == 1 + assert "Paris" in results["replies"][0][0] + assert len(results["replies"][1]) == 1 + assert "Berlin" in results["replies"][1][0] + + assert callback.responses == results["replies"][0][0] + results["replies"][1][0] + + assert len(results["metadata"]) == 2 + assert len(results["metadata"][0]) == 1 + assert "gpt-3.5-turbo" in results["metadata"][0][0]["model"] + assert "stop" == results["metadata"][0][0]["finish_reason"] + assert len(results["metadata"][1]) == 1 + assert "gpt-3.5-turbo" in results["metadata"][1][0]["model"] + assert "stop" == results["metadata"][1][0]["finish_reason"] diff --git a/haystack/preview/__init__.py b/haystack/preview/__init__.py index 36f7de744f..a5f9ed225c 100644 --- a/haystack/preview/__init__.py +++ b/haystack/preview/__init__.py @@ -1,4 +1,4 @@ from canals import component, Pipeline from canals.serialization import default_from_dict, default_to_dict -from canals.errors import DeserializationError +from canals.errors import DeserializationError, ComponentError from haystack.preview.dataclasses import * diff --git a/haystack/preview/components/generators/openai/_helpers.py b/haystack/preview/components/generators/openai/_helpers.py deleted file mode 100644 index 946901b644..0000000000 --- a/haystack/preview/components/generators/openai/_helpers.py +++ /dev/null @@ -1,33 +0,0 @@ -import logging - -from haystack.preview.lazy_imports import LazyImport - -with LazyImport("Run 'pip install tiktoken'") as tiktoken_import: - import tiktoken - - -logger = 
logging.getLogger(__name__) - - -def enforce_token_limit(prompt: str, tokenizer: "tiktoken.Encoding", max_tokens_limit: int) -> str: - """ - Ensure that the length of the prompt is within the max tokens limit of the model. - If needed, truncate the prompt text so that it fits within the limit. - - :param prompt: Prompt text to be sent to the generative model. - :param tokenizer: The tokenizer used to encode the prompt. - :param max_tokens_limit: The max tokens limit of the model. - :return: The prompt text that fits within the max tokens limit of the model. - """ - tiktoken_import.check() - tokens = tokenizer.encode(prompt) - tokens_count = len(tokens) - if tokens_count > max_tokens_limit: - logger.warning( - "The prompt has been truncated from %s tokens to %s tokens to fit within the max token limit. " - "Reduce the length of the prompt to prevent it from being cut off.", - tokens_count, - max_tokens_limit, - ) - prompt = tokenizer.decode(tokens[:max_tokens_limit]) - return prompt diff --git a/haystack/preview/components/generators/openai/chatgpt.py b/haystack/preview/components/generators/openai/chatgpt.py new file mode 100644 index 0000000000..afd61e2884 --- /dev/null +++ b/haystack/preview/components/generators/openai/chatgpt.py @@ -0,0 +1,194 @@ +from typing import Optional, List, Callable, Dict, Any + +import logging + +from haystack.preview import component, default_from_dict, default_to_dict +from haystack.preview.llm_backends.openai.chatgpt import ChatGPTBackend +from haystack.preview.llm_backends.chat_message import ChatMessage +from haystack.preview.llm_backends.openai._helpers import default_streaming_callback + + +logger = logging.getLogger(__name__) + + +TOKENS_PER_MESSAGE_OVERHEAD = 4 + + +@component +class ChatGPTGenerator: + """ + ChatGPT LLM Generator. + + Queries ChatGPT using OpenAI's GPT-3 ChatGPT API. Invocations are made using REST API. + See [OpenAI ChatGPT API](https://platform.openai.com/docs/guides/chat) for more details. + """ + + # TODO support function calling! + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "gpt-3.5-turbo", + system_prompt: Optional[str] = None, + max_tokens: Optional[int] = 500, + temperature: Optional[float] = 0.7, + top_p: Optional[float] = 1, + n: Optional[int] = 1, + stop: Optional[List[str]] = None, + presence_penalty: Optional[float] = 0, + frequency_penalty: Optional[float] = 0, + logit_bias: Optional[Dict[str, float]] = None, + stream: bool = False, + streaming_callback: Optional[Callable] = default_streaming_callback, + api_base_url: str = "https://api.openai.com/v1", + openai_organization: Optional[str] = None, + ): + """ + Creates an instance of ChatGPTGenerator for OpenAI's GPT-3.5 model. + + :param api_key: The OpenAI API key. + :param model_name: The name or path of the underlying model. + :param system_prompt: The prompt to be prepended to the user prompt. + :param max_tokens: The maximum number of tokens the output text can have. + :param temperature: What sampling temperature to use. Higher values means the model will take more risks. + Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. + :param top_p: An alternative to sampling with temperature, called nucleus sampling, where the model + considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens + comprising the top 10% probability mass are considered. + :param n: How many completions to generate for each prompt. 
+ :param stop: One or more sequences where the API will stop generating further tokens. + :param presence_penalty: What penalty to apply if a token is already present at all. Bigger values mean + the model will be less likely to repeat the same token in the text. + :param frequency_penalty: What penalty to apply if a token has already been generated in the text. + Bigger values mean the model will be less likely to repeat the same token in the text. + :param logit_bias: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the + values are the bias to add to that token. + :param stream: If set to True, the API will stream the response. The streaming_callback parameter + is used to process the stream. If set to False, the response will be returned as a string. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function should accept two parameters: the token received from the stream and **kwargs. + The callback function should return the token to be sent to the stream. If the callback function is not + provided, the token is printed to stdout. + :param api_base_url: The OpenAI API Base url, defaults to `https://api.openai.com/v1`. + :param openai_organization: The OpenAI organization ID. + + See OpenAI documentation](https://platform.openai.com/docs/api-reference/chat) for more details. + """ + self.llm = ChatGPTBackend( + api_key=api_key, + model_name=model_name, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + n=n, + stop=stop, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + stream=stream, + streaming_callback=streaming_callback, + api_base_url=api_base_url, + openai_organization=openai_organization, + ) + self.system_prompt = system_prompt + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + """ + return default_to_dict(self, system_prompt=self.system_prompt, **self.llm.to_dict()) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ChatGPTGenerator": + """ + Deserialize this component from a dictionary. + """ + # FIXME how to deserialize the streaming callback? + return default_from_dict(cls, data) + + @component.output_types(replies=List[List[str]], metadata=List[Dict[str, Any]]) + def run( + self, + prompts: List[str], + api_key: Optional[str] = None, + model_name: Optional[str] = None, + system_prompt: Optional[str] = None, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + n: Optional[int] = None, + stop: Optional[List[str]] = None, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + stream: Optional[bool] = None, + streaming_callback: Optional[Callable] = None, + api_base_url: Optional[str] = None, + openai_organization: Optional[str] = None, + ): + """ + Queries the LLM with the prompts to produce replies. + + :param prompts: The prompts to be sent to the generative model. + :param api_key: The OpenAI API key. + :param model_name: The name or path of the underlying model. + :param system_prompt: The prompt to be prepended to the user prompt. + :param max_tokens: The maximum number of tokens the output text can have. + :param temperature: What sampling temperature to use. Higher values means the model will take more risks. 
+ Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. + :param top_p: An alternative to sampling with temperature, called nucleus sampling, where the model + considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens + comprising the top 10% probability mass are considered. + :param n: How many completions to generate for each prompt. + :param stop: One or more sequences where the API will stop generating further tokens. + :param presence_penalty: What penalty to apply if a token is already present at all. Bigger values mean + the model will be less likely to repeat the same token in the text. + :param frequency_penalty: What penalty to apply if a token has already been generated in the text. + Bigger values mean the model will be less likely to repeat the same token in the text. + :param logit_bias: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the + values are the bias to add to that token. + :param stream: If set to True, the API will stream the response. The streaming_callback parameter + is used to process the stream. If set to False, the response will be returned as a string. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function should accept two parameters: the token received from the stream and **kwargs. + The callback function should return the token to be sent to the stream. If the callback function is not + provided, the token is printed to stdout. + :param api_base_url: The OpenAI API Base url, defaults to `https://api.openai.com/v1`. + :param openai_organization: The OpenAI organization ID. + + See OpenAI documentation](https://platform.openai.com/docs/api-reference/chat) for more details. 
+ """ + system_prompt = system_prompt if system_prompt is not None else self.system_prompt + if system_prompt: + system_message = ChatMessage(content=system_prompt, role="system") + chats = [] + for prompt in prompts: + message = ChatMessage(content=prompt, role="user") + if system_prompt: + chats.append([system_message, message]) + else: + chats.append([message]) + + replies, metadata = [], [] + for chat in chats: + reply, meta = self.llm.complete( + chat=chat, + api_key=api_key, + model_name=model_name, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + n=n, + stop=stop, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + api_base_url=api_base_url, + openai_organization=openai_organization, + stream=stream, + streaming_callback=streaming_callback, + ) + replies.append(reply) + metadata.append(meta) + + return {"replies": replies, "metadata": metadata} diff --git a/haystack/preview/examples/__init__.py b/haystack/preview/examples/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/haystack/preview/examples/chat_gpt_example.py b/haystack/preview/examples/chat_gpt_example.py new file mode 100644 index 0000000000..4e13f86ee3 --- /dev/null +++ b/haystack/preview/examples/chat_gpt_example.py @@ -0,0 +1,13 @@ +import os + +from haystack.preview.components.generators.openai.chatgpt import ChatGPTGenerator + +stream_response = False + +llm = ChatGPTGenerator( + api_key=os.environ.get("OPENAI_API_KEY"), model_name="gpt-3.5-turbo", max_tokens=256, stream=stream_response +) + +responses = llm.run(prompts=["What is the meaning of life?"]) +if not stream_response: + print(responses) diff --git a/haystack/preview/llm_backends/__init__.py b/haystack/preview/llm_backends/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/haystack/preview/llm_backends/chat_message.py b/haystack/preview/llm_backends/chat_message.py new file mode 100644 index 0000000000..ca20f905f3 --- /dev/null +++ b/haystack/preview/llm_backends/chat_message.py @@ -0,0 +1,7 @@ +from dataclasses import dataclass + + +@dataclass +class ChatMessage: + content: str + role: str diff --git a/haystack/preview/llm_backends/openai/__init__.py b/haystack/preview/llm_backends/openai/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/haystack/preview/llm_backends/openai/_helpers.py b/haystack/preview/llm_backends/openai/_helpers.py new file mode 100644 index 0000000000..1b446e319b --- /dev/null +++ b/haystack/preview/llm_backends/openai/_helpers.py @@ -0,0 +1,232 @@ +from typing import List, Callable, Dict, Any, Tuple +import os +import logging +import json + +import tenacity +import requests +import sseclient + +from haystack.preview.lazy_imports import LazyImport +from haystack.preview.llm_backends.chat_message import ChatMessage +from haystack.preview.llm_backends.openai.errors import OpenAIError, OpenAIRateLimitError, OpenAIUnauthorizedError + +with LazyImport("Run 'pip install tiktoken'") as tiktoken_import: + import tiktoken + + +logger = logging.getLogger(__name__) + + +OPENAI_TIMEOUT = float(os.environ.get("HAYSTACK_REMOTE_API_TIMEOUT_SEC", 30)) +OPENAI_BACKOFF = int(os.environ.get("HAYSTACK_REMOTE_API_BACKOFF_SEC", 10)) +OPENAI_MAX_RETRIES = int(os.environ.get("HAYSTACK_REMOTE_API_MAX_RETRIES", 5)) +OPENAI_TOKENIZERS = { + **tiktoken.model.MODEL_TO_ENCODING, + "gpt-35-turbo": "cl100k_base", # https://github.com/openai/tiktoken/pull/72 +} +OPENAI_TOKENIZERS_TOKEN_LIMITS = { + "text-davinci": 4097, # Ref: 
https://platform.openai.com/docs/models/gpt-3 + "gpt-35-turbo": 4097, # Ref: https://platform.openai.com/docs/models/gpt-3-5 + "gpt-3.5-turbo": 4097, # Ref: https://platform.openai.com/docs/models/gpt-3-5 + "gpt-3.5-turbo-16k": 16384, # Ref: https://platform.openai.com/docs/models/gpt-3-5 + "gpt-3": 4096, # Ref: https://platform.openai.com/docs/models/gpt-3 + "gpt-4-32k": 32768, # Ref: https://platform.openai.com/docs/models/gpt-4 + "gpt-4": 8192, # Ref: https://platform.openai.com/docs/models/gpt-4 +} +OPENAI_STREAMING_DONE_MARKER = "[DONE]" # Ref: https://platform.openai.com/docs/api-reference/chat/create#stream + + +#: Retry on OpenAI errors, except unauthorized (401) errors that retrying cannot fix +openai_retry = tenacity.retry( + reraise=True, + retry=tenacity.retry_if_exception_type(OpenAIError) + & tenacity.retry_if_not_exception_type(OpenAIUnauthorizedError), + wait=tenacity.wait_exponential(multiplier=OPENAI_BACKOFF), + stop=tenacity.stop_after_attempt(OPENAI_MAX_RETRIES), +) + + +def default_streaming_callback(token: str, **kwargs): + """ + Default callback function for streaming responses from the OpenAI API. + Prints the tokens to stdout as soon as they are received and returns them. + """ + print(token, flush=True, end="") + return token + + +@openai_retry +def complete(url: str, headers: Dict[str, str], payload: Dict[str, Any]) -> Tuple[List[str], List[Dict[str, Any]]]: + """ + Query ChatGPT without streaming the response. + + :param url: The URL to query. + :param headers: The headers to send with the request. + :param payload: The payload to send with the request. + :return: A tuple with the list of replies and the list of metadata dicts returned by the OpenAI API. + """ + response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=OPENAI_TIMEOUT) + raise_for_status(response=response) + json_response = json.loads(response.text) + check_truncated_answers(result=json_response, payload=payload) + metadata = [ + { + "model": json_response.get("model", None), + "index": choice.get("index", None), + "finish_reason": choice.get("finish_reason", None), + **json_response.get("usage", {}), + } + for choice in json_response.get("choices", []) + ] + replies = [choice["message"]["content"].strip() for choice in json_response.get("choices", [])] + return replies, metadata + + +@openai_retry +def complete_stream( + url: str, headers: Dict[str, str], payload: Dict[str, Any], callback: Callable +) -> Tuple[List[str], List[Dict[str, Any]]]: + """ + Query ChatGPT and stream the response. Once the stream finishes, return the replies and metadata, + just like complete(). + + :param url: The URL to query. + :param headers: The headers to send with the request. + :param payload: The payload to send with the request. + :param callback: A callback function that is called when a new token is received from the stream. + The callback function should accept two parameters: the token received from the stream and **kwargs. + The callback function should return the token that will be returned at the end of the streaming. + :return: A tuple with the list of replies and the list of metadata dicts returned by the OpenAI API. 
+ """ + response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=OPENAI_TIMEOUT, stream=True) + raise_for_status(response=response) + + client = sseclient.SSEClient(response) # type: ignore + event_data = None + tokens = [] + try: + for event in client.events(): + if event.data == OPENAI_STREAMING_DONE_MARKER: + break + event_data = json.loads(event.data) + delta = event_data["choices"][0]["delta"] + token = delta["content"] if "content" in delta else None + if token: + tokens.append(callback(token, event_data=event_data["choices"])) + finally: + client.close() + metadata = ( + [ + { + "model": event_data.get("model", None), + "index": choice.get("index", None), + "finish_reason": choice.get("finish_reason", None), + } + for choice in event_data.get("choices", []) + ] + if event_data + else [] + ) + return ["".join(tokens)], metadata + + +def raise_for_status(response: requests.Response): + """ + Raises the appropriate OpenAI error in case of a bad response. + + :param response: The response returned from the OpenAI API. + :raises OpenAIError: If the response status code is not 200. + """ + if response.status_code >= 400: + if response.status_code == 429: + raise OpenAIRateLimitError(f"API rate limit exceeded: {response.text}") + if response.status_code == 401: + raise OpenAIUnauthorizedError(f"API key is invalid: {response.text}") + raise OpenAIError( + f"OpenAI returned an error.\n" f"Status code: {response.status_code}\n" f"Response body: {response.text}", + status_code=response.status_code, + ) + + +def check_truncated_answers(result: Dict[str, Any], payload: Dict[str, Any]): + """ + Check the `finish_reason` the answers returned by OpenAI completions endpoint. + If the `finish_reason` is `length`, log a warning to the user. + + :param result: The result returned from the OpenAI API. + :param payload: The payload sent to the OpenAI API. + """ + truncated_completions = sum(1 for ans in result["choices"] if ans["finish_reason"] == "length") + if truncated_completions > 0: + logger.warning( + "%s out of the %s completions have been truncated before reaching a natural stopping point. " + "Increase the max_tokens parameter to allow for longer completions.", + truncated_completions, + payload["n"], + ) + + +def enforce_token_limit(prompt: str, tokenizer: "tiktoken.Encoding", max_tokens_limit: int) -> str: + """ + Ensure that the length of the prompt is within the max tokens limit of the model. + If needed, truncate the prompt text so that it fits within the limit. + + :param prompt: Prompt text to be sent to the generative model. + :param tokenizer: The tokenizer used to encode the prompt. + :param max_tokens_limit: The max tokens limit of the model. + :return: The prompt text that fits within the max tokens limit of the model. + """ + tiktoken_import.check() + tokens = tokenizer.encode(prompt) + tokens_count = len(tokens) + if tokens_count > max_tokens_limit: + logger.warning( + "The prompt has been truncated from %s tokens to %s tokens to fit within the max token limit. " + "Reduce the length of the prompt to prevent it from being cut off.", + tokens_count, + max_tokens_limit, + ) + prompt = tokenizer.decode(tokens[:max_tokens_limit]) + return prompt + + +def enforce_token_limit_chat( + chat: List[ChatMessage], tokenizer: "tiktoken.Encoding", max_tokens_limit: int, tokens_per_message_overhead: int +) -> List[ChatMessage]: + """ + Ensure that the length of the chat is within the max tokens limit of the model. 
+ If needed, truncate the messages so that the chat fits within the limit. + + :param chat: The chat messages to be sent to the generative model. + :param tokenizer: The tokenizer used to encode the chat. + :param max_tokens_limit: The max tokens limit of the model. + :param tokens_per_message_overhead: The number of tokens that are added to the prompt text for each message. + :return: A chat that fits within the max tokens limit of the model. + """ + # Token length of each message, including the per-message overhead added by the chat format. + messages_len = [len(tokenizer.encode(message.content)) + tokens_per_message_overhead for message in chat] + if (total_chat_length := sum(messages_len)) <= max_tokens_limit: + return chat + + logger.warning( + "The chat have been truncated from %s tokens to %s tokens to fit within the max token limit. " + "Reduce the length of the chat to prevent it from being cut off.", + total_chat_length, + max_tokens_limit, + ) + cut_messages = [] + cut_messages_len: List[int] = [] + for message, message_len in zip(chat, messages_len): + if sum(cut_messages_len) + message_len <= max_tokens_limit: + cut_messages.append(message) + cut_messages_len.append(message_len) + else: + remaining_tokens = max_tokens_limit - sum(cut_messages_len) + cut_messages.append( + ChatMessage( + content=enforce_token_limit(message.content, tokenizer, remaining_tokens), role=message.role + ) + ) + break + return cut_messages diff --git a/haystack/preview/llm_backends/openai/chatgpt.py b/haystack/preview/llm_backends/openai/chatgpt.py new file mode 100644 index 0000000000..d416f6e453 --- /dev/null +++ b/haystack/preview/llm_backends/openai/chatgpt.py @@ -0,0 +1,239 @@ +from typing import Optional, List, Callable, Dict, Any + +import logging +from dataclasses import asdict + +from haystack.preview.lazy_imports import LazyImport +from haystack.preview.llm_backends.chat_message import ChatMessage +from haystack.preview.llm_backends.openai._helpers import ( + default_streaming_callback, + complete, + complete_stream, + enforce_token_limit_chat, + OPENAI_TOKENIZERS, + OPENAI_TOKENIZERS_TOKEN_LIMITS, +) + + +with LazyImport("Run 'pip install tiktoken'") as tiktoken_import: + import tiktoken + + +logger = logging.getLogger(__name__) + + +TOKENS_PER_MESSAGE_OVERHEAD = 4 + + +class ChatGPTBackend: + """ + ChatGPT LLM interface. + + Queries ChatGPT using OpenAI's GPT-3 ChatGPT API. Invocations are made using REST API. + See [OpenAI ChatGPT API](https://platform.openai.com/docs/guides/chat) for more details. + """ + + # TODO support function calling! + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "gpt-3.5-turbo", + max_tokens: Optional[int] = 500, + temperature: Optional[float] = 0.7, + top_p: Optional[float] = 1, + n: Optional[int] = 1, + stop: Optional[List[str]] = None, + presence_penalty: Optional[float] = 0, + frequency_penalty: Optional[float] = 0, + logit_bias: Optional[Dict[str, float]] = None, + stream: bool = False, + streaming_callback: Optional[Callable] = default_streaming_callback, + api_base_url: str = "https://api.openai.com/v1", + openai_organization: Optional[str] = None, + ): + """ + Creates an instance of ChatGPTBackend for OpenAI's GPT-3.5 model. + + :param api_key: The OpenAI API key. + :param model_name: The name or path of the underlying model. + :param max_tokens: The maximum number of tokens the output text can have. + :param temperature: What sampling temperature to use. Higher values mean the model will take more risks. + Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. 
+ :param top_p: An alternative to sampling with temperature, called nucleus sampling, where the model + considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens + comprising the top 10% probability mass are considered. + :param n: How many completions to generate for each prompt. + :param stop: One or more sequences where the API will stop generating further tokens. + :param presence_penalty: What penalty to apply if a token is already present at all. Bigger values mean + the model will be less likely to repeat the same token in the text. + :param frequency_penalty: What penalty to apply if a token has already been generated in the text. + Bigger values mean the model will be less likely to repeat the same token in the text. + :param logit_bias: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the + values are the bias to add to that token. + :param stream: If set to True, the API will stream the response. The streaming_callback parameter + is used to process the stream. If set to False, the response will be returned as a string. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function should accept two parameters: the token received from the stream and **kwargs. + The callback function should return the token to be sent to the stream. If the callback function is not + provided, the token is printed to stdout. + :param api_base_url: The OpenAI API Base url, defaults to `https://api.openai.com/v1`. + :param openai_organization: The OpenAI organization ID. + + See OpenAI documentation](https://platform.openai.com/docs/api-reference/chat) for more details. + """ + if not api_key: + logger.warning("OpenAI API key is missing. You will need to provide an API key to Pipeline.run().") + + self.api_key = api_key + self.model_name = model_name + + self.max_tokens = max_tokens + self.temperature = temperature + self.top_p = top_p + self.n = n + self.stop = stop or [] + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.logit_bias = logit_bias or {} + self.stream = stream + self.streaming_callback = streaming_callback or default_streaming_callback + + self.openai_organization = openai_organization + self.api_base_url = api_base_url + + tokenizer = None + for model_prefix, tokenizer_name in OPENAI_TOKENIZERS.items(): + if model_name.startswith(model_prefix): + tokenizer = tiktoken.get_encoding(tokenizer_name) + break + if not tokenizer: + raise ValueError(f"Tokenizer for model '{model_name}' not found.") + self.tokenizer = tokenizer + + max_tokens_limit = None + for model_prefix, limit in OPENAI_TOKENIZERS_TOKEN_LIMITS.items(): + if model_name.startswith(model_prefix): + max_tokens_limit = limit + break + if not max_tokens_limit: + raise ValueError(f"Max tokens limit for model '{model_name}' not found.") + self.max_tokens_limit = max_tokens_limit + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize to a dictionary. + """ + return { + "api_key": self.api_key, + "model_name": self.model_name, + "max_tokens": self.max_tokens, + "temperature": self.temperature, + "top_p": self.top_p, + "n": self.n, + "stop": self.stop, + "presence_penalty": self.presence_penalty, + "frequency_penalty": self.frequency_penalty, + "logit_bias": self.logit_bias, + "stream": self.stream, + # FIXME how to serialize the streaming callback? 
+ "api_base_url": self.api_base_url, + "openai_organization": self.openai_organization, + } + + def complete( + self, + chat: List[ChatMessage], + api_key: Optional[str] = None, + model_name: Optional[str] = None, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + n: Optional[int] = None, + stop: Optional[List[str]] = None, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + api_base_url: Optional[str] = None, + openai_organization: Optional[str] = None, + stream: Optional[bool] = None, + streaming_callback: Optional[Callable] = None, + ): + """ + Queries the LLM with the prompts to produce replies. + + :param chat: The chat to be sent to the generative model. + :param api_key: The OpenAI API key. + :param model_name: The name or path of the underlying model. + :param max_tokens: The maximum number of tokens the output text can have. + :param temperature: What sampling temperature to use. Higher values means the model will take more risks. + Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. + :param top_p: An alternative to sampling with temperature, called nucleus sampling, where the model + considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens + comprising the top 10% probability mass are considered. + :param n: How many completions to generate for each prompt. + :param stop: One or more sequences where the API will stop generating further tokens. + :param presence_penalty: What penalty to apply if a token is already present at all. Bigger values mean + the model will be less likely to repeat the same token in the text. + :param frequency_penalty: What penalty to apply if a token has already been generated in the text. + Bigger values mean the model will be less likely to repeat the same token in the text. + :param logit_bias: Add a logit bias to specific tokens. The keys of the dictionary are tokens and the + values are the bias to add to that token. + :param stream: If set to True, the API will stream the response. The streaming_callback parameter + is used to process the stream. If set to False, the response will be returned as a string. + :param streaming_callback: A callback function that is called when a new token is received from the stream. + The callback function should accept two parameters: the token received from the stream and **kwargs. + The callback function should return the token to be sent to the stream. If the callback function is not + provided, the token is printed to stdout. + :param api_base_url: The OpenAI API Base url, defaults to `https://api.openai.com/v1`. + :param openai_organization: The OpenAI organization ID. + + See OpenAI documentation](https://platform.openai.com/docs/api-reference/chat) for more details. + """ + api_key = api_key if api_key is not None else self.api_key + + if not api_key: + raise ValueError("OpenAI API key is missing. 
Please provide an API key.") + + model_name = model_name if model_name is not None else self.model_name + max_tokens = max_tokens if max_tokens is not None else self.max_tokens + temperature = temperature if temperature is not None else self.temperature + top_p = top_p if top_p is not None else self.top_p + n = n if n is not None else self.n + stop = stop if stop is not None else self.stop + presence_penalty = presence_penalty if presence_penalty is not None else self.presence_penalty + frequency_penalty = frequency_penalty if frequency_penalty is not None else self.frequency_penalty + logit_bias = logit_bias if logit_bias is not None else self.logit_bias + stream = stream if stream is not None else self.stream + streaming_callback = streaming_callback if streaming_callback is not None else self.streaming_callback + api_base_url = api_base_url or self.api_base_url + openai_organization = openai_organization if openai_organization is not None else self.openai_organization + + parameters = { + "model": model_name, + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + "n": n, + "stream": stream, + "stop": stop, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "logit_bias": logit_bias, + } + headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + if openai_organization: + headers["OpenAI-Organization"] = openai_organization + url = f"{api_base_url}/chat/completions" + + chat = enforce_token_limit_chat( + chat=chat, + tokenizer=self.tokenizer, + max_tokens_limit=self.max_tokens_limit, + tokens_per_message_overhead=TOKENS_PER_MESSAGE_OVERHEAD, + ) + payload = {**parameters, "messages": [asdict(message) for message in chat]} + if stream: + return complete_stream(url=url, headers=headers, payload=payload, callback=streaming_callback) + else: + return complete(url=url, headers=headers, payload=payload) diff --git a/haystack/preview/llm_backends/openai/errors.py b/haystack/preview/llm_backends/openai/errors.py new file mode 100644 index 0000000000..1787b4e17a --- /dev/null +++ b/haystack/preview/llm_backends/openai/errors.py @@ -0,0 +1,35 @@ +from typing import Optional +from haystack.preview import ComponentError + + +class OpenAIError(ComponentError): + """Exception for issues that occur in the OpenAI APIs""" + + def __init__(self, message: Optional[str] = None, status_code: Optional[int] = None): + super().__init__() + self.message = message + self.status_code = status_code + + def __str__(self): + return self.message + f"(status code {self.status_code})" if self.status_code else "" + + +class OpenAIRateLimitError(OpenAIError): + """ + Rate limit error for OpenAI API (status code 429) + See https://help.openai.com/en/articles/5955604-how-can-i-solve-429-too-many-requests-errors + See https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits + """ + + def __init__(self, message: Optional[str] = None): + super().__init__(message=message, status_code=429) + + +class OpenAIUnauthorizedError(OpenAIError): + """ + Unauthorized error for OpenAI API (status code 401) + See https://platform.openai.com/docs/guides/error-codes/api-errors + """ + + def __init__(self, message: Optional[str] = None): + super().__init__(message=message, status_code=401) diff --git a/releasenotes/notes/chatgpt-generator-6f47f1f6207c05f5.yaml b/releasenotes/notes/chatgpt-generator-6f47f1f6207c05f5.yaml new file mode 100644 index 0000000000..363b1d092f --- /dev/null +++ 
b/releasenotes/notes/chatgpt-generator-6f47f1f6207c05f5.yaml @@ -0,0 +1,2 @@ +preview: + - Add ChatGPTGenerator component. diff --git a/test/preview/components/generators/openai/test_chatgpt_generator.py b/test/preview/components/generators/openai/test_chatgpt_generator.py new file mode 100644 index 0000000000..be1597e289 --- /dev/null +++ b/test/preview/components/generators/openai/test_chatgpt_generator.py @@ -0,0 +1,223 @@ +from unittest.mock import patch + +import pytest + +from haystack.preview.components.generators.openai.chatgpt import ChatGPTGenerator +from haystack.preview.components.generators.openai.chatgpt import default_streaming_callback + + +class TestChatGPTGenerator: + @pytest.mark.unit + def test_init_default(self, caplog): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + component = ChatGPTGenerator() + assert component.system_prompt is None + assert component.llm.api_key is None + assert component.llm.model_name == "gpt-3.5-turbo" + assert component.llm.max_tokens == 500 + assert component.llm.temperature == 0.7 + assert component.llm.top_p == 1 + assert component.llm.n == 1 + assert component.llm.stop == [] + assert component.llm.presence_penalty == 0 + assert component.llm.frequency_penalty == 0 + assert component.llm.logit_bias == {} + assert component.llm.stream is False + assert component.llm.streaming_callback == default_streaming_callback + assert component.llm.api_base_url == "https://api.openai.com/v1" + assert component.llm.openai_organization is None + assert component.llm.max_tokens_limit == 4097 + + tiktoken_patch.get_encoding.assert_called_once_with("cl100k_base") + assert caplog.records[0].message == ( + "OpenAI API key is missing. You will need to provide an API key to Pipeline.run()." 
+ ) + + @pytest.mark.unit + def test_init_with_parameters(self, caplog): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + callback = lambda x: x + component = ChatGPTGenerator( + api_key="test-api-key", + model_name="gpt-4", + system_prompt="test-system-prompt", + max_tokens=20, + temperature=1, + top_p=5, + n=10, + stop=["test-stop-word"], + presence_penalty=0.5, + frequency_penalty=0.4, + logit_bias={"test-logit-bias": 0.3}, + stream=True, + streaming_callback=callback, + api_base_url="test-base-url", + openai_organization="test-orga-id", + ) + assert component.system_prompt == "test-system-prompt" + assert component.llm.api_key == "test-api-key" + assert component.llm.model_name == "gpt-4" + assert component.llm.max_tokens == 20 + assert component.llm.temperature == 1 + assert component.llm.top_p == 5 + assert component.llm.n == 10 + assert component.llm.stop == ["test-stop-word"] + assert component.llm.presence_penalty == 0.5 + assert component.llm.frequency_penalty == 0.4 + assert component.llm.logit_bias == {"test-logit-bias": 0.3} + assert component.llm.stream is True + assert component.llm.streaming_callback == callback + assert component.llm.api_base_url == "test-base-url" + assert component.llm.openai_organization == "test-orga-id" + assert component.llm.max_tokens_limit == 8192 + + tiktoken_patch.get_encoding.assert_called_once_with("cl100k_base") + assert not caplog.records + + @pytest.mark.unit + def test_to_dict_default(self): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + component = ChatGPTGenerator() + data = component.to_dict() + assert data == { + "type": "ChatGPTGenerator", + "init_parameters": { + "api_key": None, + "model_name": "gpt-3.5-turbo", + "system_prompt": None, + "max_tokens": 500, + "temperature": 0.7, + "top_p": 1, + "n": 1, + "stop": [], + "presence_penalty": 0, + "frequency_penalty": 0, + "logit_bias": {}, + "stream": False, + # FIXME serialize callback? + "api_base_url": "https://api.openai.com/v1", + "openai_organization": None, + }, + } + + @pytest.mark.unit + def test_to_dict_with_parameters(self): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + callback = lambda x: x + component = ChatGPTGenerator( + api_key="test-api-key", + model_name="gpt-4", + system_prompt="test-system-prompt", + max_tokens=20, + temperature=1, + top_p=5, + n=10, + stop=["test-stop-word"], + presence_penalty=0.5, + frequency_penalty=0.4, + logit_bias={"test-logit-bias": 0.3}, + stream=True, + streaming_callback=callback, + api_base_url="test-base-url", + openai_organization="test-orga-id", + ) + data = component.to_dict() + assert data == { + "type": "ChatGPTGenerator", + "init_parameters": { + "api_key": "test-api-key", + "model_name": "gpt-4", + "system_prompt": "test-system-prompt", + "max_tokens": 20, + "temperature": 1, + "top_p": 5, + "n": 10, + "stop": ["test-stop-word"], + "presence_penalty": 0.5, + "frequency_penalty": 0.4, + "logit_bias": {"test-logit-bias": 0.3}, + "stream": True, + # FIXME serialize callback? 
+ "api_base_url": "test-base-url", + "openai_organization": "test-orga-id", + }, + } + + @pytest.mark.unit + def test_from_dict(self): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + data = { + "type": "ChatGPTGenerator", + "init_parameters": { + "api_key": "test-api-key", + "model_name": "gpt-4", + "system_prompt": "test-system-prompt", + "max_tokens": 20, + "temperature": 1, + "top_p": 5, + "n": 10, + "stop": ["test-stop-word"], + "presence_penalty": 0.5, + "frequency_penalty": 0.4, + "logit_bias": {"test-logit-bias": 0.3}, + "stream": True, + # FIXME serialize callback? + "api_base_url": "test-base-url", + "openai_organization": "test-orga-id", + }, + } + component = ChatGPTGenerator.from_dict(data) + assert component.system_prompt == "test-system-prompt" + assert component.llm.api_key == "test-api-key" + assert component.llm.model_name == "gpt-4" + assert component.llm.max_tokens == 20 + assert component.llm.temperature == 1 + assert component.llm.top_p == 5 + assert component.llm.n == 10 + assert component.llm.stop == ["test-stop-word"] + assert component.llm.presence_penalty == 0.5 + assert component.llm.frequency_penalty == 0.4 + assert component.llm.logit_bias == {"test-logit-bias": 0.3} + assert component.llm.stream is True + assert component.llm.streaming_callback == default_streaming_callback + assert component.llm.api_base_url == "test-base-url" + assert component.llm.openai_organization == "test-orga-id" + assert component.llm.max_tokens_limit == 8192 + + @pytest.mark.unit + def test_run_no_api_key(self): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + component = ChatGPTGenerator() + with pytest.raises(ValueError, match="OpenAI API key is missing. Please provide an API key."): + component.run(prompts=["test"]) + + @pytest.mark.unit + def test_run_no_system_prompt(self): + with patch("haystack.preview.components.generators.openai.chatgpt.ChatGPTBackend") as chatgpt_patch: + chatgpt_patch.return_value.complete.side_effect = lambda chat, **kwargs: ( + [f"{msg.role}: {msg.content}" for msg in chat], + {"some_info": None}, + ) + component = ChatGPTGenerator(api_key="test-api-key") + results = component.run(prompts=["test-prompt-1", "test-prompt-2"]) + assert results == { + "replies": [["user: test-prompt-1"], ["user: test-prompt-2"]], + "metadata": [{"some_info": None}, {"some_info": None}], + } + + @pytest.mark.unit + def test_run_with_system_prompt(self): + with patch("haystack.preview.components.generators.openai.chatgpt.ChatGPTBackend") as chatgpt_patch: + chatgpt_patch.return_value.complete.side_effect = lambda chat, **kwargs: ( + [f"{msg.role}: {msg.content}" for msg in chat], + {"some_info": None}, + ) + component = ChatGPTGenerator(api_key="test-api-key", system_prompt="test-system-prompt") + results = component.run(prompts=["test-prompt-1", "test-prompt-2"]) + assert results == { + "replies": [ + ["system: test-system-prompt", "user: test-prompt-1"], + ["system: test-system-prompt", "user: test-prompt-2"], + ], + "metadata": [{"some_info": None}, {"some_info": None}], + } diff --git a/test/preview/components/generators/openai/test_openai_helpers.py b/test/preview/components/generators/openai/test_openai_helpers.py deleted file mode 100644 index 23a66117d1..0000000000 --- a/test/preview/components/generators/openai/test_openai_helpers.py +++ /dev/null @@ -1,20 +0,0 @@ -import pytest - -from haystack.preview.components.generators.openai._helpers import enforce_token_limit - - 
-@pytest.mark.unit -def test_enforce_token_limit_above_limit(caplog, mock_tokenizer): - prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=3) - assert prompt == "This is a" - assert caplog.records[0].message == ( - "The prompt has been truncated from 5 tokens to 3 tokens to fit within the max token " - "limit. Reduce the length of the prompt to prevent it from being cut off." - ) - - -@pytest.mark.unit -def test_enforce_token_limit_below_limit(caplog, mock_tokenizer): - prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=100) - assert prompt == "This is a test prompt." - assert not caplog.records diff --git a/test/preview/conftest.py b/test/preview/conftest.py index b8abfa41a6..377370bccf 100644 --- a/test/preview/conftest.py +++ b/test/preview/conftest.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest @@ -11,3 +11,12 @@ def mock_tokenizer(): tokenizer.encode = lambda text: text.split() tokenizer.decode = lambda tokens: " ".join(tokens) return tokenizer + + +@pytest.fixture(autouse=True) +def tenacity_wait(): + """ + Mocks tenacity's wait function to speed up tests. + """ + with patch("tenacity.nap.time"): + yield diff --git a/test/preview/llm_backends/test_chatgpt_backend.py b/test/preview/llm_backends/test_chatgpt_backend.py new file mode 100644 index 0000000000..58a78b6756 --- /dev/null +++ b/test/preview/llm_backends/test_chatgpt_backend.py @@ -0,0 +1,248 @@ +from unittest.mock import patch, Mock + +import pytest + +from haystack.preview.llm_backends.openai.chatgpt import ChatGPTBackend, default_streaming_callback, ChatMessage + + +class TestChatGPTBackend: + @pytest.mark.unit + def test_init_default(self, caplog): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + component = ChatGPTBackend() + assert component.api_key is None + assert component.model_name == "gpt-3.5-turbo" + assert component.max_tokens == 500 + assert component.temperature == 0.7 + assert component.top_p == 1 + assert component.n == 1 + assert component.stop == [] + assert component.presence_penalty == 0 + assert component.frequency_penalty == 0 + assert component.logit_bias == {} + assert component.stream is False + assert component.streaming_callback == default_streaming_callback + assert component.api_base_url == "https://api.openai.com/v1" + assert component.openai_organization is None + assert component.max_tokens_limit == 4097 + + tiktoken_patch.get_encoding.assert_called_once_with("cl100k_base") + assert caplog.records[0].message == ( + "OpenAI API key is missing. You will need to provide an API key to Pipeline.run()." 
+ ) + + @pytest.mark.unit + def test_init_with_parameters(self, caplog): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + callback = lambda x: x + component = ChatGPTBackend( + api_key="test-api-key", + model_name="gpt-4", + max_tokens=20, + temperature=1, + top_p=5, + n=10, + stop=["test-stop-word"], + presence_penalty=0.5, + frequency_penalty=0.4, + logit_bias={"test-logit-bias": 0.3}, + stream=True, + streaming_callback=callback, + api_base_url="test-base-url", + openai_organization="test-orga-id", + ) + assert component.api_key == "test-api-key" + assert component.model_name == "gpt-4" + assert component.max_tokens == 20 + assert component.temperature == 1 + assert component.top_p == 5 + assert component.n == 10 + assert component.stop == ["test-stop-word"] + assert component.presence_penalty == 0.5 + assert component.frequency_penalty == 0.4 + assert component.logit_bias == {"test-logit-bias": 0.3} + assert component.stream is True + assert component.streaming_callback == callback + assert component.api_base_url == "test-base-url" + assert component.openai_organization == "test-orga-id" + assert component.max_tokens_limit == 8192 + + tiktoken_patch.get_encoding.assert_called_once_with("cl100k_base") + assert not caplog.records + + @pytest.mark.unit + def test_init_unknown_tokenizer(self): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + with pytest.raises(ValueError, match="Tokenizer for model 'test-another-model-name' not found."): + ChatGPTBackend(model_name="test-another-model-name") + + @pytest.mark.unit + def test_init_unknown_token_limit(self, monkeypatch): + monkeypatch.setattr( + "haystack.preview.llm_backends.openai.chatgpt.OPENAI_TOKENIZERS", {"test-model-name": "test-encoding"} + ) + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + with pytest.raises(ValueError, match="Max tokens limit for model 'test-model-name' not found."): + ChatGPTBackend(model_name="test-model-name") + + @pytest.mark.unit + def test_to_dict_default(self): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + component = ChatGPTBackend() + data = component.to_dict() + assert data == { + "api_key": None, + "model_name": "gpt-3.5-turbo", + "max_tokens": 500, + "temperature": 0.7, + "top_p": 1, + "n": 1, + "stop": [], + "presence_penalty": 0, + "frequency_penalty": 0, + "logit_bias": {}, + "stream": False, + # FIXME serialize callback? + "api_base_url": "https://api.openai.com/v1", + "openai_organization": None, + } + + @pytest.mark.unit + def test_to_dict_with_parameters(self): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + callback = lambda x: x + component = ChatGPTBackend( + api_key="test-api-key", + model_name="gpt-4", + max_tokens=20, + temperature=1, + top_p=5, + n=10, + stop=["test-stop-word"], + presence_penalty=0.5, + frequency_penalty=0.4, + logit_bias={"test-logit-bias": 0.3}, + stream=True, + streaming_callback=callback, + api_base_url="test-base-url", + openai_organization="test-orga-id", + ) + data = component.to_dict() + assert data == { + "api_key": "test-api-key", + "model_name": "gpt-4", + "max_tokens": 20, + "temperature": 1, + "top_p": 5, + "n": 10, + "stop": ["test-stop-word"], + "presence_penalty": 0.5, + "frequency_penalty": 0.4, + "logit_bias": {"test-logit-bias": 0.3}, + "stream": True, + # FIXME serialize callback? 
+ "api_base_url": "test-base-url", + "openai_organization": "test-orga-id", + } + + @pytest.mark.unit + def test_run_no_api_key(self): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + component = ChatGPTBackend() + with pytest.raises(ValueError, match="OpenAI API key is missing. Please provide an API key."): + component.complete(chat=[]) + + @pytest.mark.unit + def test_complete(self): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + with patch("haystack.preview.llm_backends.openai.chatgpt.complete") as complete_patch: + complete_patch.side_effect = lambda payload, **kwargs: ( + [ + f"Response for {payload['messages'][1]['content']}", + f"Another Response for {payload['messages'][1]['content']}", + ], + [{"metadata of": payload["messages"][1]["content"]}], + ) + component = ChatGPTBackend( + api_key="test-api-key", openai_organization="test_orga_id", api_base_url="test-base-url" + ) + + results = component.complete( + chat=[ + ChatMessage(content="test-prompt-system", role="system"), + ChatMessage(content="test-prompt-user", role="user"), + ] + ) + + assert results == ( + [f"Response for test-prompt-user", f"Another Response for test-prompt-user"], + [{"metadata of": "test-prompt-user"}], + ) + + complete_patch.call_count == 2 + complete_patch.assert_called_once_with( + url="test-base-url/chat/completions", + headers={ + "Authorization": f"Bearer test-api-key", + "Content-Type": "application/json", + "OpenAI-Organization": "test_orga_id", + }, + payload={ + "model": "gpt-3.5-turbo", + "max_tokens": 500, + "temperature": 0.7, + "top_p": 1, + "n": 1, + "stream": False, + "stop": [], + "presence_penalty": 0, + "frequency_penalty": 0, + "logit_bias": {}, + "messages": [ + {"role": "system", "content": "test-prompt-system"}, + {"role": "user", "content": "test-prompt-user"}, + ], + }, + ) + + @pytest.mark.unit + def test_complete_streaming(self): + with patch("haystack.preview.llm_backends.openai.chatgpt.tiktoken") as tiktoken_patch: + with patch("haystack.preview.llm_backends.openai.chatgpt.complete_stream") as complete_stream_patch: + complete_stream_patch.side_effect = lambda payload, **kwargs: ( + [f"Response for {payload['messages'][1]['content']}"], + [{"metadata of": payload["messages"][1]["content"]}], + ) + callback = Mock() + component = ChatGPTBackend(api_key="test-api-key", stream=True, streaming_callback=callback) + + results = component.complete( + chat=[ + ChatMessage(content="test-prompt-system", role="system"), + ChatMessage(content="test-prompt-user", role="user"), + ] + ) + + assert results == (["Response for test-prompt-user"], [{"metadata of": "test-prompt-user"}]) + complete_stream_patch.call_count == 2 + complete_stream_patch.assert_any_call( + url="https://api.openai.com/v1/chat/completions", + headers={"Authorization": f"Bearer test-api-key", "Content-Type": "application/json"}, + payload={ + "model": "gpt-3.5-turbo", + "max_tokens": 500, + "temperature": 0.7, + "top_p": 1, + "n": 1, + "stream": True, + "stop": [], + "presence_penalty": 0, + "frequency_penalty": 0, + "logit_bias": {}, + "messages": [ + {"role": "system", "content": "test-prompt-system"}, + {"role": "user", "content": "test-prompt-user"}, + ], + }, + callback=callback, + ) diff --git a/test/preview/llm_backends/test_openai_helpers.py b/test/preview/llm_backends/test_openai_helpers.py new file mode 100644 index 0000000000..736d7f3dd5 --- /dev/null +++ b/test/preview/llm_backends/test_openai_helpers.py @@ -0,0 +1,252 @@ 
+from unittest.mock import Mock, patch +import json + +import pytest + +from haystack.preview.llm_backends.openai.errors import OpenAIUnauthorizedError, OpenAIError, OpenAIRateLimitError +from haystack.preview.llm_backends.openai._helpers import ( + ChatMessage, + raise_for_status, + check_truncated_answers, + complete, + complete_stream, + enforce_token_limit, + enforce_token_limit_chat, + OPENAI_TIMEOUT, + OPENAI_MAX_RETRIES, +) + + +@pytest.mark.unit +def test_raise_for_status_200(): + response = Mock() + response.status_code = 200 + raise_for_status(response) + + +@pytest.mark.unit +def test_raise_for_status_401(): + response = Mock() + response.status_code = 401 + with pytest.raises(OpenAIUnauthorizedError): + raise_for_status(response) + + +@pytest.mark.unit +def test_raise_for_status_429(): + response = Mock() + response.status_code = 429 + with pytest.raises(OpenAIRateLimitError): + raise_for_status(response) + + +@pytest.mark.unit +def test_raise_for_status_500(): + response = Mock() + response.status_code = 500 + response.text = "Internal Server Error" + with pytest.raises(OpenAIError): + raise_for_status(response) + + +@pytest.mark.unit +def test_check_truncated_answers(caplog): + result = { + "choices": [ + {"finish_reason": "length"}, + {"finish_reason": "content_filter"}, + {"finish_reason": "length"}, + {"finish_reason": "stop"}, + ] + } + payload = {"n": 4} + check_truncated_answers(result, payload) + assert caplog.records[0].message == ( + "2 out of the 4 completions have been truncated before reaching a natural " + "stopping point. Increase the max_tokens parameter to allow for longer completions." + ) + + +@pytest.mark.unit +def test_query_chat_model(): + with patch("haystack.preview.llm_backends.openai._helpers.requests.post") as mock_post: + response = Mock() + response.status_code = 200 + response.text = """ + { + "model": "test-model", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": {"content": " Hello, how are you? 
"} + } + ], + "usage": { + "prompt_tokens": 4, + "completion_tokens": 5, + "total_tokens": 9 + } + + }""" + mock_post.return_value = response + replies, metadata = complete(url="test-url", headers={"header": "test-header"}, payload={"param": "test-param"}) + mock_post.assert_called_once_with( + "test-url", + headers={"header": "test-header"}, + data=json.dumps({"param": "test-param"}), + timeout=OPENAI_TIMEOUT, + ) + assert replies == ["Hello, how are you?"] + assert metadata == [ + { + "model": "test-model", + "index": 0, + "finish_reason": "stop", + "prompt_tokens": 4, + "completion_tokens": 5, + "total_tokens": 9, + } + ] + + +@pytest.mark.unit +def test_query_chat_model_fail(): + with patch("haystack.preview.llm_backends.openai._helpers.requests.post") as mock_post: + response = Mock() + response.status_code = 500 + mock_post.return_value = response + with pytest.raises(OpenAIError): + complete(url="test-url", headers={"header": "test-header"}, payload={"param": "test-param"}) + mock_post.assert_called_with( + "test-url", + headers={"header": "test-header"}, + data=json.dumps({"param": "test-param"}), + timeout=OPENAI_TIMEOUT, + ) + mock_post.call_count == OPENAI_MAX_RETRIES + + +def mock_chat_completion_stream(model="test-model", index=0, token="test", finish_reason="stop"): + return Mock( + data=f"""{{ + "model": "{model}", + "choices": [ + {{ + "index": {index}, + "delta": {{"content": "{token}"}}, + "finish_reason": "{finish_reason}" + }} + ] + }}""" + ) + + +@pytest.mark.unit +def test_query_chat_model_stream(): + with patch("haystack.preview.llm_backends.openai._helpers.requests.post") as mock_post: + with patch("haystack.preview.llm_backends.openai._helpers.sseclient.SSEClient") as mock_sseclient: + callback = lambda token, event_data: f"|{token}|" + response = Mock() + response.status_code = 200 + + mock_sseclient.return_value.events.return_value = [ + mock_chat_completion_stream(token="Hello"), + mock_chat_completion_stream(token=","), + mock_chat_completion_stream(token=" how"), + mock_chat_completion_stream(token=" are"), + mock_chat_completion_stream(token=" you"), + mock_chat_completion_stream(token="?"), + Mock(data="[DONE]"), + mock_chat_completion_stream(token="discarded tokens"), + ] + + mock_post.return_value = response + replies, metadata = complete_stream( + url="test-url", headers={"header": "test-header"}, payload={"param": "test-param"}, callback=callback + ) + mock_post.assert_called_once_with( + "test-url", + headers={"header": "test-header"}, + data=json.dumps({"param": "test-param"}), + timeout=OPENAI_TIMEOUT, + stream=True, + ) + assert replies == ["|Hello||,|| how|| are|| you||?|"] + assert metadata == [{"model": "test-model", "index": 0, "finish_reason": "stop"}] + + +@pytest.mark.unit +def test_query_chat_model_stream_fail(): + with patch("haystack.preview.llm_backends.openai._helpers.requests.post") as mock_post: + callback = Mock() + response = Mock() + response.status_code = 500 + mock_post.return_value = response + with pytest.raises(OpenAIError): + complete_stream( + url="test-url", headers={"header": "test-header"}, payload={"param": "test-param"}, callback=callback + ) + mock_post.assert_called_with( + "test-url", + headers={"header": "test-header"}, + data=json.dumps({"param": "test-param"}), + timeout=OPENAI_TIMEOUT, + ) + mock_post.call_count == OPENAI_MAX_RETRIES + + +@pytest.mark.unit +def test_enforce_token_limit_above_limit(caplog, mock_tokenizer): + prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, 
max_tokens_limit=3) + assert prompt == "This is a" + assert caplog.records[0].message == ( + "The prompt has been truncated from 5 tokens to 3 tokens to fit within the max token " + "limit. Reduce the length of the prompt to prevent it from being cut off." + ) + + +@pytest.mark.unit +def test_enforce_token_limit_below_limit(caplog, mock_tokenizer): + prompt = enforce_token_limit("This is a test prompt.", tokenizer=mock_tokenizer, max_tokens_limit=100) + assert prompt == "This is a test prompt." + assert not caplog.records + + +@pytest.mark.unit +def test_enforce_token_limit_chat_above_limit(caplog, mock_tokenizer): + prompts = enforce_token_limit_chat( + [ + ChatMessage(content="System Prompt", role="system"), + ChatMessage(content="This is a test prompt.", role="user"), + ], + tokenizer=mock_tokenizer, + max_tokens_limit=7, + tokens_per_message_overhead=2, + ) + assert prompts == [ + ChatMessage(content="System Prompt", role="system"), + ChatMessage(content="This is a", role="user"), + ] + assert caplog.records[0].message == ( + "The chat have been truncated from 11 tokens to 7 tokens to fit within the max token limit. " + "Reduce the length of the chat to prevent it from being cut off." + ) + + +@pytest.mark.unit +def test_enforce_token_limit_chat_below_limit(caplog, mock_tokenizer): + prompts = enforce_token_limit_chat( + [ + ChatMessage(content="System Prompt", role="system"), + ChatMessage(content="This is a test prompt.", role="user"), + ], + tokenizer=mock_tokenizer, + max_tokens_limit=100, + tokens_per_message_overhead=2, + ) + assert prompts == [ + ChatMessage(content="System Prompt", role="system"), + ChatMessage(content="This is a test prompt.", role="user"), + ] + assert not caplog.records
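
For reference, the new component can be exercised much like haystack/preview/examples/chat_gpt_example.py and the e2e test in this diff. The following is a minimal usage sketch, not part of the patch; it assumes OPENAI_API_KEY is exported and the optional dependencies imported by the new modules (requests, sseclient, tiktoken, tenacity) are installed, and the printed reply is illustrative only.

import os

from haystack.preview.components.generators.openai.chatgpt import ChatGPTGenerator

# Collect streamed tokens instead of printing them (mirrors the Callback class in the e2e test).
collected = []


def collect_tokens(token, **kwargs):
    collected.append(token)
    return token


llm = ChatGPTGenerator(
    api_key=os.environ["OPENAI_API_KEY"],
    model_name="gpt-3.5-turbo",
    max_tokens=256,
    stream=True,
    streaming_callback=collect_tokens,
)

# run() returns one list of replies and one list of metadata dicts per prompt.
results = llm.run(prompts=["What's the capital of France?"], n=1)
print(results["replies"][0][0])   # e.g. "The capital of France is Paris." (illustrative)
print(results["metadata"][0][0])  # model, index, finish_reason, ...

With stream=False (the default) the callback is never invoked and complete() is used instead of complete_stream(), but the shape of the returned dictionary is the same.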