diff --git a/agentops/llms/__init__.py b/agentops/llms/__init__.py
index 8c7ba5f4a..380970af8 100644
--- a/agentops/llms/__init__.py
+++ b/agentops/llms/__init__.py
@@ -13,6 +13,7 @@
 from .ollama import OllamaProvider
 from .openai import OpenAiProvider
 from .anthropic import AnthropicProvider
+from .ai21 import AI21Provider
 
 original_func = {}
 original_create = None
@@ -39,6 +40,12 @@ class LlmTracker:
         "anthropic": {
             "0.32.0": ("completions.create",),
         },
+        "ai21": {
+            "2.0.0": (
+                "chat.completions.create",
+                "client.answer.create",
+            ),
+        },
     }
 
     def __init__(self, client):
@@ -135,6 +142,23 @@ def override_api(self):
                     f"Only Anthropic>=0.32.0 supported. v{module_version} found."
                 )
 
+        if api == "ai21":
+            module_version = version(api)
+
+            if module_version is None:
+                logger.warning(
+                    "Cannot determine AI21 version. Only AI21>=2.0.0 supported."
+                )
+                return
+
+            if Version(module_version) >= parse("2.0.0"):
+                provider = AI21Provider(self.client)
+                provider.override()
+            else:
+                logger.warning(
+                    f"Only AI21>=2.0.0 supported. v{module_version} found."
+                )
+
     def stop_instrumenting(self):
         OpenAiProvider(self.client).undo_override()
         GroqProvider(self.client).undo_override()
@@ -142,3 +165,4 @@ def stop_instrumenting(self):
         LiteLLMProvider(self.client).undo_override()
         OllamaProvider(self.client).undo_override()
         AnthropicProvider(self.client).undo_override()
+        AI21Provider(self.client).undo_override()
diff --git a/agentops/llms/ai21.py b/agentops/llms/ai21.py
new file mode 100644
index 000000000..533ad2761
--- /dev/null
+++ b/agentops/llms/ai21.py
@@ -0,0 +1,251 @@
+import inspect
+import pprint
+from typing import Optional
+
+from agentops.llms.instrumented_provider import InstrumentedProvider
+from agentops.time_travel import fetch_completion_override_from_time_travel_cache
+
+from ..event import ErrorEvent, LLMEvent, ActionEvent
+from ..session import Session
+from ..log_config import logger
+from ..helpers import check_call_stack_for_agent_id, get_ISO_time
+from ..singleton import singleton
+
+
+@singleton
+class AI21Provider(InstrumentedProvider):
+
+    original_create = None
+    original_create_async = None
+    original_answer = None
+    original_answer_async = None
+
+    def __init__(self, client):
+        super().__init__(client)
+        self._provider_name = "AI21"
+
+    def handle_response(
+        self, response, kwargs, init_timestamp, session: Optional[Session] = None
+    ):
+        """Handle responses for AI21"""
+        from ai21.stream.stream import Stream
+        from ai21.stream.async_stream import AsyncStream
+        from ai21.models.chat.chat_completion_chunk import ChatCompletionChunk
+        from ai21.models.chat.chat_completion_response import ChatCompletionResponse
+        from ai21.models.responses.answer_response import AnswerResponse
+
+        llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
+        action_event = ActionEvent(init_timestamp=init_timestamp, params=kwargs)
+
+        if session is not None:
+            llm_event.session_id = session.session_id
+
+        def handle_stream_chunk(chunk: ChatCompletionChunk):
+            # We take the first ChatCompletionChunk and accumulate the deltas from all subsequent chunks to build one full chat completion
+            if llm_event.returns is None:
+                llm_event.returns = chunk
+                # Manually setting content to empty string to avoid error
+                llm_event.returns.choices[0].delta.content = ""
+
+            try:
+                accumulated_delta = llm_event.returns.choices[0].delta
+                llm_event.agent_id = check_call_stack_for_agent_id()
+                llm_event.model = kwargs["model"]
+                llm_event.prompt = [
+                    message.model_dump() for message in kwargs["messages"]
+                ]
+
+                # NOTE: We assume for completion only choices[0] is relevant
+                choice = chunk.choices[0]
+
+                if choice.delta.content:
+                    accumulated_delta.content += choice.delta.content
+
+                if choice.delta.role:
+                    accumulated_delta.role = choice.delta.role
+
+                if getattr(choice.delta, "tool_calls", None):
+                    accumulated_delta.tool_calls = choice.delta.tool_calls
+
+                if choice.finish_reason:
+                    # Streaming is done. Record LLMEvent
+                    llm_event.returns.choices[0].finish_reason = choice.finish_reason
+                    llm_event.completion = {
+                        "role": accumulated_delta.role,
+                        "content": accumulated_delta.content,
+                    }
+                    llm_event.prompt_tokens = chunk.usage.prompt_tokens
+                    llm_event.completion_tokens = chunk.usage.completion_tokens
+                    llm_event.end_timestamp = get_ISO_time()
+                    self._safe_record(session, llm_event)
+
+            except Exception as e:
+                self._safe_record(
+                    session, ErrorEvent(trigger_event=llm_event, exception=e)
+                )
+
+                kwargs_str = pprint.pformat(kwargs)
+                chunk = pprint.pformat(chunk)
+                logger.warning(
+                    f"Unable to parse a chunk for LLM call. Skipping upload to AgentOps\n"
+                    f"chunk:\n {chunk}\n"
+                    f"kwargs:\n {kwargs_str}\n"
+                )
+
+        # if the response is a generator, decorate the generator
+        # For synchronous Stream
+        if isinstance(response, Stream):
+
+            def generator():
+                for chunk in response:
+                    handle_stream_chunk(chunk)
+                    yield chunk
+
+            return generator()
+
+        # For asynchronous AsyncStream
+        if isinstance(response, AsyncStream):
+
+            async def async_generator():
+                async for chunk in response:
+                    handle_stream_chunk(chunk)
+                    yield chunk
+
+            return async_generator()
+
+        # Handle object responses
+        try:
+            if isinstance(response, ChatCompletionResponse):
+                llm_event.returns = response
+                llm_event.agent_id = check_call_stack_for_agent_id()
+                llm_event.model = kwargs["model"]
+                llm_event.prompt = [
+                    message.model_dump() for message in kwargs["messages"]
+                ]
+                llm_event.prompt_tokens = response.usage.prompt_tokens
+                llm_event.completion = response.choices[0].message.model_dump()
+                llm_event.completion_tokens = response.usage.completion_tokens
+                llm_event.end_timestamp = get_ISO_time()
+                self._safe_record(session, llm_event)
+
+            elif isinstance(response, AnswerResponse):
+                action_event.returns = response
+                action_event.agent_id = check_call_stack_for_agent_id()
+                action_event.action_type = "Contextual Answers"
+                action_event.logs = [
+                    {"context": kwargs["context"], "question": kwargs["question"]},
+                    response.model_dump() or None,
+                ]
+                action_event.end_timestamp = get_ISO_time()
+                self._safe_record(session, action_event)
+
+        except Exception as e:
+            self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e))
+            kwargs_str = pprint.pformat(kwargs)
+            response = pprint.pformat(response)
+            logger.warning(
+                f"Unable to parse response for LLM call. Skipping upload to AgentOps\n"
+                f"response:\n {response}\n"
+                f"kwargs:\n {kwargs_str}\n"
+            )
+
+        return response
+
+    def override(self):
+        self._override_completion()
+        self._override_completion_async()
+        self._override_answer()
+        self._override_answer_async()
+
+    def _override_completion(self):
+        from ai21.clients.studio.resources.chat import ChatCompletions
+
+        # Store the original method on the singleton so undo_override can restore it
+        self.original_create = ChatCompletions.create
+
+        def patched_function(*args, **kwargs):
+            # Call the original function with its original arguments
+            init_timestamp = get_ISO_time()
+            session = kwargs.get("session", None)
+            if "session" in kwargs:
+                del kwargs["session"]
+            result = self.original_create(*args, **kwargs)
+            return self.handle_response(result, kwargs, init_timestamp, session=session)
+
+        # Override the original method with the patched one
+        ChatCompletions.create = patched_function
+
+    def _override_completion_async(self):
+        from ai21.clients.studio.resources.chat import AsyncChatCompletions
+
+        # Store the original method on the singleton so undo_override can restore it
+        self.original_create_async = AsyncChatCompletions.create
+
+        async def patched_function(*args, **kwargs):
+            # Call the original function with its original arguments
+            init_timestamp = get_ISO_time()
+            session = kwargs.get("session", None)
+            if "session" in kwargs:
+                del kwargs["session"]
+            result = await self.original_create_async(*args, **kwargs)
+            return self.handle_response(result, kwargs, init_timestamp, session=session)
+
+        # Override the original method with the patched one
+        AsyncChatCompletions.create = patched_function
+
+    def _override_answer(self):
+        from ai21.clients.studio.resources.studio_answer import StudioAnswer
+
+        # Store the original method on the singleton so undo_override can restore it
+        self.original_answer = StudioAnswer.create
+
+        def patched_function(*args, **kwargs):
+            # Call the original function with its original arguments
+            init_timestamp = get_ISO_time()
+
+            session = kwargs.get("session", None)
+            if "session" in kwargs:
+                del kwargs["session"]
+            result = self.original_answer(*args, **kwargs)
+            return self.handle_response(result, kwargs, init_timestamp, session=session)
+
+        StudioAnswer.create = patched_function
+
+    def _override_answer_async(self):
+        from ai21.clients.studio.resources.studio_answer import AsyncStudioAnswer
+
+        # Store the original method on the singleton so undo_override can restore it
+        self.original_answer_async = AsyncStudioAnswer.create
+
+        async def patched_function(*args, **kwargs):
+            # Call the original function with its original arguments
+            init_timestamp = get_ISO_time()
+
+            session = kwargs.get("session", None)
+            if "session" in kwargs:
+                del kwargs["session"]
+            result = await self.original_answer_async(*args, **kwargs)
+            return self.handle_response(result, kwargs, init_timestamp, session=session)
+
+        AsyncStudioAnswer.create = patched_function
+
+    def undo_override(self):
+        if (
+            self.original_create is not None
+            and self.original_create_async is not None
+            and self.original_answer is not None
+            and self.original_answer_async is not None
+        ):
+            from ai21.clients.studio.resources.chat import (
+                ChatCompletions,
+                AsyncChatCompletions,
+            )
+            from ai21.clients.studio.resources.studio_answer import (
+                StudioAnswer,
+                AsyncStudioAnswer,
+            )
+
+            ChatCompletions.create = self.original_create
+            AsyncChatCompletions.create = self.original_create_async
+            StudioAnswer.create = self.original_answer
+            AsyncStudioAnswer.create = self.original_answer_async
diff --git a/docs/v1/integrations/ai21.mdx b/docs/v1/integrations/ai21.mdx
new file mode 100644
index 000000000..aab669ab5
--- /dev/null
+++ b/docs/v1/integrations/ai21.mdx
@@ -0,0 +1,143 @@
+---
+title: AI21
+description: "Use AI21's latest models with AgentOps, including
+Jamba 1.5, Jamba Instruct, and specialized task models"
+---
+
+import CodeTooltip from '/snippets/add-code-tooltip.mdx'
+import EnvTooltip from '/snippets/add-env-tooltip.mdx'
+
+## AI21
+
+From [AI21's docs](https://docs.ai21.com/):
+
+AI21 provides state-of-the-art language models through a simple API, offering:
+- Multiple model sizes to balance performance and cost
+- Specialized models for specific tasks like contextual answers
+- Chat and completion endpoints
+- Enterprise-grade reliability and support
+
+AI21 supports several models, including Jamba 1.5, Jamba Instruct, and task-specific models.
+
+## Using AgentOps with AI21
+
+### Requires `ai21>=2.0.0`
+
+AgentOps works seamlessly with AI21's Python SDK. Here's how to use it:
+
+
+
+  ```bash pip
+  pip install agentops
+  ```
+  ```bash poetry
+  poetry add agentops
+  ```
+
+
+
+
+
+  ```python python
+  from ai21 import AI21Client
+  from ai21.models.chat import ChatMessage
+  import agentops
+
+  # Initialize clients
+  agentops.init()
+  client = AI21Client(api_key="your-api-key")
+
+  # Your AI21 code here...
+
+  agentops.end_session("Success") # Success|Fail|Indeterminate
+  ```
+
+
+
+
+
+  ```python .env
+  AGENTOPS_API_KEY=
+  ```
+
+  Read more about environment variables in [Advanced Configuration](/v1/usage/advanced-configuration)
+
+
+
+Execute your program and visit [app.agentops.ai/drilldown](https://app.agentops.ai/drilldown) to observe your Agent! 🕵️
+
+After your run, AgentOps prints a clickable URL to the console linking directly to your session in the Dashboard.
+
+{/* Intentionally blank div for newline */}
+
+
+
+
+
+### Streaming Support
+
+AI21 supports streaming responses:
+
+```python
+response = ""
+stream_response = client.chat.completions.create(
+    messages=messages,
+    model="jamba-instruct",
+    stream=True,
+)
+
+for chunk in stream_response:
+    response += str(chunk.choices[0].delta.content)
+```
+
+### Async Support
+
+You can also use AI21 models asynchronously:
+
+```python
+from ai21 import AsyncAI21Client
+
+aclient = AsyncAI21Client(api_key="your-api-key")
+
+async def main():
+    async_response = await aclient.chat.completions.create(
+        messages=messages,
+        model="jamba-1.5-mini",
+    )
+    print(async_response.choices[0].message.content)
+
+await main()
+```
+
+### Task-Specific Models
+
+AI21 provides specialized models for specific tasks. Here's an example using the contextual answers endpoint:
+
+```python
+response = client.answer.create(
+    context="Your context text here...",
+    question="Your question here?",
+)
+print(response.answer)
+```
+
+You can also stream answers:
+
+```python
+response = client.answer.create(
+    context="Your context text here...",
+    question="Your question here?",
+    stream=True,
+)
+print(response.answer)
+```
+
+
+
+
+
\ No newline at end of file
diff --git a/examples/ai21_examples/ai21_examples.ipynb b/examples/ai21_examples/ai21_examples.ipynb
new file mode 100644
index 000000000..16afb4ad7
--- /dev/null
+++ b/examples/ai21_examples/ai21_examples.ipynb
@@ -0,0 +1,342 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# AI21 Example"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "First let's install the required packages"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%pip install -U ai21\n",
+    "%pip install -U agentops"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Then import them"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from ai21 import AI21Client, AsyncAI21Client\n",
+    "from ai21.models.chat import ChatMessage\n",
+    "from dotenv import load_dotenv\n",
+    "import os\n",
+    "import asyncio\n",
+    "import agentops"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Next, we'll grab our API keys. You can use dotenv like below or however else you like to load environment variables."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "load_dotenv()\n",
+    "AI21_API_KEY = os.getenv(\"AI21_API_KEY\") or \"\"\n",
+    "AGENTOPS_API_KEY = os.getenv(\"AGENTOPS_API_KEY\") or \"\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "agentops.init(AGENTOPS_API_KEY, default_tags=[\"ai21-example\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Setting up Messages\n",
+    "AI21 clients use a `ChatMessage` object to handle messages. We set up the following system prompt to guide the model in its response, along with a user prompt. In this example, the model plays a poet writing in the style of Edgar Allan Poe."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "messages = [\n",
+    "    ChatMessage(\n",
+    "        content=\"You are a world renowned poet in the style of Edgar Allan Poe.\",\n",
+    "        role=\"system\",\n",
+    "    ),\n",
+    "    ChatMessage(\n",
+    "        content=\"Write me a short poem about the AI agents co-existing within the human brain.\",\n",
+    "        role=\"user\",\n",
+    "    ),\n",
+    "]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Sync Examples"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "We will demonstrate a basic sync call to AI21 using the Jamba 1.5 model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "client = AI21Client(api_key=AI21_API_KEY)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "response = client.chat.completions.create(\n",
+    "    messages=messages,\n",
+    "    model=\"jamba-1.5-mini\",\n",
+    ")\n",
+    "print(response.choices[0].message.content)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The following example shows how to record data from the streamed response using the Jamba Instruct model."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "response = \"\"\n",
+    "\n",
+    "stream_response = client.chat.completions.create(\n",
+    "    messages=messages,\n",
+    "    model=\"jamba-instruct\",\n",
+    "    stream=True,\n",
+    ")\n",
+    "\n",
+    "for chunk in stream_response:\n",
+    "    response += str(chunk.choices[0].delta.content)\n",
+    "\n",
+    "print(response)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Async Example\n",
+    "The async example is very similar to the sync example, but it uses the `AsyncAI21Client` class."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "aclient = AsyncAI21Client(api_key=AI21_API_KEY)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "async def main():\n",
+    "    async_response = await aclient.chat.completions.create(\n",
+    "        messages=messages,\n",
+    "        model=\"jamba-1.5-mini\",\n",
+    "    )\n",
+    "    print(async_response.choices[0].message.content)\n",
+    "\n",
+    "\n",
+    "await main()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The following example shows how to record data from the async streamed response."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "async def main():\n",
+    "    response = \"\"\n",
+    "\n",
+    "    async_stream_response = await aclient.chat.completions.create(\n",
+    "        messages=messages,\n",
+    "        model=\"jamba-1.5-mini\",\n",
+    "        stream=True,\n",
+    "    )\n",
+    "\n",
+    "    async for chunk in async_stream_response:\n",
+    "        response += str(chunk.choices[0].delta.content)\n",
+    "\n",
+    "    print(response)\n",
+    "\n",
+    "\n",
+    "await main()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Task-Specific Models Examples"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Contextual Answers"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The following example demonstrates the answering capability of AI21 without streaming."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "CONTEXT = \"\"\"\n",
+    "In 2020 and 2021, enormous QE — approximately $4.4 trillion, or 18%, of 2021 gross\n",
+    "domestic product (GDP) — and enormous fiscal stimulus (which has been and\n",
+    "always will be inflationary) — approximately $5 trillion, or 21%, of 2021 GDP\n",
+    "— stabilized markets and allowed companies to raise enormous amounts of\n",
+    "capital. In addition, this infusion of capital saved many small businesses and\n",
+    "put more than $2.5 trillion in the hands of consumers and almost $1 trillion into\n",
+    "state and local coffers. These actions led to a rapid decline in unemployment, \n",
+    "dropping from 15% to under 4% in 20 months — the magnitude and speed of which were both\n",
+    "unprecedented. Additionally, the economy grew 7% in 2021 despite the arrival of\n",
+    "the Delta and Omicron variants and the global supply chain shortages, which were\n",
+    "largely fueled by the dramatic upswing in consumer spending and the shift in\n",
+    "that spend from services to goods.\n",
+    "\"\"\"\n",
+    "response = client.answer.create(\n",
+    "    context=CONTEXT,\n",
+    "    question=\"Did the economy shrink after the Omicron variant arrived?\",\n",
+    ")\n",
+    "print(response.answer)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Similarly, we can use streaming to get the answer."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "CONTEXT = \"\"\"\n",
+    "In the rapidly evolving field of Artificial Intelligence (AI), mathematical \n",
+    "foundations such as calculus, linear algebra, and statistics play a crucial role. \n",
+    "For instance, linear algebra is essential for understanding and developing machine \n",
+    "learning algorithms. It involves the study of vectors, matrices, and tensor operations \n",
+    "which are critical for performing transformations and optimizations. Additionally, \n",
+    "concepts from calculus like derivatives and integrals are used to optimize the \n",
+    "performance of AI models through gradient descent and other optimization techniques. \n",
+    "Statistics and probability form the backbone for making inferences and predictions, \n",
+    "enabling AI systems to learn from data and make decisions under uncertainty. \n",
\n", + "Understanding these mathematical principles allows for the development of more robust \n", + "and effective AI systems.\n", + "\"\"\"\n", + "response = client.answer.create(\n", + " context=CONTEXT,\n", + " question=\"Why is linear algebra important for machine learning algorithms?\",\n", + " stream=True,\n", + ")\n", + "print(response.answer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agentops.end_session(\"Success\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ops", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/core_manual_tests/providers/ai21_canary.py b/tests/core_manual_tests/providers/ai21_canary.py new file mode 100644 index 000000000..e38be654a --- /dev/null +++ b/tests/core_manual_tests/providers/ai21_canary.py @@ -0,0 +1,64 @@ +import asyncio + +import agentops +from dotenv import load_dotenv +import os +import ai21 +from ai21.models.chat import ChatMessage + +load_dotenv() +agentops.init(default_tags=["ai21-provider-test"]) + +api_key = os.getenv("AI2I_API_KEY") +ai21_client = ai21.AI21Client(api_key=api_key) +async_ai21_client = ai21.AsyncAI21Client(api_key=api_key) + +messages = [ + ChatMessage(content="You are an expert mathematician.", role="system"), + ChatMessage( + content="Write a summary of 5 lines on the Shockley diode equation.", + role="user", + ), +] + +response = ai21_client.chat.completions.create( + model="jamba-1.5-mini", + messages=messages, +) + + +stream_response = ai21_client.chat.completions.create( + model="jamba-1.5-mini", + messages=messages, + stream=True, +) + +response = "" +for chunk in stream_response: + response += chunk.choices[0].delta.content +print(response) + + +async def async_test(): + async_response = await async_ai21_client.chat.completions.create( + model="jamba-1.5-mini", + messages=messages, + ) + print(async_response) + + +asyncio.run(async_test()) + +agentops.stop_instrumenting() + +untracked_response = ai21_client.chat.completions.create( + model="jamba-1.5-mini", + messages=messages, +) + + +agentops.end_session(end_state="Success") + +### +# Used to verify that one session is created with one LLM event +###