Skip to content

Commit

Permalink
fix: [make_tool] struct issue
Browse files Browse the repository at this point in the history
  • Loading branch information
luochen1990 committed Aug 6, 2024
1 parent b8ce626 commit 38fd2d5
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions src/ai_powered/tool_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing_extensions import Literal, ParamSpec, Required, TypeVar, TypedDict
import msgspec

from ai_powered.colors import yellow
from ai_powered.colors import gray, green
from ai_powered.constants import DEBUG
from ai_powered.llm_adapter.openai.param_types import ChatCompletionToolMessageParam
from ai_powered.llm_adapter.openai.types import ChatCompletionMessageToolCall
Expand Down Expand Up @@ -48,13 +48,13 @@ def schema(self) -> ChatCompletionToolParam:
docstring = inspect.getdoc(self.fn)

if DEBUG:
print(f"[MakeTool.schema()] {sig =}")
print(gray(f"[MakeTool.schema()] {sig =}"))

raw_schema = msgspec.json.schema(self.struct_of_parameters())
parameters_schema = deref(raw_schema)

if DEBUG:
print(yellow(f"[MakeTool.schema()] {parameters_schema =}"))
print(gray(f"[MakeTool.schema()] {parameters_schema =}"))

return {
"type": "function",
Expand All @@ -66,10 +66,18 @@ def schema(self) -> ChatCompletionToolParam:
}

def call(self, tool_call: ChatCompletionMessageToolCall) -> ChatCompletionToolMessageParam:
if DEBUG:
print(green(f"[MakeTool.call()] {tool_call =}"))

assert self.fn.__name__ == tool_call.function.name

sig = inspect.signature(self.fn)
args_json : str = tool_call.function.arguments
args_dict = msgspec.json.decode(args_json)
args_obj = msgspec.json.decode(args_json, type=self.struct_of_parameters())
args_dict = msgspec.structs.asdict(args_obj)

if DEBUG:
print(gray(f"[MakeTool.call()] {args_dict =}"))

args = sig.bind(**args_dict).args

Expand Down

0 comments on commit 38fd2d5

Please sign in to comment.