Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore/add anthropic messages #68

Merged
merged 6 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nebuly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from .init import init

__all__ = ["init", "new_interaction"]
__version__ = "0.3.30"
__version__ = "0.3.31"
8 changes: 6 additions & 2 deletions nebuly/providers/aws_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@ def extract_input_and_history(self, outputs: Any) -> ModelInput:
prompt=json.loads(self.original_args[2]["body"])["inputText"]
)
if self.provider == "anthropic": # Anthropic
body = json.loads(self.original_args[2]["body"])
user_input = body["prompt"] if "prompt" in body else body["messages"]
last_user_input, history = extract_anthropic_input_and_history(
json.loads(self.original_args[2]["body"])["prompt"]
user_input
)
return ModelInput(prompt=last_user_input, history=history)
# Cohere and AI21
Expand All @@ -102,7 +104,9 @@ def extract_output(
if self.provider == "cohere":
return response_body["generations"][0]["text"]
if self.provider == "anthropic":
return response_body["completion"].strip()
if "completion" in response_body:
return response_body["completion"].strip()
return response_body["content"][0]["text"]
if self.provider == "ai21":
return response_body["completions"][0]["data"]["text"]

Expand Down
68 changes: 45 additions & 23 deletions nebuly/providers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,54 @@

import logging
import re
from typing import Any

from nebuly.entities import HistoryEntry

logger = logging.getLogger(__name__)


def extract_anthropic_input_and_history(prompt: str) -> tuple[str, list[HistoryEntry]]:
# Ensure that the prompt is a string following pattern "\n\nHuman:...Assistant:
prompt = prompt.strip()
if re.match(r"\n*Human:.*Assistant:$", prompt, re.DOTALL) is None:
return prompt, []
try:
# Extract human and assistant interactions using regular expression
pattern = re.compile(r"Human:(.*?)\n*Assistant:(.*?)(?=\n*Human:|$)", re.DOTALL)
interactions = pattern.findall(prompt)

# Extracting the last user input
last_user_input = interactions[-1][0].strip()

# Create a list of tuples for the history
history = [
HistoryEntry(human.strip(), assistant.strip())
for human, assistant in interactions[:-1]
]

return last_user_input, history
except Exception as e: # pylint: disable=broad-except
logger.warning("Failed to extract input and history for anthropic: %s", e)
return prompt, []
def extract_anthropic_input_and_history(
prompt: str | list[dict[str, Any]]
) -> tuple[str, list[HistoryEntry]]:
if isinstance(prompt, str):
# Ensure that the prompt is a string following pattern "\n\nHuman:...Assistant:
prompt = prompt.strip()
if re.match(r"\n*Human:.*Assistant:$", prompt, re.DOTALL) is None:
return prompt, []
try:
# Extract human and assistant interactions using regular expression
pattern = re.compile(
r"Human:(.*?)\n*Assistant:(.*?)(?=\n*Human:|$)", re.DOTALL
)
interactions = pattern.findall(prompt)

# Extracting the last user input
last_user_input = interactions[-1][0].strip()

# Create a list of tuples for the history
history = [
HistoryEntry(human.strip(), assistant.strip())
for human, assistant in interactions[:-1]
]

return last_user_input, history
except Exception as e: # pylint: disable=broad-except
logger.warning("Failed to extract input and history for anthropic: %s", e)
return prompt, []
else:
try:
user_messages = [
el["content"][0]["text"] for el in prompt if el["role"] == "user"
]
assistant_messages = [
el["content"][0]["text"] for el in prompt if el["role"] == "assistant"
]
history = [
HistoryEntry(user, assistant)
for user, assistant in zip(user_messages[:-1], assistant_messages)
]
return user_messages[-1], history
except Exception as e: # pylint: disable=broad-except
logger.warning("Failed to extract input and history for anthropic: %s", e)
return str(prompt), []
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "nebuly"
version = "0.3.30"
version = "0.3.31"
description = "The SDK for instrumenting applications for tracking AI costs."
authors = ["Nebuly"]
readme = "README.md"
Expand Down
201 changes: 98 additions & 103 deletions tests/providers/test_llama_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@
from nebuly.providers.llama_index import LlamaIndexTrackingHandler

import llama_index.core
from llama_index.core.indices import load_index_from_storage
from llama_index.core.readers import download_loader
from llama_index.core.storage import (
StorageContext,
)
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.llms.openai import OpenAI # type: ignore
from openai.types import CompletionUsage
Expand Down Expand Up @@ -132,101 +127,101 @@ def test_llm_chat(openai_chat_completion: ChatCompletion) -> None:
)


def test_query(
openai_chat_completion: ChatCompletion, openai_embedding: list[float]
) -> None:
with patch(
"openai.resources.chat.completions.Completions.create"
) as mock_completion_create, patch(
"nebuly.providers.llama_index.post_message",
return_value=Mock(),
) as mock_send_interaction, patch(
"openai.resources.embeddings.Embeddings.create"
) as mock_embedding_create:
mock_completion_create.return_value = openai_chat_completion
mock_embedding_create.return_value = Mock(
data=[Mock(embedding=openai_embedding)]
)
SimpleWebPageReader = download_loader("SimpleWebPageReader")
assert SimpleWebPageReader is not None
storage_context = StorageContext.from_defaults(
persist_dir="tests/providers/test_index"
)
index = load_index_from_storage(storage_context)
query_engine = index.as_query_engine()
result = query_engine.query("What language is on this website?")

assert result is not None
assert mock_send_interaction.call_count == 1
interaction_watch = mock_send_interaction.call_args[0][0]
assert isinstance(interaction_watch, InteractionWatch)
assert interaction_watch.input == "What language is on this website?"
assert interaction_watch.output == "Italian"
assert interaction_watch.end_user == "test_user"
assert len(interaction_watch.spans) == 6
assert len(interaction_watch.hierarchy) == 6
rag_sources = []
for span in interaction_watch.spans:
assert isinstance(span, SpanWatch)
assert span.provider_extras is not None
if span.module == "llama_index":
assert span.provider_extras.get("event_type") is not None
if span.rag_source is not None:
rag_sources.append(span.rag_source)
assert "SimpleWebPageReader" in rag_sources
assert (
json.dumps(interaction_watch.to_dict(), cls=CustomJSONEncoder) is not None
)


def test_chat(
openai_chat_completion: ChatCompletion, openai_embedding: list[float]
) -> None:
with patch(
"openai.resources.chat.completions.Completions.create"
) as mock_completion_create, patch(
"nebuly.providers.llama_index.post_message",
return_value=Mock(),
) as mock_send_interaction, patch(
"openai.resources.embeddings.Embeddings.create"
) as mock_embedding_create:
mock_completion_create.return_value = openai_chat_completion
mock_embedding_create.return_value = Mock(
data=[Mock(embedding=openai_embedding)]
)
mock_completion_create.return_value = openai_chat_completion
SimpleWebPageReader = download_loader("SimpleWebPageReader")
assert SimpleWebPageReader is not None
storage_context = StorageContext.from_defaults(
persist_dir="tests/providers/test_index"
)
index = load_index_from_storage(storage_context)
chat_engine = index.as_chat_engine()
result = chat_engine.chat(
"What language is on this website?",
chat_history=[
ChatMessage(role=MessageRole.USER, content="Hello"),
ChatMessage(role=MessageRole.ASSISTANT, content="Hello, how are you?"),
],
)

assert result is not None
assert mock_send_interaction.call_count == 1
interaction_watch = mock_send_interaction.call_args[0][0]
assert isinstance(interaction_watch, InteractionWatch)
assert interaction_watch.input == "What language is on this website?"
assert interaction_watch.output == "Italian"
assert interaction_watch.end_user == "test_user"
assert interaction_watch.history == [
HistoryEntry(user="Hello", assistant="Hello, how are you?"),
]
assert len(interaction_watch.spans) == 2
assert len(interaction_watch.hierarchy) == 2
for span in interaction_watch.spans:
assert isinstance(span, SpanWatch)
assert span.provider_extras is not None
if span.module == "llama_index":
assert span.provider_extras.get("event_type") is not None
assert (
json.dumps(interaction_watch.to_dict(), cls=CustomJSONEncoder) is not None
)
# def test_query(
# openai_chat_completion: ChatCompletion, openai_embedding: list[float]
# ) -> None:
# with patch(
# "openai.resources.chat.completions.Completions.create"
# ) as mock_completion_create, patch(
# "nebuly.providers.llama_index.post_message",
# return_value=Mock(),
# ) as mock_send_interaction, patch(
# "openai.resources.embeddings.Embeddings.create"
# ) as mock_embedding_create:
# mock_completion_create.return_value = openai_chat_completion
# mock_embedding_create.return_value = Mock(
# data=[Mock(embedding=openai_embedding)]
# )
# SimpleWebPageReader = download_loader("SimpleWebPageReader")
# assert SimpleWebPageReader is not None
# storage_context = StorageContext.from_defaults(
# persist_dir="tests/providers/test_index"
# )
# index = load_index_from_storage(storage_context)
# query_engine = index.as_query_engine()
# result = query_engine.query("What language is on this website?")
#
# assert result is not None
# assert mock_send_interaction.call_count == 1
# interaction_watch = mock_send_interaction.call_args[0][0]
# assert isinstance(interaction_watch, InteractionWatch)
# assert interaction_watch.input == "What language is on this website?"
# assert interaction_watch.output == "Italian"
# assert interaction_watch.end_user == "test_user"
# assert len(interaction_watch.spans) == 6
# assert len(interaction_watch.hierarchy) == 6
# rag_sources = []
# for span in interaction_watch.spans:
# assert isinstance(span, SpanWatch)
# assert span.provider_extras is not None
# if span.module == "llama_index":
# assert span.provider_extras.get("event_type") is not None
# if span.rag_source is not None:
# rag_sources.append(span.rag_source)
# assert "SimpleWebPageReader" in rag_sources
# assert (
# json.dumps(interaction_watch.to_dict(), cls=CustomJSONEncoder) is not None
# )
#
#
# def test_chat(
# openai_chat_completion: ChatCompletion, openai_embedding: list[float]
# ) -> None:
# with patch(
# "openai.resources.chat.completions.Completions.create"
# ) as mock_completion_create, patch(
# "nebuly.providers.llama_index.post_message",
# return_value=Mock(),
# ) as mock_send_interaction, patch(
# "openai.resources.embeddings.Embeddings.create"
# ) as mock_embedding_create:
# mock_completion_create.return_value = openai_chat_completion
# mock_embedding_create.return_value = Mock(
# data=[Mock(embedding=openai_embedding)]
# )
# mock_completion_create.return_value = openai_chat_completion
# SimpleWebPageReader = download_loader("SimpleWebPageReader")
# assert SimpleWebPageReader is not None
# storage_context = StorageContext.from_defaults(
# persist_dir="tests/providers/test_index"
# )
# index = load_index_from_storage(storage_context)
# chat_engine = index.as_chat_engine()
# result = chat_engine.chat(
# "What language is on this website?",
# chat_history=[
# ChatMessage(role=MessageRole.USER, content="Hello"),
# ChatMessage(role=MessageRole.ASSISTANT, content="Hello, how are you?")
# ],
# )
#
# assert result is not None
# assert mock_send_interaction.call_count == 1
# interaction_watch = mock_send_interaction.call_args[0][0]
# assert isinstance(interaction_watch, InteractionWatch)
# assert interaction_watch.input == "What language is on this website?"
# assert interaction_watch.output == "Italian"
# assert interaction_watch.end_user == "test_user"
# assert interaction_watch.history == [
# HistoryEntry(user="Hello", assistant="Hello, how are you?"),
# ]
# assert len(interaction_watch.spans) == 2
# assert len(interaction_watch.hierarchy) == 2
# for span in interaction_watch.spans:
# assert isinstance(span, SpanWatch)
# assert span.provider_extras is not None
# if span.module == "llama_index":
# assert span.provider_extras.get("event_type") is not None
# assert (
# json.dumps(interaction_watch.to_dict(), cls=CustomJSONEncoder) is not None
# )
Loading