Skip to content

Commit

Permalink
Fix constructor type
Browse files Browse the repository at this point in the history
  • Loading branch information
billytrend-cohere committed Feb 6, 2025
1 parent d144668 commit 03edc9a
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 13 deletions.
55 changes: 53 additions & 2 deletions src/cohere/client_v2.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,61 @@
from .client import Client, AsyncClient
from .v2.client import V2Client, AsyncV2Client
import typing
from .environment import ClientEnvironment
import os
import httpx
from concurrent.futures import ThreadPoolExecutor


class ClientV2(V2Client, Client): # type: ignore
__init__ = Client.__init__ # type: ignore
def __init__(
self,
api_key: typing.Optional[typing.Union[str,
typing.Callable[[], str]]] = None,
*,
base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
client_name: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
httpx_client: typing.Optional[httpx.Client] = None,
thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64),
log_warning_experimental_features: bool = True,
):
Client.__init__(
self,
api_key=api_key,
base_url=base_url,
environment=environment,
client_name=client_name,
timeout=timeout,
httpx_client=httpx_client,
thread_pool_executor=thread_pool_executor,
log_warning_experimental_features=log_warning_experimental_features,
)


class AsyncClientV2(AsyncV2Client, AsyncClient): # type: ignore
__init__ = AsyncClient.__init__ # type: ignore
def __init__(
self,
api_key: typing.Optional[typing.Union[str,
typing.Callable[[], str]]] = None,
*,
base_url: typing.Optional[str] = os.getenv("CO_API_URL"),
environment: ClientEnvironment = ClientEnvironment.PRODUCTION,
client_name: typing.Optional[str] = None,
timeout: typing.Optional[float] = None,
httpx_client: typing.Optional[httpx.AsyncClient] = None,
thread_pool_executor: ThreadPoolExecutor = ThreadPoolExecutor(64),
log_warning_experimental_features: bool = True,
):
AsyncClient.__init__(
self,
api_key=api_key,
base_url=base_url,
environment=environment,
client_name=client_name,
timeout=timeout,
httpx_client=httpx_client,
thread_pool_executor=thread_pool_executor,
log_warning_experimental_features=log_warning_experimental_features,
)
29 changes: 18 additions & 11 deletions tests/test_client_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,22 @@
class TestClientV2(unittest.TestCase):

def test_chat(self) -> None:
response = co.chat(model="command-r-plus", messages=[cohere.UserChatMessageV2(content="hello world!")])
response = co.chat(
model="command-r-plus", messages=[cohere.UserChatMessageV2(content="hello world!")])

print(response.message)

def test_chat_stream(self) -> None:
stream = co.chat_stream(model="command-r-plus", messages=[cohere.UserChatMessageV2(content="hello world!")])
stream = co.chat_stream(
model="command-r-plus", messages=[cohere.UserChatMessageV2(content="hello world!")])

events = set()

for chat_event in stream:
if chat_event is not None:
events.add(chat_event.type)
if chat_event.type == "content-delta":
print(chat_event.delta.message)
print(chat_event.delta)

self.assertTrue("message-start" in events)
self.assertTrue("content-start" in events)
Expand All @@ -43,10 +45,12 @@ def test_chat_documents(self) -> None:
{"title": "widget sales 2021", "text": "4 million"},
]
response = co.chat(
messages=cohere.UserChatMessageV2(
content=cohere.TextContent(text="how many widges were sold in 2020?"),
messages=[cohere.UserChatMessageV2(
content=cohere.TextContent(
text="how many widges were sold in 2020?"),
documents=documents,
),
)],
model="command-r-plus",
)

print(response.message)
Expand Down Expand Up @@ -75,9 +79,12 @@ def test_chat_tools(self) -> None:

# call the get_weather tool
tool_result = {"temperature": "30C"}
tool_content = [cohere.Content(output=tool_result, text="The weather in Toronto is 30C")]
messages.append(res.message)
messages.append(cohere.ToolChatMessageV2(tool_call_id=res.message.tool_calls[0].id, tool_content=tool_content))

res = co.chat(tools=tools, messages=messages)
tool_content = [cohere.Content(
output=tool_result, text="The weather in Toronto is 30C")]
messages.append(cohere.AssistantChatMessageV2(content=res.message))
if res.message.tool_calls is not None:
messages.append(cohere.ToolChatMessageV2(
tool_call_id=res.message.tool_calls[0].id, tool_content=tool_content))

res = co.chat(tools=tools, messages=messages, model="command-r-plus")
print(res.message)

0 comments on commit 03edc9a

Please sign in to comment.