diff --git a/cookbook/llms/claude/prompt_caching.py b/cookbook/llms/claude/prompt_caching.py
new file mode 100644
index 0000000000..4e4ca304cb
--- /dev/null
+++ b/cookbook/llms/claude/prompt_caching.py
@@ -0,0 +1,46 @@
+# Inspired by: https://github.com/anthropics/anthropic-cookbook/blob/main/misc/prompt_caching.ipynb
+import requests
+from bs4 import BeautifulSoup
+
+from phi.assistant import Assistant
+from phi.llm.anthropic import Claude
+
+
+def fetch_article_content(url):
+    response = requests.get(url)
+    soup = BeautifulSoup(response.content, "html.parser")
+    # Remove script and style elements
+    for script in soup(["script", "style"]):
+        script.decompose()
+    # Get text
+    text = soup.get_text()
+    # Break into lines and remove leading and trailing space on each
+    lines = (line.strip() for line in text.splitlines())
+    # Break multi-headlines into a line each
+    chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
+    # Drop blank lines
+    text = "\n".join(chunk for chunk in chunks if chunk)
+    return text
+
+
+# Fetch the content of the article
+book_url = "https://www.gutenberg.org/cache/epub/1342/pg1342.txt"
+book_content = fetch_article_content(book_url)
+
+print(f"Fetched {len(book_content)} characters from the book.")
+
+assistant = Assistant(
+    llm=Claude(
+        model="claude-3-5-sonnet-20240620",
+        cache_system_prompt=True,
+    ),
+    system_prompt=book_content[:10000],
+    debug_mode=True,
+)
+assistant.print_response("Give me a one line summary of this book", markdown=True, stream=True)
+print("Prompt cache creation tokens: ", assistant.llm.metrics["cache_creation_tokens"])  # type: ignore
+print("Prompt cache read tokens: ", assistant.llm.metrics["cache_read_tokens"])  # type: ignore
+
+# assistant.print_response("Give me a one line summary of this book", markdown=True, stream=False)
+# print("Prompt cache creation tokens: ", assistant.llm.metrics["cache_creation_tokens"])
+# print("Prompt cache read tokens: ", assistant.llm.metrics["cache_read_tokens"])
diff --git a/phi/llm/anthropic/claude.py b/phi/llm/anthropic/claude.py
index 221cd2bfd7..66d319b733 100644
--- a/phi/llm/anthropic/claude.py
+++ b/phi/llm/anthropic/claude.py
@@ -1,5 +1,5 @@
 import json
-from typing import Optional, List, Iterator, Dict, Any, Union
+from typing import Optional, List, Iterator, Dict, Any, Union, cast
 
 from phi.llm.base import LLM
 from phi.llm.message import Message
@@ -26,7 +26,7 @@ class Claude(LLM):
     name: str = "claude"
-    model: str = "claude-3-opus-20240229"
+    model: str = "claude-3-5-sonnet-20240620"
     # -*- Request parameters
     max_tokens: Optional[int] = 1024
     temperature: Optional[float] = None
@@ -34,6 +34,7 @@
     top_p: Optional[float] = None
     top_k: Optional[int] = None
     request_params: Optional[Dict[str, Any]] = None
+    cache_system_prompt: bool = False
     # -*- Client parameters
     api_key: Optional[str] = None
     client_params: Optional[Dict[str, Any]] = None
@@ -119,7 +120,13 @@ def invoke(self, messages: List[Message]) -> AnthropicMessage:
             else:
                 api_messages.append({"role": message.role, "content": message.content or ""})
 
-        api_kwargs["system"] = " ".join(system_messages)
+        if self.cache_system_prompt:
+            api_kwargs["system"] = [
+                {"type": "text", "text": " ".join(system_messages), "cache_control": {"type": "ephemeral"}}
+            ]
+            api_kwargs["extra_headers"] = {"anthropic-beta": "prompt-caching-2024-07-31"}
+        else:
+            api_kwargs["system"] = " ".join(system_messages)
 
         if self.tools:
             api_kwargs["tools"] = self.get_tools()
@@ -141,7 +148,13 @@ def invoke_stream(self, messages: List[Message]) -> Any:
             else:
                 api_messages.append({"role": message.role, "content": message.content or ""})
 
-        api_kwargs["system"] = " ".join(system_messages)
+        if self.cache_system_prompt:
+            api_kwargs["system"] = [
+                {"type": "text", "text": " ".join(system_messages), "cache_control": {"type": "ephemeral"}}
+            ]
+            api_kwargs["extra_headers"] = {"anthropic-beta": "prompt-caching-2024-07-31"}
+        else:
+            api_kwargs["system"] = " ".join(system_messages)
 
         if self.tools:
             api_kwargs["tools"] = self.get_tools()
@@ -165,12 +178,13 @@ def response(self, messages: List[Message]) -> str:
         logger.debug(f"Time to generate response: {response_timer.elapsed:.4f}s")
 
         # -*- Parse response
-        response_content = response.content[0]  # type: ignore
-        if isinstance(response.content[0], ToolUseBlock):
-            response_content = response.content[0].input["query"]
-
-        elif isinstance(response.content[0], TextBlock):
-            response_content = response.content[0].text
+        response_content: str = ""
+        response_block: Union[TextBlock, ToolUseBlock] = response.content[0]
+        if isinstance(response_block, TextBlock):
+            response_content = response_block.text
+        elif isinstance(response_block, ToolUseBlock):
+            tool_block = cast(dict[str, Any], response_block.input)
+            response_content = tool_block.get("query", "")
 
         # -*- Create assistant message
         assistant_message = Message(
@@ -178,8 +192,6 @@ def response(self, messages: List[Message]) -> str:
             content=response_content,
         )
 
-        logger.debug(f"Response: {response}")
-
         # Check if the response contains a tool call
         if response.stop_reason == "tool_use":
             tool_calls: List[Dict[str, Any]] = []
@@ -218,6 +230,22 @@ def response(self, messages: List[Message]) -> str:
             input_tokens = response_usage.input_tokens
             output_tokens = response_usage.output_tokens
 
+            try:
+                cache_creation_tokens = 0
+                cache_read_tokens = 0
+                if self.cache_system_prompt:
+                    cache_creation_tokens = response_usage.cache_creation_input_tokens  # type: ignore
+                    cache_read_tokens = response_usage.cache_read_input_tokens  # type: ignore
+
+                assistant_message.metrics["cache_creation_tokens"] = cache_creation_tokens
+                assistant_message.metrics["cache_read_tokens"] = cache_read_tokens
+                self.metrics["cache_creation_tokens"] = (
+                    self.metrics.get("cache_creation_tokens", 0) + cache_creation_tokens
+                )
+                self.metrics["cache_read_tokens"] = self.metrics.get("cache_read_tokens", 0) + cache_read_tokens
+            except Exception:
+                logger.debug("Prompt caching metrics not available")
+
             if input_tokens is not None:
                 assistant_message.metrics["input_tokens"] = input_tokens
                 self.metrics["input_tokens"] = self.metrics.get("input_tokens", 0) + input_tokens
@@ -352,6 +380,22 @@ def response_stream(self, messages: List[Message]) -> Iterator[str]:
             input_tokens = response_usage.input_tokens
             output_tokens = response_usage.output_tokens
 
+            try:
+                cache_creation_tokens = 0
+                cache_read_tokens = 0
+                if self.cache_system_prompt:
+                    cache_creation_tokens = response_usage.cache_creation_input_tokens  # type: ignore
+                    cache_read_tokens = response_usage.cache_read_input_tokens  # type: ignore
+
+                assistant_message.metrics["cache_creation_tokens"] = cache_creation_tokens
+                assistant_message.metrics["cache_read_tokens"] = cache_read_tokens
+                self.metrics["cache_creation_tokens"] = (
+                    self.metrics.get("cache_creation_tokens", 0) + cache_creation_tokens
+                )
+                self.metrics["cache_read_tokens"] = self.metrics.get("cache_read_tokens", 0) + cache_read_tokens
+            except Exception:
+                logger.debug("Prompt caching metrics not available")
+
             if input_tokens is not None:
                 assistant_message.metrics["input_tokens"] = input_tokens
                 self.metrics["input_tokens"] = self.metrics.get("input_tokens", 0) + input_tokens
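
Note for reviewers: a minimal sketch of the request shape that cache_system_prompt=True produces, written against the Anthropic SDK directly. The system block list, the cache_control marker, and the beta header mirror what the patched invoke()/invoke_stream() build; the model name and the system_text placeholder are illustrative only, not taken from this patch.

import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment
system_text = "<long, stable system prompt worth caching>"  # placeholder

response = client.messages.create(
    model="claude-3-5-sonnet-20240620",
    max_tokens=1024,
    # The system prompt is sent as a content block marked cacheable,
    # matching the structure the patch assigns to api_kwargs["system"]
    system=[{"type": "text", "text": system_text, "cache_control": {"type": "ephemeral"}}],
    messages=[{"role": "user", "content": "Give me a one line summary of this book"}],
    extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
)

# The first call writes the cache; repeated calls within the cache TTL read from it.
# Older SDK versions do not type these usage fields (hence the `# type: ignore`
# comments in the patch), so getattr keeps the sketch safe either way.
print(getattr(response.usage, "cache_creation_input_tokens", None))
print(getattr(response.usage, "cache_read_input_tokens", None))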