Skip to content

Commit

Permalink
fix name typo, add unit test to ollama (#11493)
Browse files Browse the repository at this point in the history
  • Loading branch information
skvrd authored Feb 29, 2024
1 parent 45f0db3 commit 348cad7
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
DEFAULT_REQUEST_TIMEOUT = 30.0


def get_additional_kwargs(
    response: Dict[str, Any], exclude: Tuple[str, ...]
) -> Dict[str, Any]:
    """Return a copy of *response* with every key listed in *exclude* removed.

    Used to split an API response dict into the fields consumed directly
    (e.g. ``content``/``role``) and the remainder, which is passed along as
    ``additional_kwargs``.

    Args:
        response: Raw response mapping from the LLM backend.
        exclude: Keys to drop from the result.

    Returns:
        A new dict containing only the key/value pairs whose key is not
        in ``exclude``.
    """
    return {k: v for k, v in response.items() if k not in exclude}
Expand Down Expand Up @@ -109,12 +109,12 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
message=ChatMessage(
content=message.get("content"),
role=MessageRole(message.get("role")),
additional_kwargs=get_addtional_kwargs(
additional_kwargs=get_additional_kwargs(
message, ("content", "role")
),
),
raw=raw,
additional_kwargs=get_addtional_kwargs(raw, ("message",)),
additional_kwargs=get_additional_kwargs(raw, ("message",)),
)

@llm_chat_callback()
Expand Down Expand Up @@ -156,13 +156,15 @@ def stream_chat(
message=ChatMessage(
content=text,
role=MessageRole(message.get("role")),
additional_kwargs=get_addtional_kwargs(
additional_kwargs=get_additional_kwargs(
message, ("content", "role")
),
),
delta=delta,
raw=chunk,
additional_kwargs=get_addtional_kwargs(chunk, ("message",)),
additional_kwargs=get_additional_kwargs(
chunk, ("message",)
),
)

@llm_completion_callback()
Expand All @@ -188,7 +190,7 @@ def complete(
return CompletionResponse(
text=text,
raw=raw,
additional_kwargs=get_addtional_kwargs(raw, ("response",)),
additional_kwargs=get_additional_kwargs(raw, ("response",)),
)

@llm_completion_callback()
Expand Down Expand Up @@ -220,7 +222,7 @@ def stream_complete(
delta=delta,
text=text,
raw=chunk,
additional_kwargs=get_addtional_kwargs(
additional_kwargs=get_additional_kwargs(
chunk, ("response",)
),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from llama_index.llms.ollama.base import get_additional_kwargs


def test_get_additional_kwargs():
    """Excluded keys are dropped; absent exclude entries are ignored."""
    payload = {"key1": "value1", "key2": "value2", "exclude_me": "value3"}
    # "exclude_me_too" is not present in the payload and must be a no-op.
    excluded_keys = ("exclude_me", "exclude_me_too")

    result = get_additional_kwargs(payload, excluded_keys)

    assert result == {"key1": "value1", "key2": "value2"}

0 comments on commit 348cad7

Please sign in to comment.