From e6437b2d48b70acb69a75b9c35ad9cee3bda01cf Mon Sep 17 00:00:00 2001
From: grisha
Date: Wed, 28 Feb 2024 22:15:18 -0800
Subject: [PATCH] fix name typo, add unit test to ollama

---
 .../llama_index/llms/ollama/base.py             | 16 +++++++++-------
 .../llama-index-llms-ollama/tests/test_utils.py | 12 ++++++++++++
 2 files changed, 21 insertions(+), 7 deletions(-)
 create mode 100644 llama-index-integrations/llms/llama-index-llms-ollama/tests/test_utils.py

diff --git a/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py b/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py
index 562b0c73429d8..78281eb5918fd 100644
--- a/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-ollama/llama_index/llms/ollama/base.py
@@ -20,7 +20,7 @@
 DEFAULT_REQUEST_TIMEOUT = 30.0


-def get_addtional_kwargs(
+def get_additional_kwargs(
     response: Dict[str, Any], exclude: Tuple[str, ...]
 ) -> Dict[str, Any]:
     return {k: v for k, v in response.items() if k not in exclude}
@@ -109,12 +109,12 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
             message=ChatMessage(
                 content=message.get("content"),
                 role=MessageRole(message.get("role")),
-                additional_kwargs=get_addtional_kwargs(
+                additional_kwargs=get_additional_kwargs(
                     message, ("content", "role")
                 ),
             ),
             raw=raw,
-            additional_kwargs=get_addtional_kwargs(raw, ("message",)),
+            additional_kwargs=get_additional_kwargs(raw, ("message",)),
         )

     @llm_chat_callback()
@@ -156,13 +156,15 @@ def stream_chat(
                 message=ChatMessage(
                     content=text,
                     role=MessageRole(message.get("role")),
-                    additional_kwargs=get_addtional_kwargs(
+                    additional_kwargs=get_additional_kwargs(
                         message, ("content", "role")
                     ),
                 ),
                 delta=delta,
                 raw=chunk,
-                additional_kwargs=get_addtional_kwargs(chunk, ("message",)),
+                additional_kwargs=get_additional_kwargs(
+                    chunk, ("message",)
+                ),
             )

     @llm_completion_callback()
@@ -188,7 +190,7 @@ def complete(
         return CompletionResponse(
             text=text,
             raw=raw,
-            additional_kwargs=get_addtional_kwargs(raw, ("response",)),
+            additional_kwargs=get_additional_kwargs(raw, ("response",)),
         )

     @llm_completion_callback()
@@ -220,7 +222,7 @@ def stream_complete(
                 delta=delta,
                 text=text,
                 raw=chunk,
-                additional_kwargs=get_addtional_kwargs(
+                additional_kwargs=get_additional_kwargs(
                     chunk, ("response",)
                 ),
             )
diff --git a/llama-index-integrations/llms/llama-index-llms-ollama/tests/test_utils.py b/llama-index-integrations/llms/llama-index-llms-ollama/tests/test_utils.py
new file mode 100644
index 0000000000000..0d300048cc47d
--- /dev/null
+++ b/llama-index-integrations/llms/llama-index-llms-ollama/tests/test_utils.py
@@ -0,0 +1,12 @@
+from llama_index.llms.ollama.base import get_additional_kwargs
+
+
+def test_get_additional_kwargs():
+    response = {"key1": "value1", "key2": "value2", "exclude_me": "value3"}
+    exclude = ("exclude_me", "exclude_me_too")
+
+    expected = {"key1": "value1", "key2": "value2"}
+
+    actual = get_additional_kwargs(response, exclude)
+
+    assert actual == expected
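
The rename can be verified end to end by running the new test with pytest, and
the helper itself is a simple dict filter. The session below is illustrative
only, assuming a local checkout with pytest installed; the example keys are
made up:

    $ pytest llama-index-integrations/llms/llama-index-llms-ollama/tests/test_utils.py

    >>> from llama_index.llms.ollama.base import get_additional_kwargs
    >>> # keys not listed in `exclude` pass through unchanged
    >>> get_additional_kwargs({"a": 1, "b": 2}, ("b",))
    {'a': 1}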