From 61ee3413025f0b3371dbdbe1d41d3f537a6df464 Mon Sep 17 00:00:00 2001
From: LuoChen
Date: Wed, 7 Aug 2024 13:11:23 +0800
Subject: [PATCH] redesign: ModelFeature as Enum instead of Literal

---
 src/ai_powered/chat_bot.py                    |  4 ++--
 src/ai_powered/constants.py                   |  6 +++++-
 src/ai_powered/decorators.py                  |  6 +++---
 src/ai_powered/llm_adapter/definitions.py     | 18 ++++++++++++++----
 src/ai_powered/llm_adapter/generic_adapter.py |  4 ++--
 src/ai_powered/llm_adapter/known_models.py    |  4 ++--
 6 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/src/ai_powered/chat_bot.py b/src/ai_powered/chat_bot.py
index 0693468..4826668 100644
--- a/src/ai_powered/chat_bot.py
+++ b/src/ai_powered/chat_bot.py
@@ -3,13 +3,13 @@
 import openai
 
 from ai_powered.colors import gray
-from ai_powered.constants import DEBUG, OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL_NAME
+from ai_powered.constants import DEBUG, OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL_NAME, OPENAI_MODEL_FEATURES
 from ai_powered.llm_adapter.known_models import complete_model_config
 from ai_powered.llm_adapter.openai.param_types import ChatCompletionMessageParam
 from ai_powered.tool_call import ChatCompletionToolParam, MakeTool
 
 default_client = openai.OpenAI(base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY)
-model_config = complete_model_config(OPENAI_BASE_URL, OPENAI_MODEL_NAME)
+model_config = complete_model_config(OPENAI_BASE_URL, OPENAI_MODEL_NAME, OPENAI_MODEL_FEATURES)
 
 @dataclass
 class ChatBot:
diff --git a/src/ai_powered/constants.py b/src/ai_powered/constants.py
index f3cce6c..6319d25 100644
--- a/src/ai_powered/constants.py
+++ b/src/ai_powered/constants.py
@@ -1,10 +1,14 @@
 import os
+from typing import Optional
+
+from ai_powered.llm_adapter.definitions import ModelFeature
 
 DEBUG = os.environ.get('DEBUG', 'False').lower() in {'true', '1', 'yes', 'on'}
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "sk-1234567890ab-MOCK-API-KEY")
 OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
 OPENAI_MODEL_NAME = os.environ.get("OPENAI_MODEL_NAME")
-OPENAI_MODEL_FEATURES = os.environ.get("OPENAI_MODEL_FEATURES")
+_features_str = os.environ.get("OPENAI_MODEL_FEATURES")
+OPENAI_MODEL_FEATURES : Optional[set[ModelFeature]] = set(ModelFeature[s] for s in _features_str.split(',')) if _features_str else None
 
 SYSTEM_PROMPT = """
 You are a function simulator,
diff --git a/src/ai_powered/decorators.py b/src/ai_powered/decorators.py
index c614ea7..225c5d0 100644
--- a/src/ai_powered/decorators.py
+++ b/src/ai_powered/decorators.py
@@ -8,7 +8,7 @@
 from ai_powered.llm_adapter.generic_adapter import GenericFunctionSimulator
 from ai_powered.llm_adapter.known_models import complete_model_config
 from ai_powered.schema_deref import deref
-from ai_powered.constants import DEBUG, OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL_NAME, SYSTEM_PROMPT, SYSTEM_PROMPT_JSON_SYNTAX, SYSTEM_PROMPT_RETURN_SCHEMA
+from ai_powered.constants import DEBUG, OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL_FEATURES, OPENAI_MODEL_NAME, SYSTEM_PROMPT, SYSTEM_PROMPT_JSON_SYNTAX, SYSTEM_PROMPT_RETURN_SCHEMA
 from ai_powered.colors import gray, green
 import inspect
 
@@ -46,7 +46,7 @@ def ai_powered(fn : Callable[P, R]) -> Callable[P, R]:
         print(f"{param_name} (json schema): {schema}")
         print(f"return (json schema): {return_schema}")
 
-    model_config = complete_model_config(OPENAI_BASE_URL, OPENAI_MODEL_NAME)
+    model_config = complete_model_config(OPENAI_BASE_URL, OPENAI_MODEL_NAME, OPENAI_MODEL_FEATURES)
     model_name = model_config.model_name
     model_features: set[ModelFeature] = model_config.supported_features
     model_options: dict[str, Any] = model_config.suggested_options
@@ -55,7 +55,7 @@ def ai_powered(fn : Callable[P, R]) -> Callable[P, R]:
         signature = sig,
         docstring = docstring or "no doc, guess intention from function name",
         parameters_schema = json.dumps(parameters_schema),
-    ) + ("" if "function_call" in model_features else SYSTEM_PROMPT_RETURN_SCHEMA.format(
+    ) + ("" if ModelFeature.tools in model_features else SYSTEM_PROMPT_RETURN_SCHEMA.format(
         return_schema = json.dumps(return_schema),
     ) + SYSTEM_PROMPT_JSON_SYNTAX
     )
diff --git a/src/ai_powered/llm_adapter/definitions.py b/src/ai_powered/llm_adapter/definitions.py
index baa0f65..d7e23d8 100644
--- a/src/ai_powered/llm_adapter/definitions.py
+++ b/src/ai_powered/llm_adapter/definitions.py
@@ -1,7 +1,7 @@
 from abc import ABC
 from dataclasses import dataclass
+import enum
 from typing import Any, Optional
-from typing_extensions import Literal
 
 '''
 The process by which this information is determined:
@@ -12,9 +12,19 @@
 When the simulated function is finally executed, the function info, the model info, and the connection info are all consulted to determine exactly how the call is carried out.
 '''
 
-#Ref: https://ollama.fan/reference/openai/#supported-features
-ModelFeature = Literal["function_call", "response_json", "specify_seed"]
-ALL_FEATURES : set[ModelFeature] = {"function_call", "response_json", "specify_seed"}
+class ModelFeature (enum.Enum):
+    '''
+    Ollama Doc: https://ollama.fan/reference/openai/#supported-features
+    OpenAI Doc:
+    - tools: https://platform.openai.com/docs/guides/function-calling
+    - response_format: https://platform.openai.com/docs/guides/structured-outputs
+    - seed: https://platform.openai.com/docs/advanced-usage/reproducible-outputs
+    '''
+    tools = "tools"
+    response_format = "response_format"
+    seed = "seed"
+
+ALL_FEATURES : set[ModelFeature] = {ModelFeature.tools, ModelFeature.response_format, ModelFeature.seed}
 
 
 @dataclass(frozen=True)
diff --git a/src/ai_powered/llm_adapter/generic_adapter.py b/src/ai_powered/llm_adapter/generic_adapter.py
index 6a603ee..90903dc 100644
--- a/src/ai_powered/llm_adapter/generic_adapter.py
+++ b/src/ai_powered/llm_adapter/generic_adapter.py
@@ -50,8 +50,8 @@ def query_model(self, user_msg: str) -> str:
         response = client.chat.completions.create(
             model = self.model_name,
             messages = messages,
-            tools = tools if "function_call" in self.model_features else openai.NOT_GIVEN,
-            tool_choice = {"type": "function", "function": {"name": "return_result"}} if "function_call" in self.model_features else openai.NOT_GIVEN,
+            tools = tools if ModelFeature.tools in self.model_features else openai.NOT_GIVEN,
+            tool_choice = {"type": "function", "function": {"name": "return_result"}} if ModelFeature.tools in self.model_features else openai.NOT_GIVEN,
         )
         if DEBUG:
             print(yellow(f"{response =}"))
diff --git a/src/ai_powered/llm_adapter/known_models.py b/src/ai_powered/llm_adapter/known_models.py
index 280f10a..f928b43 100644
--- a/src/ai_powered/llm_adapter/known_models.py
+++ b/src/ai_powered/llm_adapter/known_models.py
@@ -62,8 +62,8 @@ def equals(s: str) -> Callable[[str], bool]:
         platform_name = "deepseek",
         match_platform_url = contains("deepseek"),
         known_model_list = [
-            KnownModel("deepseek-chat", {"function_call", "response_json"}),
-            KnownModel("deepseek-coder", {"function_call", "response_json"}),
+            KnownModel("deepseek-chat", {ModelFeature.tools}),
+            KnownModel("deepseek-coder", {ModelFeature.tools}),
         ]
     ),
     KnownPlatform(
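
Note (editor's sketch, not part of the patch): the snippet below is a minimal, self-contained rendition of the two behaviors this change introduces -- parsing OPENAI_MODEL_FEATURES into a set of ModelFeature members (constants.py) and enum-membership checks in place of the old string comparisons (decorators.py, generic_adapter.py). The ModelFeature class is copied in miniature from the new definitions.py; the parse_features helper is an invented name for illustration, not an API exported by the library.

import enum
from typing import Optional, Set

class ModelFeature(enum.Enum):
    # Mirrors the enum added in src/ai_powered/llm_adapter/definitions.py.
    tools = "tools"
    response_format = "response_format"
    seed = "seed"

def parse_features(env_value: Optional[str]) -> Optional[Set[ModelFeature]]:
    # Same expression as the new constants.py line: an unset or empty variable
    # yields None ("features unknown", presumably resolved later from the
    # known-model table by complete_model_config); otherwise each
    # comma-separated token is looked up by enum *name*, so an unknown
    # name raises KeyError.
    return set(ModelFeature[s] for s in env_value.split(',')) if env_value else None

features = parse_features("tools,seed")
assert features == {ModelFeature.tools, ModelFeature.seed}

# Membership test in the style of the new generic_adapter.py check:
if features is not None and ModelFeature.tools in features:
    print("tool calling enabled")

# The old Literal spellings are no longer accepted:
try:
    parse_features("function_call")
except KeyError as exc:
    print("rejected unknown feature name:", exc)

One consequence worth noting: unlike the old Literal annotation, which only a type checker enforces, the Enum validates feature names at runtime -- a stale spelling such as "function_call" in OPENAI_MODEL_FEATURES now fails fast with a KeyError instead of flowing through as an unchecked string.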