From 9932123533ba0ec102cfb18797606bda326a7b50 Mon Sep 17 00:00:00 2001 From: Dttbd Date: Thu, 25 Apr 2024 21:44:56 +0800 Subject: [PATCH] feat: add more params type support --- taskingai/assistant/assistant.py | 84 +++++++++++++++----------- taskingai/inference/chat_completion.py | 39 ++++++++++-- 2 files changed, 82 insertions(+), 41 deletions(-) diff --git a/taskingai/assistant/assistant.py b/taskingai/assistant/assistant.py index f42f366..5760b00 100644 --- a/taskingai/assistant/assistant.py +++ b/taskingai/assistant/assistant.py @@ -33,6 +33,28 @@ DEFAULT_RETRIEVAL_CONFIG = RetrievalConfig(top_k=3, method=RetrievalMethod.USER_MESSAGE) +def _get_assistant_dict_params( + memory: Optional[Union[AssistantMemory, Dict[str, Any]]] = None, + tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None, + retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None, + retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None, +): + memory = memory if isinstance(memory, AssistantMemory) else (AssistantMemory(**memory) if memory else None) + tools = [tool if isinstance(tool, AssistantTool) else AssistantTool(**tool) for tool in (tools or [])] or None + retrievals = [ + retrieval if isinstance(retrieval, AssistantRetrieval) else AssistantRetrieval(**retrieval) + for retrieval in (retrievals or []) + ] or None + retrieval_configs = ( + retrieval_configs + if isinstance(retrieval_configs, RetrievalConfig) + else RetrievalConfig(**retrieval_configs) + if retrieval_configs + else None + ) + return memory, tools, retrievals, retrieval_configs + + def list_assistants( order: str = "desc", limit: int = 20, @@ -118,12 +140,12 @@ async def a_get_assistant(assistant_id: str) -> Assistant: def create_assistant( model_id: str, - memory: AssistantMemory, + memory: Union[AssistantMemory, Dict[str, Any]], name: Optional[str] = None, description: Optional[str] = None, system_prompt_template: Optional[List[str]] = None, - tools: Optional[List[AssistantTool]] = None, - retrievals: Optional[List[AssistantRetrieval]] = None, + tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None, + retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None, retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None, metadata: Optional[Dict[str, str]] = None, ) -> Assistant: @@ -140,12 +162,9 @@ def create_assistant( :param metadata: The assistant metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The created assistant object. """ - if retrieval_configs: - retrieval_configs = ( - retrieval_configs - if isinstance(retrieval_configs, RetrievalConfig) - else RetrievalConfig(**retrieval_configs) - ) + memory, tools, retrievals, retrieval_configs = _get_assistant_dict_params( + memory=memory, tools=tools, retrievals=retrievals, retrieval_configs=retrieval_configs + ) body = AssistantCreateRequest( model_id=model_id, @@ -164,12 +183,12 @@ def create_assistant( async def a_create_assistant( model_id: str, - memory: AssistantMemory, + memory: Union[AssistantMemory, Dict[str, Any]], name: Optional[str] = None, description: Optional[str] = None, system_prompt_template: Optional[List[str]] = None, - tools: Optional[List[AssistantTool]] = None, - retrievals: Optional[List[AssistantRetrieval]] = None, + tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None, + retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None, retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None, metadata: Optional[Dict[str, str]] = None, ) -> Assistant: @@ -186,12 +205,9 @@ async def a_create_assistant( :param metadata: The assistant metadata. It can store up to 16 key-value pairs where each key's length is less than 64 and value's length is less than 512. :return: The created assistant object. """ - if retrieval_configs: - retrieval_configs = ( - retrieval_configs - if isinstance(retrieval_configs, RetrievalConfig) - else RetrievalConfig(**retrieval_configs) - ) + memory, tools, retrievals, retrieval_configs = _get_assistant_dict_params( + memory=memory, tools=tools, retrievals=retrievals, retrieval_configs=retrieval_configs + ) body = AssistantCreateRequest( model_id=model_id, @@ -214,9 +230,9 @@ def update_assistant( name: Optional[str] = None, description: Optional[str] = None, system_prompt_template: Optional[List[str]] = None, - memory: Optional[AssistantMemory] = None, - tools: Optional[List[AssistantTool]] = None, - retrievals: Optional[List[AssistantRetrieval]] = None, + memory: Optional[Union[AssistantMemory, Dict[str, Any]]] = None, + tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None, + retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None, retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None, metadata: Optional[Dict[str, str]] = None, ) -> Assistant: @@ -235,12 +251,10 @@ def update_assistant( :return: The updated assistant object. """ - if retrieval_configs: - retrieval_configs = ( - retrieval_configs - if isinstance(retrieval_configs, RetrievalConfig) - else RetrievalConfig(**retrieval_configs) - ) + memory, tools, retrievals, retrieval_configs = _get_assistant_dict_params( + memory=memory, tools=tools, retrievals=retrievals, retrieval_configs=retrieval_configs + ) + body = AssistantUpdateRequest( model_id=model_id, name=name, @@ -262,9 +276,9 @@ async def a_update_assistant( name: Optional[str] = None, description: Optional[str] = None, system_prompt_template: Optional[List[str]] = None, - memory: Optional[AssistantMemory] = None, - tools: Optional[List[AssistantTool]] = None, - retrievals: Optional[List[AssistantRetrieval]] = None, + memory: Optional[Union[AssistantMemory, Dict[str, Any]]] = None, + tools: Optional[List[Union[AssistantTool, Dict[str, Any]]]] = None, + retrievals: Optional[List[Union[AssistantRetrieval, Dict[str, Any]]]] = None, retrieval_configs: Optional[Union[RetrievalConfig, Dict[str, Any]]] = None, metadata: Optional[Dict[str, str]] = None, ) -> Assistant: @@ -283,12 +297,10 @@ async def a_update_assistant( :return: The updated assistant object. """ - if retrieval_configs: - retrieval_configs = ( - retrieval_configs - if isinstance(retrieval_configs, RetrievalConfig) - else RetrievalConfig(**retrieval_configs) - ) + memory, tools, retrievals, retrieval_configs = _get_assistant_dict_params( + memory=memory, tools=tools, retrievals=retrievals, retrieval_configs=retrieval_configs + ) + body = AssistantUpdateRequest( model_id=model_id, name=name, diff --git a/taskingai/inference/chat_completion.py b/taskingai/inference/chat_completion.py index 3c83d52..7c509c1 100644 --- a/taskingai/inference/chat_completion.py +++ b/taskingai/inference/chat_completion.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Dict, Union +from typing import Any, Optional, List, Dict, Union from ..client.stream import Stream, AsyncStream from taskingai.client.models import * @@ -44,12 +44,35 @@ def __init__(self, id: str, content: str): super().__init__(role=ChatCompletionRole.FUNCTION, id=id, content=content) +def _get_completion_dict_params( + messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage, Dict[str, Any]]], + functions: Optional[List[Union[Function, Dict[str, Any]]]] = None, +): + def _build_message(message: Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage, Dict[str, Any]]): + if isinstance(message, Dict): + if message["role"] == ChatCompletionRole.SYSTEM.value: + return SystemMessage(**message) + if message["role"] == ChatCompletionRole.USER.value: + return UserMessage(**message) + if message["role"] == ChatCompletionRole.ASSISTANT.value: + return AssistantMessage(**message) + if message["role"] == ChatCompletionRole.FUNCTION.value: + return FunctionMessage(**message) + return message + + messages = [_build_message(message) for message in messages] + functions = [ + function if isinstance(function, Function) else Function(**function) for function in (functions or []) + ] or None + return messages, functions + + def chat_completion( model_id: str, - messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage]], + messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage, Dict[str, Any]]], configs: Optional[Dict] = None, function_call: Optional[str] = None, - functions: Optional[List[Function]] = None, + functions: Optional[List[Union[Function, Dict[str, Any]]]] = None, stream: bool = False, ) -> Union[ChatCompletion, Stream]: """ @@ -63,6 +86,9 @@ def chat_completion( :param stream: Whether to request in stream mode. :return: The list of assistants. """ + + messages, functions = _get_completion_dict_params(messages, functions) + # only add non-None parameters body = ChatCompletionRequest( model_id=model_id, @@ -82,10 +108,10 @@ def chat_completion( async def a_chat_completion( model_id: str, - messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage]], + messages: List[Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage, Dict[str, Any]]], configs: Optional[Dict] = None, function_call: Optional[str] = None, - functions: Optional[List[Function]] = None, + functions: Optional[List[Union[Function, Dict[str, Any]]]] = None, stream: bool = False, ) -> Union[ChatCompletion, AsyncStream]: """ @@ -99,6 +125,9 @@ async def a_chat_completion( :param stream: Whether to request in stream mode. :return: The list of assistants. """ + + messages, functions = _get_completion_dict_params(messages, functions) + # only add non-None parameters body = ChatCompletionRequest( model_id=model_id,