diff --git a/mirascope/core/mistral/_utils/_convert_message_params.py b/mirascope/core/mistral/_utils/_convert_message_params.py index ae3f5c316..5c1bbf10e 100644 --- a/mirascope/core/mistral/_utils/_convert_message_params.py +++ b/mirascope/core/mistral/_utils/_convert_message_params.py @@ -1,8 +1,13 @@ """Utility for converting `BaseMessageParam` to `ChatMessage`.""" +import base64 + from mistralai.models import ( AssistantMessage, + ImageURL, + ImageURLChunk, SystemMessage, + TextChunk, ToolMessage, UserMessage, ) @@ -37,9 +42,37 @@ def convert_message_params( elif isinstance(content := message_param.content, str): converted_message_params.append(_make_message(**message_param.model_dump())) else: - if len(content) != 1 or content[0].type != "text": - raise ValueError("Mistral currently only supports text parts.") + converted_content = [] + for part in content: + if part.type == "text": + converted_content.append(TextChunk(text=part.text)) + + elif part.type == "image": + if part.media_type not in [ + "image/jpeg", + "image/png", + "image/gif", + "image/webp", + ]: + raise ValueError( + f"Unsupported image media type: {part.media_type}. Mistral" + " currently only supports JPEG, PNG, GIF, and WebP images." + ) + data = base64.b64encode(part.image).decode("utf-8") + converted_content.append( + ImageURLChunk( + image_url=ImageURL( + url=f"data:{part.media_type};base64,{data}", + detail=part.detail if part.detail else "auto", + ) + ) + ) + else: + raise ValueError( + "Mistral currently only supports text and image parts. " + f"Part provided: {part.type}" + ) converted_message_params.append( - _make_message(role=message_param.role, content=content[0].text) + _make_message(role=message_param.role, content=converted_content) ) return converted_message_params diff --git a/tests/core/mistral/_utils/test_convert_message_params.py b/tests/core/mistral/_utils/test_convert_message_params.py index 01e1d5b82..cbd1a399d 100644 --- a/tests/core/mistral/_utils/test_convert_message_params.py +++ b/tests/core/mistral/_utils/test_convert_message_params.py @@ -3,7 +3,10 @@ import pytest from mistralai.models import ( AssistantMessage, + ImageURL, + ImageURLChunk, SystemMessage, + TextChunk, ToolMessage, UserMessage, ) @@ -30,12 +33,21 @@ def test_convert_message_params() -> None: ), SystemMessage(content="Hello", role="system"), ToolMessage(content="Hello", tool_call_id=Unset(), name=Unset(), role="tool"), + BaseMessageParam( + role="user", + content=[ + TextPart(type="text", text="Hello"), + ImagePart( + type="image", media_type="image/jpeg", image=b"image", detail="auto" + ), + ], + ), ] converted_message_params = convert_message_params(message_params) assert converted_message_params == [ UserMessage(content="Hello"), UserMessage(role="user", content="Hello"), - UserMessage(role="user", content="Hello"), + UserMessage(content=[TextChunk(text="Hello", TYPE="text")], role="user"), AssistantMessage(content="Hello"), SystemMessage(content="Hello"), ToolMessage(content="Hello", tool_call_id=Unset(), name=Unset(), role="tool"), @@ -44,53 +56,65 @@ def test_convert_message_params() -> None: ), SystemMessage(content="Hello", role="system"), ToolMessage(content="Hello"), + UserMessage( + role="user", + content=[ + TextChunk(text="Hello"), + ImageURLChunk( + image_url=ImageURL( + url="", detail="auto" + ) + ), + ], + ), ] with pytest.raises( ValueError, - match="Mistral currently only supports text parts.", + match="Mistral currently only supports text and image parts. Part provided: audio", ): convert_message_params( [ BaseMessageParam( role="user", content=[ - ImagePart( - type="image", - media_type="image/jpeg", - image=b"image", - detail="auto", + AudioPart( + type="audio", + media_type="audio/wav", + audio=b"audio", ) ], - ) + ), ] ) with pytest.raises( ValueError, - match="Mistral currently only supports text parts.", + match="Invalid role: invalid_role", ): convert_message_params( [ - BaseMessageParam( - role="user", - content=[ - AudioPart( - type="audio", - media_type="audio/wav", - audio=b"audio", - ) - ], - ), + BaseMessageParam(role="invalid_role", content="Hello"), ] ) with pytest.raises( ValueError, - match="Invalid role: invalid_role", + match="Unsupported image media type: image/svg." + " Mistral currently only supports JPEG, PNG, GIF, and WebP images.", ): convert_message_params( [ - BaseMessageParam(role="invalid_role", content="Hello"), + BaseMessageParam( + role="user", + content=[ + ImagePart( + type="image", + media_type="image/svg", + image=b"image", + detail="auto", + ) + ], + ) ] )