Skip to content

Commit

Permalink
Fix Bedrock Claude
Browse files Browse the repository at this point in the history
  • Loading branch information
ashpreetbedi committed Apr 19, 2024
1 parent 5ca796f commit 349d31b
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 17 deletions.
5 changes: 2 additions & 3 deletions cookbook/llms/bedrock/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from phi.llm.aws.claude import Claude

assistant = Assistant(
llm=Claude(model="anthropic.claude-v2"),
llm=Claude(model="anthropic.claude-3-sonnet-20240229-v1:0"),
description="You help people with their health and fitness goals.",
debug_mode=True,
)
assistant.print_response("Share a quick healthy breakfast recipe.", markdown=True, stream=False)
assistant.print_response("Share a quick healthy breakfast recipe.", markdown=True)
4 changes: 2 additions & 2 deletions cookbook/llms/bedrock/basic_stream_off.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from phi.assistant import Assistant
from phi.llm.anthropic import Claude
from phi.llm.aws.claude import Claude

assistant = Assistant(
llm=Claude(model="claude-3-haiku-20240307"),
llm=Claude(model="anthropic.claude-3-sonnet-20240229-v1:0"),
description="You help people with their health and fitness goals.",
)
assistant.print_response("Share a quick healthy breakfast recipe.", markdown=True, stream=False)
16 changes: 10 additions & 6 deletions phi/llm/aws/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,14 @@ def get_request_body(self, messages: List[Message]) -> Dict[str, Any]:
def parse_response_message(self, response: Dict[str, Any]) -> Message:
raise NotImplementedError("Please use a subclass of AwsBedrock")

def parse_response_delta(self, response: Dict[str, Any]) -> Optional[str]:
raise NotImplementedError("Please use a subclass of AwsBedrock")

def response(self, messages: List[Message]) -> str:
logger.debug("---------- Bedrock Response Start ----------")
# -*- Log messages for debugging
for m in messages:
m.log()

response_timer = Timer()
response_timer.start()
Expand Down Expand Up @@ -192,15 +198,13 @@ def response_stream(self, messages: List[Message]) -> Iterator[str]:
response_timer = Timer()
response_timer.start()
for delta in self.invoke_stream(body=self.get_request_body(messages)):
logger.debug(f"Delta: {delta}")
logger.debug(f"Delta type: {type(delta)}")
completion_tokens += 1
# -*- Parse response
delta_completion = delta.get("completion")
content = self.parse_response_delta(delta)
# -*- Yield completion
if delta_completion is not None:
assistant_message_content += delta_completion
yield delta_completion
if content is not None:
assistant_message_content += content
yield content

response_timer.stop()
logger.debug(f"Time to generate response: {response_timer.elapsed:.4f}s")
Expand Down
21 changes: 16 additions & 5 deletions phi/llm/aws/claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from phi.llm.message import Message
from phi.llm.aws.bedrock import AwsBedrock
from phi.utils.log import logger


class Claude(AwsBedrock):
Expand Down Expand Up @@ -62,18 +61,30 @@ def get_request_body(self, messages: List[Message]) -> Dict[str, Any]:
}
if system_prompt:
request_body["system"] = system_prompt
logger.info(f"Request body: {request_body}")
return request_body

def parse_response_message(self, response: Dict[str, Any]) -> Message:
logger.debug(f"Response: {response}")
logger.debug(f"Response type: {type(response)}")
if response.get("type") == "message":
response_message = Message(role=response.get("role"))
content: Optional[str] = ""
if response.get("content"):
response_message.content = response.get("content")
_content = response.get("content")
if isinstance(_content, str):
content = _content
elif isinstance(_content, dict):
content = _content.get("text", "")
elif isinstance(_content, list):
content = "\n".join([c.get("text") for c in _content])

response_message.content = content
return response_message

return Message(
role="assistant",
content=response.get("completion"),
)

def parse_response_delta(self, response: Dict[str, Any]) -> Optional[str]:
if "delta" in response:
return response.get("delta", {}).get("text")
return response.get("completion")
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "phidata"
version = "2.3.75"
version = "2.3.76"
description = "Build AI Assistants with memory, knowledge and tools."
requires-python = ">=3.8"
readme = "README.md"
Expand Down

0 comments on commit 349d31b

Please sign in to comment.