46 changes: 45 additions & 1 deletion docs/my-website/docs/providers/bedrock.md
@@ -889,6 +889,19 @@ curl http://0.0.0.0:4000/v1/chat/completions \

Example of using [Bedrock Guardrails with LiteLLM](https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails-use-converse-api.html)

### Selective Content Moderation with `guarded_text`

LiteLLM supports selective content moderation with the `guarded_text` content type. This lets you wrap only the specific content that Bedrock Guardrails should moderate, rather than sending the entire conversation for evaluation.

**How it works:**
- Content with `type: "guarded_text"` gets automatically wrapped in `guardrailConverseContent` blocks
- Only the wrapped content is evaluated by Bedrock Guardrails
- Regular content with `type: "text"` bypasses guardrail evaluation (see the mapping sketch below)

:::note
If `guarded_text` is not used, the entire conversation history will be sent to the guardrail for evaluation, which can increase latency and costs.
:::
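
To make the wrapping concrete, below is a minimal sketch of the mapping. The real conversion lives in `litellm/litellm_core_utils/prompt_templates/factory.py`; `to_bedrock_content_blocks` is a hypothetical helper name used only for illustration.

```python
from typing import Dict, List


def to_bedrock_content_blocks(content_elements: List[Dict]) -> List[Dict]:
    # Hypothetical helper mirroring LiteLLM's content-block conversion
    blocks: List[Dict] = []
    for element in content_elements:
        if element["type"] == "text":
            # Plain text passes through untouched and is not sent to the guardrail
            blocks.append({"text": element["text"]})
        elif element["type"] == "guarded_text":
            # guarded_text is wrapped so Bedrock Guardrails evaluates only this block
            blocks.append({"guardrailConverseContent": {"text": element["text"]}})
    return blocks


blocks = to_bedrock_content_blocks(
    [
        {"type": "text", "text": "What is the main topic of this legal document?"},
        {"type": "guarded_text", "text": "Sensitive content to moderate."},
    ]
)
# -> [{'text': '...'}, {'guardrailConverseContent': {'text': 'Sensitive content to moderate.'}}]
```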

<Tabs>
<TabItem value="sdk" label="LiteLLM SDK">

@@ -915,6 +928,24 @@ response = completion(
"trace": "disabled", # The trace behavior for the guardrail. Can either be "disabled" or "enabled"
},
)

# Selective guardrail usage with guarded_text - only specific content is evaluated
response_guard = completion(
    model="anthropic.claude-v2",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is the main topic of this legal document?"},
                {"type": "guarded_text", "text": "This document contains sensitive legal information that should be moderated by guardrails."}
            ]
        }
    ],
    guardrailConfig={
        "guardrailIdentifier": "gr-abc123",
        "guardrailVersion": "DRAFT"
    }
)
```
</TabItem>
<TabItem value="proxy" label="Proxy on request">
@@ -993,7 +1024,20 @@ response = client.chat.completions.create(model="bedrock-claude-v1", messages =
temperature=0.7
)

print(response)
# Selective guardrail usage with guarded_text - only the wrapped content is evaluated
response_guard = client.chat.completions.create(
    model="bedrock-claude-v1",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is the main topic of this legal document?"},
                {"type": "guarded_text", "text": "This document contains sensitive legal information that should be moderated by guardrails."}
            ]
        }
    ],
    temperature=0.7
)

print(response_guard)
```
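
The same `guarded_text` blocks can also be sent to the proxy over raw HTTP. A minimal sketch, assuming the proxy runs at `http://0.0.0.0:4000` and `sk-1234` is a placeholder API key:

```python
import requests

# Raw HTTP request to the LiteLLM proxy; the URL and key are placeholder assumptions.
response = requests.post(
    "http://0.0.0.0:4000/v1/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "bedrock-claude-v1",
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is the main topic of this legal document?"},
                    {
                        "type": "guarded_text",
                        "text": "This document contains sensitive legal information that should be moderated by guardrails.",
                    },
                ],
            }
        ],
    },
    timeout=60,
)
print(response.json())
```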
</TabItem>
</Tabs>
119 changes: 69 additions & 50 deletions litellm/litellm_core_utils/prompt_templates/factory.py
@@ -16,8 +16,8 @@
from litellm.llms.custom_httpx.http_handler import HTTPHandler, get_async_httpx_client
from litellm.types.files import get_file_extension_from_mime_type
from litellm.types.llms.anthropic import *
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
from litellm.types.llms.bedrock import CachePointBlock
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
from litellm.types.llms.custom_http import httpxSpecialProvider
from litellm.types.llms.ollama import OllamaVisionModelObject
from litellm.types.llms.openai import (
@@ -1067,10 +1067,10 @@ def convert_to_gemini_tool_call_invoke(
if tool_calls is not None:
for tool in tool_calls:
if "function" in tool:
gemini_function_call: Optional[VertexFunctionCall] = (
_gemini_tool_call_invoke_helper(
function_call_params=tool["function"]
)
gemini_function_call: Optional[
VertexFunctionCall
] = _gemini_tool_call_invoke_helper(
function_call_params=tool["function"]
)
if gemini_function_call is not None:
_parts_list.append(
@@ -1589,9 +1589,9 @@ def anthropic_messages_pt(  # noqa: PLR0915
)

if "cache_control" in _content_element:
_anthropic_content_element["cache_control"] = (
_content_element["cache_control"]
)
_anthropic_content_element[
"cache_control"
] = _content_element["cache_control"]
user_content.append(_anthropic_content_element)
elif m.get("type", "") == "text":
m = cast(ChatCompletionTextObject, m)
@@ -1629,9 +1629,9 @@ def anthropic_messages_pt(  # noqa: PLR0915
)

if "cache_control" in _content_element:
_anthropic_content_text_element["cache_control"] = (
_content_element["cache_control"]
)
_anthropic_content_text_element[
"cache_control"
] = _content_element["cache_control"]

user_content.append(_anthropic_content_text_element)

@@ -2482,8 +2482,7 @@ def _validate_format(mime_type: str, image_format: str) -> str:

if is_document:
return BedrockImageProcessor._get_document_format(
mime_type=mime_type,
supported_doc_formats=supported_doc_formats
mime_type=mime_type, supported_doc_formats=supported_doc_formats
)

else:
@@ -2495,12 +2494,9 @@ def _validate_format(mime_type: str, image_format: str) -> str:
f"Unsupported image format: {image_format}. Supported formats: {supported_image_and_video_formats}"
)
return image_format

@staticmethod
def _get_document_format(
mime_type: str,
supported_doc_formats: List[str]
) -> str:
def _get_document_format(mime_type: str, supported_doc_formats: List[str]) -> str:
"""
Get the document format from the mime type

@@ -2519,13 +2515,9 @@ def _get_document_format(
The document format
"""
valid_extensions: Optional[List[str]] = None
potential_extensions = mimetypes.guess_all_extensions(
mime_type, strict=False
)
potential_extensions = mimetypes.guess_all_extensions(mime_type, strict=False)
valid_extensions = [
ext[1:]
for ext in potential_extensions
if ext[1:] in supported_doc_formats
ext[1:] for ext in potential_extensions if ext[1:] in supported_doc_formats
]

# Fallback to types/files.py if mimetypes doesn't return valid extensions
@@ -2686,10 +2678,12 @@ def _convert_to_bedrock_tool_call_invoke(
)
bedrock_content_block = BedrockContentBlock(toolUse=bedrock_tool)
_parts_list.append(bedrock_content_block)

# Check for cache_control and add a separate cachePoint block
if tool.get("cache_control", None) is not None:
cache_point_block = BedrockContentBlock(cachePoint=CachePointBlock(type="default"))
cache_point_block = BedrockContentBlock(
cachePoint=CachePointBlock(type="default")
)
_parts_list.append(cache_point_block)
return _parts_list
except Exception as e:
@@ -2751,7 +2745,7 @@ def _convert_to_bedrock_tool_call_result(
for content in content_list:
if content["type"] == "text":
content_str += content["text"]

message.get("name", "")
id = str(message.get("tool_call_id", str(uuid.uuid4())))

Expand All @@ -2760,7 +2754,7 @@ def _convert_to_bedrock_tool_call_result(
content=[tool_result_content_block],
toolUseId=id,
)

content_block = BedrockContentBlock(toolResult=tool_result)

return content_block
@@ -3082,6 +3076,7 @@ def _initial_message_setup(
messages.append(DEFAULT_USER_CONTINUE_MESSAGE)
return messages


@staticmethod
async def _bedrock_converse_messages_pt_async( # noqa: PLR0915
messages: List,
@@ -3125,6 +3120,12 @@ async def _bedrock_converse_messages_pt_async(  # noqa: PLR0915
if element["type"] == "text":
_part = BedrockContentBlock(text=element["text"])
_parts.append(_part)
elif element["type"] == "guarded_text":
# Wrap guarded_text in guardrailConverseContent block
_part = BedrockContentBlock(
guardrailConverseContent={"text": element["text"]}
)
_parts.append(_part)
elif element["type"] == "image_url":
format: Optional[str] = None
if isinstance(element["image_url"], dict):
@@ -3167,6 +3168,7 @@ async def _bedrock_converse_messages_pt_async(  # noqa: PLR0915

msg_i += 1
if user_content:

if len(contents) > 0 and contents[-1]["role"] == "user":
if (
assistant_continue_message is not None
@@ -3196,26 +3198,29 @@ async def _bedrock_converse_messages_pt_async(  # noqa: PLR0915
current_message = messages[msg_i]
tool_call_result = _convert_to_bedrock_tool_call_result(current_message)
tool_content.append(tool_call_result)

# Check if we need to add a separate cachePoint block
has_cache_control = False

# Check for message-level cache_control
if current_message.get("cache_control", None) is not None:
has_cache_control = True
# Check for content-level cache_control in list content
elif isinstance(current_message.get("content"), list):
for content_element in current_message["content"]:
if (isinstance(content_element, dict) and
content_element.get("cache_control", None) is not None):
if (
isinstance(content_element, dict)
and content_element.get("cache_control", None) is not None
):
has_cache_control = True
break

# Add a separate cachePoint block if cache_control is present
if has_cache_control:
cache_point_block = BedrockContentBlock(cachePoint=CachePointBlock(type="default"))
cache_point_block = BedrockContentBlock(
cachePoint=CachePointBlock(type="default")
)
tool_content.append(cache_point_block)


msg_i += 1
if tool_content:
@@ -3296,7 +3301,7 @@ async def _bedrock_converse_messages_pt_async(  # noqa: PLR0915
image_url=image_url
)
assistants_parts.append(assistants_part)
# Add cache point block for assistant content elements
# Add cache point block for assistant content elements
_cache_point_block = (
litellm.AmazonConverseConfig()._get_cache_point_block(
message_block=cast(
@@ -3308,8 +3313,12 @@ async def _bedrock_converse_messages_pt_async(  # noqa: PLR0915
if _cache_point_block is not None:
assistants_parts.append(_cache_point_block)
assistant_content.extend(assistants_parts)
elif _assistant_content is not None and isinstance(_assistant_content, str):
assistant_content.append(BedrockContentBlock(text=_assistant_content))
elif _assistant_content is not None and isinstance(
_assistant_content, str
):
assistant_content.append(
BedrockContentBlock(text=_assistant_content)
)
# Add cache point block for assistant string content
_cache_point_block = (
litellm.AmazonConverseConfig()._get_cache_point_block(
@@ -3493,6 +3502,12 @@ def _bedrock_converse_messages_pt(  # noqa: PLR0915
if element["type"] == "text":
_part = BedrockContentBlock(text=element["text"])
_parts.append(_part)
elif element["type"] == "guarded_text":
# Wrap guarded_text in guardrailConverseContent block
_part = BedrockContentBlock(
guardrailConverseContent={"text": element["text"]}
)
_parts.append(_part)
elif element["type"] == "image_url":
format: Optional[str] = None
if isinstance(element["image_url"], dict):
@@ -3536,6 +3551,7 @@ def _bedrock_converse_messages_pt(  # noqa: PLR0915

msg_i += 1
if user_content:

if len(contents) > 0 and contents[-1]["role"] == "user":
if (
assistant_continue_message is not None
@@ -3562,29 +3578,33 @@ def _bedrock_converse_messages_pt(  # noqa: PLR0915
while msg_i < len(messages) and messages[msg_i]["role"] == "tool":
tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i])
current_message = messages[msg_i]

# Add the tool result first
tool_content.append(tool_call_result)

# Check if we need to add a separate cachePoint block
has_cache_control = False

# Check for message-level cache_control
if current_message.get("cache_control", None) is not None:
has_cache_control = True
# Check for content-level cache_control in list content
elif isinstance(current_message.get("content"), list):
for content_element in current_message["content"]:
if (isinstance(content_element, dict) and
content_element.get("cache_control", None) is not None):
if (
isinstance(content_element, dict)
and content_element.get("cache_control", None) is not None
):
has_cache_control = True
break

# Add a separate cachePoint block if cache_control is present
if has_cache_control:
cache_point_block = BedrockContentBlock(cachePoint=CachePointBlock(type="default"))
cache_point_block = BedrockContentBlock(
cachePoint=CachePointBlock(type="default")
)
tool_content.append(cache_point_block)

msg_i += 1
if tool_content:
# if last message was a 'user' message, then add a blank assistant message (bedrock requires alternating roles)
@@ -3849,10 +3869,9 @@ def function_call_prompt(messages: list, functions: list):
if isinstance(message["content"], str):
message["content"] += f""" {function_prompt}"""
else:
message["content"].append({
"type": "text",
"text": f""" {function_prompt}"""
})
message["content"].append(
{"type": "text", "text": f""" {function_prompt}"""}
)
function_added_to_prompt = True

if function_added_to_prompt is False: