Skip to content

Commit

Permalink
add: async support
Browse files Browse the repository at this point in the history
  • Loading branch information
luochen1990 committed Aug 8, 2024
1 parent 4268e68 commit 77e3930
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 10 deletions.
45 changes: 41 additions & 4 deletions src/ai_powered/decorators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from functools import wraps
from typing import Any, Callable, Generic
from typing import Any, Awaitable, Callable, Generic, overload
import openai
from typing_extensions import ParamSpec, TypeVar
import json
Expand All @@ -20,7 +21,15 @@ class Result (msgspec.Struct, Generic[A]):
P = ParamSpec("P")
R = TypeVar("R")

def ai_powered(fn : Callable[P, R]) -> Callable[P, R]:
@overload
def ai_powered(fn: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
...

@overload
def ai_powered(fn: Callable[P, R]) -> Callable[P, R]:
...

def ai_powered(fn : Callable[P, Awaitable[R]] | Callable[P, R]) -> Callable[P, Awaitable[R]] | Callable[P, R]:
''' Provide an AI powered implementation of a function '''

function_name = fn.__name__
Expand Down Expand Up @@ -48,6 +57,7 @@ def ai_powered(fn : Callable[P, R]) -> Callable[P, R]:
print(f"return (json schema): {return_schema}")

client = openai.OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL)
async_client = openai.AsyncOpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL)
model_config = complete_model_config(OPENAI_BASE_URL, OPENAI_MODEL_NAME, OPENAI_MODEL_FEATURES)
model_name = model_config.model_name
model_features: set[ModelFeature] = model_config.supported_features
Expand All @@ -67,7 +77,7 @@ def ai_powered(fn : Callable[P, R]) -> Callable[P, R]:

fn_simulator = FunctionSimulatorSelector(
function_name, f"{sig}", docstring, parameters_schema, return_schema,
client, model_name, model_features, model_options
client, async_client, model_name, model_features, model_options
)

if DEBUG:
Expand All @@ -83,6 +93,7 @@ def wrapper_fn(*args: P.args, **kwargs: P.kwargs) -> R:
print(f"{real_arg_str =}")

resp_str = fn_simulator.query_model(real_arg_str)

if DEBUG:
print(f"{resp_str =}")
print(green(f"[fn {function_name}] response extracted."))
Expand All @@ -94,4 +105,30 @@ def wrapper_fn(*args: P.args, **kwargs: P.kwargs) -> R:

return returned_result.result #type: ignore

return wrapper_fn
@wraps(fn)
async def wrapper_fn_async(*args: P.args, **kwargs: P.kwargs) -> R:
    '''Async counterpart of wrapper_fn: serialize the call arguments,
    query the simulated function, then decode and validate the reply.'''
    real_arg_str = msgspec.json.encode(sig.bind(*args, **kwargs).arguments).decode('utf-8')
    if DEBUG:
        print(f"{real_arg_str =}")

    # NOTE: the main logic — delegate the actual work to the model
    resp_str = await fn_simulator.query_model_async(real_arg_str)
    if DEBUG:
        print(f"{resp_str =}")
        print(green(f"[fn {function_name}] response extracted."))

    # decode into Result[R] so the payload is schema-validated by msgspec
    returned_result = msgspec.json.decode(resp_str, type=result_type)
    if DEBUG:
        print(f"{returned_result =}")
        print(green(f"[fn {function_name}] response validated."))

    return returned_result.result #type: ignore


# Pick the wrapper matching the decorated function's calling convention:
# coroutine functions get the awaitable wrapper, plain functions the sync one.
if asyncio.iscoroutinefunction(fn):
    return wrapper_fn_async
else:
    return wrapper_fn
10 changes: 7 additions & 3 deletions src/ai_powered/llm/adapter_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ def _select_impl(self) -> GenericFunctionSimulator:
if ModelFeature.structured_outputs in self.model_features:
return StructuredOutputFunctionSimulator(
self.function_name, self.signature, self.docstring, self.parameters_schema, self.return_schema,
self.client, self.model_name, self.model_features, self.model_options
self.client, self.async_client, self.model_name, self.model_features, self.model_options
)
elif ModelFeature.tools in self.model_features:
return ToolsFunctionSimulator(
self.function_name, self.signature, self.docstring, self.parameters_schema, self.return_schema,
self.client, self.model_name, self.model_features, self.model_options
self.client, self.async_client, self.model_name, self.model_features, self.model_options
)
else:
return ChatFunctionSimulator(
self.function_name, self.signature, self.docstring, self.parameters_schema, self.return_schema,
self.client, self.model_name, self.model_features, self.model_options
self.client, self.async_client, self.model_name, self.model_features, self.model_options
)

def __post_init__(self):
Expand All @@ -33,3 +33,7 @@ def __post_init__(self):

def query_model(self, arguments_json: str) -> str:
    '''Synchronously delegate the query to the adapter selected for this model.'''
    delegate = self._selected_impl
    return delegate.query_model(arguments_json)

async def query_model_async(self, arguments_json: str) -> str:
    '''Asynchronously delegate the query to the adapter selected for this model.'''
    return await self._selected_impl.query_model_async(arguments_json)
43 changes: 40 additions & 3 deletions src/ai_powered/llm/adapters/generic_adapter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABC
from dataclasses import dataclass, field
import json
from typing import Any, Iterable, Set
Expand All @@ -6,16 +7,17 @@
from openai.types.chat.chat_completion_tool_choice_option_param import ChatCompletionToolChoiceOptionParam
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.completion_create_params import ResponseFormat
from ai_powered.colors import green, yellow
from ai_powered.colors import green, red, yellow
from ai_powered.constants import DEBUG, SYSTEM_PROMPT
from ai_powered.llm.definitions import FunctionSimulator, ModelFeature
from ai_powered.tool_call import ChatCompletionToolParam

@dataclass
class GenericFunctionSimulator (FunctionSimulator):
class GenericFunctionSimulator (FunctionSimulator, ABC):
''' implementation of FunctionSimulator for OpenAI compatible models '''

client: openai.OpenAI
async_client: openai.AsyncOpenAI
model_name: str
model_features: Set[ModelFeature]
model_options: dict[str, Any]
Expand Down Expand Up @@ -69,10 +71,27 @@ def _chat_completion_query(self, arguments_json: str) -> ChatCompletion:
response_format=self._param_response_format,
)

async def _chat_completion_query_async(self, arguments_json: str) -> ChatCompletion:
    '''Default async implementation: one chat-completion round-trip on the async client.'''
    conversation = [
        {"role": "system", "content": self.system_prompt},
        {"role": "user", "content": arguments_json}
    ]
    # tools / tool_choice / response_format come pre-computed from the subclass hooks
    return await self.async_client.chat.completions.create(
        model=self.model_name,
        messages=conversation,
        tools=self._param_tools,
        tool_choice=self._param_tool_choice,
        response_format=self._param_response_format,
    )

def _response_message_parser(self, response_message: ChatCompletionMessage) -> str:
    ''' To be overridden by concrete adapters: extract the result JSON string
    from the model's reply message. The base implementation always raises. '''
    ...
    if DEBUG:
        # log which subclass failed to override before raising
        print(red(f"[GenericFunctionSimulator._response_message_parser()] {self.__class__ =}, {self._response_message_parser =}"))
    raise NotImplementedError

#@override
def query_model(self, arguments_json: str) -> str:

if DEBUG:
Expand All @@ -89,3 +108,21 @@ def query_model(self, arguments_json: str) -> str:
response_message = response.choices[0].message
result_str = self._response_message_parser(response_message)
return result_str

#@override
async def query_model_async(self, arguments_json: str) -> str:
    '''Async query: send the JSON-encoded arguments to the model and
    return the extracted result string.'''
    if DEBUG:
        print(yellow(f"{arguments_json =}"))
        print(yellow(f"request.tools = {self._param_tools}"))
        print(green(f"[fn {self.function_name}] request prepared."))

    response = await self._chat_completion_query_async(arguments_json)
    if DEBUG:
        print(yellow(f"[query_model_async()] {response =}"))
        print(green(f"[fn {self.function_name}] response received."))

    # hand the raw message to the subclass-specific parser
    return self._response_message_parser(response.choices[0].message)
1 change: 1 addition & 0 deletions src/ai_powered/llm/adapters/tools_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def _param_tools_maker(self) -> Iterable[ChatCompletionToolParam] | openai.NotGi
def _param_tool_choice_maker(self) -> ChatCompletionToolChoiceOptionParam | openai.NotGiven:
    '''Force the model to call the return_result tool on every completion.'''
    forced_choice = {"type": "function", "function": {"name": "return_result"}}
    return forced_choice

#@override
def _response_message_parser(self, response_message: ChatCompletionMessage) -> str:
tool_calls = response_message.tool_calls

Expand Down
3 changes: 3 additions & 0 deletions src/ai_powered/llm/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,6 @@ class FunctionSimulator (ABC):

def query_model(self, arguments_json: str) -> str:
    '''Synchronously map JSON-encoded call arguments to a JSON-encoded result.
    Placeholder body; concrete simulators supply the implementation.'''
    ...

async def query_model_async(self, arguments_json: str) -> str:
    '''Async variant of query_model with the same contract.
    Placeholder body; concrete simulators supply the implementation.'''
    ...
12 changes: 12 additions & 0 deletions test/examples/ai_powered_decorator/add_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import pytest
from ai_powered import ai_powered

# AI-powered stub: the decorator synthesizes the implementation at call time,
# inferring intent from the function name and signature. The body stays `...`
# (and no docstring is added, since the docstring is fed to the model as prompt).
@ai_powered
async def add(a: int, b: int) -> int:
    ...

@pytest.mark.asyncio
async def test_add_async():
    '''Exercise the AI-powered async add on two small exact sums.'''
    one_plus_one = await add(1, 1)
    one_plus_two = await add(1, 2)
    assert one_plus_one == 2
    assert one_plus_two == 3

0 comments on commit 77e3930

Please sign in to comment.