Commit
1 parent a53ab79 · commit cee913b
Showing 6 changed files with 226 additions and 60 deletions.
Empty file.
@@ -0,0 +1,30 @@
from abc import ABC
from dataclasses import dataclass
from typing import Any, Literal, Optional

'''
How the information gets determined:
First the connection information is fixed, e.g. the BASE URL and API KEY.
Then the model information is fixed, e.g. the model name and model parameters.
Once the model is known, the features it supports are also known, e.g. whether it supports function calling or JSON-formatted responses.
After that, the function information is fixed, e.g. the function signature and docstring.
Finally, when the function call is simulated, the function information, the model information, and the connection information are all consulted to decide how the call is actually performed.
'''

# Ref: https://ollama.fan/reference/openai/#supported-features
ModelFeature = Literal["function_call", "response_json", "specify_seed"]
ALL_FEATURES: set[ModelFeature] = {"function_call", "response_json", "specify_seed"}


@dataclass(frozen=True)
class FunctionSimulator(ABC):
    ''' just a wrapper to call the model, without checking type correctness '''

    function_name: str
    signature: str
    docstring: Optional[str]
    parameters_schema: dict[str, Any]
    return_schema: dict[str, Any]

    def execute(self, arguments_json: str) -> str:
        ...
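As an illustration of this interface, a minimal concrete simulator might look like the sketch below. The EchoFunctionSimulator class, its field values, and its behaviour are hypothetical examples (not part of this commit); they only show how a subclass is expected to accept an arguments JSON string and return a result JSON string from execute().

# Hypothetical sketch (not in this commit): a trivial FunctionSimulator that
# echoes its arguments back, illustrating the execute() contract.
@dataclass(frozen=True)
class EchoFunctionSimulator(FunctionSimulator):
    ''' wraps the received arguments as the result, without calling any model '''

    def execute(self, arguments_json: str) -> str:
        return f'{{"result": {arguments_json}}}'

echo = EchoFunctionSimulator(
    function_name = "identity",
    signature = "(x: int) -> dict",
    docstring = "return the arguments unchanged",
    parameters_schema = {"type": "object", "properties": {"x": {"type": "integer"}}},
    return_schema = {"type": "object"},
)
print(echo.execute('{"x": 1}'))  # {"result": {"x": 1}}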
@@ -0,0 +1,68 @@
from dataclasses import dataclass
from typing import Any, Set
from ai_powered.colors import green
from ai_powered.constants import DEBUG
from .definitions import FunctionSimulator, ModelFeature
import openai

@dataclass(frozen=True)
class GenericFunctionSimulator(FunctionSimulator):
    ''' implementation of FunctionSimulator for OpenAI compatible models '''

    base_url: str
    api_key: str
    model_name: str
    model_features: Set[ModelFeature]
    model_options: dict[str, Any]
    system_prompt: str

    def query_model(self, user_msg: str) -> str:
        client = openai.OpenAI(base_url=self.base_url, api_key=self.api_key, **self.model_options)

        response = client.chat.completions.create(
            model = self.model_name,
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": user_msg}
            ],
            tools = [{
                "type": "function",
                "function": {
                    "name": "return_result",
                    "parameters": self.return_schema,
                },
            }],
            tool_choice = {"type": "function", "function": {"name": "return_result"}},
        )

        if DEBUG:
            print(f"{response =}")
            print(green(f"[fn {self.function_name}] response received."))

        resp_msg = response.choices[0].message
        tool_calls = resp_msg.tool_calls

        if tool_calls is not None:
            return tool_calls[0].function.arguments
        else:
            raw_resp_str = resp_msg.content
            assert raw_resp_str is not None

            # e.g. raw_resp_str = '```json\n{"result": 2}\n```'

            if raw_resp_str.startswith("```json\n") and raw_resp_str.endswith("\n```"):
                unwrapped_resp_str = raw_resp_str[8:-4]
            else:
                unwrapped_resp_str = raw_resp_str

            # e.g. unwrapped_resp_str = "2"

            if unwrapped_resp_str.startswith('{"result":') and unwrapped_resp_str.endswith("}"):
                result_str = unwrapped_resp_str
            else:
                result_str = f'{{"result": {unwrapped_resp_str}}}'

            if DEBUG:
                print(f"{result_str =}")
            return result_str
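The fallback branch above normalises the model's reply into a single JSON object string: tool-call arguments are returned as-is, a ```json fenced reply is unwrapped, and a bare value is wrapped as {"result": ...}. Below is a small self-contained sketch that replays just that unwrapping logic on example strings; it restates the code above for illustration and is not an extra helper added by this commit.

# Illustration only: the same unwrapping steps as the fallback branch of query_model,
# pulled into a pure function so the behaviour can be checked without an API call.
def normalize_response(raw_resp_str: str) -> str:
    if raw_resp_str.startswith("```json\n") and raw_resp_str.endswith("\n```"):
        unwrapped = raw_resp_str[8:-4]   # strip the markdown code fence
    else:
        unwrapped = raw_resp_str
    if unwrapped.startswith('{"result":') and unwrapped.endswith("}"):
        return unwrapped                 # already wrapped in the expected envelope
    return f'{{"result": {unwrapped}}}'  # wrap a bare JSON value

assert normalize_response('```json\n{"result": 2}\n```') == '{"result": 2}'
assert normalize_response('2') == '{"result": 2}'
assert normalize_response('{"result": [1, 2]}') == '{"result": [1, 2]}'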
@@ -0,0 +1,71 @@
from dataclasses import dataclass, field
from typing import Any, Callable, Optional, Set, TypeAlias

from ai_powered.llm_adapter.definitions import ALL_FEATURES, ModelFeature

@dataclass
class KnownPlatform:
    ''' information about a known platform '''

    platform_name: str
    match_platform_url: Callable[[str], bool]
    known_model_list: list["KnownModel"]


@dataclass
class KnownModel:
    ''' information about a known model '''

    model_name: str
    supported_features: Set[ModelFeature]
    suggested_options: dict[str, Any] = field(default_factory=dict)

def contains(s: str) -> Callable[[str], bool]:
    ''' create a function to check if a string contains a substring '''
    return lambda text: s in text

def starts_with(s: str) -> Callable[[str], bool]:
    ''' create a function to check if a string starts with a substring '''
    return lambda text: text.startswith(s)

def equals(s: str) -> Callable[[str], bool]:
    ''' create a function to check if a string equals another string '''
    return lambda text: s == text

KNOWN_PLATFORMS: list[KnownPlatform] = [
    KnownPlatform(
        platform_name = "openai",
        match_platform_url = contains("openai"),
        known_model_list = [
            KnownModel("gpt-4o-mini", ALL_FEATURES),
            KnownModel("gpt-4o", ALL_FEATURES),
        ]
    ),
    KnownPlatform(
        platform_name = "deepseek",
        match_platform_url = contains("deepseek"),
        known_model_list = [
            KnownModel("deepseek-chat", set()),
            KnownModel("deepseek-coder", set()),
        ]
    ),
]

ModelConfig: TypeAlias = KnownModel

def complete_model_config(platform_url: str, model_name: Optional[str]) -> ModelConfig:
    ''' select a known model from a known platform '''
    for platform in KNOWN_PLATFORMS:
        if platform.match_platform_url(platform_url):
            if model_name is not None:
                for known_model in platform.known_model_list:
                    if model_name.startswith(known_model.model_name):
                        return known_model
            else:
                return platform.known_model_list[0]  # known platform, but model not specified
            return platform.known_model_list[0]  # known platform, but unknown model specified

    # unknown platform
    if model_name is not None:
        return ModelConfig(model_name, ALL_FEATURES)
    else:
        raise ValueError(f"Unknown platform: {platform_url}, please specify a model name")