[FEATURE] Add support for AI21 via the ai21 Python SDK (#382)
* add initial files for support

* working sync client

* stream not working

* updated examples notebook for current testing

* fix for `delta.content` and cleanup

* cleanup again

* cleanup and add tool event

* structure examples notebook

* add contextual answers tracking

* cleanup example notebook

* create testing file

* clean example notebook again

* rename examples directory

* updated docs page

* wrap `chunk.choices[0].delta.content` in `str(...)`

* update doc

---------

Co-authored-by: reibs <[email protected]>
the-praxs and areibman authored Nov 5, 2024
1 parent e19d56d commit 1b8f473
Showing 5 changed files with 824 additions and 0 deletions.
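
This change patches the ai21 SDK so that chat completion and Contextual Answers calls are recorded as AgentOps events. A minimal usage sketch of the instrumented sync path (the model name, API-key environment variables, and the end_session call are illustrative and not part of this diff):

import agentops
from ai21 import AI21Client
from ai21.models.chat import ChatMessage

agentops.init()  # patches supported providers, including ai21>=2.0.0 after this change

client = AI21Client()  # assumes AI21_API_KEY is set in the environment
response = client.chat.completions.create(
    model="jamba-1.5-mini",  # illustrative model name
    messages=[ChatMessage(role="user", content="Say hello in one sentence.")],
)
print(response.choices[0].message.content)

agentops.end_session("Success")  # flushes the recorded LLMEvents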
24 changes: 24 additions & 0 deletions agentops/llms/__init__.py
@@ -13,6 +13,7 @@
from .ollama import OllamaProvider
from .openai import OpenAiProvider
from .anthropic import AnthropicProvider
from .ai21 import AI21Provider

original_func = {}
original_create = None
@@ -39,6 +40,12 @@ class LlmTracker:
"anthropic": {
"0.32.0": ("completions.create",),
},
"ai21": {
"2.0.0": (
"chat.completions.create",
"client.answer.create",
),
},
}

def __init__(self, client):
@@ -135,10 +142,27 @@ def override_api(self):
f"Only Anthropic>=0.32.0 supported. v{module_version} found."
)

if api == "ai21":
module_version = version(api)

if module_version is None:
logger.warning(
f"Cannot determine AI21 version. Only AI21>=2.0.0 supported."
)

if Version(module_version) >= parse("2.0.0"):
provider = AI21Provider(self.client)
provider.override()
else:
logger.warning(
f"Only AI21>=2.0.0 supported. v{module_version} found."
)

def stop_instrumenting(self):
OpenAiProvider(self.client).undo_override()
GroqProvider(self.client).undo_override()
CohereProvider(self.client).undo_override()
LiteLLMProvider(self.client).undo_override()
OllamaProvider(self.client).undo_override()
AnthropicProvider(self.client).undo_override()
AI21Provider(self.client).undo_override()
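
The new branch mirrors the existing providers: the installed ai21 version is compared against the 2.0.0 floor pinned in SUPPORTED_APIS, and stop_instrumenting now also restores the original ai21 methods. A small sketch of the same packaging-based comparison, with a hypothetical installed version string:

from packaging.version import Version, parse

installed = "2.4.0"  # hypothetical version reported for the ai21 package
if installed is not None and Version(installed) >= parse("2.0.0"):
    print("ai21 is new enough to instrument")
else:
    print("ai21 missing or older than 2.0.0; leaving it unpatched")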
251 changes: 251 additions & 0 deletions agentops/llms/ai21.py
@@ -0,0 +1,251 @@
import inspect
import pprint
from typing import Optional

from agentops.llms.instrumented_provider import InstrumentedProvider
from agentops.time_travel import fetch_completion_override_from_time_travel_cache

from ..event import ErrorEvent, LLMEvent, ActionEvent, ToolEvent
from ..session import Session
from ..log_config import logger
from ..helpers import check_call_stack_for_agent_id, get_ISO_time
from ..singleton import singleton


@singleton
class AI21Provider(InstrumentedProvider):

original_create = None
original_create_async = None
original_answer = None
original_answer_async = None

def __init__(self, client):
super().__init__(client)
self._provider_name = "AI21"

def handle_response(
self, response, kwargs, init_timestamp, session: Optional[Session] = None
):
"""Handle responses for AI21"""
from ai21.stream.stream import Stream
from ai21.stream.async_stream import AsyncStream
from ai21.models.chat.chat_completion_chunk import ChatCompletionChunk
from ai21.models.chat.chat_completion_response import ChatCompletionResponse
from ai21.models.responses.answer_response import AnswerResponse

llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs)
action_event = ActionEvent(init_timestamp=init_timestamp, params=kwargs)

if session is not None:
llm_event.session_id = session.session_id

def handle_stream_chunk(chunk: ChatCompletionChunk):
# We take the first ChatCompletionChunk and accumulate the deltas from all subsequent chunks to build one full chat completion
if llm_event.returns is None:
llm_event.returns = chunk
# Manually setting content to empty string to avoid error
llm_event.returns.choices[0].delta.content = ""

try:
accumulated_delta = llm_event.returns.choices[0].delta
llm_event.agent_id = check_call_stack_for_agent_id()
llm_event.model = kwargs["model"]
llm_event.prompt = [
message.model_dump() for message in kwargs["messages"]
]

# NOTE: We assume for completion only choices[0] is relevant
choice = chunk.choices[0]

if choice.delta.content:
accumulated_delta.content += choice.delta.content

if choice.delta.role:
accumulated_delta.role = choice.delta.role

if getattr("choice.delta", "tool_calls", None):
accumulated_delta.tool_calls += ToolEvent(logs=choice.delta.tools)

if choice.finish_reason:
# Streaming is done. Record LLMEvent
llm_event.returns.choices[0].finish_reason = choice.finish_reason
llm_event.completion = {
"role": accumulated_delta.role,
"content": accumulated_delta.content,
}
llm_event.prompt_tokens = chunk.usage.prompt_tokens
llm_event.completion_tokens = chunk.usage.completion_tokens
llm_event.end_timestamp = get_ISO_time()
self._safe_record(session, llm_event)

except Exception as e:
self._safe_record(
session, ErrorEvent(trigger_event=llm_event, exception=e)
)

kwargs_str = pprint.pformat(kwargs)
chunk = pprint.pformat(chunk)
logger.warning(
f"Unable to parse a chunk for LLM call. Skipping upload to AgentOps\n"
f"chunk:\n {chunk}\n"
f"kwargs:\n {kwargs_str}\n"
)

# if the response is a generator, decorate the generator
# For synchronous Stream
if isinstance(response, Stream):

def generator():
for chunk in response:
handle_stream_chunk(chunk)
yield chunk

return generator()

# For asynchronous AsyncStream
if isinstance(response, AsyncStream):

async def async_generator():
async for chunk in response:
handle_stream_chunk(chunk)
yield chunk

return async_generator()

# Handle object responses
try:
if isinstance(response, ChatCompletionResponse):
llm_event.returns = response
llm_event.agent_id = check_call_stack_for_agent_id()
llm_event.model = kwargs["model"]
llm_event.prompt = [
message.model_dump() for message in kwargs["messages"]
]
llm_event.prompt_tokens = response.usage.prompt_tokens
llm_event.completion = response.choices[0].message.model_dump()
llm_event.completion_tokens = response.usage.completion_tokens
llm_event.end_timestamp = get_ISO_time()
self._safe_record(session, llm_event)

elif isinstance(response, AnswerResponse):
action_event.returns = response
action_event.agent_id = check_call_stack_for_agent_id()
action_event.action_type = "Contextual Answers"
action_event.logs = [
{"context": kwargs["context"], "question": kwargs["question"]},
response.model_dump() if response.model_dump() else None,
]
action_event.end_timestamp = get_ISO_time()
self._safe_record(session, action_event)

except Exception as e:
self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e))
kwargs_str = pprint.pformat(kwargs)
response = pprint.pformat(response)
logger.warning(
f"Unable to parse response for LLM call. Skipping upload to AgentOps\n"
f"response:\n {response}\n"
f"kwargs:\n {kwargs_str}\n"
)

return response

def override(self):
self._override_completion()
self._override_completion_async()
self._override_answer()
self._override_answer_async()

def _override_completion(self):
from ai21.clients.studio.resources.chat import ChatCompletions

self.original_create = ChatCompletions.create

def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()
session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]
result = self.original_create(*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=session)

# Override the original method with the patched one
ChatCompletions.create = patched_function

def _override_completion_async(self):
from ai21.clients.studio.resources.chat import AsyncChatCompletions

self.original_create_async = AsyncChatCompletions.create

async def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()
session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]
result = await self.original_create_async(*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=session)

# Override the original method with the patched one
AsyncChatCompletions.create = patched_function

def _override_answer(self):
from ai21.clients.studio.resources.studio_answer import StudioAnswer

self.original_answer = StudioAnswer.create

def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()

session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]
result = self.original_answer(*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=session)

StudioAnswer.create = patched_function

def _override_answer_async(self):
from ai21.clients.studio.resources.studio_answer import AsyncStudioAnswer

self.original_answer_async = AsyncStudioAnswer.create

async def patched_function(*args, **kwargs):
# Call the original function with its original arguments
init_timestamp = get_ISO_time()

session = kwargs.get("session", None)
if "session" in kwargs.keys():
del kwargs["session"]
result = await self.original_answer_async(*args, **kwargs)
return self.handle_response(result, kwargs, init_timestamp, session=session)

AsyncStudioAnswer.create = patched_function

def undo_override(self):
if (
self.original_create is not None
and self.original_create_async is not None
and self.original_answer is not None
and self.original_answer_async is not None
):
from ai21.clients.studio.resources.chat import (
ChatCompletions,
AsyncChatCompletions,
)
from ai21.clients.studio.resources.studio_answer import (
StudioAnswer,
AsyncStudioAnswer,
)

ChatCompletions.create = self.original_create
AsyncChatCompletions.create = self.original_create_async
StudioAnswer.create = self.original_answer
AsyncStudioAnswer.create = self.original_answer_async
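
Once override() has run, callers use the ai21 client as usual: streamed chunks are re-yielded by the patched generator and recorded as a single LLMEvent when a finish_reason arrives, and a Contextual Answers call is recorded as an ActionEvent carrying the context and question. A hedged usage sketch (the model name and response field access are illustrative; the answer call corresponds to the patched StudioAnswer.create above):

from ai21 import AI21Client
from ai21.models.chat import ChatMessage

client = AI21Client()

# Streaming chat: iterate the wrapped generator as normal
stream = client.chat.completions.create(
    model="jamba-1.5-mini",  # illustrative model name
    messages=[ChatMessage(role="user", content="Stream a one-line greeting.")],
    stream=True,
)
for chunk in stream:
    print(str(chunk.choices[0].delta.content), end="")

# Contextual Answers: tracked as an ActionEvent named "Contextual Answers"
answer = client.answer.create(
    context="AgentOps records LLM calls, tool use, and errors for agent observability.",
    question="What does AgentOps record?",
)
print(answer.answer)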