diff --git a/integrations/cohere/pyproject.toml b/integrations/cohere/pyproject.toml index 0e62d540b..ae96f114d 100644 --- a/integrations/cohere/pyproject.toml +++ b/integrations/cohere/pyproject.toml @@ -174,7 +174,9 @@ ignore_missing_imports = true addopts = "--strict-markers" markers = [ "integration: integration tests", + "unit: unit tests", "embedders: embedders tests", "generators: generators tests", + "chat_generators: chat_generators tests", ] log_cli = true \ No newline at end of file diff --git a/integrations/cohere/src/cohere_haystack/chat/chat_generator.py b/integrations/cohere/src/cohere_haystack/chat/chat_generator.py new file mode 100644 index 000000000..f3178d567 --- /dev/null +++ b/integrations/cohere/src/cohere_haystack/chat/chat_generator.py @@ -0,0 +1,203 @@ +import logging +import os +from typing import Any, Callable, Dict, List, Optional + +from haystack import component, default_from_dict, default_to_dict +from haystack.components.generators.utils import deserialize_callback_handler, serialize_callback_handler +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk +from haystack.lazy_imports import LazyImport + +with LazyImport(message="Run 'pip install cohere'") as cohere_import: + import cohere +logger = logging.getLogger(__name__) + + +class CohereChatGenerator: + """Enables text generation using Cohere's chat endpoint. This component is designed to run inference with + Cohere's chat models. + + Users can pass any text generation parameters valid for the `cohere.Client.chat` method + directly to this component via the `**generation_kwargs` parameter in __init__ or the `**generation_kwargs` + parameter in `run` method. + + Invocations are made using 'cohere' package. + See [Cohere API](https://docs.cohere.com/reference/chat) for more details. 
+ """ + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "command", + streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, + api_base_url: Optional[str] = None, + generation_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """ + Initialize the CohereChatGenerator instance. + + :param api_key: The API key for the Cohere API. If not set, it will be read from the COHERE_API_KEY env var. + :param model_name: The name of the model to use. Available models are: [command, command-light, command-nightly, + command-nightly-light]. Defaults to "command". + :param streaming_callback: A callback function to be called with the streaming response. Defaults to None. + :param api_base_url: The base URL of the Cohere API. Defaults to "https://api.cohere.ai". + :param generation_kwargs: Additional model parameters. These will be used during generation. Refer to + https://docs.cohere.com/reference/chat for more details. + Some of the parameters are: + - 'chat_history': A list of previous messages between the user and the model, meant to give the model + conversational context for responding to the user's message. + - 'preamble_override': When specified, the default Cohere preamble will be replaced with the provided one. + - 'conversation_id': An alternative to chat_history. Previous conversations can be resumed by providing + the conversation's identifier. The contents of message and the model's response will be stored + as part of this conversation. If a conversation with this id does not already exist, + a new conversation will be created. + - 'prompt_truncation': Defaults to AUTO when connectors are specified and OFF in all other cases. + Dictates how the prompt will be constructed. + - 'connectors': Accepts {"id": "web-search"}, and/or the "id" for a custom connector, if you've created one. + When specified, the model's reply will be enriched with information found by + querying each of the connectors (RAG). 
+ - 'documents': A list of relevant documents that the model can use to enrich its reply. + - 'search_queries_only': Defaults to false. When true, the response will only contain a + list of generated search queries, but no search will take place, and no reply from the model to the + user's message will be generated. + - 'citation_quality': Defaults to "accurate". Dictates the approach taken to generating citations + as part of the RAG flow by allowing the user to specify whether they want + "accurate" results or "fast" results. + - 'temperature': A non-negative float that tunes the degree of randomness in generation. Lower temperatures + mean less random generations. + """ + cohere_import.check() + + if not api_key: + api_key = os.environ.get("COHERE_API_KEY") + if not api_key: + error = "CohereChatGenerator needs an API key to run. Either provide it as init parameter or set the env var COHERE_API_KEY." # noqa: E501 + raise ValueError(error) + + if not api_base_url: + api_base_url = cohere.COHERE_API_URL + if generation_kwargs is None: + generation_kwargs = {} + self.api_key = api_key + self.model_name = model_name + self.streaming_callback = streaming_callback + self.api_base_url = api_base_url + self.generation_kwargs = generation_kwargs + self.model_parameters = kwargs + self.client = cohere.Client(api_key=self.api_key, api_url=self.api_base_url) + + def _get_telemetry_data(self) -> Dict[str, Any]: + """ + Data that is sent to Posthog for usage analytics. + """ + return {"model": self.model_name} + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + :return: The serialized component as a dictionary. 
+ """ + callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None + return default_to_dict( + self, + model_name=self.model_name, + streaming_callback=callback_name, + api_base_url=self.api_base_url, + generation_kwargs=self.generation_kwargs, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "CohereChatGenerator": + """ + Deserialize this component from a dictionary. + :param data: The dictionary representation of this component. + :return: The deserialized component instance. + """ + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) + return default_from_dict(cls, data) + + def _message_to_dict(self, message: ChatMessage) -> Dict[str, str]: + if message.role == ChatRole.USER: + role = "User" + elif message.role == ChatRole.ASSISTANT: + role = "Chatbot" + chat_message = {"user_name": role, "text": message.content} + return chat_message + + @component.output_types(replies=List[ChatMessage]) + def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, Any]] = None): + """ + Invoke the text generation inference based on the provided messages and generation parameters. + + :param messages: A list of ChatMessage instances representing the input messages. + :param generation_kwargs: Additional keyword arguments for text generation. These parameters will + potentially override the parameters passed in the __init__ method. + For more details on the parameters supported by the Cohere API, refer to the + Cohere [documentation](https://docs.cohere.com/reference/chat). + :return: A list containing the generated responses as ChatMessage instances. 
+ """ + # update generation kwargs by merging with the generation kwargs passed to the run method + generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + chat_history = [self._message_to_dict(m) for m in messages[:-1]] + response = self.client.chat( + message=messages[-1].content, + model=self.model_name, + stream=self.streaming_callback is not None, + chat_history=chat_history, + **generation_kwargs, + ) + if self.streaming_callback: + for chunk in response: + if chunk.event_type == "text-generation": + stream_chunk = self._build_chunk(chunk) + self.streaming_callback(stream_chunk) + chat_message = ChatMessage.from_assistant(content=response.texts) + chat_message.metadata.update( + { + "model": self.model_name, + "usage": response.token_count, + "index": 0, + "finish_reason": response.finish_reason, + "documents": response.documents, + "citations": response.citations, + } + ) + else: + chat_message = self._build_message(response) + return {"replies": [chat_message]} + + def _build_chunk(self, chunk) -> StreamingChunk: + """ + Converts the streamed response chunk from the Cohere API to a StreamingChunk. + :param chunk: The chunk returned by the Cohere API. + :return: The StreamingChunk. + """ + # if chunk.event_type == "text-generation": + chat_message = StreamingChunk( + content=chunk.text, metadata={"index": chunk.index, "event_type": chunk.event_type} + ) + return chat_message + + def _build_message(self, cohere_response): + """ + Converts the non-streaming response from the Cohere API to a ChatMessage. + :param cohere_response: The completion returned by the Cohere API. + :return: The ChatMessage. 
+ """ + content = cohere_response.text + message = ChatMessage.from_assistant(content=content) + message.metadata.update( + { + "model": self.model_name, + "usage": cohere_response.token_count, + "index": 0, + "finish_reason": None, + "documents": cohere_response.documents, + "citations": cohere_response.citations, + } + ) + return message diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py new file mode 100644 index 000000000..92954df8b --- /dev/null +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -0,0 +1,346 @@ +import os +from unittest.mock import Mock, patch + +import cohere +import pytest +from haystack.components.generators.utils import default_streaming_callback +from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk + +from cohere_haystack.chat.chat_generator import CohereChatGenerator + +pytestmark = pytest.mark.chat_generators + + +@pytest.fixture +def mock_chat_response(): + """ + Mock the Cohere API response and reuse it for tests + """ + with patch("cohere.Client.chat", autospec=True) as mock_chat_response: + # mimic the response from the Cohere API + + mock_response = Mock() + mock_response.text = "I'm fine, thanks." 
+ mock_response.token_count = { + "prompt_tokens": 66, + "response_tokens": 78, + "total_tokens": 144, + "billed_tokens": 133, + } + mock_response.meta = { + "api_version": {"version": "1"}, + "billed_units": {"input_tokens": 55, "output_tokens": 78}, + } + mock_chat_response.return_value = mock_response + yield mock_chat_response + + +def streaming_chunk(text: str): + """ + Mock chunks of streaming responses from the Cohere API + """ + # mimic the chunk response from the Cohere API + mock_chunks = Mock() + mock_chunks.index = 0 + mock_chunks.text = text + mock_chunks.event_type = "text-generation" + return mock_chunks + + +@pytest.fixture +def chat_messages(): + return [ChatMessage(content="What's the capital of France", role=ChatRole.ASSISTANT, name=None)] + + +class TestCohereChatGenerator: + @pytest.mark.unit + def test_init_default(self): + component = CohereChatGenerator(api_key="test-api-key") + assert component.api_key == "test-api-key" + assert component.model_name == "command" + assert component.streaming_callback is None + assert component.api_base_url == cohere.COHERE_API_URL + assert not component.generation_kwargs + + @pytest.mark.unit + def test_init_fail_wo_api_key(self, monkeypatch): + monkeypatch.delenv("COHERE_API_KEY", raising=False) + with pytest.raises(ValueError, match=r"^CohereChatGenerator needs an API key to run. 
(.+)$"): + CohereChatGenerator() + + @pytest.mark.unit + def test_init_with_parameters(self): + component = CohereChatGenerator( + api_key="test-api-key", + model_name="command-nightly", + streaming_callback=default_streaming_callback, + api_base_url="test-base-url", + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + assert component.api_key == "test-api-key" + assert component.model_name == "command-nightly" + assert component.streaming_callback is default_streaming_callback + assert component.api_base_url == "test-base-url" + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + + @pytest.mark.unit + def test_to_dict_default(self): + component = CohereChatGenerator(api_key="test-api-key") + data = component.to_dict() + assert data == { + "type": "cohere_haystack.chat.chat_generator.CohereChatGenerator", + "init_parameters": { + "model_name": "command", + "streaming_callback": None, + "api_base_url": "https://api.cohere.ai", + "generation_kwargs": {}, + }, + } + + @pytest.mark.unit + def test_to_dict_with_parameters(self): + component = CohereChatGenerator( + api_key="test-api-key", + model_name="command-nightly", + streaming_callback=default_streaming_callback, + api_base_url="test-base-url", + generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": "cohere_haystack.chat.chat_generator.CohereChatGenerator", + "init_parameters": { + "model_name": "command-nightly", + "streaming_callback": "haystack.components.generators.utils.default_streaming_callback", + "api_base_url": "test-base-url", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + }, + } + + @pytest.mark.unit + def test_to_dict_with_lambda_streaming_callback(self): + component = CohereChatGenerator( + api_key="test-api-key", + model_name="command", + streaming_callback=lambda x: x, + api_base_url="test-base-url", + 
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, + ) + data = component.to_dict() + assert data == { + "type": "cohere_haystack.chat.chat_generator.CohereChatGenerator", + "init_parameters": { + "model_name": "command", + "api_base_url": "test-base-url", + "streaming_callback": "tests.test_cohere_chat_generator.", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + }, + } + + @pytest.mark.unit + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("COHERE_API_KEY", "fake-api-key") + data = { + "type": "cohere_haystack.chat.chat_generator.CohereChatGenerator", + "init_parameters": { + "model_name": "command", + "api_base_url": "test-base-url", + "streaming_callback": "haystack.components.generators.utils.default_streaming_callback", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + }, + } + component = CohereChatGenerator.from_dict(data) + assert component.model_name == "command" + assert component.streaming_callback is default_streaming_callback + assert component.api_base_url == "test-base-url" + assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} + + @pytest.mark.unit + def test_from_dict_fail_wo_env_var(self, monkeypatch): + monkeypatch.delenv("COHERE_API_KEY", raising=False) + data = { + "type": "cohere_haystack.chat.chat_generator.CohereChatGenerator", + "init_parameters": { + "model_name": "command", + "api_base_url": "test-base-url", + "streaming_callback": "haystack.components.generators.utils.default_streaming_callback", + "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, + }, + } + with pytest.raises(ValueError, match=r"^CohereChatGenerator needs an API key to run. 
(.+)$"): + CohereChatGenerator.from_dict(data) + + @pytest.mark.unit + def test_run(self, chat_messages, mock_chat_response): # noqa: ARG002 + component = CohereChatGenerator(api_key="test-api-key") + response = component.run(chat_messages) + + # check that the component returns the correct ChatMessage response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + @pytest.mark.unit + def test_message_to_dict(self, chat_messages): + obj = CohereChatGenerator(api_key="api-key") + dictionary = [obj._message_to_dict(message) for message in chat_messages] + assert dictionary == [{"user_name": "Chatbot", "text": "What's the capital of France"}] + + @pytest.mark.unit + def test_run_with_params(self, chat_messages, mock_chat_response): + component = CohereChatGenerator( + api_key="test-api-key", generation_kwargs={"max_tokens": 10, "temperature": 0.5} + ) + response = component.run(chat_messages) + + # check that the component calls the Cohere API with the correct parameters + _, kwargs = mock_chat_response.call_args + assert kwargs["max_tokens"] == 10 + assert kwargs["temperature"] == 0.5 + + # check that the component returns the correct response + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + @pytest.mark.unit + def test_run_streaming(self, chat_messages, mock_chat_response): + streaming_call_count = 0 + + # Define the streaming callback function and assert that it is called with StreamingChunk objects + def streaming_callback_fn(chunk: StreamingChunk): + nonlocal streaming_call_count + streaming_call_count += 1 + assert isinstance(chunk, StreamingChunk) + + generator = CohereChatGenerator(api_key="test-api-key", 
streaming_callback=streaming_callback_fn) + + # Create a fake streamed response + # self needed here, don't remove + def mock_iter(self): # noqa: ARG001 + yield streaming_chunk("Hello") + yield streaming_chunk("How are you?") + + mock_response = Mock(**{"__iter__": mock_iter}) + mock_chat_response.return_value = mock_response + + response = generator.run(chat_messages) + + # Assert that the streaming callback was called twice + assert streaming_call_count == 2 + + # Assert that the response contains the generated replies + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) > 0 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + @pytest.mark.skipif( + not os.environ.get("COHERE_API_KEY", None), + reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_live_run(self): + chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", metadata={})] + component = CohereChatGenerator( + api_key=os.environ.get("COHERE_API_KEY"), generation_kwargs={"temperature": 0.8} + ) + results = component.run(chat_messages) + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + assert "Paris" in message.content + + @pytest.mark.skipif( + not os.environ.get("COHERE_API_KEY", None), + reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_wrong_model(self, chat_messages): + component = CohereChatGenerator( + model_name="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY") + ) + with pytest.raises(cohere.CohereAPIError, match="finetuned model something-obviously-wrong is not valid"): + component.run(chat_messages) + + @pytest.mark.skipif( + not os.environ.get("COHERE_API_KEY", None), + reason="Export an env var called COHERE_API_KEY containing the 
Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_streaming(self): + class Callback: + def __init__(self): + self.responses = "" + self.counter = 0 + + def __call__(self, chunk: StreamingChunk) -> None: + self.counter += 1 + self.responses += chunk.content if chunk.content else "" + + callback = Callback() + component = CohereChatGenerator(os.environ.get("COHERE_API_KEY"), streaming_callback=callback) + results = component.run( + [ChatMessage(content="What's the capital of France? answer in a word", role=ChatRole.USER, name=None)] + ) + + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + assert "Paris" in message.content[0] + + assert message.metadata["finish_reason"] == "COMPLETE" + + assert callback.counter > 1 + assert "Paris" in callback.responses + + @pytest.mark.skipif( + not os.environ.get("COHERE_API_KEY", None), + reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_with_connector(self): + chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", metadata={})] + component = CohereChatGenerator( + api_key=os.environ.get("COHERE_API_KEY"), generation_kwargs={"temperature": 0.8} + ) + results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]}) + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + assert "Paris" in message.content + assert message.metadata["documents"] is not None + assert message.metadata["citations"] is not None + + @pytest.mark.skipif( + not os.environ.get("COHERE_API_KEY", None), + reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_live_run_streaming_with_connector(self): + class Callback: + def __init__(self): + self.responses = "" + self.counter = 0 + + def __call__(self, chunk: 
StreamingChunk) -> None: + self.counter += 1 + self.responses += chunk.content if chunk.content else "" + + callback = Callback() + chat_messages = [ChatMessage(content="What's the capital of France? answer in a word", role=None, name=None)] + component = CohereChatGenerator(os.environ.get("COHERE_API_KEY"), streaming_callback=callback) + results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]}) + + assert len(results["replies"]) == 1 + message: ChatMessage = results["replies"][0] + assert "Paris" in message.content[0] + + assert message.metadata["finish_reason"] == "COMPLETE" + + assert callback.counter > 1 + assert "Paris" in callback.responses + + assert message.metadata["documents"] is not None + assert message.metadata["citations"] is not None