diff --git a/llama-index-core/llama_index/core/base/query_pipeline/query.py b/llama-index-core/llama_index/core/base/query_pipeline/query.py index 01f890e47bebb..fd5de0bd0e00e 100644 --- a/llama-index-core/llama_index/core/base/query_pipeline/query.py +++ b/llama-index-core/llama_index/core/base/query_pipeline/query.py @@ -16,6 +16,7 @@ from llama_index.core.base.llms.types import ( ChatResponse, + ChatMessage, CompletionResponse, ) from llama_index.core.base.response.schema import Response @@ -27,6 +28,7 @@ StringableInput = Union[ CompletionResponse, ChatResponse, + ChatMessage, str, QueryBundle, Response, diff --git a/llama-index-core/tests/query_pipeline/test_utils.py b/llama-index-core/tests/query_pipeline/test_utils.py new file mode 100644 index 0000000000000..8f4648905784d --- /dev/null +++ b/llama-index-core/tests/query_pipeline/test_utils.py @@ -0,0 +1,8 @@ +from llama_index.core.base.llms.types import ChatMessage, MessageRole +from llama_index.core.base.query_pipeline.query import validate_and_convert_stringable + + +def test_validate_and_convert_stringable() -> None: + """Test conversion of stringable object into string.""" + message = ChatMessage(role=MessageRole.USER, content="hello") + assert validate_and_convert_stringable(message) == "user: hello"