diff --git a/docs/docs/examples/llm/maritalk.ipynb b/docs/docs/examples/llm/maritalk.ipynb index d885d8635c690..b046df1dbd15b 100644 --- a/docs/docs/examples/llm/maritalk.ipynb +++ b/docs/docs/examples/llm/maritalk.ipynb @@ -13,7 +13,10 @@ "MariTalk is an assistant developed by the Brazilian company [Maritaca AI](https://www.maritaca.ai).\n", "MariTalk is based on language models that have been specially trained to understand Portuguese well.\n", "\n", - "This notebook demonstrates how to use MariTalk with llama-index through a simple example." + "This notebook demonstrates how to use MariTalk with LlamaIndex through two examples:\n", + "\n", + "1. Get pet name suggestions with the chat method;\n", + "2. Classify film reviews as negative or positive with few-shot examples and the complete method." ] }, { @@ -46,9 +49,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Usage\n", - "\n", - "### Chat" + "### Example 1 - Pet Name Suggestions with Chat" ] }, { @@ -60,11 +61,11 @@ "from llama_index.core.llms import ChatMessage\n", "from llama_index.llms.maritalk import Maritalk\n", "\n", + "import asyncio\n", + "\n", "# To customize your API key, do this\n", "# otherwise it will lookup MARITALK_API_KEY from your env variable\n", - "# llm = Maritalk(api_key=\"\")\n", - "\n", - "llm = Maritalk()\n", + "llm = Maritalk(api_key=\"\", model=\"sabia-2-medium\")\n", "\n", "# Call chat with a list of messages\n", "messages = [\n", @@ -75,15 +76,55 @@ " ChatMessage(role=\"user\", content=\"I have a dog.\"),\n", "]\n", "\n", + "# Sync chat\n", "response = llm.chat(messages)\n", - "print(response)" + "print(response)\n", + "\n", + "\n", + "# Async chat\n", + "async def get_dog_name(llm, messages):\n", + " response = await llm.achat(messages)\n", + " print(response)\n", + "\n", + "\n", + "asyncio.run(get_dog_name(llm, messages))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Few-shot examples\n", + "#### Stream Generation\n", + "\n", + "For tasks involving the generation of long text, such as creating an extensive article or translating a large document, it can be advantageous to receive the response in parts, as the text is generated, instead of waiting for the complete text. This makes the application more responsive and efficient. We offer two approaches to meet this need: one synchronous and one asynchronous."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Sync streaming chat\n", + "response = llm.stream_chat(messages)\n", + "for chunk in response:\n", + " print(chunk.delta, end=\"\", flush=True)\n", + "\n", + "\n", + "# Async streaming chat\n", + "async def get_dog_name_streaming(llm, messages):\n", + " async for chunk in await llm.astream_chat(messages):\n", + " print(chunk.delta, end=\"\", flush=True)\n", + "\n", + "\n", + "asyncio.run(get_dog_name_streaming(llm, messages))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example 2 - Few-shot Examples with Complete\n", + "\n", "We recommend using the `llm.complete()` method when using the model with few-shot examples" ] }, @@ -105,8 +146,39 @@ "Resenha: Apesar de longo, valeu o ingresso..\n", "Classe:\"\"\"\n", "\n", - "response = llm.complete(prompt, stopping_tokens=[\"\\n\"])\n", - "print(response)" + "# Sync complete\n", + "response = llm.complete(prompt)\n", + "print(response)\n", + "\n", + "\n", + "# Async complete\n", + "async def classify_review(llm, prompt):\n", + " response = await llm.acomplete(prompt)\n", + " print(response)\n", + "\n", + "\n", + "asyncio.run(classify_review(llm, prompt))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Sync streaming complete\n", + "response = llm.stream_complete(prompt)\n", + "for chunk in response:\n", + " print(chunk.delta, end=\"\", flush=True)\n", + "\n", + "\n", + "# Async streaming complete\n", + "async def classify_review_streaming(llm, prompt):\n", + " async for chunk in await llm.astream_complete(prompt):\n", + " print(chunk.delta, end=\"\", flush=True)\n", + "\n", + "\n", + "asyncio.run(classify_review_streaming(llm, prompt))" + ] + } ], diff --git a/llama-index-integrations/llms/llama-index-llms-maritalk/README.md b/llama-index-integrations/llms/llama-index-llms-maritalk/README.md index bf3413f76d27b..1277c35f2ce8c 100644 --- a/llama-index-integrations/llms/llama-index-llms-maritalk/README.md +++ b/llama-index-integrations/llms/llama-index-llms-maritalk/README.md @@ -1 +1,19 @@ # LlamaIndex Llms Integration: Maritalk + +MariTalk is an assistant developed by the Brazilian company [Maritaca AI](https://www.maritaca.ai). MariTalk is based on language models that have been specially trained to understand Portuguese well. + +## Installation + +First, install the LlamaIndex library (and all its dependencies) using the following command: + +``` +$ pip install llama-index llama-index-llms-maritalk +``` + +## API Key + +You will need an API key, which can be obtained from chat.maritaca.ai ("Chaves da API" section). + +## Examples + +Usage examples are available in the [LlamaIndex documentation](https://docs.llamaindex.ai/en/stable/examples/llm/maritalk/).
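For quick reference, a minimal usage sketch that mirrors the notebook and README above. The API key value and the pet-naming system prompt are placeholders, and `sabia-2-medium` is one of the models listed in `base.py` below:

```python
# Minimal chat example mirroring the notebook cells above.
# The API key is a placeholder; per the notebook, MariTalk can also read
# MARITALK_API_KEY from the environment.
from llama_index.core.llms import ChatMessage
from llama_index.llms.maritalk import Maritalk

llm = Maritalk(api_key="<your_api_key>", model="sabia-2-medium")

messages = [
    ChatMessage(
        role="system",
        # Illustrative system prompt; adjust to your task.
        content="You are an assistant specialized in suggesting pet names.",
    ),
    ChatMessage(role="user", content="I have a dog."),
]

# Synchronous chat; the notebook also shows async and streaming variants.
response = llm.chat(messages)
print(response)
```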
diff --git a/llama-index-integrations/llms/llama-index-llms-maritalk/llama_index/llms/maritalk/base.py b/llama-index-integrations/llms/llama-index-llms-maritalk/llama_index/llms/maritalk/base.py index b9be1d2fb4dc5..bd851a94f50d2 100644 --- a/llama-index-integrations/llms/llama-index-llms-maritalk/llama_index/llms/maritalk/base.py +++ b/llama-index-integrations/llms/llama-index-llms-maritalk/llama_index/llms/maritalk/base.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence, Union from llama_index.core.base.llms.types import ( ChatMessage, ChatResponse, @@ -16,10 +16,47 @@ llm_chat_callback, llm_completion_callback, ) + +from llama_index.core.base.llms.generic_utils import ( + achat_to_completion_decorator, + astream_chat_to_completion_decorator, + chat_to_completion_decorator, + stream_chat_to_completion_decorator, +) + +from requests import Response +from requests.exceptions import HTTPError +from http import HTTPStatus + import requests +import json import os +class MaritalkHTTPError(HTTPError): + def __init__(self, request_obj: Response) -> None: + self.request_obj = request_obj + try: + response_json = request_obj.json() + if "detail" in response_json: + api_message = response_json["detail"] + elif "message" in response_json: + api_message = response_json["message"] + else: + api_message = response_json + except Exception: + api_message = request_obj.text + + self.message = api_message + self.status_code = request_obj.status_code + + def __str__(self) -> str: + status_code_meaning = HTTPStatus(self.status_code).phrase + formatted_message = f"HTTP Error: {self.status_code} - {status_code_meaning}" + formatted_message += f"\nDetail: {self.message}" + return formatted_message + + class Maritalk(LLM): """Maritalk LLM. @@ -50,7 +87,19 @@ class Maritalk(LLM): ``` """ - api_key: Optional[str] = Field(default=None, description="Your MariTalk API key.") + api_key: Optional[str] = Field( + default=None, + description="Your MariTalk API key.", + ) + + model: str = Field( + default="sabia-2-medium", + description="Choose one of the available models:\n" + "- `sabia-2-medium`\n" + "- `sabia-2-small`\n" + "- `maritalk-2024-01-08`", + ) + temperature: float = Field( default=0.7, gt=0.0, @@ -58,15 +107,18 @@ description="Run inference with this temperature. Must be in the" "closed interval [0.0, 1.0].", ) + max_tokens: int = Field( default=512, gt=0, description="The maximum number of tokens to" "generate in the reply.", ) + do_sample: bool = Field( default=True, description="Whether or not to use sampling; use `True` to enable.", ) + top_p: float = Field( default=0.95, gt=0.0, @@ -92,6 +144,33 @@ def __init__(self, **kwargs) -> None: def class_name(cls) -> str: return "Maritalk" + def parse_messages_for_model( + self, messages: Sequence[ChatMessage] + ) -> List[Dict[str, Union[str, List[Union[str, Dict[Any, Any]]]]]]: + """ + Parses messages from LlamaIndex's format to the format expected by + the MariTalk API. + + Parameters: + messages (Sequence[ChatMessage]): A list of messages in LlamaIndex + format to be parsed. + + Returns: + A list of messages formatted for the MariTalk API.
+ """ + formatted_messages = [] + + for message in messages: + if message.role.value == MessageRole.USER: + role = "user" + elif message.role.value == MessageRole.ASSISTANT: + role = "assistant" + elif message.role.value == MessageRole.SYSTEM: + role = "system" + + formatted_messages.append({"role": role, "content": message.content}) + return formatted_messages + @property def metadata(self) -> LLMMetadata: return LLMMetadata( @@ -103,125 +182,204 @@ def metadata(self) -> LLMMetadata: @llm_chat_callback() def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: # Prepare the data payload for the Maritalk API - formatted_messages = [] - for msg in messages: - if msg.role == MessageRole.SYSTEM: - # Add system message as a user message - formatted_messages.append({"role": "user", "content": msg.content}) - # Follow it by an assistant message acknowledging it, to maintain conversation flow - formatted_messages.append({"role": "assistant", "content": "ok"}) - else: - # Format user and assistant messages as before - formatted_messages.append( - { - "role": "user" if msg.role == MessageRole.USER else "assistant", - "content": msg.content, - } - ) + formatted_messages = self.parse_messages_for_model(messages) data = { + "model": self.model, "messages": formatted_messages, "do_sample": self.do_sample, "max_tokens": self.max_tokens, "temperature": self.temperature, "top_p": self.top_p, + **kwargs, } - # Update data payload with additional kwargs if any - data.update(kwargs) - headers = {"authorization": f"Key {self.api_key}"} response = requests.post(self._endpoint, json=data, headers=headers) - if response.status_code == 429: - return ChatResponse( - message=ChatMessage( - role=MessageRole.SYSTEM, - content="Rate limited, please try again soon", - ), - raw=response.text, - ) - elif response.ok: - answer = response.json()["answer"] + + if response.ok: + answer = response.json().get("answer", "No answer found") return ChatResponse( message=ChatMessage(role=MessageRole.ASSISTANT, content=answer), raw=response.json(), ) else: - response.raise_for_status() # noqa: RET503 + raise MaritalkHTTPError(response) @llm_completion_callback() def complete( self, prompt: str, formatted: bool = False, **kwargs: Any ) -> CompletionResponse: + complete_fn = chat_to_completion_decorator(self.chat) + return complete_fn(prompt, **kwargs) + + @llm_chat_callback() + async def achat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponse: + try: + import httpx + + # Prepare the data payload for the Maritalk API + formatted_messages = self.parse_messages_for_model(messages) + + data = { + "model": self.model, + "messages": formatted_messages, + "do_sample": self.do_sample, + "max_tokens": self.max_tokens, + "temperature": self.temperature, + "top_p": self.top_p, + **kwargs, + } + + headers = {"authorization": f"Key {self.api_key}"} + + async with httpx.AsyncClient() as client: + response = await client.post( + self._endpoint, json=data, headers=headers, timeout=None + ) + + if response.status_code == 200: + answer = response.json().get("answer", "No answer found") + return ChatResponse( + message=ChatMessage(role=MessageRole.ASSISTANT, content=answer), + raw=response.json(), + ) + else: + raise MaritalkHTTPError(response) + + except ImportError: + raise ImportError( + "Could not import httpx python package. " + "Please install it with `pip install httpx`." 
+ ) + + @llm_completion_callback() + async def acomplete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + acomplete_fn = achat_to_completion_decorator(self.achat) + return await acomplete_fn(prompt, **kwargs) + + @llm_chat_callback() + def stream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseGen: # Prepare the data payload for the Maritalk API + formatted_messages = self.parse_messages_for_model(messages) + data = { - "messages": prompt, + "model": self.model, + "messages": formatted_messages, "do_sample": self.do_sample, "max_tokens": self.max_tokens, "temperature": self.temperature, "top_p": self.top_p, - "chat_mode": False, + "stream": True, + **kwargs, } - # Update data payload with additional kwargs if any - data.update(kwargs) - headers = {"authorization": f"Key {self.api_key}"} - response = requests.post(self._endpoint, json=data, headers=headers) - if response.status_code == 429: - return CompletionResponse( - text="Rate limited, please try again soon", - raw=response.text, + def gen() -> ChatResponseGen: + response = requests.post( + self._endpoint, json=data, headers=headers, stream=True ) - elif response.ok: - answer = response.json()["answer"] - return CompletionResponse( - text=answer, - raw=response.json(), - ) - else: - response.raise_for_status() # noqa: RET503 + if response.ok: + content = "" + for line in response.iter_lines(): + if line.startswith(b"data: "): + response_data = line.replace(b"data: ", b"").decode("utf-8") + if response_data: + parsed_data = json.loads(response_data) + if "text" in parsed_data: + content_delta = parsed_data["text"] + content += content_delta + yield ChatResponse( + message=ChatMessage( + role=MessageRole.ASSISTANT, content=content + ), + delta=content_delta, + raw=parsed_data, + ) + else: + raise MaritalkHTTPError(response) - def stream_chat( - self, messages: Sequence[ChatMessage], **kwargs: Any - ) -> ChatResponseGen: - raise NotImplementedError( - "Maritalk does not currently support streaming completion." - ) + return gen() + @llm_completion_callback() def stream_complete( self, prompt: str, formatted: bool = False, **kwargs: Any ) -> CompletionResponseGen: - raise NotImplementedError( - "Maritalk does not currently support streaming completion." - ) - - @llm_chat_callback() - async def achat( - self, messages: Sequence[ChatMessage], **kwargs: Any - ) -> ChatResponse: - return self.chat(messages, **kwargs) - - @llm_completion_callback() - async def acomplete( - self, prompt: str, formatted: bool = False, **kwargs: Any - ) -> CompletionResponse: - return self.complete(prompt, formatted, **kwargs) + stream_complete_fn = stream_chat_to_completion_decorator(self.stream_chat) + return stream_complete_fn(prompt, **kwargs) @llm_chat_callback() async def astream_chat( self, messages: Sequence[ChatMessage], **kwargs: Any ) -> ChatResponseAsyncGen: - raise NotImplementedError( - "Maritalk does not currently support streaming completion." 
- ) + try: + import httpx + + # Prepare the data payload for the Maritalk API + formatted_messages = self.parse_messages_for_model(messages) + + data = { + "model": self.model, + "messages": formatted_messages, + "do_sample": self.do_sample, + "max_tokens": self.max_tokens, + "temperature": self.temperature, + "top_p": self.top_p, + "stream": True, + **kwargs, + } + + headers = {"authorization": f"Key {self.api_key}"} + + async def gen() -> ChatResponseAsyncGen: + async with httpx.AsyncClient() as client: + async with client.stream( + "POST", + self._endpoint, + data=json.dumps(data), + headers=headers, + timeout=None, + ) as response: + if response.status_code == 200: + content = "" + async for line in response.aiter_lines(): + if line.startswith("data: "): + response_data = line.replace("data: ", "") + if response_data: + parsed_data = json.loads(response_data) + if "text" in parsed_data: + content_delta = parsed_data["text"] + content += content_delta + yield ChatResponse( + message=ChatMessage( + role=MessageRole.ASSISTANT, + content=content, + ), + delta=content_delta, + raw=parsed_data, + ) + else: + raise MaritalkHTTPError(response) + + return gen() + + except ImportError: + raise ImportError( + "Could not import httpx python package. " + "Please install it with `pip install httpx`." + ) @llm_completion_callback() async def astream_complete( self, prompt: str, formatted: bool = False, **kwargs: Any ) -> CompletionResponseAsyncGen: - raise NotImplementedError( - "Maritalk does not currently support streaming completion." - ) + astream_complete_fn = astream_chat_to_completion_decorator(self.astream_chat) + return await astream_complete_fn(prompt, **kwargs) diff --git a/llama-index-integrations/llms/llama-index-llms-maritalk/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-maritalk/pyproject.toml index 7724410818352..e473e803e75ff 100644 --- a/llama-index-integrations/llms/llama-index-llms-maritalk/pyproject.toml +++ b/llama-index-integrations/llms/llama-index-llms-maritalk/pyproject.toml @@ -30,7 +30,7 @@ license = "MIT" name = "llama-index-llms-maritalk" packages = [{include = "llama_index/"}] readme = "README.md" -version = "0.1.1" +version = "0.2.0" [tool.poetry.dependencies] python = ">=3.8.1,<3.12"
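As a closing usage note, a sketch of how a caller might exercise the new streaming and error-handling paths added in `base.py`. It assumes `MaritalkHTTPError` can be imported from the integration's `base` module and reuses the placeholder API key from the earlier sketch:

```python
# Streaming chat with error handling, exercising the stream_chat method and the
# MaritalkHTTPError class introduced in base.py. The API key is a placeholder.
from llama_index.core.llms import ChatMessage
from llama_index.llms.maritalk import Maritalk
from llama_index.llms.maritalk.base import MaritalkHTTPError  # assumed importable from the base module

llm = Maritalk(api_key="<your_api_key>", model="sabia-2-medium")
messages = [ChatMessage(role="user", content="Suggest a name for a dog.")]

try:
    # stream_chat yields ChatResponse chunks; `delta` holds only the newly generated text.
    for chunk in llm.stream_chat(messages):
        print(chunk.delta, end="", flush=True)
except MaritalkHTTPError as err:
    # __str__ reports the HTTP status phrase plus the API's "detail"/"message" field.
    print(err)
```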