Skip to content

Commit

Permalink
feat: add more params type support
Browse files Browse the repository at this point in the history
  • Loading branch information
Dttbd committed Apr 25, 2024
1 parent 96b1bf8 commit 9932123
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 41 deletions.
84 changes: 48 additions & 36 deletions taskingai/assistant/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,28 @@
DEFAULT_RETRIEVAL_CONFIG = RetrievalConfig(top_k=3, method=RetrievalMethod.USER_MESSAGE)


def _get_assistant_dict_params(
memory: Optional[Union[AssistantMemory, Dict[str, Any]]] = None,
tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None,
retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None,
retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None,
):
memory = memory if isinstance(memory, AssistantMemory) else (AssistantMemory(**memory) if memory else None)
tools = [tool if isinstance(tool, AssistantTool) else AssistantTool(**tool) for tool in (tools or [])] or None
retrievals = [
retrieval if isinstance(retrieval, AssistantRetrieval) else AssistantRetrieval(**retrieval)
for retrieval in (retrievals or [])
] or None
retrieval_configs = (
retrieval_configs
if isinstance(retrieval_configs, RetrievalConfig)
else RetrievalConfig(**retrieval_configs)
if retrieval_configs
else None
)
return memory, tools, retrievals, retrieval_configs


def list_assistants(
order: str = "desc",
limit: int = 20,
Expand Down Expand Up @@ -118,12 +140,12 @@ async def a_get_assistant(assistant_id: str) -> Assistant:

def create_assistant(
model_id: str,
memory: AssistantMemory,
memory: Union[AssistantMemory, Dict[str, Any]],
name: Optional[str] = None,
description: Optional[str] = None,
system_prompt_template: Optional[List[str]] = None,
tools: Optional[List[AssistantTool]] = None,
retrievals: Optional[List[AssistantRetrieval]] = None,
tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None,
retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None,
retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None,
metadata: Optional[Dict[str, str]] = None,
) -> Assistant:
Expand All @@ -140,12 +162,9 @@ def create_assistant(
:param metadata: The assistant metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512.
:return: The created assistant object.
"""
if retrieval_configs:
retrieval_configs = (
retrieval_configs
if isinstance(retrieval_configs, RetrievalConfig)
else RetrievalConfig(**retrieval_configs)
)
memory, tools, retrievals, retrieval_configs = _get_assistant_dict_params(
memory=memory, tools=tools, retrievals=retrievals, retrieval_configs=retrieval_configs
)

body = AssistantCreateRequest(
model_id=model_id,
Expand All @@ -164,12 +183,12 @@ def create_assistant(

async def a_create_assistant(
model_id: str,
memory: AssistantMemory,
memory: Union[AssistantMemory, Dict[str, Any]],
name: Optional[str] = None,
description: Optional[str] = None,
system_prompt_template: Optional[List[str]] = None,
tools: Optional[List[AssistantTool]] = None,
retrievals: Optional[List[AssistantRetrieval]] = None,
tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None,
retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None,
retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None,
metadata: Optional[Dict[str, str]] = None,
) -> Assistant:
Expand All @@ -186,12 +205,9 @@ async def a_create_assistant(
:param metadata: The assistant metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512.
:return: The created assistant object.
"""
if retrieval_configs:
retrieval_configs = (
retrieval_configs
if isinstance(retrieval_configs, RetrievalConfig)
else RetrievalConfig(**retrieval_configs)
)
memory, tools, retrievals, retrieval_configs = _get_assistant_dict_params(
memory=memory, tools=tools, retrievals=retrievals, retrieval_configs=retrieval_configs
)

body = AssistantCreateRequest(
model_id=model_id,
Expand All @@ -214,9 +230,9 @@ def update_assistant(
name: Optional[str] = None,
description: Optional[str] = None,
system_prompt_template: Optional[List[str]] = None,
memory: Optional[AssistantMemory] = None,
tools: Optional[List[AssistantTool]] = None,
retrievals: Optional[List[AssistantRetrieval]] = None,
memory: Optional[Union[AssistantMemory, Dict[str, Any]]] = None,
tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None,
retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None,
retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None,
metadata: Optional[Dict[str, str]] = None,
) -> Assistant:
Expand All @@ -235,12 +251,10 @@ def update_assistant(
:return: The updated assistant object.
"""

if retrieval_configs:
retrieval_configs = (
retrieval_configs
if isinstance(retrieval_configs, RetrievalConfig)
else RetrievalConfig(**retrieval_configs)
)
memory, tools, retrievals, retrieval_configs = _get_assistant_dict_params(
memory=memory, tools=tools, retrievals=retrievals, retrieval_configs=retrieval_configs
)

body = AssistantUpdateRequest(
model_id=model_id,
name=name,
Expand All @@ -262,9 +276,9 @@ async def a_update_assistant(
name: Optional[str] = None,
description: Optional[str] = None,
system_prompt_template: Optional[List[str]] = None,
memory: Optional[AssistantMemory] = None,
tools: Optional[List[AssistantTool]] = None,
retrievals: Optional[List[AssistantRetrieval]] = None,
memory: Optional[Union[AssistantMemory, Dict[str, Any]]] = None,
tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None,
retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None,
retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None,
metadata: Optional[Dict[str, str]] = None,
) -> Assistant:
Expand All @@ -283,12 +297,10 @@ async def a_update_assistant(
:return: The updated assistant object.
"""

if retrieval_configs:
retrieval_configs = (
retrieval_configs
if isinstance(retrieval_configs, RetrievalConfig)
else RetrievalConfig(**retrieval_configs)
)
memory, tools, retrievals, retrieval_configs = _get_assistant_dict_params(
memory=memory, tools=tools, retrievals=retrievals, retrieval_configs=retrieval_configs
)

body = AssistantUpdateRequest(
model_id=model_id,
name=name,
Expand Down
39 changes: 34 additions & 5 deletions taskingai/inference/chat_completion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, List, Dict, Union
from typing import Any, Optional, List, Dict, Union
from ..client.stream import Stream, AsyncStream

from taskingai.client.models import *
Expand Down Expand Up @@ -44,12 +44,35 @@ def __init__(self, id: str, content: str):
super().__init__(role=ChatCompletionRole.FUNCTION, id=id, content=content)


def _get_completion_dict_params(
messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage, Dict[str, Any]]],
functions: Optional[List[Union[Function, Dict[str, Any]]]] = None,
):
def _build_message(message: Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage, Dict[str, Any]]):
if isinstance(message, Dict):
if message["role"] == ChatCompletionRole.SYSTEM.value:
return SystemMessage(**message)
if message["role"] == ChatCompletionRole.USER.value:
return UserMessage(**message)
if message["role"] == ChatCompletionRole.ASSISTANT.value:
return AssistantMessage(**message)
if message["role"] == ChatCompletionRole.FUNCTION.value:
return FunctionMessage(**message)
return message

messages = [_build_message(message) for message in messages]
functions = [
function if isinstance(function, Function) else Function(**function) for function in (functions or [])
] or None
return messages, functions


def chat_completion(
model_id: str,
messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage]],
messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage, Dict[str, Any]]],
configs: Optional[Dict] = None,
function_call: Optional[str] = None,
functions: Optional[List[Function]] = None,
functions: Optional[List[Union[Function, Dict[str, Any]]]] = None,
stream: bool = False,
) -> Union[ChatCompletion, Stream]:
"""
Expand All @@ -63,6 +86,9 @@ def chat_completion(
:param stream: Whether to request in stream mode.
:return: The list of assistants.
"""

messages, functions = _get_completion_dict_params(messages, functions)

# only add non-None parameters
body = ChatCompletionRequest(
model_id=model_id,
Expand All @@ -82,10 +108,10 @@ def chat_completion(

async def a_chat_completion(
model_id: str,
messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage]],
messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage, Dict[str, Any]]],
configs: Optional[Dict] = None,
function_call: Optional[str] = None,
functions: Optional[List[Function]] = None,
functions: Optional[List[Union[Function, Dict[str, Any]]]] = None,
stream: bool = False,
) -> Union[ChatCompletion, AsyncStream]:
"""
Expand All @@ -99,6 +125,9 @@ async def a_chat_completion(
:param stream: Whether to request in stream mode.
:return: The list of assistants.
"""

messages, functions = _get_completion_dict_params(messages, functions)

# only add non-None parameters
body = ChatCompletionRequest(
model_id=model_id,
Expand Down

0 comments on commit 9932123

Please sign in to comment.