Commit

system-prompt-caching-claude-llm-phi-1299

ysolanky committed Sep 27, 2024
1 parent 6a8139b commit a7140ae
Showing 2 changed files with 102 additions and 12 deletions.
46 changes: 46 additions & 0 deletions cookbook/llms/claude/prompt_caching.py
@@ -0,0 +1,46 @@
# Inspired by: https://github.com/anthropics/anthropic-cookbook/blob/main/misc/prompt_caching.ipynb
import requests
from bs4 import BeautifulSoup

from phi.assistant import Assistant
from phi.llm.anthropic import Claude


def fetch_article_content(url):
    response = requests.get(url)
    soup = BeautifulSoup(response.content, "html.parser")
    # Remove script and style elements
    for script in soup(["script", "style"]):
        script.decompose()
    # Get text
    text = soup.get_text()
    # Break into lines and remove leading and trailing space on each
    lines = (line.strip() for line in text.splitlines())
    # Break multi-headlines into a line each
    chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
    # Drop blank lines
    text = "\n".join(chunk for chunk in chunks if chunk)
    return text


# Fetch the content of the book
book_url = "https://www.gutenberg.org/cache/epub/1342/pg1342.txt"
book_content = fetch_article_content(book_url)

print(f"Fetched {len(book_content)} characters from the book.")

assistant = Assistant(
    llm=Claude(
        model="claude-3-5-sonnet-20240620",
        cache_system_prompt=True,
    ),
    system_prompt=book_content[:10000],
    debug_mode=True,
)
assistant.print_response("Give me a one line summary of this book", markdown=True, stream=True)
print("Prompt cache creation tokens: ", assistant.llm.metrics["cache_creation_tokens"]) # type: ignore
print("Prompt cache read tokens: ", assistant.llm.metrics["cache_read_tokens"]) # type: ignore

# assistant.print_response("Give me a one line summary of this book", markdown=True, stream=False)
# print("Prompt cache creation tokens: ", assistant.llm.metrics["cache_creation_tokens"])
# print("Prompt cache read tokens: ", assistant.llm.metrics["cache_read_tokens"])
68 changes: 56 additions & 12 deletions phi/llm/anthropic/claude.py
@@ -1,5 +1,5 @@
import json
from typing import Optional, List, Iterator, Dict, Any, Union
from typing import Optional, List, Iterator, Dict, Any, Union, cast

from phi.llm.base import LLM
from phi.llm.message import Message
@@ -26,14 +26,15 @@

class Claude(LLM):
    name: str = "claude"
    model: str = "claude-3-opus-20240229"
    model: str = "claude-3-5-sonnet-20240620"
    # -*- Request parameters
    max_tokens: Optional[int] = 1024
    temperature: Optional[float] = None
    stop_sequences: Optional[List[str]] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    request_params: Optional[Dict[str, Any]] = None
    cache_system_prompt: bool = False
    # -*- Client parameters
    api_key: Optional[str] = None
    client_params: Optional[Dict[str, Any]] = None
@@ -119,7 +120,13 @@ def invoke(self, messages: List[Message]) -> AnthropicMessage:
            else:
                api_messages.append({"role": message.role, "content": message.content or ""})

        api_kwargs["system"] = " ".join(system_messages)
        if self.cache_system_prompt:
            api_kwargs["system"] = [
                {"type": "text", "text": " ".join(system_messages), "cache_control": {"type": "ephemeral"}}
            ]
            api_kwargs["extra_headers"] = {"anthropic-beta": "prompt-caching-2024-07-31"}
        else:
            api_kwargs["system"] = " ".join(system_messages)

        if self.tools:
            api_kwargs["tools"] = self.get_tools()
@@ -141,7 +148,13 @@ def invoke_stream(self, messages: List[Message]) -> Any:
            else:
                api_messages.append({"role": message.role, "content": message.content or ""})

        api_kwargs["system"] = " ".join(system_messages)
        if self.cache_system_prompt:
            api_kwargs["system"] = [
                {"type": "text", "text": " ".join(system_messages), "cache_control": {"type": "ephemeral"}}
            ]
            api_kwargs["extra_headers"] = {"anthropic-beta": "prompt-caching-2024-07-31"}
        else:
            api_kwargs["system"] = " ".join(system_messages)

        if self.tools:
            api_kwargs["tools"] = self.get_tools()
@@ -165,21 +178,20 @@ def response(self, messages: List[Message]) -> str:
logger.debug(f"Time to generate response: {response_timer.elapsed:.4f}s")

# -*- Parse response
response_content = response.content[0] # type: ignore
if isinstance(response.content[0], ToolUseBlock):
response_content = response.content[0].input["query"]

elif isinstance(response.content[0], TextBlock):
response_content = response.content[0].text
response_content: str = ""
response_block: Union[TextBlock, ToolUseBlock] = response.content[0]
if isinstance(response_block, TextBlock):
response_content = response_block.text
elif isinstance(response_block, ToolUseBlock):
tool_block = cast(dict[str, Any], response_block.input)
response_content = tool_block.get("query", "")

        # -*- Create assistant message
        assistant_message = Message(
            role=response.role or "assistant",
            content=response_content,
        )

        logger.debug(f"Response: {response}")

        # Check if the response contains a tool call
        if response.stop_reason == "tool_use":
            tool_calls: List[Dict[str, Any]] = []
@@ -218,6 +230,22 @@ def response(self, messages: List[Message]) -> str:
            input_tokens = response_usage.input_tokens
            output_tokens = response_usage.output_tokens

            try:
                cache_creation_tokens = 0
                cache_read_tokens = 0
                if self.cache_system_prompt:
                    cache_creation_tokens = response_usage.cache_creation_input_tokens  # type: ignore
                    cache_read_tokens = response_usage.cache_read_input_tokens  # type: ignore

                assistant_message.metrics["cache_creation_tokens"] = cache_creation_tokens
                assistant_message.metrics["cache_read_tokens"] = cache_read_tokens
                self.metrics["cache_creation_tokens"] = (
                    self.metrics.get("cache_creation_tokens", 0) + cache_creation_tokens
                )
                self.metrics["cache_read_tokens"] = self.metrics.get("cache_read_tokens", 0) + cache_read_tokens
            except Exception:
                logger.debug("Prompt caching metrics not available")

            if input_tokens is not None:
                assistant_message.metrics["input_tokens"] = input_tokens
                self.metrics["input_tokens"] = self.metrics.get("input_tokens", 0) + input_tokens
@@ -352,6 +380,22 @@ def response_stream(self, messages: List[Message]) -> Iterator[str]:
            input_tokens = response_usage.input_tokens
            output_tokens = response_usage.output_tokens

            try:
                cache_creation_tokens = 0
                cache_read_tokens = 0
                if self.cache_system_prompt:
                    cache_creation_tokens = response_usage.cache_creation_input_tokens  # type: ignore
                    cache_read_tokens = response_usage.cache_read_input_tokens  # type: ignore

                assistant_message.metrics["cache_creation_tokens"] = cache_creation_tokens
                assistant_message.metrics["cache_read_tokens"] = cache_read_tokens
                self.metrics["cache_creation_tokens"] = (
                    self.metrics.get("cache_creation_tokens", 0) + cache_creation_tokens
                )
                self.metrics["cache_read_tokens"] = self.metrics.get("cache_read_tokens", 0) + cache_read_tokens
            except Exception:
                logger.debug("Prompt caching metrics not available")

            if input_tokens is not None:
                assistant_message.metrics["input_tokens"] = input_tokens
                self.metrics["input_tokens"] = self.metrics.get("input_tokens", 0) + input_tokens