Skip to content

Commit

Permalink
redesign: ModelFeature as Enum instead of Literal
Browse files Browse the repository at this point in the history
  • Loading branch information
luochen1990 committed Aug 7, 2024
1 parent a74e311 commit 61ee341
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/ai_powered/chat_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import openai

from ai_powered.colors import gray
from ai_powered.constants import DEBUG, OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL_NAME
from ai_powered.constants import DEBUG, OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL_NAME, OPENAI_MODEL_FEATURES
from ai_powered.llm_adapter.known_models import complete_model_config
from ai_powered.llm_adapter.openai.param_types import ChatCompletionMessageParam
from ai_powered.tool_call import ChatCompletionToolParam, MakeTool

default_client = openai.OpenAI(base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY)
model_config = complete_model_config(OPENAI_BASE_URL, OPENAI_MODEL_NAME)
model_config = complete_model_config(OPENAI_BASE_URL, OPENAI_MODEL_NAME, OPENAI_MODEL_FEATURES)

@dataclass
class ChatBot:
Expand Down
6 changes: 5 additions & 1 deletion src/ai_powered/constants.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import os
from typing import Optional

from ai_powered.llm_adapter.definitions import ModelFeature

DEBUG = os.environ.get('DEBUG', 'False').lower() in {'true', '1', 'yes', 'on'}
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "sk-1234567890ab-MOCK-API-KEY")
OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
OPENAI_MODEL_NAME = os.environ.get("OPENAI_MODEL_NAME")
OPENAI_MODEL_FEATURES = os.environ.get("OPENAI_MODEL_FEATURES")
# Optional override of the model's feature set, e.g. OPENAI_MODEL_FEATURES="tools,seed".
# None (unset/empty) means "unknown": let complete_model_config() infer features instead.
_features_str = os.environ.get("OPENAI_MODEL_FEATURES")
# Tolerate whitespace around commas and empty segments ("tools, seed" / trailing comma)
# instead of raising KeyError on ModelFeature lookup.
OPENAI_MODEL_FEATURES : Optional[set[ModelFeature]] = (
    {ModelFeature[s.strip()] for s in _features_str.split(',') if s.strip()}
    if _features_str else None
)

SYSTEM_PROMPT = """
You are a function simulator,
Expand Down
6 changes: 3 additions & 3 deletions src/ai_powered/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ai_powered.llm_adapter.generic_adapter import GenericFunctionSimulator
from ai_powered.llm_adapter.known_models import complete_model_config
from ai_powered.schema_deref import deref
from ai_powered.constants import DEBUG, OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL_NAME, SYSTEM_PROMPT, SYSTEM_PROMPT_JSON_SYNTAX, SYSTEM_PROMPT_RETURN_SCHEMA
from ai_powered.constants import DEBUG, OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL_FEATURES, OPENAI_MODEL_NAME, SYSTEM_PROMPT, SYSTEM_PROMPT_JSON_SYNTAX, SYSTEM_PROMPT_RETURN_SCHEMA
from ai_powered.colors import gray, green
import inspect

Expand Down Expand Up @@ -46,7 +46,7 @@ def ai_powered(fn : Callable[P, R]) -> Callable[P, R]:
print(f"{param_name} (json schema): {schema}")
print(f"return (json schema): {return_schema}")

model_config = complete_model_config(OPENAI_BASE_URL, OPENAI_MODEL_NAME)
model_config = complete_model_config(OPENAI_BASE_URL, OPENAI_MODEL_NAME, OPENAI_MODEL_FEATURES)
model_name = model_config.model_name
model_features: set[ModelFeature] = model_config.supported_features
model_options: dict[str, Any] = model_config.suggested_options
Expand All @@ -55,7 +55,7 @@ def ai_powered(fn : Callable[P, R]) -> Callable[P, R]:
signature = sig,
docstring = docstring or "no doc, guess intention from function name",
parameters_schema = json.dumps(parameters_schema),
) + ("" if "function_call" in model_features else SYSTEM_PROMPT_RETURN_SCHEMA.format(
) + ("" if ModelFeature.tools in model_features else SYSTEM_PROMPT_RETURN_SCHEMA.format(
return_schema = json.dumps(return_schema),
) + SYSTEM_PROMPT_JSON_SYNTAX )

Expand Down
18 changes: 14 additions & 4 deletions src/ai_powered/llm_adapter/definitions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC
from dataclasses import dataclass
import enum
from typing import Any, Optional
from typing_extensions import Literal

'''
信息确定的过程:
Expand All @@ -12,9 +12,19 @@
最终模拟函数执行时,需要参考函数信息和模型信息,以及连接信息,来确定函数执行的具体方式
'''

#Ref: https://ollama.fan/reference/openai/#supported-features
ModelFeature = Literal["function_call", "response_json", "specify_seed"]
ALL_FEATURES : set[ModelFeature] = {"function_call", "response_json", "specify_seed"}
class ModelFeature (enum.Enum):
    """API capabilities that an OpenAI-compatible model may advertise.

    References:
    - Ollama supported features: https://ollama.fan/reference/openai/#supported-features
    - OpenAI function calling (tools): https://platform.openai.com/docs/guides/function-calling
    - OpenAI structured outputs (response_format): https://platform.openai.com/docs/guides/structured-outputs
    - OpenAI reproducible outputs (seed): https://platform.openai.com/docs/advanced-usage/reproducible-outputs
    """
    # Member values double as the wire-level parameter names used by the API.
    tools = "tools"
    response_format = "response_format"
    seed = "seed"

# Every feature defined above; deriving it from the Enum keeps it in sync automatically.
ALL_FEATURES : set[ModelFeature] = set(ModelFeature)


@dataclass(frozen=True)
Expand Down
4 changes: 2 additions & 2 deletions src/ai_powered/llm_adapter/generic_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def query_model(self, user_msg: str) -> str:
response = client.chat.completions.create(
model = self.model_name,
messages = messages,
tools = tools if "function_call" in self.model_features else openai.NOT_GIVEN,
tool_choice = {"type": "function", "function": {"name": "return_result"}} if "function_call" in self.model_features else openai.NOT_GIVEN,
tools = tools if ModelFeature.tools in self.model_features else openai.NOT_GIVEN,
tool_choice = {"type": "function", "function": {"name": "return_result"}} if ModelFeature.tools in self.model_features else openai.NOT_GIVEN,
)
if DEBUG:
print(yellow(f"{response =}"))
Expand Down
4 changes: 2 additions & 2 deletions src/ai_powered/llm_adapter/known_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ def equals(s: str) -> Callable[[str], bool]:
platform_name = "deepseek",
match_platform_url = contains("deepseek"),
known_model_list = [
KnownModel("deepseek-chat", {"function_call", "response_json"}),
KnownModel("deepseek-coder", {"function_call", "response_json"}),
KnownModel("deepseek-chat", {ModelFeature.tools}),
KnownModel("deepseek-coder", {ModelFeature.tools}),
]
),
KnownPlatform(
Expand Down

0 comments on commit 61ee341

Please sign in to comment.